Commit 1969e36

Fix cuBLAS library path resolution, fix FP4 quantization precision in the Triton and eager backends, and boost Triton FP4 quantization performance
1 parent adeefec commit 1969e36

File tree: 5 files changed (+82, -85 lines)

comfy_kitchen/backends/cuda/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -43,7 +43,7 @@ def find_lib_dir(start_dir, lib_pattern):
                 return root
     return None
 
-nvidia_cu13_path = os.path.dirname(nvidia.cu13.__path__[0])
+nvidia_cu13_path = nvidia.cu13.__path__[0]
 
 if sys.platform == "win32":
     lib_dir = find_lib_dir(nvidia_cu13_path, "cublasLt64")
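For context on the one-line change above, a small sketch (not from the commit; it assumes the usual pip layout where the nvidia-cu13 wheel installs under site-packages/nvidia/cu13) showing what the two expressions resolve to. The commit moves the search root from the parent nvidia directory down into the cu13 package itself, which is where find_lib_dir is expected to locate the cuBLAS libraries:

# Hypothetical layout assumed here:
#   .../site-packages/nvidia/cu13/...  containing cublasLt64_*.dll or libcublasLt.so.*
import os
import nvidia.cu13

pkg_dir = nvidia.cu13.__path__[0]   # .../site-packages/nvidia/cu13  (new search root)
old_dir = os.path.dirname(pkg_dir)  # .../site-packages/nvidia       (old search root)

print("new:", pkg_dir)
print("old:", old_dir)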

comfy_kitchen/backends/eager/quantization.py

Lines changed: 11 additions & 8 deletions
@@ -10,7 +10,6 @@
 
 from comfy_kitchen.float_utils import (
     F4_E2M1_MAX,
-    F8_E4M3_EPS,
     F8_E4M3_MAX,
     F8_E5M2_MAX,
     _f32_to_floatx_unpacked,
@@ -82,15 +81,19 @@ def quantize_nvfp4(
 
     x = x.reshape(orig_shape[0], -1, block_size)
     max_abs = torch.amax(torch.abs(x), dim=-1)
-    block_scale = max_abs / F4_E2M1_MAX
-    block_scale_fp32 = block_scale.to(torch.float32)
-    scaled_block_scales = block_scale_fp32 / per_tensor_scale
-    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, min=F8_E4M3_EPS, max=F8_E4M3_MAX)
+    block_scale = max_abs.to(torch.float32) / F4_E2M1_MAX
+    scaled_block_scales = block_scale / per_tensor_scale
+    scaled_block_scales_fp8 = torch.clamp(scaled_block_scales, max=F8_E4M3_MAX)
     scaled_block_scales_fp32 = _float8_round(scaled_block_scales_fp8)
-    # We "temporarily" dequant the scaled_block_scales_fp32 to get the per_tensor_scale
-    # To apply to data
     total_scale = per_tensor_scale * scaled_block_scales_fp32
-    data_scaled = x / total_scale.unsqueeze(-1)
+
+    # Handle zero blocks (from padding): avoid 0/0 NaN
+    zero_scale_mask = (total_scale == 0)
+    total_scale_safe = torch.where(zero_scale_mask, torch.ones_like(total_scale), total_scale)
+
+    data_scaled = x.float() / total_scale_safe.unsqueeze(-1)
+    data_scaled = torch.where(zero_scale_mask.unsqueeze(-1), torch.zeros_like(data_scaled), data_scaled)
+
     out_scales = scaled_block_scales_fp8
 
     data_scaled = torch.clamp(data_scaled, -F4_E2M1_MAX, F4_E2M1_MAX)
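The rewrite above keeps the block arithmetic in float32, drops the lower clamp to F8_E4M3_EPS, and masks blocks whose total scale is zero (all-zero blocks introduced by padding) so the division no longer yields NaN. Below is a self-contained sketch of that two-level scaling, using a float8_e4m3fn round trip as a stand-in for the library's _float8_round helper; the constant values and the block size of 16 are assumptions taken from the NVFP4 format, not from this file:

import torch

F4_E2M1_MAX = 6.0      # largest FP4 E2M1 magnitude (assumed value)
F8_E4M3_MAX = 448.0    # largest FP8 E4M3 magnitude (assumed value)

def nvfp4_scale_blocks(x, per_tensor_scale, block_size=16):
    x = x.reshape(x.shape[0], -1, block_size)
    max_abs = torch.amax(torch.abs(x), dim=-1)

    # Per-block scale expressed relative to the per-tensor scale, stored as FP8 E4M3.
    block_scale = max_abs.to(torch.float32) / F4_E2M1_MAX
    scaled = torch.clamp(block_scale / per_tensor_scale, max=F8_E4M3_MAX)
    scaled_fp8 = scaled.to(torch.float8_e4m3fn)                  # stand-in for _float8_round
    total_scale = per_tensor_scale * scaled_fp8.to(torch.float32)

    # All-zero blocks (e.g. padding) give total_scale == 0, so 0/0 would be NaN.
    zero = total_scale == 0
    safe = torch.where(zero, torch.ones_like(total_scale), total_scale)
    data = x.float() / safe.unsqueeze(-1)
    data = torch.where(zero.unsqueeze(-1), torch.zeros_like(data), data)
    return torch.clamp(data, -F4_E2M1_MAX, F4_E2M1_MAX), scaled_fp8

x = torch.zeros(2, 32)            # all-zero blocks are the worst case for the old path
x[0, :16] = torch.randn(16) * 4
per_tensor = (x.abs().max() / (F8_E4M3_MAX * F4_E2M1_MAX)).to(torch.float32)
data_scaled, block_scales = nvfp4_scale_blocks(x, per_tensor)
assert not torch.isnan(data_scaled).any()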

comfy_kitchen/backends/triton/quantization.py

Lines changed: 26 additions & 31 deletions
@@ -262,8 +262,6 @@ def quantize_nvfp4_kernel_tl(
 
     # Scale block scale to FP8
     scaled_block_scale = block_scale / per_tensor_scale
-    # Clamp to [F8_E4M3_EPS, F8_E4M3_MAX] = [0.125, 448.0]
-    scaled_block_scale = tl.maximum(scaled_block_scale, 0.125)
     scaled_block_scale = tl.minimum(scaled_block_scale, 448.0)
 
     # Round to FP8 precision
@@ -280,44 +278,41 @@ def quantize_nvfp4_kernel_tl(
     # Calculate total scale for data quantization
     scaled_block_scale_fp32 = scaled_block_scale_fp8.to(tl.float32)
     total_scale = per_tensor_scale * scaled_block_scale_fp32
-    total_scale = tl.where(total_scale < 1e-10, 1.0, total_scale)
+    zero_scale_mask = total_scale < 1e-10
+    total_scale = tl.where(zero_scale_mask, 1.0, total_scale)
 
-    # Scale and clamp data
+    # Scale data (satfinite modifier in PTX will handle clamping)
     data_scaled = x / total_scale
-    data_scaled = tl.maximum(data_scaled, -6.0)  # -F4_E2M1_MAX
-    data_scaled = tl.minimum(data_scaled, 6.0)  # F4_E2M1_MAX
-
-    # Quantize to FP4 values and pack - optimized version
-    # Convert all values to FP4 representation
-    sign_all = tl.where(data_scaled < 0, 1, 0)
-    abs_all = tl.abs(data_scaled)
-
-    # Map all to FP4 bit pattern (E2M1)
-    q_all = tl.where(abs_all <= 0.25, 0,
-            tl.where(abs_all < 0.75, 1,
-            tl.where(abs_all <= 1.25, 2,
-            tl.where(abs_all < 1.75, 3,
-            tl.where(abs_all <= 2.5, 4,
-            tl.where(abs_all < 3.5, 5,
-            tl.where(abs_all <= 5.0, 6, 7)))))))
-
-    # Add sign bits to all values
-    fp4_all = (sign_all.to(tl.int32) << 3) | q_all.to(tl.int32)
-
-    # Pack consecutive pairs of FP4 values
-    # fp4_all has 16 elements: [v0, v1, v2, v3, ..., v15]
+    data_scaled = tl.where(zero_scale_mask, 0.0, data_scaled)
+
     # We want to pack: (v0,v1), (v2,v3), ..., (v14,v15)
     pair_idx = tl.arange(0, block_size // 2)
     even_idx = pair_idx * 2
     odd_idx = pair_idx * 2 + 1
 
     # Extract even and odd elements using one-hot selection
     indices = tl.arange(0, block_size)
-    fp4_even = tl.sum(tl.where(indices == even_idx[:, None], fp4_all, 0), axis=1)
-    fp4_odd = tl.sum(tl.where(indices == odd_idx[:, None], fp4_all, 0), axis=1)
-
-    # Pack two 4-bit values into one uint8
-    packed_bytes = ((fp4_even << 4) | fp4_odd).to(tl.uint8)
+    f32_even = tl.sum(tl.where(indices == even_idx[:, None], data_scaled, 0), axis=1)
+    f32_odd = tl.sum(tl.where(indices == odd_idx[:, None], data_scaled, 0), axis=1)
+
+    packed_bytes_u16 = tl.inline_asm_elementwise(
+        asm="""
+        {
+            .reg .b8 fp4_byte;
+            .reg .b16 result;
+            cvt.rn.satfinite.e2m1x2.f32 fp4_byte, $1, $2;
+            mov.b16 result, {fp4_byte, 0};
+            mov.u16 $0, result;
+        }
+        """,
+        constraints="=h,f,f",
+        args=[f32_even, f32_odd],
+        dtype=tl.uint16,
+        is_pure=True,
+        pack=1,
+    )
+    # Extract the low byte
+    packed_bytes = (packed_bytes_u16 & 0xFF).to(tl.uint8)
 
     # Store packed bytes
     out_offs = pid_m * (n // 2) + pid_n * (block_size // 2) + pair_idx
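The main performance change: the kernel no longer builds FP4 codes with a chain of tl.where comparisons plus shift/or packing; a single cvt.rn.satfinite.e2m1x2.f32 PTX instruction now rounds, saturates, and packs two float32 values into one E2M1 pair per call. As a CPU-side reference only (not kernel code), here is that conversion spelled out with the thresholds and byte layout the removed Triton code used, even element in the high nibble; whether the hardware instruction orders the nibbles the same way is not asserted here:

def encode_e2m1(v: float) -> int:
    """One float -> 4-bit E2M1 code (sign in bit 3, thresholds as in the removed ladder)."""
    sign = 0x8 if v < 0 else 0x0
    a = abs(v)
    if   a <= 0.25: code = 0   # 0.0
    elif a <  0.75: code = 1   # 0.5
    elif a <= 1.25: code = 2   # 1.0
    elif a <  1.75: code = 3   # 1.5
    elif a <= 2.5:  code = 4   # 2.0
    elif a <  3.5:  code = 5   # 3.0
    elif a <= 5.0:  code = 6   # 4.0
    else:           code = 7   # 6.0, satfinite clamps anything larger to this
    return sign | code

def pack_e2m1_pair(even: float, odd: float) -> int:
    """Two floats -> one packed byte, even value in the high nibble (old layout)."""
    return (encode_e2m1(even) << 4) | encode_e2m1(odd)

assert pack_e2m1_pair(1.0, -6.0) == (0x2 << 4) | 0xF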

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "comfy-kitchen"
-version = "0.1.4"
+version = "0.1.5"
 description = "Fast Kernel Library for ComfyUI with multiple compute backends"
 readme = "README.md"
 requires-python = ">=3.10"

tests/test_qdq.py

Lines changed: 43 additions & 44 deletions
@@ -3,9 +3,7 @@
 
 import comfy_kitchen as ck
 from comfy_kitchen.float_utils import (
-    F4_E2M1_EPS,
     F4_E2M1_MAX,
-    F8_E4M3_EPS,
     F8_E4M3_MAX,
     fp4_x2_to_f32,
 )
@@ -146,67 +144,48 @@ def capable_backends(self, device):
 
     @pytest.mark.parametrize("m,k", [
         (1024, 2048),
+        (512, 1024),
         (129, 128),  # Edge case: odd rows requiring padding
         (33, 65),  # Edge case: both dimensions odd
     ])
     def test_quantize_nvfp4_all_backends(self, capable_backends, device, seed, m, k):
-        """Test NVFP4 quantization across all capable backends."""
-        for backend_name in capable_backends:
-            inputs = ConstraintAwareTestInputs("quantize_nvfp4", backend_name, device)
-            x = inputs.tensor("x", shape=(m, k), dtype=torch.bfloat16)
-            x = x * 4  # Scale up for better test coverage
-
-            scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX)
-            scale = scale.to(torch.float32)
-
-            needs_padding = (m % 16 != 0) or (k % 16 != 0)
-
-            with ck.use_backend(backend_name):
-                qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding)
-
-            assert qx.dtype == torch.uint8
-            assert sx.dtype == torch.float8_e4m3fn
-
-    @pytest.mark.parametrize("m,k", [(512, 1024)])
-    def test_quantize_nvfp4_cross_backend_consistency(
-        self, capable_backends, device, seed, m, k
-    ):
-        """Test that all backends produce consistent NVFP4 results."""
-        if len(capable_backends) < 2:
-            pytest.skip("Need at least 2 backends for cross-validation")
+        """Test NVFP4 quantization across all capable backends with accuracy testing."""
+        if "eager" not in capable_backends:
+            pytest.skip("Need eager backend as reference")
 
+        # Create test input
         x = torch.randn(m, k, device=device, dtype=torch.bfloat16) * 4
         scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX)
         scale = scale.to(torch.float32)
+        needs_padding = (m % 16 != 0) or (k % 16 != 0)
+
+        with ck.use_backend("eager"):
+            ref_qx, ref_sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding)
 
-        results = {}
         for backend_name in capable_backends:
            with ck.use_backend(backend_name):
-                qx, sx = ck.quantize_nvfp4(x, scale)
-                results[backend_name] = (qx, sx)
+                qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding)
 
-        # Compare all against first
-        ref_backend = capable_backends[0]
-        ref_qx, ref_sx = results[ref_backend]
+            # Check basic properties
+            assert qx.dtype == torch.uint8
+            assert sx.dtype == torch.float8_e4m3fn
 
-        for backend_name, (qx, sx) in results.items():
-            if backend_name != ref_backend:
             assert_values_close(
                 sx.to(torch.float32),
                 ref_sx.to(torch.float32),
-                rtol=F8_E4M3_EPS,
-                atol=F8_E4M3_EPS,
-                name=f"scales ({backend_name} vs {ref_backend})"
+                rtol=1e-5,
+                atol=1e-3,
+                name=f"scales ({backend_name} vs eager)"
             )
 
             qx_f32 = fp4_x2_to_f32(qx)
             ref_qx_f32 = fp4_x2_to_f32(ref_qx)
             assert_values_close(
                 qx_f32,
                 ref_qx_f32,
-                rtol=F4_E2M1_EPS,
-                atol=F4_E2M1_EPS,
-                name=f"quantized ({backend_name} vs {ref_backend})"
+                rtol=1e-2,
+                atol=2.0,
+                name=f"quantized data ({backend_name} vs eager)"
             )
 
     def test_quantize_nvfp4_cpu_fallback(self, seed):
@@ -235,28 +214,48 @@ def capable_backends(self, device):
             pytest.skip(f"No backend supports dequantize_nvfp4 on {device}")
         return backends
 
-    @pytest.mark.parametrize("m,k", [(1024, 2048), (512, 4096)])
+    @pytest.mark.parametrize("m,k", [
+        (1024, 2048),
+        (512, 4096),
+        (129, 128),  # Edge case with padding
+    ])
     @pytest.mark.parametrize("output_dtype", [torch.float16, torch.bfloat16])
     def test_dequantize_nvfp4_all_backends(
         self, capable_backends, device, seed, m, k, output_dtype
     ):
-        """Test NVFP4 dequantization across all capable backends."""
+        """Test NVFP4 dequantization across all capable backends with accuracy testing."""
+        if "eager" not in capable_backends:
+            pytest.skip("Need eager backend as reference")
+
         x = torch.randn(m, k, device=device, dtype=torch.bfloat16) * 4
         scale = torch.max(torch.abs(x)) / (F8_E4M3_MAX * F4_E2M1_MAX)
         scale = scale.to(torch.float32)
+        needs_padding = (m % 16 != 0) or (k % 16 != 0)
 
         # Quantize with eager
         with ck.use_backend("eager"):
-            qx, sx = ck.quantize_nvfp4(x, scale)
+            qx, sx = ck.quantize_nvfp4(x, scale, pad_16x=needs_padding)
+            ref_result = ck.dequantize_nvfp4(qx, scale, sx, output_type=output_dtype)
+            # Unpad if needed
+            ref_result = ref_result[:m, :k]
 
         for backend_name in capable_backends:
            with ck.use_backend(backend_name):
                result = ck.dequantize_nvfp4(qx, scale, sx, output_type=output_dtype)
+                result = result[:m, :k]  # Unpad if needed
 
-            assert result.shape == x.shape
+            assert result.shape == (m, k)
            assert result.dtype == output_dtype
            assert result.device == x.device
 
+            assert_values_close(
+                result,
+                ref_result,
+                rtol=1e-3,
+                atol=1e-2,
+                name=f"dequantized output ({backend_name} vs eager)"
+            )
+
 
 class TestScaledMMNVFP4:
     """NVFP4 matrix multiplication tests."""
