diff --git a/xtuner/v1/config/fsdp.py b/xtuner/v1/config/fsdp.py
index b7c58ac76..b49376570 100644
--- a/xtuner/v1/config/fsdp.py
+++ b/xtuner/v1/config/fsdp.py
@@ -22,6 +22,7 @@ class FSDPConfig(BaseModel):
     # TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts
     param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16
     reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16
+    lm_head_fp32: Annotated[bool, Parameter(help="Use float32 for the language model head")] = False
     torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = False
     mesh_prefix: Annotated[str, Parameter(help="Prefix for device mesh configuration in distributed training")] = (
         "default"
diff --git a/xtuner/v1/float8/float8_handler.py b/xtuner/v1/float8/float8_handler.py
index f7b0dee01..749771ca4 100644
--- a/xtuner/v1/float8/float8_handler.py
+++ b/xtuner/v1/float8/float8_handler.py
@@ -36,8 +36,8 @@ def default_grouped_linear_filter_fn(mod: nn.Module, fqn: str):
 
 # Should the handler correspond one-to-one with the Engine?
 class Float8Handler:
-    scaling_granularity_gemm: ScalingGranularity
-    scaling_granularity_grouped_gemm: ScalingGranularity
+    scaling_granularity_gemm: Optional[ScalingGranularity]
+    scaling_granularity_grouped_gemm: Optional[ScalingGranularity]
     fsdp_mesh: Optional[DeviceMesh] = None
     tilewise_reduce_mesh_devided_64: Optional[DeviceMesh] = None
     tilewise_reduce_mesh_mapping: Dict[Tuple[int, int], DeviceMesh] = {}
@@ -61,12 +61,12 @@ def __init__(
             )
             return
 
-        assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
-            "scaling_granularity_gemm must be TILEWISE or TENSORWISE."
-        )
-        assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
-            "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
-        )
+        assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
+            scaling_granularity_gemm is None
+        ), "scaling_granularity_gemm must be TILEWISE, TENSORWISE, or None."
+        assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
+            scaling_granularity_grouped_gemm is None
+        ), "scaling_granularity_grouped_gemm must be TILEWISE, TENSORWISE, or None."
         self.scaling_granularity_gemm = scaling_granularity_gemm
         self.scaling_granularity_grouped_gemm = scaling_granularity_grouped_gemm
