 from jax.experimental import checkify
 
 from MaxText import exceptions
-from MaxText import max_logging
+from MaxText.sharding import get_input_data_sharding, maybe_shard_with_name
 from MaxText.utils.goodput_utils import (
     GoodputEvent,
     maybe_record_goodput,
@@ -42,15 +42,16 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     else:
       self.data_iterator = data_iterator
     self.last_batch = None
+    self.input_data_shardings = get_input_data_sharding(config, mesh)
 
   def update_data_iterator(self):
     """Update to the next data iterator in the list, if applicable."""
     if hasattr(self, "data_iterator_list"):
       self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list)
       self.data_iterator = self.data_iterator_list[self.data_iterator_index]
 
-  def load_next_batch(self):
-    """Loads the next batch. Can keep reusing the same batch for performance reasons."""
+  def load_next_batch_pre_sharding(self):
+    """Loads the next batch without applying sharding. Can keep reusing the same batch for performance reasons."""
     with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
       try:
         if self.config.reuse_example_batch and self.last_batch:
@@ -67,6 +68,14 @@ def load_next_batch(self):
         raise exceptions.StopTraining(f"`load_next_batch()` failed with {type(e)} exception: ({e}).")
     return self.last_batch
 
+  def load_next_batch(self, *args, **kwargs):
+    """Loads the next batch and applies the input sharding hint."""
+    return maybe_shard_with_name(
+        self.load_next_batch_pre_sharding(),
+        self.input_data_shardings,
+        self.config.shard_mode,
+    )
+
   def check_example_batch(self):
     if self.config.max_checkify:
       jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!"))
@@ -90,22 +99,11 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
     # Call parent constructor
     super().__init__(config, mesh, data_iterator, goodput_recorder)
 
-    # Get ramp-up parameters from config, with safe defaults
-    self.global_batch_size_end = config.global_batch_size_to_load
-    self.global_batch_size_start = config.global_batch_size_to_load_start
-    self.increment = config.global_batch_size_to_load_increment
-    self.samples_per_increment = config.rampup_samples_per_increment_to_load
-
-    # Check if ramp-up is active
-    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
-
-    # State for tracking ramp-up
-    self.accum_samples = 0
-    self.global_batch_size_current = self.global_batch_size_start
+    self.rampup_active = True
     self.batch_buffer = None
     self.buffer_start = 0
 
-  def load_next_batch(self):
+  def load_next_batch(self, *args, rampup_manager=None, **kwargs):
     """
     Updates the batch size based on the schedule and then loads the next
     batch using the parent method.
@@ -114,68 +112,56 @@ def load_next_batch(self):
     if not self.rampup_active:
       return super().load_next_batch()
 
-    # If in rampup phase, we use batch buffer to save data
-    # Check if it's time to increment the batch size
-    is_time_to_increment = self.accum_samples >= self.samples_per_increment
-
-    if is_time_to_increment:
-      # Update current batch size and refresh accumulate samples
-      max_logging.log(
-          f"Global batch size increments from {self.global_batch_size_current}"
-          f" to {self.global_batch_size_current + self.increment}"
-      )
-      self.global_batch_size_current += self.increment
-      self.accum_samples = 0
-      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
-
-    self.accum_samples += self.global_batch_size_current
-    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
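+    # During ramp-up, carve a window of the current global batch size out of the buffered full-size batch.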
+    slice_start, slice_end = self.buffer_start, self.buffer_start + rampup_manager.global_batch_size_current
 
-    # Load new batch if batch_buffer is None or slice overpast the buffer end
+    # Load new batch if batch_buffer is None
     if self.batch_buffer is None:
-      self.batch_buffer = super().load_next_batch()
-      slice_start, slice_end = 0, self.global_batch_size_current
+      self.batch_buffer = super().load_next_batch_pre_sharding()
+      slice_start, slice_end = 0, rampup_manager.global_batch_size_current
 
-    if slice_end > self.global_batch_size_end:
-      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()
+    # If the slice end runs past the end of the buffered batch, fetch a new batch.
+    if slice_end > rampup_manager.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch_pre_sharding()
 
       # self.global_batch_size_end is batch_buffer size
       def _slice_and_concat(old_data, new_data):
         sliced_old_data = jax.lax.dynamic_slice_in_dim(
             old_data,
             slice_start,
-            self.global_batch_size_end - slice_start,
+            rampup_manager.global_batch_size_end - slice_start,
             axis=0,
         )
         sliced_new_data = jax.lax.dynamic_slice_in_dim(
             new_data,
             0,
-            slice_end - self.global_batch_size_end,
+            slice_end - rampup_manager.global_batch_size_end,
             axis=0,
         )
         return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)
 
-      self.buffer_start = slice_end - self.global_batch_size_end
-      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+      self.buffer_start = slice_end - rampup_manager.global_batch_size_end
+      output = jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
     else:
 
       def _slice(data):
         return jax.lax.dynamic_slice_in_dim(
             data,
             slice_start,
-            self.global_batch_size_current,
+            rampup_manager.global_batch_size_current,
             axis=0,
         )
 
       self.buffer_start = slice_end
-      return jax.tree.map(_slice, self.batch_buffer)
+      output = jax.tree.map(_slice, self.batch_buffer)
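+    # Advance the externally supplied ramp-up schedule; its return value says whether ramp-up is still active.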
+    self.rampup_active = rampup_manager.update()
+    return maybe_shard_with_name(output, self.input_data_shardings, self.config.shard_mode)
 
 
-def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+def create_dataloader(config, mesh, data_iterator, goodput_recorder, rampup_manager):
   """
   Create the dataloader
   """
-  if config.enable_rampup_batch_size:
+  if rampup_manager and rampup_manager.num_accum_samples < config.global_rampup_samples:
     return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
   else:
     return DataLoader(config, mesh, data_iterator, goodput_recorder)
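
The loader now delegates all ramp-up bookkeeping to an externally supplied rampup_manager. From the way this diff uses it, that object is expected to expose global_batch_size_current, global_batch_size_end, num_accum_samples, and an update() method whose return value says whether ramp-up should continue. The sketch below is illustrative only: the SimpleRampupManager class and its schedule are assumptions inferred from the diff, not part of this PR or of MaxText.

```python
# Illustrative sketch only: NOT part of the PR. The attributes and update() contract
# are inferred from how data_loader.py uses rampup_manager in the diff above.
class SimpleRampupManager:
  """Linear batch-size ramp-up schedule (assumed interface, not the MaxText implementation)."""

  def __init__(self, start, end, increment, samples_per_increment):
    self.global_batch_size_current = start  # batch size used for the next slice
    self.global_batch_size_end = end        # size of the buffered (full) batch
    self.increment = increment
    self.samples_per_increment = samples_per_increment
    self.num_accum_samples = 0              # total samples consumed so far
    self._since_last_increment = 0          # samples since the last size bump

  def update(self):
    """Advance the schedule by one step; return True while ramp-up should continue."""
    self.num_accum_samples += self.global_batch_size_current
    self._since_last_increment += self.global_batch_size_current
    if self._since_last_increment >= self.samples_per_increment:
      self._since_last_increment = 0
      self.global_batch_size_current = min(
          self.global_batch_size_current + self.increment, self.global_batch_size_end
      )
    return self.global_batch_size_current < self.global_batch_size_end


# Intended wiring (config, mesh, data_iterator, goodput_recorder come from the usual
# MaxText train setup; RampUpDataLoader.load_next_batch accepts rampup_manager as a
# keyword argument):
#   rampup_manager = SimpleRampupManager(start=64, end=512, increment=64, samples_per_increment=10_000)
#   data_loader = create_dataloader(config, mesh, data_iterator, goodput_recorder, rampup_manager)
#   batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
```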