6161from tunix .rl import rl_cluster as rl_cluster_lib
6262from tunix .rl .rollout import base_rollout
6363from tunix .rl .grpo .grpo_learner import GrpoConfig , GrpoLearner
64- from tunix .sft import metrics_logger
64+ from tunix .sft import metrics_logger , profiler
6565
6666
6767from transformers import AutoTokenizer
@@ -201,6 +201,8 @@ def rl_train(tmvp_config):
201201 os .makedirs (train_data_dir )
202202 if not os .path .exists (test_data_dir ):
203203 os .makedirs (test_data_dir )
204+ if not os .path .exists (tmvp_config .tensorboard_dir ):
205+ os .makedirs (tmvp_config .tensorboard_dir )
204206
205207 # Create model tokenizer
206208 model_tokenizer = AutoTokenizer .from_pretrained (tmvp_config .tokenizer_path )
@@ -271,16 +273,24 @@ def rl_train(tmvp_config):
271273 save_interval_steps = tmvp_config .checkpoint_period , max_to_keep = tmvp_config .max_num_checkpoints_to_keep
272274 )
273275
276+ # Set up micro batching
277+ micro_batch_size = None if tmvp_config .micro_batch_size == - 1 else tmvp_config .micro_batch_size
278+
274279 # Setup metrics logging
275280 max_logging .log (f"TensorBoard logs directory: { tmvp_config .tensorboard_dir } " )
276281 # Metrics logger
277282 metrics_logging_options = metrics_logger .MetricsLoggerOptions (
278283 log_dir = tmvp_config .tensorboard_dir , flush_every_n_steps = tmvp_config .log_period
279284 )
280285
281- # Profiler configurations
282- # TODO: xfgu@: add profiling
283286 profiler_options = None
287+ if tmvp_config .profiler == "xplane" :
288+ profiler_options = profiler .ProfilerOptions (
289+ log_dir = tmvp_config .tensorboard_dir ,
290+ skip_first_n_steps = tmvp_config .skip_first_n_steps_for_profiler ,
291+ profiler_steps = tmvp_config .profiler_steps ,
292+ set_profile_options = False ,
293+ )
284294
285295 # RL Cluster config
286296 # Note that we use vLLM as the rollout engine.
@@ -297,11 +307,15 @@ def rl_train(tmvp_config):
297307 actor_optimizer = optimizer ,
298308 eval_every_n_steps = tmvp_config .eval_interval ,
299309 max_steps = max_train_steps ,
300- # metrics logging
310+ # Micro batching
311+ mini_batch_size = tmvp_config .batch_size ,
312+ train_micro_batch_size = micro_batch_size ,
313+ rollout_micro_batch_size = micro_batch_size ,
314+ # Metrics logging
301315 metrics_logging_options = metrics_logging_options ,
302- # profiling
316+ # Profiling
303317 profiler_options = profiler_options ,
304- # checkpoint saving
318+ # Checkpoint saving
305319 checkpoint_root_directory = ckpt_dir ,
306320 checkpointing_options = checkpointing_options ,
307321 ),
0 commit comments