
Commit 4325980

Merge pull request #2785 from AI-Hypercomputer:aireen/fix_interleave
PiperOrigin-RevId: 841811659
2 parents (2226c7d + 0ba5c52), commit 4325980

2 files changed: +59 / -27 lines changed


src/MaxText/configs/base.yml

Lines changed: 3 additions & 1 deletion

@@ -599,8 +599,10 @@ grain_train_mixture_config_path: '' # Path to a JSON file specifying the mixture
 grain_file_type: 'arrayrecord' # arrayrecord or parquet
 grain_worker_count: 1 # Set to -1 to enable auto-tuning: automatically determines optimal worker count. See https://google-grain.readthedocs.io/en/latest/_autosummary/grain.experimental.pick_performance_config.html
 grain_per_worker_buffer_size: 1
-# num_threads and prefetch_buffer_size are per-worker per-dataset. Used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
+# num_threads and prefetch_buffer_size are per-worker per-dataset.
+# When using array_records, they are used in ReadOptions (https://google-grain.readthedocs.io/en/latest/tutorials/data_loader_tutorial.html#per-worker-readoptions)
 # The default value matches that in the Grain package. If mixing multiple data sources, consider lowering these values to reduce memory usage.
+# When using parquet, grain_num_threads is the number of files to read and interleave in parallel
 grain_num_threads: 16
 grain_prefetch_buffer_size: 500
 grain_worker_count_eval: 1
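For the arrayrecord path, these two values feed directly into Grain's ReadOptions. A minimal sketch of that wiring, assuming the current top-level `import grain` API and using the base.yml defaults above as placeholder values (not part of this change):

import grain

# Per-worker, per-dataset knobs mirroring the base.yml defaults above.
grain_num_threads = 16
grain_prefetch_buffer_size = 500

# With arrayrecord files the knobs parameterize ReadOptions, which
# to_iter_dataset() uses when pulling records; with parquet, grain_num_threads
# instead bounds how many files are interleaved in parallel (see the pipeline diff below).
read_options = grain.ReadOptions(
    num_threads=grain_num_threads,
    prefetch_buffer_size=grain_prefetch_buffer_size,
)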

src/MaxText/input_pipeline/_grain_data_processing.py

Lines changed: 56 additions & 26 deletions

@@ -47,6 +47,30 @@ def find_data_files(data_file_pattern):
   return data_files
 
 
+def _apply_mapdataset_transforms(
+    dataset,
+    shuffle,
+    shuffle_seed,
+    num_epoch,
+    dataloading_host_index,
+    dataloading_host_count,
+    grain_num_threads,
+    grain_prefetch_buffer_size,
+):
+  """Apply standard shuffle, repeat, shard, and iter conversion transforms."""
+  if shuffle:
+    dataset = dataset.shuffle(seed=shuffle_seed)
+  dataset = dataset.repeat(num_epoch)
+  dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
+  dataset = dataset.to_iter_dataset(
+      read_options=grain.ReadOptions(
+          num_threads=grain_num_threads,
+          prefetch_buffer_size=grain_prefetch_buffer_size,
+      )
+  )
+  return dataset
+
+
 def get_datasets(
     data_file_pattern,
     data_file_type,
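To see what the extracted helper does end-to-end, here is a standalone sketch with an in-memory source standing in for ArrayRecord files; the toy data and parameter values are illustrative, not from the PR, and it assumes the top-level `import grain` dataset API:

import grain

# Toy stand-in for a file-backed MapDataset.
ds = grain.MapDataset.source(list(range(100)))

# Same transform order as _apply_mapdataset_transforms:
# shuffle -> repeat -> shard by host -> convert to an IterDataset.
ds = ds.shuffle(seed=0)
ds = ds.repeat(2)  # num_epoch
ds = ds[0::1]  # dataloading_host_index::dataloading_host_count (single host here)
ds = ds.to_iter_dataset(
    read_options=grain.ReadOptions(num_threads=4, prefetch_buffer_size=64)
)

# The resulting IterDataset is a plain iterable of examples.
print(list(ds)[:5])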
@@ -84,12 +108,16 @@ def create_dataset_from_pattern(pattern):
       datasets_dict = dict(zip(mixture_config.keys(), dataset_list))
 
       for name, ds in datasets_dict.items():
-        if shuffle:
-          ds = ds.shuffle(seed=shuffle_seed)
-        ds = ds.repeat(num_epoch)
-        ds = ds[dataloading_host_index::dataloading_host_count]  # sharding
-        ds = ds.to_iter_dataset()
-        datasets_dict[name] = ds
+        datasets_dict[name] = _apply_mapdataset_transforms(
+            ds,
+            shuffle,
+            shuffle_seed,
+            num_epoch,
+            dataloading_host_index,
+            dataloading_host_count,
+            grain_num_threads,
+            grain_prefetch_buffer_size,
+        )
 
       # Normalize weights
       total_weight = sum(weights)
@@ -111,15 +139,15 @@ def create_dataset_from_pattern(pattern):
 
       # Apply shuffle, repeat, sharding, and conversion to IterDataset to each dataset before mixing
       for d, _ in enumerate(dataset_list):
-        if shuffle:
-          dataset_list[d] = dataset_list[d].shuffle(seed=shuffle_seed)
-        dataset_list[d] = dataset_list[d].repeat(num_epoch)
-        dataset_list[d] = dataset_list[d][dataloading_host_index::dataloading_host_count]  # sharding
-        dataset_list[d] = dataset_list[d].to_iter_dataset(
-            read_options=grain.ReadOptions(
-                num_threads=grain_num_threads,
-                prefetch_buffer_size=grain_prefetch_buffer_size,
-            )
+        dataset_list[d] = _apply_mapdataset_transforms(
+            dataset_list[d],
+            shuffle,
+            shuffle_seed,
+            num_epoch,
+            dataloading_host_index,
+            dataloading_host_count,
+            grain_num_threads,
+            grain_prefetch_buffer_size,
         )
       # Use IterDataset.mix instead of MapDataset.mix in order to have per-mixture component checkpoints
       # for supporting changing the mixture after checkpointing
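The comment above motivates mixing only after each component has been converted to an IterDataset. A toy sketch of that idea; it assumes grain.IterDataset.mix accepts a list of IterDatasets plus weights, analogous to the documented grain.MapDataset.mix, and uses in-memory data rather than the PR's datasets:

import grain

ds_a = grain.MapDataset.source([0, 1, 2, 3]).repeat().to_iter_dataset()
ds_b = grain.MapDataset.source([10, 11, 12, 13]).repeat().to_iter_dataset()

# Assumed API: mixing at the IterDataset level keeps one checkpointable iterator
# per component, which is what allows changing the mixture after checkpointing.
mixed = grain.IterDataset.mix([ds_a, ds_b], weights=[0.75, 0.25])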
@@ -128,15 +156,15 @@ def create_dataset_from_pattern(pattern):
     else:
       # Single pattern case - no need for parallelization
       dataset = create_dataset_from_pattern(data_file_pattern)
-      if shuffle:
-        dataset = dataset.shuffle(seed=shuffle_seed)
-      dataset = dataset.repeat(num_epoch)
-      dataset = dataset[dataloading_host_index::dataloading_host_count]  # sharding
-      dataset = dataset.to_iter_dataset(
-          read_options=grain.ReadOptions(
-              num_threads=grain_num_threads,
-              prefetch_buffer_size=grain_prefetch_buffer_size,
-          )
+      dataset = _apply_mapdataset_transforms(
+          dataset,
+          shuffle,
+          shuffle_seed,
+          num_epoch,
+          dataloading_host_index,
+          dataloading_host_count,
+          grain_num_threads,
+          grain_prefetch_buffer_size,
       )
     return dataset
   elif data_file_type == "parquet":
@@ -152,8 +180,10 @@ def create_dataset_from_pattern(pattern):
           f"Please lower grain_worker_count or increase file shard count."
       )
     dataset = dataset.map(grain.experimental.ParquetIterDataset)
-    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=len(dataset))
-    dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=100, seed=shuffle_seed)
+    cycle_length = min(len(dataset) // num_epoch, grain_num_threads)
+    dataset = grain.experimental.InterleaveIterDataset(dataset, cycle_length=cycle_length)
+    if shuffle:
+      dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=100, seed=shuffle_seed)
     return dataset
   else:
     raise ValueError(f"grain pipeline supports (arrayrecord, parquet) as grain_file_type, but got {data_file_type}")
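A sketch of the fixed parquet path in isolation. The file names and parameter values are placeholders, and it assumes (as the division by num_epoch in the PR suggests) that the file-list MapDataset has already been repeated num_epoch times before interleaving:

import grain

data_files = ["shard-00000.parquet", "shard-00001.parquet", "shard-00002.parquet"]
num_epoch = 1
grain_num_threads = 16
shuffle, shuffle_seed = True, 0

# One lazily opened ParquetIterDataset per file entry (file list repeated per epoch).
files_ds = grain.MapDataset.source(data_files).repeat(num_epoch)
per_file = files_ds.map(grain.experimental.ParquetIterDataset)

# The fix: bound cycle_length by the number of distinct files per epoch and by
# grain_num_threads, instead of interleaving every repeated file entry at once.
cycle_length = min(len(per_file) // num_epoch, grain_num_threads)
dataset = grain.experimental.InterleaveIterDataset(per_file, cycle_length=cycle_length)

# Window shuffle only when shuffling is requested (the second part of the fix).
if shuffle:
  dataset = grain.experimental.WindowShuffleIterDataset(dataset, window_size=100, seed=shuffle_seed)

Bounding cycle_length keeps the number of concurrently open parquet readers at or below grain_num_threads, which is exactly the interpretation the new base.yml comment documents.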
