Skip to content

Commit ab53efa

Browse files
patnotzrecml authors
authored andcommitted
Add batch_number to AbstractInputBatch.
This CL adds the batch_number attribute to the AbstractInputBatch base class. The number should be unique and incremental, but can be reset to 0 on restart or between epochs. PiperOrigin-RevId: 782134551
1 parent eb58583 commit ab53efa

File tree

1 file changed

+3
-0
lines changed

1 file changed

+3
-0
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -328,6 +328,8 @@ def _to_np(x: Any) -> np.ndarray:
328328
if weights[key] is not None:
329329
weights[key] = np.reshape(weights[key], (-1, 1))
330330

331+
# TODO(patn): Find the step number. DO NOT SUBMIT
332+
batch_number = 0
331333
csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
332334
features=features,
333335
features_weights=weights,
@@ -337,6 +339,7 @@ def _to_np(x: Any) -> np.ndarray:
337339
num_sc_per_device=self.sparsecore_config.num_sc_per_device,
338340
sharding_strategy=self.sparsecore_config.sharding_strategy,
339341
allow_id_dropping=False,
342+
batch_number=batch_number,
340343
)
341344

342345
processed_inputs = {

0 commit comments

Comments
 (0)