Skip to content

Commit b63ce5c

Browse files
Merge pull request #2645 from AI-Hypercomputer:xfgu-micro-batch
PiperOrigin-RevId: 830690451
2 parents 188c426 + fed1495 commit b63ce5c

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

src/MaxText/configs/rl.yml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,9 @@ batch_size: 1
8181
# Increase `batch_size` and `MAX_STEPS` for better results.
8282
# num_batches: 3738
8383
num_batches: 4 # 200
84+
# A batch can be split into multiple micro batches for memory management
85+
# and/or async sampling and training.
86+
micro_batch_size: -1
8487
# Keep `num_test_batches` low so that evaluation runs quickly. It can be
8588
# increased to a max. of 330 (if batch size is 4).
8689
num_test_batches: 5 # 200

src/MaxText/rl/train_rl.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@
6161
from tunix.rl import rl_cluster as rl_cluster_lib
6262
from tunix.rl.rollout import base_rollout
6363
from tunix.rl.grpo.grpo_learner import GrpoConfig, GrpoLearner
64-
from tunix.sft import metrics_logger
64+
from tunix.sft import metrics_logger, profiler
6565

6666

6767
from transformers import AutoTokenizer
@@ -201,6 +201,8 @@ def rl_train(tmvp_config):
201201
os.makedirs(train_data_dir)
202202
if not os.path.exists(test_data_dir):
203203
os.makedirs(test_data_dir)
204+
if not os.path.exists(tmvp_config.tensorboard_dir):
205+
os.makedirs(tmvp_config.tensorboard_dir)
204206

205207
# Create model tokenizer
206208
model_tokenizer = AutoTokenizer.from_pretrained(tmvp_config.tokenizer_path)
@@ -271,16 +273,24 @@ def rl_train(tmvp_config):
271273
save_interval_steps=tmvp_config.checkpoint_period, max_to_keep=tmvp_config.max_num_checkpoints_to_keep
272274
)
273275

276+
# Set up micro batching
277+
micro_batch_size = None if tmvp_config.micro_batch_size == -1 else tmvp_config.micro_batch_size
278+
274279
# Setup metrics logging
275280
max_logging.log(f"TensorBoard logs directory: {tmvp_config.tensorboard_dir}")
276281
# Metrics logger
277282
metrics_logging_options = metrics_logger.MetricsLoggerOptions(
278283
log_dir=tmvp_config.tensorboard_dir, flush_every_n_steps=tmvp_config.log_period
279284
)
280285

281-
# Profiler configurations
282-
# TODO: xfgu@: add profiling
283286
profiler_options = None
287+
if tmvp_config.profiler == "xplane":
288+
profiler_options = profiler.ProfilerOptions(
289+
log_dir=tmvp_config.tensorboard_dir,
290+
skip_first_n_steps=tmvp_config.skip_first_n_steps_for_profiler,
291+
profiler_steps=tmvp_config.profiler_steps,
292+
set_profile_options=False,
293+
)
284294

285295
# RL Cluster config
286296
# Note that we use vLLM as the rollout engine.
@@ -297,11 +307,15 @@ def rl_train(tmvp_config):
297307
actor_optimizer=optimizer,
298308
eval_every_n_steps=tmvp_config.eval_interval,
299309
max_steps=max_train_steps,
300-
# metrics logging
310+
# Micro batching
311+
mini_batch_size=tmvp_config.batch_size,
312+
train_micro_batch_size=micro_batch_size,
313+
rollout_micro_batch_size=micro_batch_size,
314+
# Metrics logging
301315
metrics_logging_options=metrics_logging_options,
302-
# profiling
316+
# Profiling
303317
profiler_options=profiler_options,
304-
# checkpoint saving
318+
# Checkpoint saving
305319
checkpoint_root_directory=ckpt_dir,
306320
checkpointing_options=checkpointing_options,
307321
),

0 commit comments

Comments
 (0)