
Commit e6e0c4a

fix shapes
1 parent 0eb3989 commit e6e0c4a

2 files changed: +6 additions, -6 deletions

src/diffusers/quantizers/finegrained_fp8/finegrained_fp8_quantizer.py

Lines changed: 5 additions & 5 deletions
@@ -106,13 +106,13 @@ def create_quantized_param(

         if rows % block_size_m != 0 or cols % block_size_n != 0:
             raise ValueError(
-                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
+                f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n} for {param_name})"
             )
         param_value_orig_shape = param_value.shape

         param_value = param_value.reshape(
-            -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
-        ).permute(0, 1, 3, 2, 4)
+            rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
+        ).permute(0, 2, 1, 3)

         # Calculate scaling factor for each block
         max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))

@@ -123,12 +123,12 @@ def create_quantized_param(
         # Quantize the weights
         quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

-        quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
+        quantized_param = quantized_param.permute(0, 2, 1, 3)
         # Reshape back to matrix shape
         quantized_param = quantized_param.reshape(param_value_orig_shape)

         # Reshape scale to match the number of blocks
-        scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
+        scale = scale.reshape(scale_orig_shape).reciprocal()

         # Load into the model
         module._parameters[tensor_name] = quantized_param.to(target_device)
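
Note on the shape fix: dropping the leading -1 from the reshape views the 2-D weight (rows, cols) as a 4-D tensor (rows // block_size_m, block_size_m, cols // block_size_n, block_size_n), and permute(0, 2, 1, 3) moves the block grid to the front so each (block_size_m, block_size_n) tile gets one scale. Below is a minimal standalone sketch of that block-wise FP8 path with the corrected shapes; quantize_blockwise_fp8 is a hypothetical helper name, and the scale formula (fp8_max / max_abs, with a small floor against division by zero) is an assumption because the lines that compute scale fall outside the hunks shown above.

import torch

def quantize_blockwise_fp8(weight: torch.Tensor, block_size_m: int = 128, block_size_n: int = 128):
    # weight: 2-D (rows, cols); both dimensions must split evenly into blocks.
    rows, cols = weight.shape
    if rows % block_size_m != 0 or cols % block_size_n != 0:
        raise ValueError(
            f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
        )

    finfo = torch.finfo(torch.float8_e4m3fn)
    fp8_min, fp8_max = finfo.min, finfo.max

    # View as (row_blocks, block_size_m, col_blocks, block_size_n), then bring the
    # block grid up front: (row_blocks, col_blocks, block_size_m, block_size_n).
    blocks = weight.reshape(
        rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
    ).permute(0, 2, 1, 3)

    # One scale per block from that block's max magnitude (formula assumed, see note above).
    max_abs = torch.amax(torch.abs(blocks), dim=(-1, -2)).clamp_min(1e-12)
    scale = fp8_max / max_abs  # shape: (row_blocks, col_blocks)

    quantized = torch.clamp(blocks * scale[:, :, None, None], min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)

    # Undo the permute/reshape so the quantized tensor matches the original (rows, cols) layout.
    quantized = quantized.permute(0, 2, 1, 3).reshape(rows, cols)
    # Store the dequantization scale (reciprocal), one entry per block.
    return quantized, scale.reciprocal()

For example, quantize_blockwise_fp8(torch.randn(256, 256)) returns a float8 weight of shape (256, 256) plus a (2, 2) scale grid for 128-sized blocks.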

src/diffusers/quantizers/finegrained_fp8/utils.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def w8a8_block_fp8_matmul_triton(
     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
     M = A.numel() // A.shape[-1]

-    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2, f"B {B.shape} and Bs {Bs.shape}"
     N, K = B.shape
     assert triton.cdiv(N, block_n) == Bs.shape[0]
     assert triton.cdiv(K, block_k) == Bs.shape[1]
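
For reference, the augmented assert fires when the weight scale grid does not match the block layout the surrounding asserts imply: with B stored as (N, K), Bs must hold one scale per (block_n x block_k) tile, i.e. shape (cdiv(N, block_n), cdiv(K, block_k)), and As one scale per (row, k-block) of A. A small shape check mirroring those asserts, with made-up sizes and math.ceil standing in for triton.cdiv so it runs without Triton:

import math
import torch

M, K, N = 4, 512, 768          # made-up problem sizes
block_n, block_k = 128, 128    # block sizes along N and K

A = torch.randn(M, K).to(torch.float8_e4m3fn)                      # activations, (M, K)
As = torch.rand(M, math.ceil(K / block_k))                         # per-(row, k-block) scales
B = torch.randn(N, K).to(torch.float8_e4m3fn)                      # weight, (N, K), contiguous
Bs = torch.rand(math.ceil(N / block_n), math.ceil(K / block_k))    # per-(n-block, k-block) scales

# math.ceil(a / b) plays the role of triton.cdiv(a, b) here.
assert math.ceil(A.shape[-1] / block_k) == As.shape[-1]
assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2, f"B {B.shape} and Bs {Bs.shape}"
N_, K_ = B.shape
assert math.ceil(N_ / block_n) == Bs.shape[0]
assert math.ceil(K_ / block_k) == Bs.shape[1]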
