@@ -36,8 +36,8 @@ def default_grouped_linear_filter_fn(mod: nn.Module, fqn: str):
3636
3737# Should each handler correspond one-to-one with an Engine?
3838class Float8Handler :
39- scaling_granularity_gemm : ScalingGranularity
40- scaling_granularity_grouped_gemm : ScalingGranularity
39+ scaling_granularity_gemm : Optional [ ScalingGranularity ]
40+ scaling_granularity_grouped_gemm : Optional [ ScalingGranularity ]
4141 fsdp_mesh : Optional [DeviceMesh ] = None
4242 tilewise_reduce_mesh_devided_64 : Optional [DeviceMesh ] = None
4343 tilewise_reduce_mesh_mapping : Dict [Tuple [int , int ], DeviceMesh ] = {}
@@ -64,9 +64,9 @@ def __init__(
6464 assert scaling_granularity_gemm in (ScalingGranularity .TILEWISE , ScalingGranularity .TENSORWISE ) or (
6565 scaling_granularity_gemm is None
6666 ), "scaling_granularity_gemm must be TILEWISE or TENSORWISE."
67- assert scaling_granularity_grouped_gemm in (ScalingGranularity .TILEWISE , ScalingGranularity .TENSORWISE ), (
68- " scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
69- )
67+ assert scaling_granularity_grouped_gemm in (ScalingGranularity .TILEWISE , ScalingGranularity .TENSORWISE ) or (
68+ scaling_granularity_grouped_gemm is None
69+ ), "scaling_granularity_grouped_gemm must be TILEWISE or TENSORWISE."
7070
7171 self .scaling_granularity_gemm = scaling_granularity_gemm
7272 self .scaling_granularity_grouped_gemm = scaling_granularity_grouped_gemm
0 commit comments