Skip to content

Commit 2894e8c

Browse files
Use tensors instead of numpy arrays in feed dicts
This makes it easier to use tf.function on a training step, as we are now passing tensors around rather than converting to tensors later on.

Reviewers: virajbshah, ondrasej, orodley
Reviewed By: ondrasej
Pull Request: #344
1 parent 0832708 commit 2894e8c

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

gematria/model/python/model_base.py

Lines changed: 12 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -60,7 +60,6 @@
6060
INVALID_THROUGHPUT_VALUE = -1
6161

6262
_BASIC_BLOCK_INDEX_TF_DTYPE = tf.dtypes.int32
63-
_BASIC_BLOCK_INDEX_NUMPY_DTYPE = _BASIC_BLOCK_INDEX_TF_DTYPE.as_numpy_dtype()
6463

6564
# A type variable that is either a basic block or block with throughput. The
6665
# advantage over typing.Union is that in each context, the typevar represents
@@ -755,20 +754,20 @@ def _finalize_batch(self, include_expected_outputs: bool) -> FeedDict:
755754
"""
756755
schedule = self._make_batch_feed_dict()
757756
if self._create_delta_block_index:
758-
schedule['delta_block_index'] = np.array(
759-
self._batch_delta_block_index, dtype=_BASIC_BLOCK_INDEX_NUMPY_DTYPE
757+
schedule['delta_block_index'] = tf.constant(
758+
self._batch_delta_block_index, dtype=_BASIC_BLOCK_INDEX_TF_DTYPE
760759
)
761760
if include_expected_outputs:
762-
schedule['expected_outputs'] = np.reshape(
763-
np.array(self._batch_expected_outputs, dtype=self.numpy_dtype),
761+
schedule['expected_outputs'] = tf.reshape(
762+
tf.constant(self._batch_expected_outputs, dtype=self.dtype),
764763
[-1, self.num_tasks],
765764
)
766-
schedule['output_mask'] = np.array(self._batch_mask, dtype=bool)
765+
schedule['output_mask'] = tf.constant(
766+
self._batch_mask, dtype=tf.dtypes.bool
767+
)
767768
if self._use_deltas:
768-
schedule['expected_outputs_deltas'] = np.reshape(
769-
np.array(
770-
self._batch_expected_outputs_deltas, dtype=self.numpy_dtype
771-
),
769+
schedule['expected_outputs_deltas'] = tf.reshape(
770+
tf.constant(self._batch_expected_outputs_deltas, dtype=self.dtype),
772771
[-1, self.num_tasks],
773772
)
774773

@@ -1304,16 +1303,16 @@ def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation:
13041303
output = self(schedule, train=True)
13051304
loss = loss_utils.LossComputation(
13061305
output['output'],
1307-
tf.constant(schedule['expected_outputs']),
1308-
tf.constant(schedule['output_mask']),
1306+
schedule['expected_outputs'],
1307+
schedule['output_mask'],
13091308
percentile_ranks=self._collected_percentile_ranks,
13101309
dtype=self.dtype,
13111310
)
13121311

13131312
if self._use_deltas:
13141313
self._delta_loss = loss_utils.LossComputation(
13151314
output['output_deltas'],
1316-
tf.constant(schedule['expected_outputs_deltas']),
1315+
schedule['expected_outputs_deltas'],
13171316
output['output_mask_deltas'],
13181317
percentile_ranks=self._collected_percentile_ranks,
13191318
dtype=self.dtype,

0 commit comments

Comments (0)