3 files changed: +16 −12 lines.

File 1 of 3:

@@ -104,6 +104,7 @@ gematria_py_library(
        "//gematria/basic_block/python:basic_block",
        "//gematria/basic_block/python:throughput",
        "//gematria/utils/python:timer",
+        "@graph_nets_repo//:graph_nets",
    ],
)

File 2 of 3:

@@ -45,12 +45,15 @@
import scipy.stats
import tensorflow as tf
import tf_slim.evaluation
+import graph_nets

# The type used for TensorFlow feed_dict objects. The type we use here is
# simpler than what is actually accepted by TensorFlow, but the typing should be
# sufficient for our use. Moreover, since TensorFlow and NumPy do not provide
# type annotations, both the key and the value are reduced to typing.Any.
-FeedDict = MutableMapping[str, Union[np.ndarray, tf.Tensor]]
+FeedDict = MutableMapping[
+    str, Union[np.ndarray, tf.Tensor, graph_nets.graphs.GraphsTuple]
+]

# A throughput value used as a placeholder in the expected output tensors for
# masked expected outputs.
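The widened alias means a feed-dict value may now be a whole GraphsTuple, in addition to plain NumPy arrays and TensorFlow tensors. A minimal sketch of a dictionary that satisfies the new alias (the keys and feature sizes below are made up for illustration and are not the model's real input names):

import graph_nets
import numpy as np
import tensorflow as tf

# A tiny graph with two nodes and one edge; all feature sizes are arbitrary.
graph = graph_nets.graphs.GraphsTuple(
    nodes=np.zeros((2, 4), dtype=np.float32),
    edges=np.zeros((1, 3), dtype=np.float32),
    globals=np.zeros((1, 2), dtype=np.float32),
    senders=np.array([0], dtype=np.int32),
    receivers=np.array([1], dtype=np.int32),
    n_node=np.array([2], dtype=np.int32),
    n_edge=np.array([1], dtype=np.int32),
)

# Each value below is admitted by the widened FeedDict alias: an ndarray,
# a tf.Tensor, and now a graph_nets.graphs.GraphsTuple.
feed_dict = {
    'graph_tuple': graph,
    'expected_outputs': np.zeros((1, 1), dtype=np.float32),
    'output_mask': tf.constant([[True]]),
}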
@@ -82,7 +85,7 @@ def __init__(
      self,
      error_tensor: tf.Tensor,
      checkpoint_dir: str,
-      global_step: int,
+      global_step: tf.Tensor,
      max_to_keep: int = 15,
  ):
    """Initializes the hook.
@@ -581,7 +584,7 @@ def _make_spearman_correlations(
    )
    return tf.concat(task_correlations, axis=0)

-  def _clip_if_not_none(self, grad: Optional[tf.Tensor]) -> tf.Tensor:
+  def _clip_if_not_none(self, grad: Optional[tf.Tensor]) -> tf.Tensor | None:
    if grad is None:
      return grad
    return tf.clip_by_norm(grad, self._grad_clip_norm)
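The helper passes None gradients through unchanged and clips everything else, which is what the corrected return annotation now states (note the `tf.Tensor | None` spelling needs Python 3.10+ unless postponed annotations are enabled). A self-contained sketch of the same clip-unless-None pattern around an optimizer; the variable names and clip norm are arbitrary, and this is not the surrounding class's actual training code:

from typing import Optional

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()


def clip_if_not_none(
    grad: Optional[tf.Tensor], clip_norm: float = 1.0
) -> tf.Tensor | None:
  # compute_gradients() yields None for variables the loss does not touch;
  # those entries pass through untouched, everything else is norm-clipped.
  if grad is None:
    return grad
  return tf.clip_by_norm(grad, clip_norm)


x = tf.get_variable('x', shape=[3], initializer=tf.ones_initializer())
unused = tf.get_variable('unused', shape=[3], initializer=tf.ones_initializer())
loss = tf.reduce_sum(tf.square(x))

optimizer = tf.train.AdamOptimizer(learning_rate=0.01)
grads_and_vars = optimizer.compute_gradients(loss, var_list=[x, unused])
clipped = [(clip_if_not_none(g), v) for g, v in grads_and_vars]
train_op = optimizer.apply_gradients(clipped)  # None gradients are skipped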
File 3 of 3:

@@ -70,17 +70,17 @@ class TrainingEpochStats:
      [len(this.percentile_ranks), num_tasks].
  """

-  epoch: int
+  epoch: tf.Tensor
  loss: float
  percentile_ranks: Sequence[int]
-  absolute_mse: np.ndarray
-  relative_mae: np.ndarray
-  relative_mse: np.ndarray
-  absolute_error_percentiles: np.ndarray
-  relative_error_percentiles: np.ndarray
-  absolute_delta_mse: Optional[np.ndarray] = None
-  absolute_delta_mae: Optional[np.ndarray] = None
-  absolute_delta_error_percentiles: Optional[np.ndarray] = None
+  absolute_mse: tf.Tensor
+  relative_mae: tf.Tensor
+  relative_mse: tf.Tensor
+  absolute_error_percentiles: tf.Tensor
+  relative_error_percentiles: tf.Tensor
+  absolute_delta_mse: Optional[tf.Tensor] = None
+  absolute_delta_mae: Optional[tf.Tensor] = None
+  absolute_delta_error_percentiles: Optional[tf.Tensor] = None

  def __post_init__(self) -> None:
    """Validate and finalize the initialization of TrainingEpochStats.