|
60 | 60 | INVALID_THROUGHPUT_VALUE = -1 |
61 | 61 |
|
62 | 62 | _BASIC_BLOCK_INDEX_TF_DTYPE = tf.dtypes.int32 |
63 | | -_BASIC_BLOCK_INDEX_NUMPY_DTYPE = _BASIC_BLOCK_INDEX_TF_DTYPE.as_numpy_dtype() |
64 | 63 |
|
65 | 64 | # A type variable that is either a basic block or block with throughput. The |
66 | 65 | # advantage over typing.Union is that in each context, the typevar represents |
@@ -755,20 +754,20 @@ def _finalize_batch(self, include_expected_outputs: bool) -> FeedDict: |
755 | 754 | """ |
756 | 755 | schedule = self._make_batch_feed_dict() |
757 | 756 | if self._create_delta_block_index: |
758 | | - schedule['delta_block_index'] = np.array( |
759 | | - self._batch_delta_block_index, dtype=_BASIC_BLOCK_INDEX_NUMPY_DTYPE |
| 757 | + schedule['delta_block_index'] = tf.constant( |
| 758 | + self._batch_delta_block_index, dtype=_BASIC_BLOCK_INDEX_TF_DTYPE |
760 | 759 | ) |
761 | 760 | if include_expected_outputs: |
762 | | - schedule['expected_outputs'] = np.reshape( |
763 | | - np.array(self._batch_expected_outputs, dtype=self.numpy_dtype), |
| 761 | + schedule['expected_outputs'] = tf.reshape( |
| 762 | + tf.constant(self._batch_expected_outputs, dtype=self.dtype), |
764 | 763 | [-1, self.num_tasks], |
765 | 764 | ) |
766 | | - schedule['output_mask'] = np.array(self._batch_mask, dtype=bool) |
| 765 | + schedule['output_mask'] = tf.constant( |
| 766 | + self._batch_mask, dtype=tf.dtypes.bool |
| 767 | + ) |
767 | 768 | if self._use_deltas: |
768 | | - schedule['expected_outputs_deltas'] = np.reshape( |
769 | | - np.array( |
770 | | - self._batch_expected_outputs_deltas, dtype=self.numpy_dtype |
771 | | - ), |
| 769 | + schedule['expected_outputs_deltas'] = tf.reshape( |
| 770 | + tf.constant(self._batch_expected_outputs_deltas, dtype=self.dtype), |
772 | 771 | [-1, self.num_tasks], |
773 | 772 | ) |
774 | 773 |
|
@@ -1304,16 +1303,16 @@ def _compute_loss(self, schedule: FeedDict) -> loss_utils.LossComputation: |
1304 | 1303 | output = self(schedule, train=True) |
1305 | 1304 | loss = loss_utils.LossComputation( |
1306 | 1305 | output['output'], |
1307 | | - tf.constant(schedule['expected_outputs']), |
1308 | | - tf.constant(schedule['output_mask']), |
| 1306 | + schedule['expected_outputs'], |
| 1307 | + schedule['output_mask'], |
1309 | 1308 | percentile_ranks=self._collected_percentile_ranks, |
1310 | 1309 | dtype=self.dtype, |
1311 | 1310 | ) |
1312 | 1311 |
|
1313 | 1312 | if self._use_deltas: |
1314 | 1313 | self._delta_loss = loss_utils.LossComputation( |
1315 | 1314 | output['output_deltas'], |
1316 | | - tf.constant(schedule['expected_outputs_deltas']), |
| 1315 | + schedule['expected_outputs_deltas'], |
1317 | 1316 | output['output_mask_deltas'], |
1318 | 1317 | percentile_ranks=self._collected_percentile_ranks, |
1319 | 1318 | dtype=self.dtype, |
|
0 commit comments