
Commit 7459ac8

fix mypy check

1 parent: be56a6a

3 files changed: +8, -8 lines

xtuner/v1/float8/float8_handler.py
Lines changed: 5 additions & 5 deletions

@@ -36,8 +36,8 @@ def default_grouped_linear_filter_fn(mod: nn.Module, fqn: str):
 
 # Should the handler correspond one-to-one with the Engine?
 class Float8Handler:
-    scaling_granularity_gemm: ScalingGranularity
-    scaling_granularity_grouped_gemm: ScalingGranularity
+    scaling_granularity_gemm: Optional[ScalingGranularity]
+    scaling_granularity_grouped_gemm: Optional[ScalingGranularity]
     fsdp_mesh: Optional[DeviceMesh] = None
     tilewise_reduce_mesh_devided_64: Optional[DeviceMesh] = None
     tilewise_reduce_mesh_mapping: Dict[Tuple[int, int], DeviceMesh] = {}

@@ -64,9 +64,9 @@ def __init__(
         assert scaling_granularity_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
             scaling_granularity_gemm is None
         ), "scaling_granularity_gemm must be TILEWISE or TENSORWISE."
-        assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE), (
-            "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
-        )
+        assert scaling_granularity_grouped_gemm in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
+            scaling_granularity_grouped_gemm is None
+        ), "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
 
         self.scaling_granularity_gemm = scaling_granularity_gemm
         self.scaling_granularity_grouped_gemm = scaling_granularity_grouped_gemm
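Context for this change: the two class-level annotations are widened to Optional so they match an __init__ that accepts None, and the second assert gains an explicit "or ... is None" branch so the runtime check agrees with the new annotation. A minimal sketch of the pattern, using hypothetical names rather than the project's actual classes:

from enum import Enum
from typing import Optional


class ScalingGranularity(Enum):
    TILEWISE = "tilewise"
    TENSORWISE = "tensorwise"


class Handler:
    # Annotating the attribute as Optional tells mypy that None is a legal value,
    # which the assignment in __init__ below requires.
    granularity: Optional[ScalingGranularity]

    def __init__(self, granularity: Optional[ScalingGranularity] = None) -> None:
        # "None in (TILEWISE, TENSORWISE)" is False, so the assert needs the
        # explicit "or granularity is None" branch to accept the None case.
        assert granularity in (ScalingGranularity.TILEWISE, ScalingGranularity.TENSORWISE) or (
            granularity is None
        ), "granularity must be TILEWISE or TENSORWISE."
        self.granularity = granularity

With this shape, Handler(None) satisfies both mypy and the runtime assert.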

xtuner/v1/ray/rollout/lmdeploy.py
Lines changed: 1 addition & 1 deletion

@@ -234,7 +234,7 @@ def _transform_rollout_config_to_server_configs(self) -> Namespace:
         lmdeploy_config_kwargs["uvicorn_log_level"] = lmdeploy_config_kwargs.pop("uvicorn_log_level", "ERROR")
         lmdeploy_config_kwargs["tm_log_level"] = lmdeploy_config_kwargs.pop("tm_log_level", "ERROR")
 
-        extra_engine_config = {}
+        extra_engine_config: dict[str, Any] = {}
         if backend == "pytorch" and self.config.enable_return_routed_experts:
             extra_engine_config["enable_return_routed_experts"] = True
         if backend == "pytorch" and self.config.router_n_groups:
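Context for the annotation: mypy cannot give a type to a bare empty dict literal; it either demands an explicit annotation or infers a narrow value type from the first key assignment and then rejects later values of another type. Annotating the dict as dict[str, Any] avoids both, assuming Any is imported from typing in this module. A minimal sketch with an illustrative function (the value stored for router_n_groups is an assumption here):

from typing import Any


def build_extra_engine_config(enable_flag: bool, n_groups: int) -> dict[str, Any]:
    # Without the annotation, mypy reports "Need type annotation" here, or pins the
    # value type to bool after the first assignment and then flags the int below.
    extra_engine_config: dict[str, Any] = {}
    if enable_flag:
        extra_engine_config["enable_return_routed_experts"] = True
    if n_groups:
        extra_engine_config["router_n_groups"] = n_groups
    return extra_engine_config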

xtuner/v1/rl/base/worker.py
Lines changed: 2 additions & 2 deletions

@@ -410,8 +410,8 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
 
         # old logprobs are inplaced updated in compute_actor_logprobs
         if isinstance(self.config.model_cfg, BaseComposeConfig):
-            if self._engine.llm_float8_handler is not None and self._engine.llm_float8_handler.enabled:
-                self._engine.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(
+            if self._engine.llm_float8_handler is not None and self._engine.llm_float8_handler.enabled:  # type: ignore [attr-defined]
+                self._engine.llm_float8_handler.precompute_float8_dynamic_scale_for_fsdp(  # type: ignore [attr-defined]
                     self._engine.model.language_model
                 )
         else:
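Context for the ignores: the isinstance check narrows self.config.model_cfg, not self._engine, so mypy presumably still checks these lines against an engine type that does not declare llm_float8_handler and reports [attr-defined]. An error-code-scoped "type: ignore" silences only that diagnostic on those two lines. A minimal sketch of the idea, with hypothetical classes in place of the project's engine types:

class BaseEngine:
    pass


class ComposeEngine(BaseEngine):
    def __init__(self) -> None:
        # Only this subclass carries the handler attribute.
        self.llm_float8_handler: object | None = object()


def fit(engine: BaseEngine) -> None:
    # mypy checks this access against BaseEngine, which has no llm_float8_handler,
    # so it reports [attr-defined]; the scoped ignore suppresses only that error code.
    handler = engine.llm_float8_handler  # type: ignore[attr-defined]
    if handler is not None:
        print("float8 handler present:", handler)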
