1 change: 1 addition & 0 deletions xtuner/v1/config/fsdp.py
@@ -22,6 +22,7 @@ class FSDPConfig(BaseModel):
# TODO: (caoweihan) Convert `torch.dtype` to `Annotated` for compatibility with cyclopts
param_dtype: Annotated[torch.dtype, Parameter(help="Data type for model parameters")] = torch.bfloat16
reduce_dtype: Annotated[torch.dtype, Parameter(help="Data type for reduction operations")] = torch.bfloat16
lm_head_fp32: Annotated[bool, Parameter(help="Use float32 for language model head")] = False
torch_compile: Annotated[bool, Parameter(help="Enable model compilation for faster inference")] = False
mesh_prefix: Annotated[str, Parameter(help="Prefix for device mesh configuration in distributed training")] = (
"default"
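A minimal sketch of turning the new flag on, assuming the remaining `FSDPConfig` fields keep the defaults shown above:

```python
from xtuner.v1.config.fsdp import FSDPConfig

# Keep parameters and reductions in bf16, but run the LM head in fp32
# (see the MixedPrecisionPolicy selection in dense.py / moe.py below).
fsdp_cfg = FSDPConfig(lm_head_fp32=True)
```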
16 changes: 8 additions & 8 deletions xtuner/v1/float8/float8_handler.py
@@ -36,8 +36,8 @@ def default_grouped_linear_filter_fn(mod: nn.Module, fqn: str):

# Should each handler correspond one-to-one with an Engine?
class Float8Handler:
scaling_granularity_gemm: ScalingGranularity
scaling_granularity_grouped_gemm: ScalingGranularity
scaling_granularity_gemm: Optional[ScalingGranularity]
scaling_granularity_grouped_gemm: Optional[ScalingGranularity]
fsdp_mesh: Optional[DeviceMesh] = None
tilewise_reduce_mesh_devided_64: Optional[DeviceMesh] = None
tilewise_reduce_mesh_mapping: Dict[Tuple[int, int], DeviceMesh] = {}
@@ -61,12 +61,12 @@ def __init__(
)
return

assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
"scaling_granularity_gemm must be TILEWISE or TENSORWISE."
)
assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
"scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
)
assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
scaling_granularity_gemm is None
), "scaling_granularity_gemm must be TILEWISE or TENSORWISE."
assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
scaling_granularity_grouped_gemm is None
), "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."

self.scaling_granularity_gemm = scaling_granularity_gemm
self.scaling_granularity_grouped_gemm = scaling_granularity_grouped_gemm
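A hedged sketch of what the relaxed asserts permit; the `ScalingGranularity` import path and any additional `Float8Handler` constructor arguments are assumptions here:

```python
from xtuner.v1.float8.float8_handler import Float8Handler
from xtuner.v1.float8.config import ScalingGranularity  # import path assumed

# With the granularities now Optional, either GEMM path can opt out of float8:
# pass None for plain linears while keeping tile-wise fp8 for grouped GEMMs.
handler = Float8Handler(
    scaling_granularity_gemm=None,
    scaling_granularity_grouped_gemm=ScalingGranularity.TILEWISE,
    # any other required constructor arguments omitted for brevity
)
```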
6 changes: 3 additions & 3 deletions xtuner/v1/float8/fsdp_utils.py
@@ -159,7 +159,7 @@ def precompute_tilewise_float8_scale_for_fsdp(
assert reduce_mesh.ndim == 1, (
f"Currently only reduce_mesh.ndim should equal to 1, got reduce_mesh.ndim = {reduce_mesh.ndim} for local_shape {local_shape}"
)
weights_same_shape_stack = torch.stack(weights_same_shape, dim=0) # type: ignore
weights_same_shape_stack = torch.stack(weights_same_shape, dim=0).bfloat16().float() # type: ignore
if dim >= 128 and dim % 128 == 64:
assert reduce_mesh_devided_64 is not None, (
f"reduce_mesh_devided_64 should not be None for local_shape {local_shape}."
@@ -382,15 +382,15 @@ def fsdp_pre_all_gather(self, mesh):
assert self._precomputed_scale is not None
if self._tensor.shape[0] >= 128 and self._tensor.shape[0] % 128 == 64:
w_fp8_data = cast_to_per_block_fp8_devided_64_with_scales(
tensor=self._tensor,
tensor=self._tensor.bfloat16().float(),
scales=self._precomputed_scale,
fsdp_mesh=mesh,
block_size=128,
float8_dtype=self._dtype,
)
else:
w_fp8_data = cast_to_per_block_fp8_with_scales(
tensor=self._tensor,
tensor=self._tensor.bfloat16().float(),
scales=self._precomputed_scale,
block_size=128,
float8_dtype=self._dtype,
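The `.bfloat16().float()` round-trips appear intended to make the scale precomputation and the fp8 cast see the same bf16-rounded values as the all-gathered weights. A small standalone illustration of that rounding (not the repo's helper):

```python
import torch

# Round to bf16 precision, then return to fp32 for the scale math: the amax
# (and hence the scale) is computed on exactly the values the fp8 cast will see.
w = torch.randn(256, 512, dtype=torch.float32)
w_rounded = w.bfloat16().float()
assert w_rounded.dtype == torch.float32
amax = w_rounded.abs().amax()
```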
4 changes: 2 additions & 2 deletions xtuner/v1/model/base.py
@@ -677,7 +677,7 @@ def _get_hf_params(

for param, load_spec in params:
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype)
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
hf_params, name_list = _get_hf_params(tensor_list, name_list)
@@ -719,7 +719,7 @@ def _get_same_hf_param(
buffer_name_list.append(load_spec.hf_keys[0])
continue
local_tensor = param._local_tensor if isinstance(param, DTensor) else param
local_tensor = local_tensor.bfloat16()
local_tensor = local_tensor.to(dtype)
tensor_size = self._get_tensor_size(param, dtype)
if safetensor_size + tensor_size > bucket_size and tensor_list:
if self.fsdp_mesh is not None:
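With the hard-coded `.bfloat16()` replaced by `.to(dtype)`, exported HF shards follow the requested dtype. A tiny standalone illustration:

```python
import torch

# The export dtype now follows the caller's request instead of collapsing to bf16.
local_tensor = torch.randn(4, 4)
for dtype in (torch.bfloat16, torch.float32):
    exported = local_tensor.to(dtype)  # previously always .bfloat16()
    assert exported.dtype == dtype
```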
6 changes: 5 additions & 1 deletion xtuner/v1/model/dense/dense.py
@@ -204,6 +204,10 @@ def fully_shard(
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
if self.fsdp_config.lm_head_fp32:
lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
else:
lm_head_mp_policy = mp_policy
num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

generator = torch.Generator()
@@ -256,7 +260,7 @@ def fully_shard(
fully_shard(
self.lm_head,
mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
mp_policy=mp_policy,
mp_policy=lm_head_mp_policy,
reshard_after_forward=self.fsdp_config.reshard_after_forward,
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
)
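A hedged sketch of the policy split introduced here (the same pattern is reused for the MoE model below); the `MixedPrecisionPolicy` import path varies with the PyTorch version:

```python
import torch
from torch.distributed.fsdp import MixedPrecisionPolicy  # older releases expose it under torch.distributed._composable.fsdp

# The model body keeps the configured (typically bf16) policy, while the
# LM head can be promoted to full fp32 when lm_head_fp32 is set.
lm_head_fp32 = True  # stand-in for self.fsdp_config.lm_head_fp32
mp_policy = MixedPrecisionPolicy(param_dtype=torch.bfloat16, reduce_dtype=torch.bfloat16)
lm_head_mp_policy = (
    MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
    if lm_head_fp32
    else mp_policy
)
```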
6 changes: 5 additions & 1 deletion xtuner/v1/model/moe/moe.py
@@ -721,6 +721,10 @@ def fully_shard(
mp_policy = MixedPrecisionPolicy(
param_dtype=self.fsdp_config.param_dtype, reduce_dtype=fsdp_config.reduce_dtype
)
if self.fsdp_config.lm_head_fp32:
lm_head_mp_policy = MixedPrecisionPolicy(param_dtype=torch.float32, reduce_dtype=torch.float32)
else:
lm_head_mp_policy = mp_policy
num_recompute_layers = int(self.config.num_hidden_layers * self.fsdp_config.recompute_ratio)

for layer_idx, layer in tqdm(self.layers.items(), desc="[FSDP Sharding]"):
@@ -766,7 +770,7 @@ fully_shard(
fully_shard(
self.lm_head,
mesh=self.fsdp_mesh if self.hsdp_mesh is None else self.hsdp_mesh,
mp_policy=mp_policy,
mp_policy=lm_head_mp_policy,
reshard_after_forward=self.fsdp_config.reshard_after_forward,
offload_policy=CPUOffloadPolicy() if self.fsdp_config.cpu_offload else None,
)
2 changes: 1 addition & 1 deletion xtuner/v1/module/linear/linear.py
@@ -33,7 +33,7 @@ def build_linear(
float8_cfg=None,
) -> nn.Module:
"""Build a linear layer with optional float8 support."""
if float8_cfg is None:
if float8_cfg is None or float8_cfg.scaling_granularity_gemm is None:
return _Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
elif float8_cfg.scaling_granularity_gemm is ScalingGranularity.TILEWISE:
return TileWiseFloat8Linear(in_features, out_features, bias=bias, device=device, dtype=dtype)
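A hedged usage sketch of the new fallback; the positional and keyword names are assumed to mirror `nn.Linear`:

```python
import torch
from xtuner.v1.module.linear.linear import build_linear

# float8_cfg=None already returned a plain linear; a float8_cfg whose
# scaling_granularity_gemm is None now takes the same branch.
layer = build_linear(1024, 4096, bias=False, device="cpu", dtype=torch.bfloat16, float8_cfg=None)
```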
14 changes: 14 additions & 0 deletions xtuner/v1/ray/config/worker.py
@@ -228,6 +228,20 @@ class RolloutConfig(BaseModel):
help="Maximum number of retries per sample before marking it as failed.",
),
] = 1
max_prefill_token_num: Annotated[
Optional[int],
Parameter(
group=infer_group,
help="Maximum number of prefill token.",
),
] = None
router_n_groups: Annotated[
Optional[int],
Parameter(
group=infer_group,
help="router_n_groups.",
),
] = None
worker_log_dir: Annotated[Path, Parameter(help="Directory to save worker logs.")] = Path.cwd() / "work_dir"
_logged_server_urls_per_engine: bool = PrivateAttr(default=False)

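A hedged sketch of setting the two new options; any other required `RolloutConfig` fields are omitted and assumed to have defaults:

```python
from xtuner.v1.ray.config.worker import RolloutConfig

rollout_cfg = RolloutConfig(
    max_prefill_token_num=8192,  # cap on prefill tokens forwarded to the engine
    router_n_groups=8,           # forwarded to the engine via hf_overrides (pytorch backend)
)
```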
30 changes: 15 additions & 15 deletions xtuner/v1/ray/dataflow/flow.py
@@ -109,7 +109,7 @@ def __init__(
self.env = env
self.config = dataflow_cfg
replay_buffer_cfg.worker_log_dir = self.config.worker_log_dir
self.replay_buffer = ReplayBuffer.remote(replay_buffer_cfg) # type: ignore[attr-defined]
self.replay_buffer = ReplayBuffer(replay_buffer_cfg) # type: ignore[attr-defined]
self.env_controller = environment
self.finished_samples_count = 0
self.skipped_sample_count = 0
@@ -168,7 +168,7 @@ def _reset_internal_states(
@ray_method
def get_train_dataset_length(self):
"""Gets the length of the training dataset from the replay buffer."""
return ray.get(self.replay_buffer.get_train_dataset_length.remote())
return self.replay_buffer.get_train_dataset_length()

@ray_method
async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowItem]] = None):
@@ -192,7 +192,7 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
# step 1: sample
# TODO(@duanyanhui): More fine-grained control over group data generation:
# Pass n to the inference engine to ensure that the same data is processed by the same server, improving efficiency.
group_data_items = await self.replay_buffer.sample.remote( # type: ignore[attr-defined]
group_data_items = self.replay_buffer.sample( # type: ignore[attr-defined]
self.env, self.enable_partial_rollout, self.config.prompt_repeat_k
)
assert len(group_data_items) > 0, "Sampled empty group data items from replay buffer."
@@ -206,12 +206,12 @@ async def worker_task(self, group_samples_for_retry: Optional[List[RLDataFlowIte
group_state = determine_group_state(group_data_items)
self.logger.debug(f"Determined replay state for {action_id}: {group_state}")
if group_state == RolloutState.COMPLETED:
group_data_items = await self.replay_buffer.post_processor.remote(group_data_items) # type: ignore[attr-defined]
group_data_items = self.replay_buffer.post_processor(group_data_items) # type: ignore[attr-defined]
if len(group_data_items) > 0:
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
self.replay_buffer.add(group_data_items) # type: ignore[attr-defined]
self.logger.debug(f"Worker task completed successfully for {action_id}.")
elif group_state == RolloutState.ABORTED:
await self.replay_buffer.add.remote(group_data_items) # type: ignore[attr-defined]
self.replay_buffer.add(group_data_items) # type: ignore[attr-defined]
self.logger.debug(f"Adding aborted sample {action_id} to aborted storage")
elif group_state == RolloutState.SKIPPED:
self.skipped_sample_count += 1
@@ -267,7 +267,7 @@ async def concurrent_task_runner(self):
waiting_tasks.add(task)

_, pending_tasks = await asyncio.wait(waiting_tasks, timeout=0.1, return_when=asyncio.FIRST_COMPLETED)
self.finished_samples_count = ray.get(self.replay_buffer.get_finished_samples.remote())
self.finished_samples_count = self.replay_buffer.get_finished_samples()
waiting_tasks = pending_tasks

pbar.n = self.finished_samples_count
@@ -349,25 +349,25 @@ async def run(
if resume:
assert resume_path, "Resuming is enabled but no resume path is provided."
self.logger.info(f"Resuming replay buffer from {resume_path}")
await self.replay_buffer.resume_storage.remote(resume_path)
self.replay_buffer.resume_storage(resume_path)

await self.concurrent_task_runner()

if dump:
assert dump_path, "Dumping is enabled but no dump path is provided."
self.logger.info(f"Dump replay buffer from {dump_path}")
await self.replay_buffer.dump_storage.remote(dump_path)
self.replay_buffer.dump_storage(dump_path)

return await self.replay_buffer.get_samples.remote(self.target_batch_size) # type: ignore[attr-defined]
return self.replay_buffer.get_samples(self.target_batch_size) # type: ignore[attr-defined]

def logging_replaybuffer_state(self):
ray.get(self.replay_buffer.print.remote())
self.replay_buffer.print()

def get_replaybuffer_status(self):
return ray.get(self.replay_buffer.status.remote())
return self.replay_buffer.status()

def clear_replaybuffer(self):
return ray.get(self.replay_buffer.clear.remote())
return self.replay_buffer.clear()

async def _send_abort_request(self, client, url, timeout):
worker_url = f"{url}/abort_request"
@@ -386,15 +386,15 @@ def save(self, save_path: Path | str):
Args:
save_path (str): The path to the checkpoint file to save to.
"""
ray.get(self.replay_buffer.save.remote(save_path))
self.replay_buffer.save(save_path)

def resume(self, resume_path: Path | str):
"""Resumes the replay buffer from the specified path.

Args:
resume_path (str): The path to the checkpoint file to resume from.
"""
ray.get(self.replay_buffer.resume.remote(resume_path))
self.replay_buffer.resume(resume_path)


DataFlow = ray.remote(RawDataFlow)
1 change: 0 additions & 1 deletion xtuner/v1/ray/dataflow/replay_buffer.py
@@ -509,7 +509,6 @@ def resume(self, file_path: Path | str):
self.print()


@ray.remote
class ReplayBuffer:
"""A Ray actor that manages experience replay for reinforcement
learning."""
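Together with the flow.py changes above, this turns `ReplayBuffer` from a Ray actor into a plain in-process object: the `@ray.remote` decorator is removed and every `ray.get(...remote(...))` call becomes a direct method call. A hedged before/after sketch (`replay_buffer_cfg` is a stand-in for the config built in `RawDataFlow.__init__`):

```python
from xtuner.v1.ray.dataflow.replay_buffer import ReplayBuffer

# Before (Ray actor): every access paid an RPC round-trip.
#   buffer = ReplayBuffer.remote(replay_buffer_cfg)
#   count = ray.get(buffer.get_finished_samples.remote())
# After (local object owned by the dataflow):
buffer = ReplayBuffer(replay_buffer_cfg)
count = buffer.get_finished_samples()
```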
6 changes: 5 additions & 1 deletion xtuner/v1/ray/rollout/lmdeploy.py
@@ -234,9 +234,13 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
lmdeploy_config_kwargs["uvicorn_log_level"] = lmdeploy_config_kwargs.pop("uvicorn_log_level", "ERROR")
lmdeploy_config_kwargs["tm_log_level"] = lmdeploy_config_kwargs.pop("tm_log_level", "ERROR")

extra_engine_config = {}
extra_engine_config: dict[str, Any] = {}
if backend == "pytorch" and self.config.enable_return_routed_experts:
extra_engine_config["enable_return_routed_experts"] = True
if backend == "pytorch" and self.config.router_n_groups:
extra_engine_config["hf_overrides"] = dict(router_n_groups=self.config.router_n_groups)
if backend == "pytorch" and self.config.max_prefill_token_num:
extra_engine_config["max_prefill_token_num"] = self.config.max_prefill_token_num

dp_rank = 0
if backend == "pytorch":
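For the pytorch backend with all three options enabled, the branch above would produce a dict along these lines (key names from the diff, values illustrative):

```python
# Illustrative result for enable_return_routed_experts=True,
# router_n_groups=8, max_prefill_token_num=8192:
extra_engine_config = {
    "enable_return_routed_experts": True,
    "hf_overrides": {"router_n_groups": 8},  # overrides the HF model config
    "max_prefill_token_num": 8192,           # caps tokens per prefill step
}
```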
8 changes: 6 additions & 2 deletions xtuner/v1/rl/base/controller.py
@@ -79,7 +79,6 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
assert language_cfg is not None
has_rollout_routed_experts = True
n_routed_experts = language_cfg.n_routed_experts

for pack_info in pack_infos:
indices = pack_info["indices"]
@@ -119,7 +118,12 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)

if has_rollout_routed_experts:
pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
pad_rand_index = torch.randint(
low=0,
high=1,
size=(pad_len, 1, 1), # add dummy data, true data will be initialized in train worker.fit
)
pad_seq_ctx.rollout_routed_experts = pad_rand_index

seq_ctx_list.append(pad_seq_ctx)
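The padded expert indices are now placeholders: `torch.randint(low=0, high=1, ...)` can only draw zeros, and per the in-line comment the real routed-expert data is filled in later by the train worker, so `language_cfg.n_routed_experts` is no longer needed here. A tiny standalone illustration:

```python
import torch

pad_len = 7
pad_rand_index = torch.randint(low=0, high=1, size=(pad_len, 1, 1))
assert torch.all(pad_rand_index == 0)  # all-zero dummy block
```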
25 changes: 10 additions & 15 deletions xtuner/v1/rl/base/rollout_is.py
@@ -53,24 +53,14 @@ class RolloutImportanceSampling(BaseModel):
rollout_is_mask_threshold: Optional[Tuple[float, float]] = None
rollout_is_veto_threshold: Optional[Tuple[float, float]] = None

def compute_rollout_importance_weights_and_metrics(
def compute_rollout_importance_weights(
self,
old_log_prob: torch.Tensor,
rollout_log_prob: torch.Tensor,
num_tokens: torch.Tensor,
response_mask: torch.Tensor,
) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any], dict[str, Any]]:
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
mismatch_metrics_scalar = {}
for key, value in mismatch_metrics.items():
if isinstance(value, torch.Tensor):
mismatch_metrics_scalar[f"mismatch/{key}"] = value.item()
else:
mismatch_metrics_scalar[f"mismatch/{key}"] = value

rollout_is_weights, modified_response_mask, metrics_scalar = compute_rollout_importance_weights(
) -> tuple[Optional[torch.Tensor], torch.Tensor, dict[str, Any]]:
return compute_rollout_importance_weights(
old_log_prob,
rollout_log_prob,
num_tokens,
@@ -81,7 +71,6 @@ def compute_rollout_importance_weights_and_metrics(
rollout_is_mask_threshold=self.rollout_is_mask_threshold,
rollout_is_veto_threshold=self.rollout_is_veto_threshold,
)
return rollout_is_weights, modified_response_mask, mismatch_metrics_scalar, metrics_scalar


def compute_rollout_importance_weights(
Expand Down Expand Up @@ -150,7 +139,7 @@ def compute_rollout_importance_weights(
metrics: Dict of IS and mismatch metrics, all scalars with "mismatch/" prefix
"""
if rollout_is_threshold is None:
return None, response_mask, {}
return None, response_mask, compute_mismatch_metrics(old_log_prob, rollout_log_prob, response_mask)

assert rollout_is_mode in ["truncate", "mask", "both"], (
f"Invalid rollout_is_mode: {rollout_is_mode}. Must be 'truncate', 'mask', or 'both'."
@@ -302,6 +291,12 @@ def compute_rollout_importance_weights(
# This is different from rejection - padding must be zeroed regardless of mode
rollout_is_weights = rollout_is_weights * response_mask

# Compute mismatch metrics (KL, PPL, etc.) and merge with IS metrics
mismatch_metrics = compute_mismatch_metrics(
old_log_prob=old_log_prob, rollout_log_prob=rollout_log_prob, response_mask=response_mask
)
metrics.update(mismatch_metrics)

# Convert all tensor metrics to scalars for logging
# Note: No need to detach since old_log_prob and rollout_log_prob are computed with torch.no_grad()
metrics_scalar = {}
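After this refactor the wrapper returns the three-tuple from `compute_rollout_importance_weights` directly, and the mismatch metrics (KL, PPL, etc.) are computed inside that function and merged into its metrics dict. A hedged call-pattern sketch with illustrative tensors; constructor fields beyond those shown above are assumed to have defaults:

```python
import torch
from xtuner.v1.rl.base.rollout_is import RolloutImportanceSampling

rollout_is = RolloutImportanceSampling()  # assumes defaults for unspecified fields
old_log_prob = torch.randn(2, 8)
rollout_log_prob = torch.randn(2, 8)
num_tokens = torch.full((2,), 8)
response_mask = torch.ones(2, 8)

# One call now returns (weights, possibly-modified response_mask, metrics),
# with the mismatch metrics folded into the same metrics dict.
weights, mask, metrics = rollout_is.compute_rollout_importance_weights(
    old_log_prob, rollout_log_prob, num_tokens, response_mask
)
```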