diff --git a/xtuner/v1/float8/fsdp_utils.py b/xtuner/v1/float8/fsdp_utils.py
index 644be106e..bedfff4fe 100644
--- a/xtuner/v1/float8/fsdp_utils.py
+++ b/xtuner/v1/float8/fsdp_utils.py
@@ -159,7 +159,7 @@ def precompute_tilewise_float8_scale_for_fsdp(
         assert reduce_mesh.ndim == 1, (
            f"Currently only reduce_mesh.ndim should equal to 1, got reduce_mesh.ndim = {reduce_mesh.ndim} for local_shape {local_shape}"
         )
-        weights_same_shape_stack = torch.stack(weights_same_shape, dim=0)  # type: ignore
+        weights_same_shape_stack = torch.stack(weights_same_shape, dim=0).bfloat16().float()  # type: ignore
         if dim >= 128 and dim % 128 == 64:
             assert reduce_mesh_devided_64 is not None, (
                 f"reduce_mesh_devided_64 should not be None for local_shape {local_shape}."
@@ -382,7 +382,7 @@ def fsdp_pre_all_gather(self, mesh):
         assert self._precomputed_scale is not None
         if self._tensor.shape[0] >= 128 and self._tensor.shape[0] % 128 == 64:
             w_fp8_data = cast_to_per_block_fp8_devided_64_with_scales(
-                tensor=self._tensor,
+                tensor=self._tensor.bfloat16().float(),
                 scales=self._precomputed_scale,
                 fsdp_mesh=mesh,
                 block_size=128,
@@ -390,7 +390,7 @@ def fsdp_pre_all_gather(self, mesh):
             )
         else:
             w_fp8_data = cast_to_per_block_fp8_with_scales(
-                tensor=self._tensor,
+                tensor=self._tensor.bfloat16().float(),
                 scales=self._precomputed_scale,
                 block_size=128,
                 float8_dtype=self._dtype,
diff --git a/xtuner/v1/model/base.py b/xtuner/v1/model/base.py
index e6dd7b230..adc7c0b07 100644
--- a/xtuner/v1/model/base.py
+++ b/xtuner/v1/model/base.py
@@ -677,7 +677,7 @@ def _get_hf_params(
         for param, load_spec in params:
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype)
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 hf_params, name_list = _get_hf_params(tensor_list, name_list)
@@ -719,7 +719,7 @@ def _get_same_hf_param(
                 buffer_name_list.append(load_spec.hf_keys[0])
                 continue
             local_tensor = param._local_tensor if isinstance(param, DTensor) else param
-            local_tensor = local_tensor.bfloat16()
+            local_tensor = local_tensor.to(dtype)
             tensor_size = self._get_tensor_size(param, dtype)
             if safetensor_size + tensor_size > bucket_size and tensor_list:
                 if self.fsdp_mesh is not None:
diff --git a/xtuner/v1/model/dense/dense.py b/xtuner/v1/model/dense/dense.py
index 02298350b..bb092cb55 100644
--- a/xtuner/v1/model/dense/dense.py
+++ b/xtuner/v1/model/dense/dense.py
@@ -204,6 +204,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)
 
         generator = torch.Generator()
@@ -256,7 +260,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
diff --git a/xtuner/v1/model/moe/moe.py b/xtuner/v1/model/moe/moe.py
index 2f9abe8c4..35fe0b324 100644
--- a/xtuner/v1/model/moe/moe.py
+++ b/xtuner/v1/model/moe/moe.py
@@ -721,6 +721,10 @@ def fully_shard(
         mp_policy = MixedPrecisionPolicy(
             param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
         )
+        if self.fsdp_config.lm_head_fp32:
+            lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
+        else:
+            lm_head_mp_policy = mp_policy
         num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)
 
         for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
@@ -766,7 +770,7 @@ def fully_shard(
         fully_shard(
             self.lm_head,
             mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
-            mp_policy=mp_policy,
+            mp_policy=lm_head_mp_policy,
             reshard_after_forward=self.fsdp_config.reshard_after_forward,
             offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
         )
diff --git a/xtuner/v1/module/linear/linear.py b/xtuner/v1/module/linear/linear.py
index 6c0e85a70..89fef6c18 100644
--- a/xtuner/v1/module/linear/linear.py
+++ b/xtuner/v1/module/linear/linear.py
@@ -33,7 +33,7 @@ def build_linear(
     float8_cfg=None,
 ) -> nn.Module:
     """Build a linear layer with optional float8 support."""
-    if float8_cfg is None:
+    if float8_cfg is None or float8_cfg.scaling_granularity_gemm is None:
         return _Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
     elif float8_cfg.scaling_granularity_gemm is ScalingGranularity.TILEWISE:
         return TileWiseFloat8Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
diff --git a/xtuner/v1/ray/config/worker.py b/xtuner/v1/ray/config/worker.py
index f394b019c..a51dcc9c2 100644
--- a/xtuner/v1/ray/config/worker.py
+++ b/xtuner/v1/ray/config/worker.py
@@ -228,6 +228,20 @@ class RolloutConfig(BaseModel):
             help="Maximum number of retries per sample before marking it as failed.",
         ),
     ] = 1
+    max_prefill_token_num: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Maximum number of prefill tokens.",
+        ),
+    ] = None
+    router_n_groups: Annotated[
+        Optional[int],
+        Parameter(
+            group=infer_group,
+            help="Number of router groups (passed to the inference engine via hf_overrides).",
+        ),
+    ] = None
     worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
 
     _logged_server_urls_per_engine: bool = PrivateAttr(default=False)
diff --git a/xtuner/v1/ray/dataflow/flow.py b/xtuner/v1/ray/dataflow/flow.py
index aa14c5122..ee341fcf2 100644
--- a/xtuner/v1/ray/dataflow/flow.py
+++ b/xtuner/v1/ray/dataflow/flow.py
@@ -109,7 +109,7 @@ def __init__(
         self.env = env
         self.config = dataflow_cfg
         replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
-        self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg)  # type: ignore[attr-defined]
+        self.replay_buffer = ReplayBuffer(replay_buffer_cfg)  # type: ignore[attr-defined]
         self.env_controller = environment
         self.finished_samples_count = 0
         self.skipped_sample_count = 0
