@@ -8102,15 +8102,13 @@ def repack_mxfp4(self, new_name: str, blocks: Tensor, scales: Tensor):
     def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
         blocks0: Tensor = torch.zeros(1)
         blocks1: Tensor = torch.zeros(1)
-        found_mxfp4_tensors = False
         # we assume that tensors are loaded in the correct order
         for name, data_torch in self.get_tensors():
             if "mlp.experts.down_proj_blocks" in name:
                 blocks0 = data_torch
             elif "mlp.experts.down_proj_scales" in name:
                 new_name = self.map_tensor_name(name.replace("_scales", ".weight"))
                 self.repack_mxfp4(new_name, blocks0, data_torch)
-                found_mxfp4_tensors = True
             elif "mlp.experts.gate_up_proj_blocks" in name:
                 blocks0, blocks1 = data_torch[:, ::2, :, :], data_torch[:, 1::2, :, :]
             elif "mlp.experts.gate_up_proj_scales" in name:
@@ -8119,9 +8117,6 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
                 new_name_up = self.map_tensor_name(name.replace("gate_up_proj_scales", "up_proj.weight"))
                 self.repack_mxfp4(new_name_gate, blocks0, scales0)
                 self.repack_mxfp4(new_name_up, blocks1, scales1)
-                found_mxfp4_tensors = True
-        if not found_mxfp4_tensors:
-            raise ValueError("No MXFP4 tensors found in the model. Please make sure you are using MXFP4 model.")
         return []
 
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
@@ -8134,7 +8129,12 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         if "down_proj" in name:
             if name.endswith("_bias"):
                 name = name.replace("down_proj_bias", "down_proj.bias")
+            elif "_blocks" not in name and "_scales" not in name:
+                logger.warning(f"{name} is not in MXFP4, performance may be degraded")
+                name = name.replace("down_proj", "down_proj.weight")
+                data_torch = data_torch.transpose(-1, -2)
             else:
+                # otherwise, it should already be repacked to ggml MXFP4 format
                 return []
 
         # split the gate_up into gate and up
@@ -8147,7 +8147,18 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
                     (self.map_tensor_name(name_gate), gate_proj_bias),
                     (self.map_tensor_name(name_up), up_proj_bias)
                 ]
+            elif "_blocks" not in name and "_scales" not in name:
+                logger.warning(f"{name} is not in MXFP4, performance may be degraded")
+                name_up = name.replace("gate_up_proj", "up_proj.weight")
+                name_gate = name.replace("gate_up_proj", "gate_proj.weight")
+                data_torch = data_torch.transpose(-1, -2)
+                gate_proj_weight, up_proj_weight = data_torch[:, ::2, :], data_torch[:, 1::2, :]
+                return [
+                    (self.map_tensor_name(name_gate), gate_proj_weight),
+                    (self.map_tensor_name(name_up), up_proj_weight)
+                ]
             else:
+                # otherwise, it should already be repacked to ggml MXFP4 format
                 return []
 
         return [(self.map_tensor_name(name), data_torch)]
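For readers following the new non-MXFP4 fallback paths: below is a minimal, standalone sketch of the de-interleave performed by the added `gate_up_proj` branch. The shapes (`n_experts`, `d_model`, `d_ff`) are illustrative assumptions, not values from any model config; the transpose and slicing mirror the added lines above.

```python
import torch

# Hypothetical shapes for illustration only.
n_experts, d_model, d_ff = 4, 8, 16

# Fused HF tensor: gate and up rows interleaved along the last dim.
gate_up = torch.randn(n_experts, d_model, 2 * d_ff)

# Same two steps as the new elif branch in modify_tensors():
t = gate_up.transpose(-1, -2)               # (n_experts, 2 * d_ff, d_model)
gate_w, up_w = t[:, ::2, :], t[:, 1::2, :]  # even rows -> gate, odd rows -> up

assert gate_w.shape == up_w.shape == (n_experts, d_ff, d_model)
```

The same even/odd pattern appears in `generate_extra_tensors()` for the 4D `gate_up_proj_blocks` tensor, just with one extra trailing dimension.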
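Since both hunks key off `_blocks`/`_scales` tensor pairs, here is a rough sketch of how such a pair decodes to floats under the OCP MX spec (FP4 E2M1 values sharing an E8M0 exponent per 32-element block). This is not the `repack_mxfp4()` implementation, which lies outside these hunks, and the nibble ordering within a block is an assumption that differs between layouts.

```python
import torch

# FP4 E2M1 magnitudes; the high bit of each nibble is the sign.
FP4_E2M1 = torch.tensor([0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
                         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])

def dequant_mxfp4(blocks: torch.Tensor, scales: torch.Tensor) -> torch.Tensor:
    """blocks: uint8 [..., n_blocks, 16], two FP4 values packed per byte;
    scales: uint8 [..., n_blocks], one shared E8M0 exponent per 32-value block."""
    lo = FP4_E2M1[(blocks & 0x0F).long()]   # low nibbles
    hi = FP4_E2M1[(blocks >> 4).long()]     # high nibbles
    vals = torch.cat((lo, hi), dim=-1)      # [..., n_blocks, 32]; order is assumed
    # E8M0 scale: 2^(e - 127), broadcast over the 32 values of each block
    return vals * torch.exp2(scales.float() - 127.0).unsqueeze(-1)
```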