Skip to content

Commit cf083b5

Browse files
RecML authors (recml authors)
authored and committed
[Efficient LM] Support cache() for TFDatasetFactory and add 'array_record' in its docstring.
PiperOrigin-RevId: 748762181
1 parent 6593c25 commit cf083b5

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

recml/core/data/tf_dataset_factory.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
115115
Optionally, a sequence of such strings can be provided to create an evenly
116116
distributed mixture of datasets. This or `input_path` must be set.
117117
file_format: The file format of the input files. Must be one of 'tfrecord',
118-
'recordio', 'sstable'. Defaults to recordio.
118+
'recordio', 'sstable', 'array_record'. Defaults to recordio.
119119
global_batch_size: The global batch size across all replicas.
120120
drop_remainder: Whether the last batch should be dropped in the case it has
121121
fewer than `global_batch_size` elements.
@@ -211,6 +211,7 @@ class TFDatasetFactory(types.Factory[tf.data.Dataset]):
211211
infinitely repeated
212212
"""
213213

214+
cache_reading: bool = False
214215
input_path: str | Sequence[str] = ""
215216
tfds_source: str | Sequence[str] = ""
216217
file_format: FileFormat = FileFormat.RECORDIO
@@ -555,7 +556,10 @@ def _maybe_apply_tf_data_service(
555556
def make(self) -> tf.data.Dataset:
556557
"""Creates a `tf.data.Dataset` instance with all dataset ops applied."""
557558
# Create an examples dataset.
558-
dataset = self._create_dataset()
559+
if self.cache_reading:
560+
dataset = self._create_dataset().cache()
561+
else:
562+
dataset = self._create_dataset()
559563
# Shuffle and repeat the dataset.
560564
dataset = self._maybe_shuffle_and_repeat(dataset)
561565
# Batch and parse the examples dataset.

0 commit comments

Comments (0)