@@ -314,12 +314,12 @@ def should_convert_module(current_key_name, patterns):
314
314
def dequantize (module , param_name , param_value , target_device , dq_param_name , ** kwargs ):
315
315
from ..integrations .tensor_parallel import shard_and_distribute_module
316
316
317
- model = kwargs .get ("model" , None )
318
- empty_param = kwargs .get ("empty_param" , None )
319
- casting_dtype = kwargs .get ("casting_dtype" , None )
320
- to_contiguous = kwargs .get ("to_contiguous" , None )
321
- rank = kwargs .get ("rank" , None )
322
- device_mesh = kwargs .get ("device_mesh" , None )
317
+ model = kwargs .get ("model" )
318
+ empty_param = kwargs .get ("empty_param" )
319
+ casting_dtype = kwargs .get ("casting_dtype" )
320
+ to_contiguous = kwargs .get ("to_contiguous" )
321
+ rank = kwargs .get ("rank" )
322
+ device_mesh = kwargs .get ("device_mesh" )
323
323
324
324
for proj in ["gate_up_proj" , "down_proj" ]:
325
325
if proj in param_name :
@@ -357,12 +357,12 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwa
357
357
)
358
358
from ..integrations .tensor_parallel import shard_and_distribute_module
359
359
360
- model = kwargs .get ("model" , None )
361
- empty_param = kwargs .get ("empty_param" , None )
362
- casting_dtype = kwargs .get ("casting_dtype" , None )
363
- to_contiguous = kwargs .get ("to_contiguous" , None )
364
- rank = kwargs .get ("rank" , None )
365
- device_mesh = kwargs .get ("device_mesh" , None )
360
+ model = kwargs .get ("model" )
361
+ empty_param = kwargs .get ("empty_param" )
362
+ casting_dtype = kwargs .get ("casting_dtype" )
363
+ to_contiguous = kwargs .get ("to_contiguous" )
364
+ rank = kwargs .get ("rank" )
365
+ device_mesh = kwargs .get ("device_mesh" )
366
366
367
367
for proj in ["gate_up_proj" , "down_proj" ]:
368
368
if proj in param_name :
0 commit comments