
Commit 831c32d

Lint
Signed-off-by: Jingyu Xin <[email protected]>
1 parent 071f167 commit 831c32d

4 files changed: +33 -20 lines changed


examples/diffusers/quantization/config.py

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@
         "*weight_quantizer": {"num_bits": (4, 3), "axis": None},
         "*input_quantizer": {"num_bits": (4, 3), "axis": None},
         "*output_quantizer": {"enable": False},
-        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3),"block_sizes": {-2: 32}},
+        "*[qkv]_bmm_quantizer": {"type": "dynamic", "num_bits": (4, 3), "block_sizes": {-2: 32}},
         "*softmax_quantizer": {
             "num_bits": (4, 3),
             "axis": None,

modelopt/torch/quantization/export_onnx.py

Lines changed: 8 additions & 5 deletions
@@ -234,6 +234,7 @@ def _fp8_quantize(
     )
     return q_op
 
+
 def _fp8_block_quantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,

@@ -289,13 +290,14 @@ def _fp8_dequantize(
     out = g.op("Cast", out, to_i=onnx_dtype_map[otype])  # type: ignore[index]
     return out
 
+
 def _fp8_block_dequantize(
     g: torch.onnx._internal.jit_utils.GraphContext,
     inputs: torch.Value,
     scales: torch.Value,
     trt_high_precision_dtype: str,
     otype: str | None = None,
-    block_sizes: list = [1,1,128,1]
+    block_sizes: list = [1, 1, 128, 1],
 ):
     """Helper Function for Dequantization."""
     output_shape = sym_help._get_tensor_sizes(inputs)

@@ -339,8 +341,7 @@ def export_fp8(
         )
         return _fp8_block_dequantize(
             g, q_tensor, scales_output, trt_high_precision_dtype, otype, block_sizes
-            )
-
+        )
 
 
 def scaled_dot_product_attention(

@@ -498,7 +499,9 @@ def export_fp8_mha(
     v_input_dtype = value.type().scalarType()
     if {q_input_dtype, k_input_dtype, v_input_dtype} != {high_precision_flag}:
         raise ValueError("The quantized MHA must have 16-bit inputs.")
-    query_scaled = export_fp8(g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape)
+    query_scaled = export_fp8(
+        g, query_scaled, q_quantized_scale, high_precision_flag, q_block_shape
+    )
     query_scaled = g.op("Cast", query_scaled, to_i=onnx_dtype_map["Float"])
     key_transposed_scaled = export_fp8(
         g, key_transposed_scaled, k_quantized_scale, high_precision_flag, k_block_shape

@@ -531,7 +534,7 @@ def export_fp8_mha(
 
     if not disable_fp8_mha:
         # Softmax's output scale is hard coded to 1.0
-        # We cannot do block quant for the softmax's output
+        # We cannot do block quant for the softmax's output
         attn_weight = export_fp8(g, attn_weight, 1.0, high_precision_flag, None)
         attn_weight = g.op("Cast", attn_weight, to_i=onnx_dtype_map["Float"])
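As a sanity check on the new block_sizes argument (default [1, 1, 128, 1]), here is a small sketch of how a per-dimension block-size list relates an input shape to the shape of its per-block scales; block_scale_shape is a hypothetical helper written for illustration, assuming one scale per block along each dimension:

# Hypothetical helper, not part of the diff: one scale per block per dimension.
def block_scale_shape(tensor_shape: list[int], block_sizes: list[int]) -> list[int]:
    assert len(tensor_shape) == len(block_sizes)
    # Ceil-divide each dimension by its block size; a block size of 1 keeps the dim.
    return [-(-dim // block) for dim, block in zip(tensor_shape, block_sizes)]

# A [2, 24, 4608, 128] attention input with block_sizes [1, 1, 32, 1]
# would carry a [2, 24, 144, 128] scale tensor.
print(block_scale_shape([2, 24, 4608, 128], [1, 1, 32, 1]))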

modelopt/torch/quantization/nn/modules/tensor_quantizer.py

Lines changed: 6 additions & 6 deletions
@@ -650,27 +650,27 @@ def _real_quantize(self, inputs):
 
     def _get_block_sizes_list(self, shape):
         """Convert block_sizes dict to list format based on tensor shape.
-
+
         Args:
             shape: The tensor shape to use for conversion (can be tuple or torch.Size)
-
+
         Returns:
             List of block sizes for each dimension, or None if block_sizes is None
-
+
         Example:
             block_sizes = {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
         """
         if self.block_sizes is None:
             return None
-
+
         block_sizes_list = []
         for dim in range(len(shape)):
             # Check both positive and negative dimension indices
             dim_negative = dim - len(shape)
             block_size = self.block_sizes.get(dim, None) or self.block_sizes.get(dim_negative, None)
             block_sizes_list.append(block_size if block_size is not None else 1)
         return block_sizes_list
-
+
     def _fake_quantize(self, inputs):
         """Fake quantization."""
         amax = None

@@ -956,7 +956,7 @@ def forward(self, inputs):
             and self.block_sizes.get("type", None) != "dynamic"
             and self._fake_quant
         ):
-            # Reshape is required if the logic isn’t handled in the simulation kernel
+            # Reshape is required if the logic isn't handled in the simulation kernel
            self._setup_for_blockquant(inputs)
            setattr(self, "_original_input_shape", inputs.shape)
            inputs = self._process_for_blockquant(inputs)
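The new _get_block_sizes_list helper is self-contained enough to restate outside the class; a standalone sketch (the free-function name is ours, not the module's) that reproduces the docstring example:

# Standalone restatement of TensorQuantizer._get_block_sizes_list (sketch only).
def get_block_sizes_list(block_sizes: dict | None, shape) -> list[int] | None:
    if block_sizes is None:
        return None
    block_sizes_list = []
    for dim in range(len(shape)):
        # A dimension may be keyed either positively (e.g. 2) or negatively (e.g. -2).
        dim_negative = dim - len(shape)
        block_size = block_sizes.get(dim) or block_sizes.get(dim_negative)
        block_sizes_list.append(block_size if block_size is not None else 1)
    return block_sizes_list

# Docstring example: {-2: 32} with shape [2, 24, 4608, 128] -> [1, 1, 32, 1]
print(get_block_sizes_list({-2: 32}, [2, 24, 4608, 128]))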

modelopt/torch/quantization/plugins/diffusers.py

Lines changed: 18 additions & 8 deletions
@@ -114,18 +114,26 @@ def _quantized_sdpa(self, *args, **kwargs):
     key = self.k_bmm_quantizer(key)
     value = self.v_bmm_quantizer(value)
 
-    if not self.q_bmm_quantizer._dynamic and not self.k_bmm_quantizer._dynamic and not self.v_bmm_quantizer._dynamic:
+    if (
+        not self.q_bmm_quantizer._dynamic
+        and not self.k_bmm_quantizer._dynamic
+        and not self.v_bmm_quantizer._dynamic
+    ):
         q_quantized_scale = self.q_bmm_quantizer._get_amax(query)
         k_quantized_scale = self.k_bmm_quantizer._get_amax(key)
         v_quantized_scale = self.v_bmm_quantizer._get_amax(value)
     else:
-        assert self.q_bmm_quantizer._dynamic and self.k_bmm_quantizer._dynamic and self.v_bmm_quantizer._dynamic, "QKV QDQS must be in the same type"
+        assert (
+            self.q_bmm_quantizer._dynamic
+            and self.k_bmm_quantizer._dynamic
+            and self.v_bmm_quantizer._dynamic
+        ), "QKV QDQS must be in the same type"
         q_quantized_scale, k_quantized_scale, v_quantized_scale = None, None, None
-
+
     # Get block sizes lists for each quantizer if needed
-    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)
-    k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)
-    v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)
+    q_block_sizes = self.q_bmm_quantizer._get_block_sizes_list(query.shape)  # type: ignore[union-attr]
+    k_block_sizes = self.k_bmm_quantizer._get_block_sizes_list(key.shape)  # type: ignore[union-attr]
+    v_block_sizes = self.v_bmm_quantizer._get_block_sizes_list(value.shape)  # type: ignore[union-attr]
 
     # We don't need to calibrate the output of softmax
     return self.bmm2_output_quantizer(

@@ -142,7 +150,7 @@ def _quantized_sdpa(self, *args, **kwargs):
             else "Half",
             self._disable_fp8_mha if hasattr(self, "_disable_fp8_mha") else True,
             q_block_sizes,
-            k_block_sizes,
+            k_block_sizes,
             v_block_sizes,
         )
     )

@@ -218,7 +226,9 @@ def forward(
     )
 
     @staticmethod
-    @symbolic_helper.parse_args("v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is")
+    @symbolic_helper.parse_args(
+        "v", "v", "v", "v", "f", "b", "v", "t", "t", "t", "s", "b", "is", "is", "is"
+    )
     def symbolic(
         g: jit_utils.GraphContext,
         query: torch._C.Value,
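The reformatted guard encodes an all-or-nothing rule: the Q, K and V BMM quantizers must either all be statically calibrated or all be dynamic. A minimal sketch of that rule in isolation (the function and argument names are ours, not the plugin's):

# Sketch only: static quantizers export their calibrated amax values;
# dynamic quantizers export no scales because they are computed at runtime.
def resolve_qkv_scales(q_quant, k_quant, v_quant, query, key, value):
    flags = (q_quant._dynamic, k_quant._dynamic, v_quant._dynamic)
    if not any(flags):
        # Static path: per-tensor amax values recorded during calibration.
        return (
            q_quant._get_amax(query),
            k_quant._get_amax(key),
            v_quant._get_amax(value),
        )
    assert all(flags), "Q, K and V quantizers must all be static or all dynamic"
    return None, None, None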
