
Commit b1b4655

Re-apply make style (#40106)
make style
1 parent a07b5e9 commit b1b4655

2 files changed: +18 -18 lines changed
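The change itself is purely cosmetic: dict.get already returns None when the key is absent and no default is given, so the explicit None second argument is redundant. A minimal standalone demonstration of the equivalence (not taken from the diff):

    kwargs = {"model": "dummy"}

    # dict.get returns None by default when a key is missing,
    # so spelling the None default out changes nothing.
    assert kwargs.get("missing", None) is None
    assert kwargs.get("missing") is None

    # Keys that are present behave identically under both spellings.
    assert kwargs.get("model", None) == kwargs.get("model") == "dummy"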

src/transformers/integrations/mxfp4.py

Lines changed: 12 additions & 12 deletions
@@ -314,12 +314,12 @@ def should_convert_module(current_key_name, patterns):
 def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
     from ..integrations.tensor_parallel import shard_and_distribute_module
 
-    model = kwargs.get("model", None)
-    empty_param = kwargs.get("empty_param", None)
-    casting_dtype = kwargs.get("casting_dtype", None)
-    to_contiguous = kwargs.get("to_contiguous", None)
-    rank = kwargs.get("rank", None)
-    device_mesh = kwargs.get("device_mesh", None)
+    model = kwargs.get("model")
+    empty_param = kwargs.get("empty_param")
+    casting_dtype = kwargs.get("casting_dtype")
+    to_contiguous = kwargs.get("to_contiguous")
+    rank = kwargs.get("rank")
+    device_mesh = kwargs.get("device_mesh")
 
     for proj in ["gate_up_proj", "down_proj"]:
         if proj in param_name:
@@ -357,12 +357,12 @@ def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, **kwargs):
     )
     from ..integrations.tensor_parallel import shard_and_distribute_module
 
-    model = kwargs.get("model", None)
-    empty_param = kwargs.get("empty_param", None)
-    casting_dtype = kwargs.get("casting_dtype", None)
-    to_contiguous = kwargs.get("to_contiguous", None)
-    rank = kwargs.get("rank", None)
-    device_mesh = kwargs.get("device_mesh", None)
+    model = kwargs.get("model")
+    empty_param = kwargs.get("empty_param")
+    casting_dtype = kwargs.get("casting_dtype")
+    to_contiguous = kwargs.get("to_contiguous")
+    rank = kwargs.get("rank")
+    device_mesh = kwargs.get("device_mesh")
 
     for proj in ["gate_up_proj", "down_proj"]:
         if proj in param_name:
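In transformers, make style applies the repository's automated formatting and lint fixes; the commit does not show which tool or rule produced the rewrite, but removing a redundant None default from dict.get is a classic "simplify"-style autofix. A hypothetical sketch of how such a check could be written with Python's ast module (illustration only, not the project's actual tooling):

    import ast

    # Sample source: one redundant None default, one already-clean call.
    CODE = 'model = kwargs.get("model", None)\nrank = kwargs.get("rank")\n'

    for node in ast.walk(ast.parse(CODE)):
        # Flag calls of the form <expr>.get(<key>, None).
        if (
            isinstance(node, ast.Call)
            and isinstance(node.func, ast.Attribute)
            and node.func.attr == "get"
            and len(node.args) == 2
            and isinstance(node.args[1], ast.Constant)
            and node.args[1].value is None
        ):
            print(f"line {node.lineno}: redundant None default in .get()")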

src/transformers/quantizers/quantizer_mxfp4.py

Lines changed: 6 additions & 6 deletions
@@ -101,7 +101,7 @@ def validate_environment(self, *args, **kwargs):
         global triton_kernels_hub
         triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
 
-        device_map = kwargs.get("device_map", None)
+        device_map = kwargs.get("device_map")
         if device_map is None:
            logger.warning_once(
                "You have loaded an FP4 model on CPU and have a CUDA device available, make sure to set "
@@ -210,11 +210,11 @@ def create_quantized_param(
         # we take this path if already quantized but not in a compatible way
         # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
         else:
-            empty_param = kwargs.get("empty_param", None)
-            casting_dtype = kwargs.get("casting_dtype", None)
-            to_contiguous = kwargs.get("to_contiguous", None)
-            rank = kwargs.get("rank", None)
-            device_mesh = kwargs.get("device_mesh", None)
+            empty_param = kwargs.get("empty_param")
+            casting_dtype = kwargs.get("casting_dtype")
+            to_contiguous = kwargs.get("to_contiguous")
+            rank = kwargs.get("rank")
+            device_mesh = kwargs.get("device_mesh")
             if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
                 # blocks and scales have the same length that's this works for both
                 module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
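One caveat is worth keeping in mind: the cleanup is behavior-preserving only because every removed default was None. A non-None default is load-bearing and has to stay explicit, as this standalone snippet shows:

    opts = {}

    # Equivalent: dict.get's implicit default is None.
    assert opts.get("rank") is None

    # Not equivalent: a non-None default changes the result,
    # so it cannot be dropped.
    assert opts.get("to_contiguous", False) is False
    assert opts.get("to_contiguous") is None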
