@@ -42,10 +42,10 @@ class ModelUtils:
4242 def create_global_steps ():
4343 """Creates TF ops to track and increment global training step."""
4444 global_step = tf .Variable (
45- 0 , name = "global_step" , trainable = False , dtype = tf .int32
45+ 0 , name = "global_step" , trainable = False , dtype = tf .int64
4646 )
4747 steps_to_increment = tf .placeholder (
48- shape = [], dtype = tf .int32 , name = "steps_to_increment"
48+ shape = [], dtype = tf .int64 , name = "steps_to_increment"
4949 )
5050 increment_step = tf .assign (global_step , tf .add (global_step , steps_to_increment ))
5151 return global_step , increment_step , steps_to_increment
@@ -195,7 +195,7 @@ def create_normalizer(vector_obs: tf.Tensor) -> NormalizerTensors:
195195 "normalization_steps" ,
196196 [],
197197 trainable = False ,
198- dtype = tf .int32 ,
198+ dtype = tf .int64 ,
199199 initializer = tf .zeros_initializer (),
200200 )
201201 running_mean = tf .get_variable (
@@ -244,7 +244,7 @@ def create_normalizer_update(
244244 # Based on Welford's algorithm for running mean and standard deviation, for batch updates. Discussion here:
245245 # https://stackoverflow.com/questions/56402955/whats-the-formula-for-welfords-algorithm-for-variance-std-with-batch-updates
246246 steps_increment = tf .shape (vector_input )[0 ]
247- total_new_steps = tf .add (steps , steps_increment )
247+ total_new_steps = tf .add (steps , tf . cast ( steps_increment , dtype = tf . int64 ) )
248248
249249 # Compute the incremental update and divide by the number of new steps.
250250 input_to_old_mean = tf .subtract (vector_input , running_mean )
0 commit comments