diff --git a/tests/ray/test_evaluator.py b/tests/ray/test_evaluator.py index 61611e1f8..7915d96b9 100644 --- a/tests/ray/test_evaluator.py +++ b/tests/ray/test_evaluator.py @@ -94,27 +94,17 @@ def custom_compute_metric(samples): evaluator_cfg = EvaluatorConfig( dataset_cfg=self.eval_dataset_cfg, tokenizer=self.tokenizer, - max_concurrent=1, - eval_sample_ratio=0.004, # generate 5 samples - compute_metric_func=None, - sample_params=self.sample_params, - worker_log_dir=self.worker_log_dir - ) - evaluator = Evaluator.remote(evaluator_cfg, self.test_env) - correctness = ray.get(evaluator.run.remote()) - custom_evaluator_cfg = EvaluatorConfig( - dataset_cfg=self.eval_dataset_cfg, - tokenizer=self.tokenizer, - max_concurrent=1, + max_concurrent=16, eval_sample_ratio=0.004, # generate 5 samples compute_metric_func=custom_compute_metric, sample_params=self.sample_params, worker_log_dir=self.worker_log_dir ) - custom_evaluator = Evaluator.remote(custom_evaluator_cfg, self.test_env) - custom_correctness = ray.get(custom_evaluator.run.remote()) - self.assertEqual(correctness['accuracy'], custom_correctness['custom_accuracy']) - ray.get(self.test_env.shutdown.remote()) + evaluator = Evaluator.remote(evaluator_cfg, self.test_env) + try: + ray.get(evaluator.run.remote()) + except Exception as e: + self.fail(f"evaluator.run.remote() raised an exception: {e}") if __name__ == '__main__': unittest.main() diff --git a/tests/ray/test_grpo_train.py b/tests/ray/test_grpo_train.py deleted file mode 100644 index 720d334e5..000000000 --- a/tests/ray/test_grpo_train.py +++ /dev/null @@ -1,133 +0,0 @@ -import os -import torch -import json -import time -import unittest -from transformers import AutoTokenizer -import shutil -import tempfile - -import ray -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers -from xtuner.v1.data_proto.sequence_context import SequenceContext -from xtuner.v1.config import ( - AdamWConfig, - FSDPConfig, - LRConfig, -) -from xtuner.v1.model.moe.moe import BalancingLossConfig, ZLossConfig -# from xtuner.v1.rl.grpo.config import WorkerConfig, LossConfig -from xtuner.v1.rl.base import WorkerConfig, TrainingController, TrainingWorker as BaseTrainingWorker -from xtuner.v1.rl.grpo.loss import GRPOLossConfig as LossConfig -from xtuner.v1.model.dense.qwen3 import Qwen3Dense8BConfig - - -# Qwen3 30B A3 -QWEN3_PATH = os.environ["QWEN3_PATH"] - -class TestGRPOTrain(unittest.TestCase): - def setUp(self): - ray.init(num_cpus=80, ignore_reinit_error=True) - - resources = AcceleratorResourcesConfig( - accelerator="GPU", - num_accelerators_per_worker=1, - num_cpus_per_worker=8, - num_workers=8, - cpu_memory_per_worker=16 * 1024**3, # 16 GB - ) - - pg = AutoAcceleratorWorkers.build_placement_group(resources) - self.pg = pg - - self.temp_dir = tempfile.mkdtemp() - tokenizer = AutoTokenizer.from_pretrained(QWEN3_PATH, trust_remote_code=True) - self.tokenizer = tokenizer - self.prompt_repeat_k = 8 - file = './tests/ray/rollout_output.jsonl' - with open(file, 'r') as f: - data = [json.loads(line) for line in f] - data_groups = [data[i:i + self.prompt_repeat_k] for i in range(0, len(data), self.prompt_repeat_k)] - data_groups = data_groups[:8] - data_batches = [] - for group in data_groups: - prompt_ids = tokenizer(group[0]['prompt'], return_tensors='pt')['input_ids'].flatten().tolist() - rewards = [item['reward'] for item in group] - rewards = torch.tensor(rewards, dtype=torch.float32) - advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - - for i in 
range(self.prompt_repeat_k): - item = group[i] - response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() - input_ids = prompt_ids + response_ids - shifted_labels = [-100] * (len(prompt_ids) - 1) + response_ids + [-100] - input_ids = torch.tensor(input_ids, dtype=torch.int64).unsqueeze(0) - shifted_labels = torch.tensor(shifted_labels, dtype=torch.int64).unsqueeze(0) - data_batches.append( - dict( - seq_ctx=SequenceContext.from_input_ids((input_ids, ), device="cpu"), - shifted_labels=shifted_labels, - advantage=advantages[i].item(), - ) - ) - self.data_batches = data_batches - - def tearDown(self): - shutil.rmtree(self.temp_dir) - ray.shutdown() - - def build_train_controller(self): - model_cfg = Qwen3Dense8BConfig() - optim_cfg: AdamWConfig = AdamWConfig(lr=5e-7, foreach=False) - fsdp_cfg: FSDPConfig = FSDPConfig( - torch_compile=True, - cpu_offload=False, - ep_size=1, - - ) - lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=5e-7) - worker_cfg: WorkerConfig = WorkerConfig( - model_cfg=model_cfg, - optim_cfg=optim_cfg, - loss_cfg=LossConfig( - policy_loss_cfg=dict( - cliprange_high=0.28, - cliprange_low=0.2, - loss_type="vanilla", - ), - ignore_idx=-100, - use_kl_loss=True, - kl_loss_coef=0.001, - kl_loss_type="low_var_kl", - mode="eager"), - lr_cfg=lr_cfg, - fsdp_cfg=fsdp_cfg, - load_from=QWEN3_PATH, - sp_size=1, - pack_max_length=8192, - ) - - TrainingWorker = ray.remote( - runtime_env={ - "env_vars": { - "RAY_EXPERIMENTAL_NOSET_CUDA_VISIBLE_DEVICES": "1", - "RAY_EXPERIMENTAL_NOSET_ASCEND_RT_VISIBLE_DEVICES": "1", - } - }, - )(BaseTrainingWorker) - train_workers, _ = AutoAcceleratorWorkers.from_placement_group( - TrainingWorker, worker_cfg, self.pg - ) - futures = [ worker.test_all_reduce.remote() for worker in train_workers ] - print(ray.get(futures)) - train_controller = TrainingController.remote( - workers=train_workers, - ) - ray.get(train_controller.__ray_ready__.remote()) - return train_controller - - def test_grpo_train_and_save(self): - train_controller = self.build_train_controller() - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - save_path = os.path.join(self.temp_dir, "hf_test") - ray.get(train_controller.save_hf.remote(str(save_path))) diff --git a/tests/ray/test_rl_trainer.py b/tests/ray/test_rl_trainer.py new file mode 100644 index 000000000..930cf7b14 --- /dev/null +++ b/tests/ray/test_rl_trainer.py @@ -0,0 +1,250 @@ +import os +import tempfile +import unittest +from pathlib import Path + +import ray +import torch + +from transformers import AutoTokenizer +from xtuner.v1.config import ( + AdamWConfig, + FSDPConfig, + LRConfig, +) +from xtuner.v1.data_proto.rl_data import SampleParams +from xtuner.v1.datasets import RLTokenizeFnConfig +from xtuner.v1.datasets.config import DataloaderConfig, DatasetConfig +from xtuner.v1.model import get_model_config_from_hf +from xtuner.v1.ray.base import AcceleratorResourcesConfig, CPUResourcesConfig +from xtuner.v1.ray.config.worker import RolloutConfig +from xtuner.v1.ray.dataflow import DataFlowConfig, ReplayBufferConfig +from xtuner.v1.ray.judger.controller import JudgerConfig +from xtuner.v1.rl.base import WorkerConfig +from xtuner.v1.rl.grpo import GRPOLossConfig +from xtuner.v1.train.rl_trainer import RLTrainer, RLTrainerConfig + + +MODEL_PATH = os.environ["ROLLOUT_MODEL_PATH"] +TRAIN_DATA_PATH = os.environ["ROLLOUT_DATA_PATH"] +TEST_DATA_PATH = os.environ["ROLLOUT_TEST_DATA_PATH"] +resource_map = { + "npu": "NPU", + "cuda": "GPU", +} 
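+# NOTE: resource_map above translates the device type reported by
+# torch.accelerator.current_accelerator() ("npu"/"cuda") into the Ray resource
+# label ("NPU"/"GPU") consumed by AcceleratorResourcesConfig in the tests below.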
+
+
+class TestRLTrainer(unittest.TestCase):
+    @classmethod
+    def setUpClass(cls):
+        os.environ["XTUNER_USE_FA3"] = "1"
+
+    @classmethod
+    def tearDownClass(cls):
+        del os.environ["XTUNER_USE_FA3"]
+
+    def init_train_worker_config(self, train_optimizer_steps, pack_max_length):
+        model_cfg = get_model_config_from_hf(Path(MODEL_PATH))
+        optim_cfg = AdamWConfig(lr=1e-6, betas=(0.9, 0.999), max_grad_norm=1.0, weight_decay=0.1, foreach=False)
+        loss_cfg = GRPOLossConfig(
+            policy_loss_cfg=dict(
+                cliprange_high=0.28,
+                cliprange_low=0.2,
+                loss_type="vanilla",
+                clip_ratio_c=10.0,
+                log_prob_diff_min=-20.0,
+                log_prob_diff_max=20.0,
+            ),
+            ignore_idx=-100,
+            use_kl_loss=False,
+            kl_loss_coef=0.0,
+            kl_loss_type="low_var_kl",
+            mode="chunk",
+            chunk_size=512,
+        )
+        lr_cfg = LRConfig(lr_type="constant", warmup_ratio=0, lr_min=1e-6)
+        fsdp_cfg = FSDPConfig(torch_compile=False, cpu_offload=False, ep_size=1)
+        train_worker_cfg: WorkerConfig = WorkerConfig(
+            model_cfg=model_cfg,
+            load_from=MODEL_PATH,
+            optim_cfg=optim_cfg,
+            loss_cfg=loss_cfg,
+            lr_cfg=lr_cfg,
+            fsdp_cfg=fsdp_cfg,
+            sp_size=1,
+            optimizer_steps=train_optimizer_steps,
+            pack_max_length=pack_max_length,
+        )
+        return train_worker_cfg
+
+    def init_replay_buffer_config(self, max_prompt_length):
+        train_dataset_cfg = [
+            {
+                "dataset": DatasetConfig(name="gsm8k", anno_path=TRAIN_DATA_PATH, sample_ratio=1.0),
+                "tokenize_fn": RLTokenizeFnConfig(max_length=max_prompt_length),
+            },
+        ]
+        dataloader_cfg = DataloaderConfig(
+            collator="fake_collator",
+            pack_level="none",
+            group_by_length=False,
+        )
+        tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)
+        replay_buffer_cfg = ReplayBufferConfig(
+            dataset_cfg=train_dataset_cfg,
+            dataloader_cfg=dataloader_cfg,
+            tokenizer=tokenizer,
+            worker_log_dir=self.worker_log_dir,
+        )
+        return replay_buffer_cfg
+
+    def init_resources_config(self, num_workers, num_cpus_per_worker, cpu_memory_per_worker):
+        resources = AcceleratorResourcesConfig(
+            accelerator=resource_map[torch.accelerator.current_accelerator().type],
+            num_workers=num_workers,
+            num_cpus_per_worker=num_cpus_per_worker,
+            cpu_memory_per_worker=cpu_memory_per_worker,
+        )
+        return resources
+
+    def init_cpu_resources_config(self, num_cpus_per_worker, cpu_memory_per_worker):
+        cpu_resources = CPUResourcesConfig(
+            num_cpus_per_worker=num_cpus_per_worker,
+            cpu_memory_per_worker=cpu_memory_per_worker,
+        )
+        return cpu_resources
+
+    def init_rollout_config(self, max_prompt_length, max_response_length):
+        rollout_config = RolloutConfig(
+            env="test_rl_trainer",
+            model_path=MODEL_PATH,
+            worker_log_dir=self.worker_log_dir,
+            rollout_max_batch_size_per_instance=1024,
+            context_length=max_response_length + max_prompt_length,
+        )
+        return rollout_config
+
+    def init_dataflow_config(self, max_response_length, global_batch_size, prompt_repeat_k, enable_partial_rollout):
+        sample_params = SampleParams(
+            max_tokens=max_response_length,
+        )
+        dataflow_config = DataFlowConfig(
+            env="test_rl_trainer",
+            global_batch_size=global_batch_size,
+            prompt_repeat_k=prompt_repeat_k,
+            worker_log_dir=self.worker_log_dir,
+            sample_params=sample_params,
+            enable_partial_rollout=enable_partial_rollout,
+            max_concurrent=1024,
+        )
+        return dataflow_config
+
+    def init_judger_config(self):
+        from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
+
+        gsm8k_judger_config = GSM8KJudgerConfig(judger_name="openai/gsm8k")
+        judger_cfg = JudgerConfig(reward_judger_configs=[gsm8k_judger_config], worker_log_dir=self.worker_log_dir)
+        return judger_cfg
+
+    def init_multi_judger_config(self):
+        from xtuner.v1.ray.judger.gsm8k import GSM8KJudgerConfig
+
+        # A single GSM8KJudgerConfig class can be used to create multiple judger instances.
+        gsm8k_judger_config_1 = GSM8KJudgerConfig(judger_name="openai/gsm8k-1")
+        gsm8k_judger_config_2 = GSM8KJudgerConfig(judger_name="openai/gsm8k-2")
+        judger_cfg = JudgerConfig(
+            reward_judger_configs=[gsm8k_judger_config_1, gsm8k_judger_config_2],
+            worker_log_dir=self.worker_log_dir,
+        )
+        return judger_cfg
+
+    def setUp(self):
+        ray.init(num_cpus=80, ignore_reinit_error=True)
+        self.temp_dir = tempfile.TemporaryDirectory()
+        self.worker_log_dir = os.path.join(self.temp_dir.name, "work_dirs")
+
+        train_optimizer_steps = 2
+        pack_max_length = 32768
+        max_prompt_length = 2048
+        max_response_length = 1024
+        global_batch_size = 4
+        prompt_repeat_k = 4
+        enable_partial_rollout = False
+
+        self.train_worker_cfg = self.init_train_worker_config(train_optimizer_steps, pack_max_length)
+        self.replay_buffer_cfg = self.init_replay_buffer_config(max_prompt_length)
+        self.resources_cfg = self.init_resources_config(
+            num_workers=8, num_cpus_per_worker=8, cpu_memory_per_worker=8 * 1024**3
+        )
+        self.cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3)
+        self.rollout_config = self.init_rollout_config(
+            max_response_length=max_response_length, max_prompt_length=max_prompt_length
+        )
+        self.dataflow_config = self.init_dataflow_config(
+            max_response_length=max_response_length,
+            global_batch_size=global_batch_size,
+            prompt_repeat_k=prompt_repeat_k,
+            enable_partial_rollout=enable_partial_rollout,
+        )
+        self.judger_config = self.init_judger_config()
+
+    def tearDown(self):
+        self.temp_dir.cleanup()
+        ray.shutdown()
+
+    def test_rl_trainer(self):
+        multi_judger_config = self.init_multi_judger_config()
+        cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=2, cpu_memory_per_worker=2 * 1024**3)
+        trainer_config = RLTrainerConfig(
+            load_from=MODEL_PATH,
+            resources=self.resources_cfg,
+            cpu_resources=cpu_resources,
+            rollout_config=self.rollout_config,
+            dataflow_config=self.dataflow_config,
+            judger_config=multi_judger_config,
+            replay_buffer_config=self.replay_buffer_cfg,
+            train_worker_config=self.train_worker_cfg,
+            work_dir=self.worker_log_dir,
+            tokenizer_path=MODEL_PATH,
+            total_epochs=1,
+            rollout_steps=1,
+        )
+        trainer = RLTrainer.from_config(trainer_config)
+        self.assertIsNotNone(trainer, "Trainer should be created successfully")
+        try:
+            trainer.fit()
+        except Exception as e:
+            self.fail(f"trainer.fit() raised unexpected exception: {e}")
+        # Ensure all writers are closed before checking log files.
+        del trainer
+        log_files = list(Path(self.worker_log_dir).rglob("*.log"))
+        self.assertGreater(len(log_files), 0, "Should generate log files")
+        trajectory_files = list(Path(self.worker_log_dir).rglob("*_trajectory.jsonl"))
+        self.assertGreater(len(trajectory_files), 0, "Should generate trajectory files")
+
+    def test_judger_cpu_pg_creation_with_error(self):
+        """Test RLTrainer judger_cpu_pg creation with insufficient CPU resources."""
+        multi_judger_config = self.init_multi_judger_config()
+        # Deliberately undersized CPU resources for the two-judger setup.
+        cpu_resources = self.init_cpu_resources_config(num_cpus_per_worker=1, cpu_memory_per_worker=1 * 1024**3)
+        trainer_config = RLTrainerConfig(
+            load_from=MODEL_PATH,
+            resources=self.resources_cfg,
+            cpu_resources=cpu_resources,
+            rollout_config=self.rollout_config,
+            dataflow_config=self.dataflow_config,
+            judger_config=multi_judger_config,
+            replay_buffer_config=self.replay_buffer_cfg,
train_worker_config=self.train_worker_cfg, + work_dir=self.worker_log_dir, + tokenizer_path=MODEL_PATH, + total_epochs=1, + rollout_steps=1, + ) + with self.assertRaises(AssertionError) as cm: + trainer = RLTrainer.from_config(trainer_config) + + print(f"Expected AssertionError caught: {cm.exception}") + +if __name__ == "__main__": + unittest.main() diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 98a8e7524..cccf06005 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -8,7 +8,7 @@ import ray import torch from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from ray import ObjectRef from typing_extensions import Annotated @@ -181,10 +181,12 @@ class ReplayBufferConfig(BaseModel): tokenizer: Annotated[ Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], + Field(exclude=True), Parameter(help="The tokenizer for processing text data, e.g., for partial rollouts."), ] postprocessor_func: Annotated[ Optional[Callable], + Field(exclude=True), Parameter(help="An optional function to filter or modify data groups after they are generated."), ] = None replay_ratio: Annotated[ diff --git a/xtuner/v1/ray/evaluator.py b/xtuner/v1/ray/evaluator.py index 798788565..f82d2900b 100644 --- a/xtuner/v1/ray/evaluator.py +++ b/xtuner/v1/ray/evaluator.py @@ -4,7 +4,7 @@ import ray from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from ray.actor import ActorProxy from tqdm.auto import tqdm from typing_extensions import Annotated @@ -85,6 +85,7 @@ class EvaluatorConfig(BaseModel): tokenizer: Annotated[ Union[PreTrainedTokenizer, PreTrainedTokenizerFast, str], + Field(exclude=True), Parameter(help="Tokenizer for text processing."), ] max_concurrent: Annotated[ @@ -103,6 +104,7 @@ class EvaluatorConfig(BaseModel): evaluate_step: Annotated[int, Parameter(help="Step interval for evaluation.")] = 1 compute_metric_func: Annotated[ Optional[Callable], + Field(exclude=True), Parameter(help="An optional function to filter or modify data groups after they are generated."), ] = None sample_params: Annotated[ diff --git a/xtuner/v1/ray/judger/controller.py b/xtuner/v1/ray/judger/controller.py index 5369f0e79..01b4ada6e 100644 --- a/xtuner/v1/ray/judger/controller.py +++ b/xtuner/v1/ray/judger/controller.py @@ -5,12 +5,14 @@ import ray from cyclopts import Parameter -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, computed_field from ray.util.placement_group import PlacementGroup, placement_group from typing_extensions import Annotated from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem +from .native import NativeJudgerConfig + PG_READY_TIMEOUT = 30 @@ -70,12 +72,33 @@ class JudgerConfig(BaseModel): bool, Parameter(help="Whether to enable weighted reward calculation on multi judgers.") ] = False reward_judger_configs: Annotated[ - List[BaseModel], + List[NativeJudgerConfig], Parameter(help="A custom Python function for computing reward given model output and label."), ] = [] judger_timeout: Annotated[float, Parameter(help="Timeout for each judger request in seconds.")] = 1200.0 worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir" + @computed_field + def total_bundles_needed(self) -> list[dict]: + judger_total_bundles = [ + {"CPU": cfg.num_cpus_per_actor, 
"memory": cfg.num_cpus_per_actor * 1024**3} + for cfg in self.reward_judger_configs + for _ in range(cfg.num_ray_actors) + ] + return judger_total_bundles + + @computed_field + def total_cpus_needed(self) -> int: + judger_total_cpus = sum(cfg.num_cpus_per_actor * cfg.num_ray_actors for cfg in self.reward_judger_configs) + return judger_total_cpus + + @computed_field + def total_memory_needed(self) -> int: + judger_total_memory = sum( + cfg.num_cpus_per_actor * 1024**3 * cfg.num_ray_actors for cfg in self.reward_judger_configs + ) + return judger_total_memory + @ray.remote class JudgerController: diff --git a/xtuner/v1/ray/judger/native.py b/xtuner/v1/ray/judger/native.py index 2d76caf12..a4940659a 100644 --- a/xtuner/v1/ray/judger/native.py +++ b/xtuner/v1/ray/judger/native.py @@ -3,7 +3,7 @@ import httpx import ray -from pydantic import BaseModel, ConfigDict +from pydantic import BaseModel, ConfigDict, Field from ray.util.placement_group import PlacementGroup from xtuner.v1.data_proto.rl_data import RLDataFlowItem, RLJudgerResponseItem @@ -11,12 +11,33 @@ class NativeJudgerConfig(BaseModel): - """Base configuration class for judgers.""" + """Configuration class for NativeJudger. + + This class defines the configuration options for initializing a NativeJudger, + including resource allocation (number of Ray actors and CPUs per actor), + reward function or remote judging service, optional pre/post-processing functions, + request timeout, and any extra information needed for judging. + + Attributes: + judger_name (str): Name identifier for the judger. + num_ray_actors (int): Number of Ray actor instances to launch. + num_cpus_per_actor (int): Number of CPUs allocated per actor. + reward_func (Optional[Callable]): Local reward function for judging. + Exactly one of reward_func or remote_url must be provided. + remote_url (Optional[str]): Remote service URL for judging. + Exactly one of reward_func or remote_url must be provided. + preprocess_func (Optional[Callable]): Function to preprocess input data before judging. + postprocess_func (Optional[Callable]): Function to postprocess the judging result. + request_timeout (float): Timeout (in seconds) for remote requests. + extra_info (dict): Additional information to be passed to the judger or reward function. + """ model_config = ConfigDict(arbitrary_types_allowed=True, extra="forbid") judger_name: str num_ray_actors: int = 1 - reward_func: Optional[Callable] = None + num_cpus_per_actor: int = 1 + cpu_memory_per_actor: int = 1024**3 + reward_func: Optional[Callable] = Field(default=None, exclude=True) remote_url: Optional[str] = None preprocess_func: Optional[Callable] = None postprocess_func: Optional[Callable] = None @@ -40,7 +61,13 @@ def build_actor(self, pg: PlacementGroup, start_bundle_idx: int) -> List[ray.act workers_list = [] for idx in range(self.num_ray_actors): bundle_idx = start_bundle_idx + idx - pg_options = {"num_cpus": pg.bundle_specs[bundle_idx].get("CPU", 1)} + pg_options = {"num_cpus": self.num_cpus_per_actor, "memory": self.cpu_memory_per_actor} + assert pg.bundle_specs[bundle_idx].get("CPU", 1) >= self.num_cpus_per_actor, ( + f"Placement group bundle {bundle_idx} does not have enough CPU resources." + ) + assert pg.bundle_specs[bundle_idx].get("memory", 0) >= self.cpu_memory_per_actor, ( + f"Placement group bundle {bundle_idx} does not have enough memory resources." 
+ ) worker = ( ray.remote(NativeJudger) .options( diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index 3c507f8fa..ce41dc739 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -11,7 +11,7 @@ from mmengine import load from mmengine.dist import get_rank from mmengine.runner import set_random_seed -from pydantic import BaseModel, ConfigDict, field_serializer, model_validator +from pydantic import BaseModel, ConfigDict, model_validator from ray.util.placement_group import placement_group from typing_extensions import Self @@ -20,7 +20,7 @@ from xtuner.v1.data_proto.rl_data import is_valid_for_training from xtuner.v1.data_proto.sequence_context import SequenceContext from xtuner.v1.patch import patch_default_save_plan -from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, AutoCPUWorkers, CPUResourcesConfig +from xtuner.v1.ray.base import AcceleratorResourcesConfig, AutoAcceleratorWorkers, CPUResourcesConfig from xtuner.v1.ray.config.worker import RolloutConfig from xtuner.v1.ray.dataflow import DataFlow, DataFlowConfig, DataFlowProxy, ReplayBufferConfig from xtuner.v1.ray.environment import SingleTurnEnvironment, SingleTurnEnvironmentProxy @@ -97,21 +97,6 @@ def _convert_work_dir(self): self.work_dir = Path.cwd() return self - @field_serializer("replay_buffer_config") - def serialize_replay_buffer_cfg(self, replay_buffer_config: ReplayBufferConfig) -> str: - return replay_buffer_config.model_dump(include={"replay_ratio", "replay_weights"}) - - @field_serializer("evaluator_config") - def serialize_evaluator_cfg(self, evaluator_config: EvaluatorConfig) -> str: - if evaluator_config: - return evaluator_config.model_dump(exclude={"tokenizer", "dataset_cfg", "compute_metric_func"}) - else: - return "" - - @field_serializer("judger_config") - def serialize_judger_config(self, judger_config: JudgerConfig) -> str: - return judger_config.model_dump(exclude={"tokenizer", "reward_func"}) - def get_train_seq_ctx( input_ids: torch.LongTensor, multimodal_train_info: dict | None = None, len_response_ids: int = 0 @@ -318,11 +303,23 @@ def __init__( self._enable_evaluate = evaluator_config.enable_evaluate self._enable_initial_evaluate = evaluator_config.enable_initial_evaluate self._pg = AutoAcceleratorWorkers.build_placement_group(resources) - if cpu_resources is None: - self._cpu_pg = placement_group(bundles=[{"CPU": 1, "memory": 1024**3}], strategy="PACK") - ray.get(self._cpu_pg.ready(), timeout=PG_READY_TIMEOUT) - else: - self._cpu_pg = AutoCPUWorkers.build_placement_group(cpu_resources) + + if cpu_resources is not None: + # NOTE: Here we only check CPU and memory for judger actors because only judger actors use CPU resources currently. + assert judger_config.total_cpus_needed <= cpu_resources.num_cpus_per_worker * cpu_resources.num_workers, ( + f"Not enough CPU resources for judger actors, " + f"required {judger_config.total_cpus_needed}, but got {cpu_resources.num_cpus_per_worker * cpu_resources.num_workers}." + ) + assert ( + judger_config.total_memory_needed <= cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers + ), ( + f"Not enough memory resources for judger actors, " + f"required {judger_config.total_memory_needed}, but got {cpu_resources.cpu_memory_per_worker * cpu_resources.num_workers}." 
+ ) + + self._judger_cpu_pg = placement_group(bundles=judger_config.total_bundles_needed, strategy="SPREAD") + ray.get(self._judger_cpu_pg.ready(), timeout=PG_READY_TIMEOUT) + # We need to build train controller first, and then build rollout dataflow to make # inference engines know how much memory they can utilize. self._train_controller = self._build_train_controller(train_worker_cfg) @@ -399,6 +396,12 @@ def __init__( self._writer = TensorboardWriter(log_dir / "tb") + def __del__(self): + if hasattr(self, "_writer") and self._writer is not None: + self._writer.close() + if hasattr(self, "_rollout_env_controller"): + ray.get(self._rollout_env_controller.shutdown.remote()) + def _resolve_load_checkpoint_cfg( self, auto_resume: bool, load_checkpoint_cfg: LoadCheckpointConfig ) -> LoadCheckpointConfig: @@ -444,6 +447,8 @@ def from_config(cls, config: RLTrainerConfig) -> Self: skip_checkpoint_validation=config.skip_checkpoint_validation, seed=config.seed, debug=config.debug, + debug_rollout=config.debug_rollout, + rollout_steps=config.rollout_steps, trainer_cfg=config, ) return self @@ -455,7 +460,7 @@ def _build_rollout_dataflow( judger_cfg: JudgerConfig, replay_buffer_config: ReplayBufferConfig, ) -> tuple[SingleTurnEnvironmentProxy, DataFlowProxy]: - env = SingleTurnEnvironment.remote("grpo", self._pg, rollout_cfg, self._cpu_pg, judger_cfg) + env = SingleTurnEnvironment.remote("grpo", self._pg, rollout_cfg, self._judger_cpu_pg, judger_cfg) flow = DataFlow.remote("grpo", dataflow_cfg, replay_buffer_config, env) return env, flow
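
For reference, a minimal standalone sketch of the judger resource accounting introduced above. Plain dicts stand in for the pydantic configs, and the numbers mirror the defaults in this diff (one Ray actor, one CPU, and 1 GiB per judger config); the single-CPU-worker pool at the end is a hypothetical value chosen only to illustrate the trainer-side check.

# Plain-Python mirror of JudgerConfig.total_bundles_needed / total_cpus_needed /
# total_memory_needed and of the RLTrainer CPU-budget check added in this diff.
judger_cfgs = [
    {"num_ray_actors": 1, "num_cpus_per_actor": 1, "cpu_memory_per_actor": 1024**3},  # openai/gsm8k-1
    {"num_ray_actors": 1, "num_cpus_per_actor": 1, "cpu_memory_per_actor": 1024**3},  # openai/gsm8k-2
]

# One placement-group bundle per Ray actor, sized from the per-actor requests.
bundles = [
    {"CPU": cfg["num_cpus_per_actor"], "memory": cfg["cpu_memory_per_actor"]}
    for cfg in judger_cfgs
    for _ in range(cfg["num_ray_actors"])
]
total_cpus = sum(cfg["num_cpus_per_actor"] * cfg["num_ray_actors"] for cfg in judger_cfgs)
total_memory = sum(cfg["cpu_memory_per_actor"] * cfg["num_ray_actors"] for cfg in judger_cfgs)

# Trainer-side check: the CPU pool must cover the judger totals, otherwise the
# AssertionError exercised by test_judger_cpu_pg_creation_with_error is raised.
num_workers, num_cpus_per_worker, cpu_memory_per_worker = 1, 2, 2 * 1024**3
assert total_cpus <= num_cpus_per_worker * num_workers
assert total_memory <= cpu_memory_per_worker * num_workers
print(len(bundles), total_cpus, total_memory)  # 2 2 2147483648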