@@ -1613,10 +1613,13 @@ def _extract_op_args(node):
16131613 "is_gated_mlp" ,
16141614 )
16151615
1616- def _stack (param_list , dim = 0 ):
1617- return torch .stack (
1618- [get_param_or_buffer (element .target ) for element in param_list ], dim = dim
1619- ).contiguous ()
1616+ def _stack (param_list , dim = 0 , device = None , dtype = None ):
1617+ if param_list :
1618+ return torch .stack (
1619+ [get_param_or_buffer (element .target ) for element in param_list ], dim = dim
1620+ ).contiguous ()
1621+ else :
1622+ return torch .empty (0 , device = device , dtype = dtype )
16201623
16211624 def _prepare_args_cutlass_format_nvfp4 ():
16221625 if is_gated_mlp :
@@ -1627,9 +1630,15 @@ def _prepare_args_cutlass_format_nvfp4():
16271630 fc1_act_scale = torch .cat (
16281631 [w3_input_scale_stacked , w1_input_scale_stacked ], dim = 1
16291632 ).contiguous ()
1633+ fc1_alpha_stacked = torch .cat ([w3_alpha_stacked , w1_alpha_stacked ], dim = 1 ).contiguous ()
1634+ fc1_weight_blockscale_fp8_stacked = torch .cat (
1635+ [w3_weight_blockscale_fp8_stacked , w1_weight_blockscale_fp8_stacked ], dim = 1
1636+ ).contiguous ()
16301637 else :
16311638 fc1_expert_weights = w1_stacked
16321639 fc1_act_scale = w1_input_scale_stacked
1640+ fc1_alpha_stacked = w1_alpha_stacked
1641+ fc1_weight_blockscale_fp8_stacked = w1_weight_blockscale_fp8_stacked
16331642
16341643 fc2_expert_weights = w2_stacked
16351644 fc2_act_scale = w2_input_scale_stacked
@@ -1651,11 +1660,13 @@ def _prepare_args_cutlass_format_nvfp4():
16511660 weight_dtype = torch .float8_e4m3fn
16521661 _register_parameter (gm , new_key_fc1_expert_weights , fc1_expert_weights .to (weight_dtype ))
16531662 _register_parameter (gm , new_key_fc2_expert_weights , fc2_expert_weights .to (weight_dtype ))
1654- _register_parameter (gm , new_key_fc1_weight_blockscale_fp8 , w1_weight_blockscale_fp8_stacked )
1663+ _register_parameter (
1664+ gm , new_key_fc1_weight_blockscale_fp8 , fc1_weight_blockscale_fp8_stacked
1665+ )
16551666 _register_parameter (gm , new_key_fc2_weight_blockscale_fp8 , w2_weight_blockscale_fp8_stacked )
16561667 _register_parameter (gm , new_key_fc1_act_scale , fc1_act_scale )
16571668 _register_parameter (gm , new_key_fc2_act_scale , fc2_act_scale )
1658- _register_parameter (gm , new_key_fc1_alpha , w1_alpha_stacked )
1669+ _register_parameter (gm , new_key_fc1_alpha , fc1_alpha_stacked )
16591670 _register_parameter (gm , new_key_fc2_alpha , w2_alpha_stacked )
16601671
16611672 with graph .inserting_before (node ):
@@ -1705,50 +1716,23 @@ def _prepare_args_cutlass_format_nvfp4():
17051716 # Stack the actual tensor values (fast, like in quantize_moe.py)
17061717 w1_stacked = _stack (w1_list , dim = 0 )
17071718 w2_stacked = _stack (w2_list , dim = 0 )
1708- w3_stacked = (
1709- _stack (w3_list , dim = 0 )
1710- if w3_list
1711- else torch .empty (0 , device = w1_stacked .device , dtype = w1_stacked .dtype )
1712- )
1719+ device , dtype = (w1_stacked .device , w1_stacked .dtype )
1720+ w3_stacked = _stack (w3_list , dim = 0 , device = device , dtype = dtype )
17131721
17141722 # Scales are buffers, not parameters
17151723 w1_input_scale_stacked = _stack (w1_input_scale , dim = 0 )
17161724 w2_input_scale_stacked = _stack (w2_input_scale , dim = 0 )
1717- w3_input_scale_stacked = (
1718- _stack (w3_input_scale , dim = 0 )
1719- if w3_input_scale
1720- else torch .empty (
1721- 0 , device = w1_input_scale_stacked .device , dtype = w1_input_scale_stacked .dtype
1722- )
1723- )
1724- # assert torch.all(w1_input_scale_stacked[0] == w1_input_scale_stacked), (
1725- # "All w1 scales should have the same value."
1726- # )
1727- # assert torch.all(w2_input_scale_stacked[0] == w2_input_scale_stacked), (
1728- # "All w2 scales should have the same value."
1729- # )
1725+ w3_input_scale_stacked = _stack (w3_input_scale , dim = 0 , device = device , dtype = dtype )
17301726
17311727 w1_weight_blockscale_fp8_stacked = _stack (w1_weight_scale , dim = 0 ).to (torch .float8_e4m3fn )
17321728 w2_weight_blockscale_fp8_stacked = _stack (w2_weight_scale , dim = 0 ).to (torch .float8_e4m3fn )
1733- # w3_weight_blockscale_fp8_stacked = (
1734- # (
1735- # _stack(w3_weight_scale, dim=0)
1736- # if w3_weight_scale
1737- # else torch.empty(
1738- # 0,
1739- # device=w1_weight_blockscale_fp8_stacked.device,
1740- # dtype=w1_weight_blockscale_fp8_stacked.dtype,
1741- # )
1742- # )
1743- # .to(torch.float8_e4m3fn)
1744- # .contiguous()
1745- # )
1746-
1747- ###
1729+ w3_weight_blockscale_fp8_stacked = _stack (
1730+ w3_weight_scale , dim = 0 , device = device , dtype = dtype
1731+ ).to (torch .float8_e4m3fn )
1732+
17481733 w1_alpha_stacked = _stack (w1_alpha , dim = 0 )
17491734 w2_alpha_stacked = _stack (w2_alpha , dim = 0 )
1750- # w3_alpha_stacked = _stack(w3_alpha, dim=0)
1751- ###
1735+ w3_alpha_stacked = _stack (w3_alpha , dim = 0 , device = device , dtype = dtype )
17521736
17531737 args = _prepare_args_cutlass_format_nvfp4 ()
17541738
@@ -1770,7 +1754,6 @@ def _prepare_args_cutlass_format_nvfp4():
17701754 # will remove the parameters/buffers that are no longer referenced
17711755 gm .graph .eliminate_dead_code ()
17721756 gm .delete_all_unused_submodules ()
1773-
17741757 return fused_key_counter
17751758
17761759
0 commit comments