Skip to content

Commit bce4600

Browse files
Misc fixes for google import (#333)
This patch contains some misc fixes most related to type checking that enable us to import gematria into google after the TF2 migration. Most are pretty standard. We disable an attribute error on run_continuous_evaluation mostly to silence it for now as tf_slim is not really compatible with TF2 and needs to be rewritten.
1 parent 7199a13 commit bce4600

File tree

3 files changed

+17
-9
lines changed

3 files changed

+17
-9
lines changed

gematria/granite/python/gnn_model_base.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ class GraphNetworkLayer:
9090
num_iterations: Optional[int]
9191
layer_normalization: options.EnableFeature
9292
residual_connection: options.EnableFeature
93-
edges_output_size: Sequence[int] = None
94-
nodes_output_size: Sequence[int] = None
95-
globals_output_size: Sequence[int] = None
93+
edges_output_size: Sequence[int] | None = None
94+
nodes_output_size: Sequence[int] | None = None
95+
globals_output_size: Sequence[int] | None = None
9696

9797

9898
class GnnModelBase(model_base.ModelBase):
@@ -445,7 +445,9 @@ def _execute_graph_network(self, feed_dict) -> graph_nets.graphs.GraphsTuple:
445445
return graphs_tuple
446446

447447
@abc.abstractmethod
448-
def _execute_readout_network(self, graph_tuple) -> tf.Tensor:
448+
def _execute_readout_network(
449+
self, graph_tuple, feed_dict: model_base.FeedDict
450+
) -> tf.Tensor:
449451
"""Creates a readout part of the network.
450452
451453
Creates TensorFlow ops that take the output of the graph network and

gematria/granite/python/graph_builder_model_base.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,14 +208,18 @@ def _start_batch(self) -> None:
208208
# @Override
209209
def _make_batch_feed_dict(self) -> model_base.FeedDict:
210210
feed_dict = super()._make_batch_feed_dict()
211-
feed_dict['instruction_node_mask'] = np.array(
211+
feed_dict['instruction_node_mask'] = tf.constant(
212212
self._batch_graph_builder.instruction_node_mask, dtype=bool
213213
)
214-
self._instruction_node_mask = feed_dict['instruction_node_mask']
215-
feed_dict['instruction_annotations'] = (
214+
self._instruction_node_mask = tf.constant(
215+
feed_dict['instruction_node_mask']
216+
)
217+
feed_dict['instruction_annotations'] = tf.constant(
216218
self._batch_graph_builder.instruction_annotations
217219
)
218-
self._instruction_annotations = feed_dict['instruction_annotations']
220+
self._instruction_annotations = tf.constant(
221+
feed_dict['instruction_annotations']
222+
)
219223
return feed_dict
220224

221225
# @Override

gematria/model/python/model_base.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1113,8 +1113,10 @@ def run_continuous_evaluation(
11131113
tf_slim.evaluation.StopAfterNEvalsHook(1),
11141114
tf_slim.evaluation.SummaryAtEndHook(summary_dir, feed_dict=schedule),
11151115
# Save the models with the best MAPE.
1116+
# We disable attribute error detection on self._loss because it is
1117+
# nullable and pytype expects there to be a check here.
11161118
SaveBestCheckpoint(
1117-
error_tensor=self._loss.mean_absolute_percentage_error,
1119+
error_tensor=self._loss.mean_absolute_percentage_error, # pytype: disable=attribute-error
11181120
checkpoint_dir=os.path.join(summary_dir, 'best_models'),
11191121
global_step=self.global_step,
11201122
),

0 commit comments

Comments
 (0)