
Commit cfc06af

add rampup batch checkpoint support
1 parent 77d3bbd

File tree

10 files changed: +215 -68 lines changed

src/MaxText/configs/types.py

Lines changed: 19 additions & 6 deletions

@@ -1378,6 +1378,11 @@ class DerivedValues(BaseModel):
      description="The total size of context parallelism, derived from ICI and DCN values.",
  )

+  num_target_devices: None | int = Field(
+      None,
+      description="The number of devices computed from topology in train_compile or jax.devices() in train",
+  )
+
  global_batch_size_to_train_on: None | int = Field(
      None,
      description="The total batch size for training across all devices. Derived from `per_device_batch_size` and data"
@@ -1640,9 +1645,9 @@ def get_num_target_devices():
      else:
        return len(jax.devices())

-    num_devices = 1  # Default for validation when JAX is not initialized
+    self.num_target_devices = 1  # Default for validation when JAX is not initialized
    try:
-      num_devices = get_num_target_devices()
+      self.num_target_devices = get_num_target_devices()
    except (RuntimeError, IndexError):
      logger.warning("JAX device system not available for config validation. Assuming 1 device.")

@@ -1679,15 +1684,20 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
        self.global_batch_size_to_train_on,
        self.micro_batch_size_to_train_on,
    ) = calculate_global_batch_sizes(
-        self.per_device_batch_size, self.expansion_factor_real_data, num_devices, self.gradient_accumulation_steps
+        self.per_device_batch_size,
+        self.expansion_factor_real_data,
+        self.num_target_devices,
+        self.gradient_accumulation_steps,
    )

    # Calculate final evaluation batch sizes.
    (
        self.global_batch_size_to_load_eval,
        self.global_batch_size_to_eval_on,
        self.micro_batch_size_to_eval_on,
-    ) = calculate_global_batch_sizes(self.eval_per_device_batch_size, self.expansion_factor_real_data, num_devices, 1)
+    ) = calculate_global_batch_sizes(
+        self.eval_per_device_batch_size, self.expansion_factor_real_data, self.num_target_devices, 1
+    )

    # Calculate ramp-up batch size parameters if enabled.
    if self.enable_rampup_batch_size:
@@ -1696,7 +1706,10 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
          _,
          _,
      ) = calculate_global_batch_sizes(
-          self.per_device_batch_size_start, self.expansion_factor_real_data, num_devices, self.gradient_accumulation_steps
+          self.per_device_batch_size_start,
+          self.expansion_factor_real_data,
+          self.num_target_devices,
+          self.gradient_accumulation_steps,
      )
      (
          self.global_batch_size_to_load_increment,
@@ -1705,7 +1718,7 @@ def calculate_global_batch_sizes(per_device_batch_size, expansion_factor, num_de
      ) = calculate_global_batch_sizes(
          self.per_device_batch_size_increment,
          self.expansion_factor_real_data,
-          num_devices,
+          self.num_target_devices,
          self.gradient_accumulation_steps,
      )
      diff_batch_size = self.global_batch_size_to_load - self.global_batch_size_to_load_start
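
The calculate_global_batch_sizes helper itself is not shown in this diff. As a rough, hypothetical sketch of the arithmetic that the derived num_target_devices now feeds (simplified: it ignores fractional per-device batch sizes and the expansion_factor_real_data == -1 special case):

# Hypothetical sketch only; the real helper lives in configs/types.py and handles more cases.
def sketch_global_batch_sizes(per_device_batch_size, expansion_factor, num_devices, grad_accum_steps):
  # Batch trained on per micro-step, summed across all devices.
  micro_batch_size_to_train_on = int(num_devices * per_device_batch_size)
  # Batch loaded from the input pipeline; expansion_factor > 1 loads extra real data.
  micro_batch_size_to_load = int(micro_batch_size_to_train_on * expansion_factor)
  # Gradient accumulation scales both into per-optimizer-step (global) sizes.
  global_batch_size_to_load = micro_batch_size_to_load * grad_accum_steps
  global_batch_size_to_train_on = micro_batch_size_to_train_on * grad_accum_steps
  return global_batch_size_to_load, global_batch_size_to_train_on, micro_batch_size_to_train_on

# Example: 4 target devices, per-device batch 8, no expansion, no accumulation.
print(sketch_global_batch_sizes(8, 1, 4, 1))  # (32, 32, 32)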

src/MaxText/data_loader.py

Lines changed: 31 additions & 45 deletions

@@ -20,7 +20,7 @@
from jax.experimental import checkify

from MaxText import exceptions
-from MaxText import max_logging
+from MaxText.sharding import get_input_data_sharding, maybe_shard_with_name
from MaxText.utils.goodput_utils import (
    GoodputEvent,
    maybe_record_goodput,
@@ -42,15 +42,16 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
    else:
      self.data_iterator = data_iterator
    self.last_batch = None
+    self.input_data_shardings = get_input_data_sharding(config, mesh)

  def update_data_iterator(self):
    """Update to the next data iterator in the list, if applicable."""
    if hasattr(self, "data_iterator_list"):
      self.data_iterator_index = (self.data_iterator_index + 1) % len(self.data_iterator_list)
      self.data_iterator = self.data_iterator_list[self.data_iterator_index]

-  def load_next_batch(self):
-    """Loads the next batch. Can keep reusing the same batch for performance reasons."""
+  def load_next_batch_pre_sharding(self):
+    """Loads the next batch without sharding. Can keep reusing the same batch for performance reasons."""
    with maybe_record_goodput(self.goodput_recorder, GoodputEvent.DATA_LOADING):
      try:
        if self.config.reuse_example_batch and self.last_batch:
@@ -67,6 +68,14 @@ def load_next_batch(self):
        raise exceptions.StopTraining(f"`load_next_batch()` failed with {type(e)} exception: ({e}).")
    return self.last_batch

+  def load_next_batch(self, *args, **kwargs):
+    """Loads the next batch with a sharding hint."""
+    return maybe_shard_with_name(
+        self.load_next_batch_pre_sharding(),
+        self.input_data_shardings,
+        self.config.shard_mode,
+    )
+
  def check_example_batch(self):
    if self.config.max_checkify:
      jittable_f = checkify.checkify(lambda x: checkify.check(jnp.any(x > -1), "Batch contains bad synthetic data!"))
@@ -90,22 +99,11 @@ def __init__(self, config, mesh, data_iterator, goodput_recorder):
    # Call parent constructor
    super().__init__(config, mesh, data_iterator, goodput_recorder)

-    # Get ramp-up parameters from config, with safe defaults
-    self.global_batch_size_end = config.global_batch_size_to_load
-    self.global_batch_size_start = config.global_batch_size_to_load_start
-    self.increment = config.global_batch_size_to_load_increment
-    self.samples_per_increment = config.rampup_samples_per_increment_to_load
-
-    # Check if ramp-up is active
-    self.rampup_active = self.global_batch_size_start < self.global_batch_size_end
-
-    # State for tracking ramp-up
-    self.accum_samples = 0
-    self.global_batch_size_current = self.global_batch_size_start
+    self.rampup_active = True
    self.batch_buffer = None
    self.buffer_start = 0

-  def load_next_batch(self):
+  def load_next_batch(self, *args, rampup_manager=None, **kwargs):
    """
    Updates the batch size based on the schedule and then loads the next
    batch using the parent method.
@@ -114,68 +112,56 @@ def load_next_batch(self):
    if not self.rampup_active:
      return super().load_next_batch()

-    # If in rampup phase, we use batch buffer to save data
-    # Check if it's time to increment the batch size
-    is_time_to_increment = self.accum_samples >= self.samples_per_increment
-
-    if is_time_to_increment:
-      # Update current batch size and refresh accumulate samples
-      max_logging.log(
-          f"Global batch size increments from {self.global_batch_size_current}"
-          f" to {self.global_batch_size_current + self.increment}"
-      )
-      self.global_batch_size_current += self.increment
-      self.accum_samples = 0
-      self.rampup_active = self.global_batch_size_current < self.global_batch_size_end
-
-    self.accum_samples += self.global_batch_size_current
-    slice_start, slice_end = self.buffer_start, self.buffer_start + self.global_batch_size_current
+    slice_start, slice_end = self.buffer_start, self.buffer_start + rampup_manager.global_batch_size_current

-    # Load new batch if batch_buffer is None or slice overpast the buffer end
+    # Load a new batch if batch_buffer is None
    if self.batch_buffer is None:
-      self.batch_buffer = super().load_next_batch()
-      slice_start, slice_end = 0, self.global_batch_size_current
+      self.batch_buffer = super().load_next_batch_pre_sharding()
+      slice_start, slice_end = 0, rampup_manager.global_batch_size_current

-    if slice_end > self.global_batch_size_end:
-      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch()
+    # If the slice end runs past the buffer end, collect new batch data
+    if slice_end > rampup_manager.global_batch_size_end:
+      old_buffer, self.batch_buffer = self.batch_buffer, super().load_next_batch_pre_sharding()

      # self.global_batch_size_end is batch_buffer size
      def _slice_and_concat(old_data, new_data):
        sliced_old_data = jax.lax.dynamic_slice_in_dim(
            old_data,
            slice_start,
-            self.global_batch_size_end - slice_start,
+            rampup_manager.global_batch_size_end - slice_start,
            axis=0,
        )
        sliced_new_data = jax.lax.dynamic_slice_in_dim(
            new_data,
            0,
-            slice_end - self.global_batch_size_end,
+            slice_end - rampup_manager.global_batch_size_end,
            axis=0,
        )
        return jax.lax.concatenate((sliced_old_data, sliced_new_data), dimension=0)

-      self.buffer_start = slice_end - self.global_batch_size_end
-      return jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
+      self.buffer_start = slice_end - rampup_manager.global_batch_size_end
+      output = jax.tree.map(_slice_and_concat, old_buffer, self.batch_buffer)
    else:

      def _slice(data):
        return jax.lax.dynamic_slice_in_dim(
            data,
            slice_start,
-            self.global_batch_size_current,
+            rampup_manager.global_batch_size_current,
            axis=0,
        )

      self.buffer_start = slice_end
-      return jax.tree.map(_slice, self.batch_buffer)
+      output = jax.tree.map(_slice, self.batch_buffer)
+    self.rampup_active = rampup_manager.update()
+    return maybe_shard_with_name(output, self.input_data_shardings, self.config.shard_mode)


-def create_dataloader(config, mesh, data_iterator, goodput_recorder):
+def create_dataloader(config, mesh, data_iterator, goodput_recorder, rampup_manager):
  """
  Create the dataloader
  """
-  if config.enable_rampup_batch_size:
+  if rampup_manager and rampup_manager.num_accum_samples < config.global_rampup_samples:
    return RampUpDataLoader(config, mesh, data_iterator, goodput_recorder)
  else:
    return DataLoader(config, mesh, data_iterator, goodput_recorder)
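
For intuition, here is a small self-contained illustration of the slice-and-concatenate path in RampUpDataLoader.load_next_batch, with toy arrays standing in for real example batches (the variable names mirror the code above):

import jax
import jax.numpy as jnp

global_batch_size_end = 8      # buffer size, i.e. the full batch produced by the iterator
global_batch_size_current = 6  # current ramped-up batch size
buffer_start = 4               # leftover offset carried over from the previous call

old_buffer = jnp.arange(8)      # previously loaded batch; positions 0..3 already consumed
new_buffer = jnp.arange(8, 16)  # freshly loaded batch

slice_start, slice_end = buffer_start, buffer_start + global_batch_size_current  # (4, 10)

# slice_end > global_batch_size_end, so stitch the tail of the old buffer
# to the head of the new one, exactly like _slice_and_concat above.
tail = jax.lax.dynamic_slice_in_dim(old_buffer, slice_start, global_batch_size_end - slice_start, axis=0)
head = jax.lax.dynamic_slice_in_dim(new_buffer, 0, slice_end - global_batch_size_end, axis=0)
batch = jax.lax.concatenate((tail, head), dimension=0)

print(batch)  # [4 5 6 7 8 9] -> 6 examples spanning both buffers
# The next call starts reading the new buffer at buffer_start = slice_end - global_batch_size_end = 2.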

src/MaxText/elastic_train.py

Lines changed: 4 additions & 0 deletions

@@ -124,6 +124,8 @@ def elastic_handler(
      learning_rate_schedule,
      data_iterator,
      _,
+      _,
+      _,
      state,
  ) = setup_train_loop(config, recorder, elastic_manager.good_devices)

@@ -178,6 +180,8 @@ def train_loop(config, elastic_manager, recorder, state=None):
      learning_rate_schedule,
      data_iterator,
      _,
+      _,
+      _,
      state,
  ) = setup_train_loop(config, recorder)

src/MaxText/rampup_batch.py

Lines changed: 112 additions & 0 deletions

@@ -0,0 +1,112 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# pytype: disable=unsupported-operands
+"""Module for batch size managing classes."""
+
+import math
+
+
+class RampupBatchManager:
+  """
+  A stateful class that tracks the current global batch size for a given train step.
+  """
+
+  def __init__(self, config, step_num):
+    self._verify_inputs(config)
+    self._init_values(config)
+    self.num_accum_samples = 0
+
+    # Compute the number of samples already used, given the recovered step number
+    self._recover_states(step_num)
+
+  def _verify_inputs(self, config):
+    """Verify the rampup batch related inputs."""
+    diff_batch_size = config.per_device_batch_size - config.per_device_batch_size_start
+    if diff_batch_size <= 0:
+      raise ValueError(
+          "per_device_batch_size must be greater than per_device_batch_size_start. "
+          f"Got batch size {config.per_device_batch_size} and "
+          f"batch size start {config.per_device_batch_size_start}."
+      )
+    if diff_batch_size % config.per_device_batch_size_increment:
+      raise ValueError(
+          "Expected the rampup batch size change to be divisible by the batch size increment. "
+          f"Got per_device_batch_size={config.per_device_batch_size} and "
+          f"per_device_batch_size_start={config.per_device_batch_size_start}."
+      )
+
+  def _init_values(self, config):
+    """Initialize rampup batch related parameters."""
+    diff_batch_size = config.per_device_batch_size - config.per_device_batch_size_start
+    num_increments = diff_batch_size // config.per_device_batch_size_increment
+    self.samples_per_increment = config.global_rampup_samples / num_increments
+    num_devices = int(config.num_target_devices)
+    self.global_batch_size_end = int(num_devices * config.per_device_batch_size)
+    self.global_batch_size_start = int(num_devices * config.per_device_batch_size_start)
+    self.increment = int(num_devices * config.per_device_batch_size_increment)
+    self.global_rampup_samples = config.global_rampup_samples
+    self.global_batch_size_current = self.global_batch_size_start
+    self.total_rampup_steps = self._compute_total_rampup_steps(config)
+    self.total_used_samples = 0
+
+  def _compute_total_rampup_steps(self, config):
+    """Compute the total number of rampup steps."""
+    batch_size_start = config.per_device_batch_size_start
+    batch_size_end = config.per_device_batch_size
+    batch_size_increment = config.per_device_batch_size_increment
+    diff_batch_size = batch_size_end - batch_size_start
+    num_increments = diff_batch_size // batch_size_increment
+    rampup_samples = config.global_rampup_samples / config.num_target_devices
+    rampup_samples_per_increment = rampup_samples / num_increments
+    total_rampup_steps = 0
+    current_batch_size = batch_size_start
+
+    while current_batch_size < batch_size_end:
+      steps_for_this_stage = math.ceil(rampup_samples_per_increment / current_batch_size)
+      total_rampup_steps += steps_for_this_stage
+      current_batch_size += batch_size_increment
+    return total_rampup_steps
+
+  def _recover_states(self, step_num):
+    """Recover the number of samples already used."""
+    if step_num < 0:
+      return
+    for _ in range(step_num):
+      _ = self.update()
+    return
+
+  def update(self):
+    """Update values when load_next_batch is called."""
+    self.total_used_samples += self.global_batch_size_current
+    self.num_accum_samples += self.global_batch_size_current
+    # Check if it's time to increment the batch size
+    is_time_to_increment = self.num_accum_samples >= self.samples_per_increment
+    if is_time_to_increment:
+      self.global_batch_size_current = min(self.increment + self.global_batch_size_current, self.global_batch_size_end)
+      self.num_accum_samples = 0
+    # Return whether the rampup phase is still active
+    return self.global_batch_size_current < self.global_batch_size_end
+
+
+def create_rampup_manager(config, checkpoint_manager):
+  if not config.enable_rampup_batch_size:
+    return None
+
+  # The current step defaults to -1 if no checkpoint exists
+  current_step = -1
+  if checkpoint_manager and checkpoint_manager.latest_step():
+    current_step = checkpoint_manager.latest_step()
+
+  return RampupBatchManager(config, current_step)
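
A minimal sketch of driving RampupBatchManager directly, with a SimpleNamespace standing in for the real config object (the field values below are made up for illustration):

from types import SimpleNamespace

from MaxText.rampup_batch import RampupBatchManager

config = SimpleNamespace(
    per_device_batch_size_start=2,
    per_device_batch_size_increment=2,
    per_device_batch_size=8,
    global_rampup_samples=96,  # total samples spent ramping up
    num_target_devices=4,
)

# Fresh run: step_num=-1 means there is no checkpoint, so no state to recover.
manager = RampupBatchManager(config, step_num=-1)
print(manager.global_batch_size_start, manager.global_batch_size_end)  # 8 32
print(manager.samples_per_increment)                                   # 32.0

schedule = []
rampup_active = True
while rampup_active:
  schedule.append(manager.global_batch_size_current)
  rampup_active = manager.update()  # mirrors the call in RampUpDataLoader.load_next_batch
print(schedule)  # [8, 8, 8, 8, 16, 16, 24, 24], after which the batch stays at 32

# Restart: replaying step_num updates puts the schedule back at the same point,
# which is what create_rampup_manager(config, checkpoint_manager) does via latest_step().
restored = RampupBatchManager(config, step_num=4)
print(restored.global_batch_size_current)  # 16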

src/MaxText/sft_trainer.py

Lines changed: 2 additions & 0 deletions

@@ -65,6 +65,8 @@ def train_loop(config, recorder, state=None):
      mesh,
      learning_rate_schedule,
      data_iterator,
+      _,
+      _,
      eval_data_iterator,
      state,
  ) = setup_train_loop(config, recorder)

src/MaxText/train.py

Lines changed: 3 additions & 3 deletions

@@ -53,7 +53,6 @@
from MaxText import sharding
from MaxText.layers.multi_token_prediction import calculate_mtp_acceptance_rate, calculate_mtp_loss
from MaxText.common_types import ShardMode
-from MaxText.data_loader import create_dataloader
from MaxText.globals import EPS
from MaxText.metric_logger import MetricLogger
from MaxText.utils import gcs_utils
@@ -377,6 +376,8 @@ def train_loop(config, recorder, state=None):
      mesh,
      learning_rate_schedule,
      data_iterator,
+      data_loader,
+      rampup_manager,
      eval_data_iterator,
      state,
  ) = train_utils.setup_train_loop(config, recorder)
@@ -412,7 +413,6 @@ def train_loop(config, recorder, state=None):

  start_step = get_first_step(state)  # this is the start_step for training
  prof = profiler.Profiler(config, offset_step=start_step)
-  data_loader = create_dataloader(config, mesh, data_iterator, recorder)
  metric_logger = MetricLogger(config=config, learning_rate_schedule=learning_rate_schedule)

  # Write train config params, num model params, and XLA flags to tensorboard
@@ -424,7 +424,7 @@ def train_loop(config, recorder, state=None):
      prof.maybe_activate_profiler(step, state)

      with jax.profiler.StepTraceAnnotation("train", step_num=step):
-        example_batch = data_loader.load_next_batch()
+        example_batch = data_loader.load_next_batch(rampup_manager=rampup_manager)
        # Reshard data from loaded sharding to performant activation sharding
        example_batch = sharding.maybe_shard_with_name(
            example_batch,
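
setup_train_loop itself lives in train_utils and is not part of this section, so the following is only a hypothetical sketch of how the pieces at these call sites fit together: the ramp-up manager is rebuilt from checkpoint state, handed to create_dataloader, and both are returned so train_loop can pass rampup_manager into load_next_batch.

# Hypothetical wiring sketch; the function name and argument list are illustrative,
# not the actual train_utils.setup_train_loop signature.
from MaxText.data_loader import create_dataloader
from MaxText.rampup_batch import create_rampup_manager


def build_dataloader_sketch(config, mesh, data_iterator, checkpoint_manager, goodput_recorder):
  # None unless config.enable_rampup_batch_size is set.
  rampup_manager = create_rampup_manager(config, checkpoint_manager)
  # RampUpDataLoader while ramp-up samples remain, plain DataLoader otherwise.
  data_loader = create_dataloader(config, mesh, data_iterator, goodput_recorder, rampup_manager)
  return data_loader, rampup_manager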
