Skip to content

Commit 8ac007f

Browse files
More typing changes for google import (#335)
This patch contains a couple more typing changes that are needed to make the google import work with the pytype checks. I have validated that after this patch all of the tests run successfully internally.
1 parent 0d094dc commit 8ac007f

File tree

3 files changed

+16
-12
lines changed

3 files changed

+16
-12
lines changed

gematria/model/python/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ gematria_py_library(
104104
"//gematria/basic_block/python:basic_block",
105105
"//gematria/basic_block/python:throughput",
106106
"//gematria/utils/python:timer",
107+
"@graph_nets_repo//:graph_nets",
107108
],
108109
)
109110

gematria/model/python/model_base.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,12 +45,15 @@
4545
import scipy.stats
4646
import tensorflow as tf
4747
import tf_slim.evaluation
48+
import graph_nets
4849

4950
# The type used for TensorFlow feed_dict objects. The type we use here is
5051
# simpler than what is actually accepted by TensorFlow, but the typing should be
5152
# sufficient for our use. Moreover, since TensorFlow and NumPy do not provide
5253
# type annotations, both the key and the value are reduced to typing.Any.
53-
FeedDict = MutableMapping[str, Union[np.ndarray, tf.Tensor]]
54+
FeedDict = MutableMapping[
55+
str, Union[np.ndarray, tf.Tensor, graph_nets.graphs.GraphsTuple]
56+
]
5457

5558
# A throughput value used as a placeholder in the expected output tensors for
5659
# masked expected outputs.
@@ -82,7 +85,7 @@ def __init__(
8285
self,
8386
error_tensor: tf.Tensor,
8487
checkpoint_dir: str,
85-
global_step: int,
88+
global_step: tf.Tensor,
8689
max_to_keep: int = 15,
8790
):
8891
"""Initializes the hook.
@@ -581,7 +584,7 @@ def _make_spearman_correlations(
581584
)
582585
return tf.concat(task_correlations, axis=0)
583586

584-
def _clip_if_not_none(self, grad: Optional[tf.Tensor]) -> tf.Tensor:
587+
def _clip_if_not_none(self, grad: Optional[tf.Tensor]) -> tf.Tensor | None:
585588
if grad is None:
586589
return grad
587590
return tf.clip_by_norm(grad, self._grad_clip_norm)

gematria/model/python/training.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -70,17 +70,17 @@ class TrainingEpochStats:
7070
[len(this.percentile_ranks), num_tasks].
7171
"""
7272

73-
epoch: int
73+
epoch: tf.Tensor
7474
loss: float
7575
percentile_ranks: Sequence[int]
76-
absolute_mse: np.ndarray
77-
relative_mae: np.ndarray
78-
relative_mse: np.ndarray
79-
absolute_error_percentiles: np.ndarray
80-
relative_error_percentiles: np.ndarray
81-
absolute_delta_mse: Optional[np.ndarray] = None
82-
absolute_delta_mae: Optional[np.ndarray] = None
83-
absolute_delta_error_percentiles: Optional[np.ndarray] = None
76+
absolute_mse: tf.Tensor
77+
relative_mae: tf.Tensor
78+
relative_mse: tf.Tensor
79+
absolute_error_percentiles: tf.Tensor
80+
relative_error_percentiles: tf.Tensor
81+
absolute_delta_mse: Optional[tf.Tensor] = None
82+
absolute_delta_mae: Optional[tf.Tensor] = None
83+
absolute_delta_error_percentiles: Optional[tf.Tensor] = None
8484

8585
def __post_init__(self) -> None:
8686
"""Validate and finalize the initialization of TrainingEpochStats.

0 commit comments

Comments
 (0)