@@ -168,7 +168,7 @@ def _reset_internal_states(
     @ray_method
     def get_train_dataset_length(self):
         """Gets the length of the training dataset from the replay buffer."""
-        return ray.get(self.replay_buffer.get_train_dataset_length.remote())
+        return self.replay_buffer.get_train_dataset_length()
 
     @ray_method
     async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
@@ -192,7 +192,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
         # step 1: sample
         # TODO(@duanyanhui): More fine-grained control over group data generation:
         # Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
-        group_data_items = await self.replay_buffer.sample.remote(  # type: ignore[attr-defined]
+        group_data_items = self.replay_buffer.sample(  # type: ignore[attr-defined]
             self.env, self.enable_partial_rollout, self.config.prompt_repeat_k
         )
         assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
@@ -206,12 +206,12 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
             group_state = determine_group_state(group_data_items)
             self.logger.debug(f"Determined replay state for {action_id}: {group_state}")
             if group_state == RolloutState.COMPLETED:
-                group_data_items = await self.replay_buffer.post_processor.remote(group_data_items)  # type: ignore[attr-defined]
+                group_data_items = self.replay_buffer.post_processor(group_data_items)  # type: ignore[attr-defined]
                 if len(group_data_items) > 0:
-                    await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+                    self.replay_buffer.add(group_data_items)  # type: ignore[attr-defined]
                     self.logger.debug(f"Worker task completed successfully for {action_id}.")
             elif group_state == RolloutState.ABORTED:
-                await self.replay_buffer.add.remote(group_data_items)  # type: ignore[attr-defined]
+                self.replay_buffer.add(group_data_items)  # type: ignore[attr-defined]
                 self.logger.debug(f"Adding aborted sample {action_id} to aborted storage")
             elif group_state == RolloutState.SKIPPED:
                 self.skipped_sample_count += 1
@@ -267,7 +267,7 @@ async def concurrent_task_runner(self):
                     waiting_tasks.add(task)
 
             _, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
-            self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
+            self.finished_samples_count = self.replay_buffer.get_finished_samples()
             waiting_tasks = pending_tasks
 
             pbar.n = self.finished_samples_count
@@ -349,25 +349,25 @@ async def run(
         if resume:
             assert resume_path, "Resuming is enabled but no resume path is provided."
             self.logger.info(f"Resuming replay buffer from {resume_path}")
-            await self.replay_buffer.resume_storage.remote(resume_path)
+            self.replay_buffer.resume_storage(resume_path)
 
         await self.concurrent_task_runner()
 
         if dump:
             assert dump_path, "Dumping is enabled but no dump path is provided."
             self.logger.info(f"Dump replay buffer from {dump_path}")
-            await self.replay_buffer.dump_storage.remote(dump_path)
+            self.replay_buffer.dump_storage(dump_path)
 
-        return await self.replay_buffer.get_samples.remote(self.target_batch_size)  # type: ignore[attr-defined]
+        return self.replay_buffer.get_samples(self.target_batch_size)  # type: ignore[attr-defined]
 
     def logging_replaybuffer_state(self):
-        ray.get(self.replay_buffer.print.remote())
+        self.replay_buffer.print()
 
     def get_replaybuffer_status(self):
-        return ray.get(self.replay_buffer.status.remote())
+        return self.replay_buffer.status()
 
     def clear_replaybuffer(self):
-        return ray.get(self.replay_buffer.clear.remote())
+        return self.replay_buffer.clear()
 
     async def _send_abort_request(self, client, url, timeout):
         worker_url = f"{url}/abort_request"
@@ -386,7 +386,7 @@ def save(self, save_path: Path | str):
         Args:
             save_path (str): The path to the checkpoint file to save to.
         """
-        ray.get(self.replay_buffer.save.remote(save_path))
+        self.replay_buffer.save(save_path)
 
     def resume(self, resume_path: Path | str):
         """Resumes the replay buffer from the specified path.
@@ -394,7 +394,7 @@ def resume(self, resume_path: Path | str):
 
         Args:
             resume_path (str): The path to the checkpoint file to resume from.
""" - ray.get(self.replay_buffer.resume.remote(resume_path)) + self.replay_buffer.resume(resume_path) DataFlow = ray.remote(RawDataFlow) diff --git a/xtuner/v1/ray/dataflow/replay_buffer.py b/xtuner/v1/ray/dataflow/replay_buffer.py index 98a8e7524..dd3c84d9b 100644 --- a/xtuner/v1/ray/dataflow/replay_buffer.py +++ b/xtuner/v1/ray/dataflow/replay_buffer.py @@ -509,7 +509,6 @@ def resume(self, file_path: Path | str): self.print() -@ray.remote class ReplayBuffer: """A Ray actor that manages experience replay for reinforcement learning.""" diff --git a/xtuner/v1/ray/rollout/lmdeploy.py b/xtuner/v1/ray/rollout/lmdeploy.py index 1db75cc27..829b710b2 100644 --- a/xtuner/v1/ray/rollout/lmdeploy.py +++ b/xtuner/v1/ray/rollout/lmdeploy.py @@ -234,9 +234,13 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace: lmdeploy_config_kwargs["uvicorn_log_level"] = lmdeploy_config_kwargs.pop("uvicorn_log_level", "ERROR") lmdeploy_config_kwargs["tm_log_level"] = lmdeploy_config_kwargs.pop("tm_log_level", "ERROR") - extra_engine_config = {} + extra_engine_config: dict[str, Any] = {} if backend == "pytorch" and self.config.enable_return_routed_experts: extra_engine_config["enable_return_routed_experts"] = True + if backend == "pytorch" and self.config.router_n_groups: + extra_engine_config["hf_overrides"] = dict(router_n_groups=self.config.router_n_groups) + if backend == "pytorch" and self.config.max_prefill_token_num: + extra_engine_config["max_prefill_token_num"] = self.config.max_prefill_token_num dp_rank = 0 if backend == "pytorch": diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index e5c2a7088..e3ec1a205 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -79,7 +79,6 @@ def _packing(self, data_batches, pack_max_length, language_cfg): if data_batches[0]["seq_ctx"].rollout_routed_experts is not None: assert language_cfg is not None has_rollout_routed_experts = True - n_routed_experts = language_cfg.n_routed_experts for pack_info in pack_infos: indices = pack_info["indices"] @@ -119,7 +118,12 @@ def _packing(self, data_batches, pack_max_length, language_cfg): pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1) if has_rollout_routed_experts: - pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1)) + pad_rand_index = torch.randint( + low=0, + high=1, + size=(pad_len, 1, 1), # add dummy data, true data will be initialized in train worker.fit + ) + pad_seq_ctx.rollout_routed_experts = pad_rand_index pad_seq_ctx.rollout_routed_experts = pad_rand_index seq_ctx_list.append(pad_seq_ctx) diff --git a/xtuner/v1/rl/base/rollout_is.py b/xtuner/v1/rl/base/rollout_is.py index 42a34b6ae..bfb3190ab 100644 --- a/xtuner/v1/rl/base/rollout_is.py +++ b/xtuner/v1/rl/base/rollout_is.py @@ -53,24 +53,14 @@ class RolloutImportanceSampling(BaseModel): rollout_is_mask_threshold: Optional[Tuple[float, float]] = None rollout_is_veto_threshold: Optional[Tuple[float, float]] = None - def compute_rollout_importance_weights_and_metrics( + def compute_rollout_importance_weights( self, old_log_prob: torch.Tensor, rollout_log_prob: torch.Tensor, num_tokens: torch.Tensor, response_mask: torch.Tensor, - ) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any], dict[str, Any]]: - mismatch_metrics = compute_mismatch_metrics( - old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask - ) - mismatch_metrics_scalar = {} - for key, value in mismatch_metrics.items(): - if isinstance(value, 
-                mismatch_metrics_scalar[f"mismatch/{key}"] = value.item()
-            else:
-                mismatch_metrics_scalar[f"mismatch/{key}"] = value
-
-        rollout_is_weights, modified_response_mask, metrics_scalar = compute_rollout_importance_weights(
+    ) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any]]:
+        return compute_rollout_importance_weights(
             old_log_prob,
             rollout_log_prob,
             num_tokens,
@@ -81,7 +71,6 @@ def compute_rollout_importance_weights(
             rollout_is_mask_threshold=self.rollout_is_mask_threshold,
             rollout_is_veto_threshold=self.rollout_is_veto_threshold,
         )
-        return rollout_is_weights, modified_response_mask, mismatch_metrics_scalar, metrics_scalar
 
 
 def compute_rollout_importance_weights(
@@ -150,7 +139,7 @@ def compute_rollout_importance_weights(
         metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix
     """
     if rollout_is_threshold is None:
-        return None, response_mask, {}
+        return None, response_mask, compute_mismatch_metrics(old_log_prob, rollout_log_prob, response_mask)
 
     assert rollout_is_mode in ["truncate", "mask", "both"], (
         f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate', 'mask', or 'both'."
     )
@@ -302,6 +291,12 @@ def compute_rollout_importance_weights(
     # This is different from rejection - padding must be zeroed regardless of mode
     rollout_is_weights = rollout_is_weights * response_mask
 
+    # Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
+    mismatch_metrics = compute_mismatch_metrics(
+        old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
+    )
+    metrics.update(mismatch_metrics)
+
     # Convert all tensor metrics to scalars for logging
     # Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
     metrics_scalar = {}
diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py
index 5e2f07941..e174b5ffc 100644
--- a/xtuner/v1/rl/base/worker.py
+++ b/xtuner/v1/rl/base/worker.py
@@ -350,27 +350,29 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
             if rollout_routed_experts is not None:
                 if isinstance(rollout_routed_experts, list):
                     # list[n,l,e]
-                    out_rollout_routed_expert = []
+                    rollout_routed_experts_list = []
                     for rollout_routed_expert in rollout_routed_experts:
-                        if isinstance(rollout_routed_expert, torch.Tensor):
-                            rollout_routed_experts_tensor = torch.randint(
-                                low=0,
-                                high=language_cfg.n_routed_experts,
-                                size=(
-                                    rollout_routed_expert.size(0),
-                                    language_cfg.num_hidden_layers,
-                                    language_cfg.num_experts_per_tok,
-                                ),
-                            )
-                            out_rollout_routed_expert.append(rollout_routed_experts_tensor)
-                        else:
-                            rollout_routed_expert_refs = rollout_routed_expert
-                            rollout_routed_expert = ray.get(rollout_routed_expert_refs)
+                        if not isinstance(rollout_routed_expert, torch.Tensor):
+                            rollout_routed_expert_ref = rollout_routed_expert
+                            rollout_routed_expert = ray.get(rollout_routed_expert_ref)
                             # free obj store explicitly
-                            ray._private.internal_api.free(rollout_routed_expert_refs)
-                            out_rollout_routed_expert.append(torch.as_tensor(rollout_routed_expert, dtype=torch.long))
-
-                    seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0)  # max_len,l,e
+                            ray._private.internal_api.free(rollout_routed_expert_ref)
+                            if not isinstance(rollout_routed_expert[0], torch.Tensor):
+                                rollout_routed_expert = torch.as_tensor(rollout_routed_expert, dtype=torch.long)
+                        else:
+                            if rollout_routed_expert.size(dim=-1) == 1:
+                                pad_len = rollout_routed_expert.size(dim=0)
+                                rollout_routed_expert = torch.randint(
+                                    low=0,
+                                    high=language_cfg.n_routed_experts,
+                                    size=(
+                                        pad_len,
+                                        language_cfg.num_hidden_layers,
+                                        language_cfg.num_experts_per_tok,
+                                    ),
+                                )
+                        rollout_routed_experts_list.append(rollout_routed_expert)
+                    seq_ctx.rollout_routed_experts = torch.cat(rollout_routed_experts_list, dim=0)  # max_len,l,e
                 else:
                     assert isinstance(rollout_routed_experts, torch.Tensor), (
                         f"padding experts should be a dummy tensor, bug got {type(rollout_routed_experts)}"
@@ -417,6 +419,14 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
         dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)
 
         # old logprobs are inplaced updated in compute_actor_logprobs
+        if isinstance(self.config.model_cfg, BaseComposeConfig):
+            if self._engine.llm_float8_handler is not None and self._engine.llm_float8_handler.enabled:  # type: ignore [attr-defined]
+                self._engine.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(  # type: ignore [attr-defined]
+                    self._engine.model.language_model
+                )
+        else:
+            if self._engine.float8_handler is not None and self._engine.float8_handler.enabled:
+                self._engine.float8_handler.precompute_float8_dynamic_scale_for_fsdp(self._engine.model)
         loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
         sum_entropy: torch.Tensor | None = None
         sum_rollout_entropy: torch.Tensor | None = None
@@ -426,7 +436,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
         )
 
         all_rollout_is_metrics = []
-        all_mismatch_metrics = []
         for i, loss_ctx_input in enumerate(loss_ctx_input_list):
             mask = loss_ctx_input.shifted_labels != -100
             entropy = -(cast(torch.Tensor, loss_ctx_input.old_logprobs) * mask).sum()
@@ -446,8 +455,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
 
                 cu_seq_lens = seq_ctx_list[i].cu_seq_lens_q
                 num_tokens = cu_seq_lens[1:] - cu_seq_lens[:-1]
-                rollout_is_weights, rollout_is_mask, mismatch_metrics, rollout_is_metrics = (
-                    loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics(
+                rollout_is_weights, rollout_is_mask, rollout_is_metrics = (
+                    loss_cfg.rollout_is.compute_rollout_importance_weights(
                         old_log_prob=loss_ctx_input.old_logprobs,
                         rollout_log_prob=rollout_logprobs_list[i],
                         num_tokens=num_tokens,
@@ -457,7 +466,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
                 loss_ctx_input.shifted_labels[~rollout_is_mask.bool()] = -100  # update loss mask
                 loss_ctx_input.is_weights = rollout_is_weights
                 all_rollout_is_metrics.append(rollout_is_metrics)
-                all_mismatch_metrics.append(mismatch_metrics)
 
         logger_msg = f"Rollout {rollout_idx}: "
@@ -472,11 +480,6 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
             avg_rollout_entropy = sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else 0
             logger_msg += f", avg rollout entropy: {avg_rollout_entropy:.4f}"
 
-        if len(all_mismatch_metrics) > 0:
-            mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE)
-            if len(mismatch_metrics) > 0:
-                logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}"
-
         if len(all_rollout_is_metrics) > 0:
             rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
             if len(rollout_is_metrics) > 0:
@@ -623,24 +626,21 @@ def update_weights(self):
         if self.rollout_cfg_info.get("backend") == "turbomind":
             self._update_weights_by_layer()
         else:
-            self._update_weights_hf_generator()
+            if isinstance(self.config.model_cfg, BaseComposeConfig):
+                self._update_weights_hf_generator(submodule="vision_tower", final_update=False)
+                self._update_weights_hf_generator(submodule="multi_modal_projector", final_update=False)
+                self._update_weights_hf_generator(submodule="language_model", final_update=True)
+            else:
+                self._update_weights_hf_generator()
 
-    def _update_weights_hf_generator(self):
+    def _update_weights_hf_generator(self, submodule=None, final_update=False):
         """Update the model weights."""
         self.endpoints["update_weights"] = "update_weights"
         assert self.rollout_device_mesh is not None
         model = self._engine.model
         DEVICE_MODULE.empty_cache()
-
-        if isinstance(model.config, BaseComposeConfig):
-            dtype = torch.bfloat16
-        else:
-            if (model.config.float8_cfg is not None) and (model.config.float8_cfg.enable_float8):
-                dtype = torch.float8_e4m3fn
-            else:
-                dtype = torch.bfloat16
-
+        dtype = torch.bfloat16
         bucket_size = int(self.config.update_weight_bucket_size_in_gb * 1024**3)
         same_gen = model._get_same_hf_param(
             model._group_param_by_load_spec(LoadEnum.SAME), dtype=dtype, device=DEVICE, bucket_size=bucket_size
@@ -663,7 +663,10 @@ def _update_weights_hf_generator(self):
         # We can all gather them to get full fused param but it would lead to a larger memory usage.
         # So we broadcast the part fused param from each ep rank in ep_group sequentially,
         # and update the part of the fused param sequentially to reduce memory usage.
-        ep_mesh: DeviceMesh = model.ep_mesh
+        if isinstance(model.config, BaseComposeConfig):
+            ep_mesh: DeviceMesh = model.language_model.ep_mesh
+        else:
+            ep_mesh: DeviceMesh = model.ep_mesh
         ep_group = ep_mesh.get_group()
         global_rank = dist.get_rank()
         for src_global_rank in dist.get_process_group_ranks(ep_group):
@@ -690,7 +693,7 @@ def _update_weights_hf_generator(self):
             self.request_update_params(state_dict, finished=False)
             del state_dict, name_list, param_list
 
-        if self.rollout_cfg_info["backend"] == "pytorch":
+        if self.rollout_cfg_info["backend"] == "pytorch" and final_update:
             self.request_update_params({}, finished=True)
 
         dist.barrier()
diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py
index ef2d37ecb..eb2a608ac 100644
--- a/xtuner/v1/train/rl_trainer.py
+++ b/xtuner/v1/train/rl_trainer.py
@@ -626,7 +626,7 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
             }
 
             if "routed_experts" in group[i].env.rollout.extra_info:
-                routed_experts = group[i].env.rollout.extra_info["routed_experts"]  # n,layer*expert
+                routed_experts = group[i].env.rollout.extra_info.pop("routed_experts")  # n,layer*expert
                 seq_ctx.rollout_routed_experts = routed_experts  # n,layer,expert
 
             data_batches.append(data_dict)