Skip to content

Commit 0832708

Browse files
Remove model_ref in token_graph_builder_model
With the port to TF2 I added a hack to get access to instruction annotations and the node mask within the call function of TokenGraphBuilderModelNodeEmbed. This was necessary due to TF2 preferring eager mode so we could not just pass tensor references around. This patch makes everything more canonically TF2 by moving the data around through standard dataflow (with some slight complexity to ensure things are getting passed around to the right places) rather than passing references around. Either one should theoretically work with tf.function annotations depending upon scope, but this will definitely work. Reviewers: orodley, virajbshah, ondrasej Reviewed By: ondrasej Pull Request: #342
1 parent cdf8b1f commit 0832708

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

gematria/granite/python/gnn_model_base.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ class GraphNetworkLayer(tf.Module):
6161
features computed by the layer. When both a residual connection and layer
6262
normalization are used, the layer normalization op is inserted after the
6363
residual connection.
64+
extra_node_inputs: Names of extra feed_dict members that should be passed
65+
to the node model.
6466
"""
6567

6668
# NOTE(ondrasej): This should be one of the classes defined in
@@ -74,6 +76,7 @@ class GraphNetworkLayer(tf.Module):
7476
edges_output_size: Sequence[int] | None = None
7577
nodes_output_size: Sequence[int] | None = None
7678
globals_output_size: Sequence[int] | None = None
79+
extra_node_inputs: Sequence[str] | None = None
7780

7881

7982
class GnnModelBase(model_base.ModelBase):
@@ -376,7 +379,15 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple:
376379
)
377380
for iteration in range(num_iterations):
378381
residual_input = graphs_tuple
379-
graphs_tuple = layer.module(graphs_tuple)
382+
extra_node_args = {}
383+
if layer.extra_node_inputs is not None:
384+
for extra_node_arg_name in layer.extra_node_inputs:
385+
extra_node_args[extra_node_arg_name] = feed_dict[
386+
extra_node_arg_name
387+
]
388+
graphs_tuple = layer.module(
389+
graphs_tuple, node_model_kwargs=extra_node_args
390+
)
380391
if use_residual_connections:
381392
residual_op_name_base = (
382393
f'residual_connection_{layer_index}_{iteration}'

gematria/granite/python/token_graph_builder_model.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,6 @@ def _create_graph_network_modules(
345345
vocab_size=len(self._token_list),
346346
common_embed_dim=self._common_node_embedding_size,
347347
num_annotations=self._num_annotations,
348-
model_ref=self,
349348
initializer=embedding_initializer,
350349
),
351350
global_model_fn=functools.partial(
@@ -364,6 +363,10 @@ def _create_graph_network_modules(
364363
num_iterations=1,
365364
layer_normalization=options.EnableFeature.NEVER,
366365
residual_connection=options.EnableFeature.NEVER,
366+
extra_node_inputs=(
367+
'instruction_node_mask',
368+
'instruction_annotations',
369+
),
367370
),
368371
gnn_model_base.GraphNetworkLayer(
369372
module=graph_nets.modules.GraphNetwork(
@@ -410,7 +413,6 @@ def __init__(
410413
self,
411414
common_embed_dim,
412415
num_annotations,
413-
model_ref,
414416
**kwargs,
415417
) -> None:
416418
"""Initializes node embeddings.
@@ -420,11 +422,8 @@ def __init__(
420422
embedding vectors. The remainder of the vector is filled with
421423
instruction annotation.
422424
num_annotations: The number of annotations per instruction.
423-
model_ref: A reference to the model to get the instruction node mask and
424-
instruction annotations.
425425
kwargs: Additional arguments to be passed to the internal `snt.Embed`s.
426426
"""
427-
self._model_ref = model_ref
428427

429428
# The first `embed_dim - num_annotations` embedding values for all nodes.
430429
self._common_embed = snt.Embed(
@@ -446,6 +445,8 @@ def __init__(
446445
def __call__(
447446
self,
448447
inputs,
448+
instruction_node_mask,
449+
instruction_annotations,
449450
):
450451
if not self._extra_embed:
451452
return self._common_embed(inputs)
@@ -458,10 +459,8 @@ def __call__(
458459
common_embeddings,
459460
tf.tensor_scatter_nd_update(
460461
extra_embeddings,
461-
indices=tf.where(
462-
self._model_ref._instruction_node_mask,
463-
),
464-
updates=self._model_ref._instruction_annotations,
462+
indices=tf.where(instruction_node_mask),
463+
updates=instruction_annotations,
465464
),
466465
],
467466
axis=1,

0 commit comments

Comments
 (0)