Skip to content

Commit 7391815

Browse files
authored
Merge pull request #1 from Comfy-Org/refactor-1
Refactors to quantization code, fixes to README + LICENSE
2 parents 16604c5 + 12d61e4 commit 7391815

File tree

9 files changed

+184
-55
lines changed

9 files changed

+184
-55
lines changed

.github/workflows/build-wheels.yml

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,15 @@ on:
44
push:
55
branches:
66
- main
7+
tags:
8+
- "v*"
79
pull_request:
810
workflow_dispatch:
911

12+
permissions:
13+
id-token: write
14+
contents: read
15+
1016
jobs:
1117
build_wheels:
1218
name: Build wheel for ${{ matrix.os }}
@@ -253,3 +259,24 @@ jobs:
253259
# Tests requiring CUDA will be skipped on CPU-only runners
254260
python -m pytest tests/ -v --tb=short
255261
262+
publish:
263+
name: Publish to PyPI
264+
needs: [build_wheels, test]
265+
runs-on: ubuntu-latest
266+
if: github.event_name == 'push' && startsWith(github.ref, 'refs/tags/v')
267+
environment: pypi
268+
269+
steps:
270+
- name: Download all wheel artifacts
271+
uses: actions/download-artifact@v4
272+
with:
273+
pattern: wheels-*
274+
path: dist/
275+
merge-multiple: true
276+
277+
- name: List wheels to publish
278+
run: ls -la dist/
279+
280+
- name: Publish to PyPI
281+
uses: pypa/gh-action-pypi-publish@release/v1
282+

LICENSE

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
Apache License
23
Version 2.0, January 2004
34
http://www.apache.org/licenses/
@@ -162,7 +163,7 @@
162163
other commercial damages or losses), even if such Contributor
163164
has been advised of the possibility of such damages.
164165

165-
9. Accepting Warranty or Additional Liability. While redistributing
166+
9. Accepting Warranty or Additional Support. While redistributing
166167
the Work or Derivative Works thereof, You may choose to offer,
167168
and charge a fee for, acceptance of support, warranty, indemnity,
168169
or other liability obligations and/or rights consistent with this
@@ -186,7 +187,7 @@
186187
same "printed page" as the copyright notice for easier
187188
identification within third-party archives.
188189

189-
Copyright [yyyy] [name of copyright owner]
190+
Copyright (c) 2025 Comfy Org. All rights reserved.
190191

191192
Licensed under the Apache License, Version 2.0 (the "License");
192193
you may not use this file except in compliance with the License.

README.md

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -68,18 +68,29 @@ pip install -e ".[dev]"
6868
# Skip build isolation for faster rebuilds
6969
pip install -e . --no-build-isolation -v
7070

71-
# Install without CUDA backend
72-
pip install . --no-cuda
7371
```
7472

7573
#### Available Build Options
7674

77-
| Option | Description | Default |
78-
|--------|-------------|---------|
79-
| `--no-cuda` | Build without CUDA backend | Enabled (build with CUDA) |
80-
| `--cuda-archs=...` | CUDA architectures to build for | Windows: `80;89;120f`<br>Linux: `80;89;90a;100a;120f` |
81-
| `--debug-build` | Build in debug mode with symbols | Disabled (Release) |
82-
| `--lineinfo` | Enable NVCC line information for profiling | Disabled |
75+
These options require using `setup.py` directly (not `pip install`):
76+
77+
| Option | Command | Description | Default |
78+
|--------|---------|-------------|---------|
79+
| `--no-cuda` | `python setup.py bdist_wheel --no-cuda` | Build without CUDA backend | Enabled (build with CUDA) |
80+
| `--cuda-archs=...` | `python setup.py build_ext --cuda-archs="80;89"` | CUDA architectures to build for | Windows: `80;89;120f`<br>Linux: `80;89;90a;100f;120f` |
81+
| `--debug-build` | `python setup.py build_ext --debug-build` | Build in debug mode with symbols | Disabled (Release) |
82+
| `--lineinfo` | `python setup.py build_ext --lineinfo` | Enable NVCC line info for profiling | Disabled |
83+
84+
```bash
85+
# Build without CUDA
86+
python setup.py bdist_wheel --no-cuda
87+
88+
# Build with custom CUDA architectures
89+
python setup.py build_ext --cuda-archs="80;89" bdist_wheel
90+
91+
# Debug build with line info for profiling
92+
python setup.py build_ext --debug-build --lineinfo bdist_wheel
93+
```
8394

8495

8596

comfy_kitchen/tensor/base.py

Lines changed: 59 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Base classes for quantized tensors with typed layout parameters."""
22
from __future__ import annotations
33

4+
import contextlib
45
import dataclasses
56
import logging
67
from abc import ABC, abstractmethod
@@ -90,6 +91,11 @@ def dequantize(cls, qdata: torch.Tensor, params: Any) -> torch.Tensor:
9091
def get_plain_tensors(cls, qtensor: QuantizedTensor) -> tuple[torch.Tensor, ...]:
9192
raise NotImplementedError
9293

94+
@classmethod
95+
@abstractmethod
96+
def state_dict_tensors(cls, qdata: torch.Tensor, params: Any) -> dict[str, torch.Tensor]:
97+
raise NotImplementedError
98+
9399
@classmethod
94100
def supports_fast_matmul(cls) -> bool:
95101
"""Check if fast quantized matmul is supported on current hardware."""
@@ -263,6 +269,10 @@ def dequantize(self) -> torch.Tensor:
263269
return full[slices]
264270
return full
265271

272+
def state_dict(self, prefix: str = "") -> dict[str, torch.Tensor]:
273+
tensors = self._layout_cls.state_dict_tensors(self._qdata, self._params)
274+
return {f"{prefix}{suffix}": tensor for suffix, tensor in tensors.items()}
275+
266276
# ==================== Flatten/Unflatten Protocol ====================
267277

268278
def __tensor_flatten__(self):
@@ -344,6 +354,23 @@ def dequantize_args(args):
344354

345355
# ==================== Dispatch Handlers ====================
346356

357+
def _parse_to_args(args, kwargs):
358+
"""Extract device and dtype from .to() arguments."""
359+
device = kwargs.get("device")
360+
dtype = kwargs.get("dtype")
361+
for arg in args[1:]:
362+
if isinstance(arg, torch.device):
363+
device = arg
364+
elif isinstance(arg, torch.dtype):
365+
dtype = arg
366+
elif isinstance(arg, str):
367+
with contextlib.suppress(Exception):
368+
device = torch.device(arg)
369+
if isinstance(device, str):
370+
device = torch.device(device)
371+
return device, dtype
372+
373+
347374
def _handle_detach(qt, args, kwargs):
348375
return qt._copy_with(qdata=qt._qdata.detach())
349376

@@ -352,29 +379,32 @@ def _handle_clone(qt, args, kwargs):
352379
return qt._copy_with(qdata=qt._qdata.clone())
353380

354381

355-
def _handle_to(qt, args, kwargs):
356-
"""Unified handler for device/dtype changes."""
357-
target_device = kwargs.get("device")
358-
target_dtype = kwargs.get("dtype")
359-
360-
if isinstance(target_device, str):
361-
target_device = torch.device(target_device)
382+
def _handle_to(qt, args, kwargs, force_copy=False):
383+
target_device, target_dtype = _parse_to_args(args, kwargs)
362384

363385
needs_device = target_device is not None and target_device != qt._qdata.device
364386
needs_dtype = target_dtype is not None and target_dtype != qt._params.orig_dtype
365387

366-
if not needs_device and not needs_dtype:
388+
if not needs_device and not needs_dtype and not force_copy:
367389
return qt
368390

369-
new_qdata = qt._qdata.to(device=target_device) if needs_device else qt._qdata
370-
new_params = qt._params.clone()
371391
if needs_device:
372-
new_params = new_params.to_device(target_device)
392+
new_qdata = qt._qdata.to(device=target_device)
393+
new_params = qt._params.to_device(target_device)
394+
else:
395+
new_qdata = qt._qdata.clone() if force_copy else qt._qdata
396+
new_params = qt._params.clone()
397+
373398
if needs_dtype:
374399
new_params.orig_dtype = target_dtype
400+
375401
return qt._copy_with(qdata=new_qdata, params=new_params, clone_params=False)
376402

377403

404+
def _handle_to_copy(qt, args, kwargs):
405+
return _handle_to(qt, args, kwargs, force_copy=True)
406+
407+
378408
def _handle_contiguous(qt, args, kwargs):
379409
if qt._qdata.is_contiguous():
380410
return qt
@@ -386,53 +416,46 @@ def _handle_is_contiguous(qt, args, kwargs):
386416

387417

388418
def _handle_copy_(qt, args, kwargs):
389-
"""Handle in-place copy between QuantizedTensors.
390-
391-
Raises:
392-
TypeError: If src is not a QuantizedTensor or layouts don't match.
393-
"""
394419
dst, src = args[0], args[1]
395-
non_blocking = kwargs.get("non_blocking", False)
396-
if len(args) >= 3:
397-
non_blocking = True
398420
if not isinstance(src, QuantizedTensor):
399-
raise TypeError(
400-
f"Cannot copy {type(src).__name__} to QuantizedTensor. "
401-
"Use QuantizedTensor.from_float() to create a new quantized tensor."
402-
)
421+
raise TypeError(f"Cannot copy {type(src).__name__} to QuantizedTensor")
403422
if dst._layout_cls != src._layout_cls:
404-
raise TypeError(
405-
f"Cannot copy between different layouts: "
406-
f"{dst._layout_cls.__name__} vs {src._layout_cls.__name__}"
407-
)
423+
raise TypeError(f"Layout mismatch: {dst._layout_cls.__name__} vs {src._layout_cls.__name__}")
424+
425+
dst_orig_dtype = dst._params.orig_dtype
426+
non_blocking = kwargs.get("non_blocking", len(args) >= 3)
427+
408428
dst._qdata.copy_(src._qdata, non_blocking=non_blocking)
409429
dst._params.copy_from(src._params, non_blocking=non_blocking)
430+
dst._params.orig_dtype = dst_orig_dtype
410431
return dst
411432

412433

413434
def _handle_empty_like(qt, args, kwargs):
414-
new_qdata = torch.empty_like(qt._qdata, **kwargs)
435+
target_dtype = kwargs.pop("dtype", None)
436+
target_device = kwargs.get("device")
437+
438+
new_qdata = torch.empty_like(qt._qdata, device=target_device)
415439
new_params = qt._params.clone()
416-
if "device" in kwargs:
417-
new_params = new_params.to_device(kwargs["device"])
418-
return qt._copy_with(qdata=new_qdata, params=new_params, clone_params=False)
419440

441+
if target_device is not None:
442+
new_params = new_params.to_device(target_device)
443+
if target_dtype is not None:
444+
new_params.orig_dtype = target_dtype
420445

421-
def _handle_has_compatible_shallow_copy_type(qt, args, kwargs):
422-
"""QuantizedTensors support shallow copy compatibility."""
423-
return True
446+
return qt._copy_with(qdata=new_qdata, params=new_params, clone_params=False)
424447

425448

426449
_DISPATCH_TABLE = {
427450
torch.ops.aten.detach.default: _handle_detach,
428451
torch.ops.aten.clone.default: _handle_clone,
429-
torch.ops.aten._to_copy.default: _handle_to,
452+
torch.ops.aten._to_copy.default: _handle_to_copy,
430453
torch.ops.aten.to.dtype_layout: _handle_to,
431454
torch.ops.aten.contiguous.default: _handle_contiguous,
432455
torch.ops.aten.is_contiguous.default: _handle_is_contiguous,
433456
torch.ops.aten.copy_.default: _handle_copy_,
434457
torch.ops.aten.empty_like.default: _handle_empty_like,
435-
torch.ops.aten._has_compatible_shallow_copy_type.default: _handle_has_compatible_shallow_copy_type,
458+
torch.ops.aten._has_compatible_shallow_copy_type.default: lambda qt, args, kwargs: True,
436459
}
437460

438461
# Layout-specific dispatch table: {torch_op: {layout_cls: handler}}

comfy_kitchen/tensor/fp8.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,13 +44,14 @@ class Params(BaseLayoutParams):
4444
def quantize(
4545
cls,
4646
tensor: torch.Tensor,
47-
scale: torch.Tensor | float | None = None,
47+
scale: torch.Tensor | float | str | None = None,
4848
dtype: torch.dtype = torch.float8_e4m3fn,
49+
**kwargs,
4950
) -> tuple[torch.Tensor, Params]:
5051
orig_dtype = tensor.dtype
5152
orig_shape = tuple(tensor.shape)
5253

53-
if scale is None:
54+
if scale is None or scale == "recalculate":
5455
scale = torch.amax(tensor.abs()) / torch.finfo(dtype).max
5556

5657
if not isinstance(scale, torch.Tensor):
@@ -69,6 +70,14 @@ def dequantize(cls, qdata: torch.Tensor, params: Params) -> torch.Tensor:
6970
def get_plain_tensors(cls, qtensor: QuantizedTensor) -> tuple[torch.Tensor, torch.Tensor]:
7071
return qtensor._qdata, qtensor._params.scale
7172

73+
@classmethod
74+
def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, torch.Tensor]:
75+
"""Return key suffix → tensor mapping for serialization."""
76+
return {
77+
"": qdata,
78+
"_scale": params.scale,
79+
}
80+
7281

7382
# ==================== Helper Utilities ====================
7483

comfy_kitchen/tensor/nvfp4.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,16 @@ def _tensor_fields(self) -> list[str]:
4848
def quantize(
4949
cls,
5050
tensor: torch.Tensor,
51-
scale: torch.Tensor | float | None = None,
51+
scale: torch.Tensor | float | str | None = None,
52+
**kwargs,
5253
) -> tuple[torch.Tensor, Params]:
5354
if tensor.dim() != 2:
5455
raise ValueError(f"NVFP4 requires 2D tensor, got {tensor.dim()}D")
5556

5657
orig_dtype = tensor.dtype
5758
orig_shape = tuple(tensor.shape)
5859

59-
if scale is None:
60+
if scale is None or scale == "recalculate":
6061
scale = torch.amax(tensor.abs()) / (F8_E4M3_MAX * F4_E2M1_MAX)
6162

6263
if not isinstance(scale, torch.Tensor):
@@ -86,6 +87,15 @@ def get_plain_tensors(
8687
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
8788
return qtensor._qdata, qtensor._params.scale, qtensor._params.block_scale
8889

90+
@classmethod
91+
def state_dict_tensors(cls, qdata: torch.Tensor, params: Params) -> dict[str, torch.Tensor]:
92+
"""Return key suffix → tensor mapping for serialization."""
93+
return {
94+
"": qdata,
95+
"_scale": params.block_scale,
96+
"_scale_2": params.scale,
97+
}
98+
8999
@classmethod
90100
def get_padded_shape(cls, orig_shape: tuple[int, ...]) -> tuple[int, ...]:
91101
if len(orig_shape) != 2:

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@ classifiers = [
2323
]
2424
dependencies = [
2525
"torch>=2.5.0",
26-
"cuda-core[cu13]>=0.3.2,<1.0",
2726
"nvidia-cublas>=13.0.0",
2827
]
2928

setup.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,8 +289,12 @@ def get_cmdclass(has_extensions):
289289
class CUDABdistWheel(bdist_wheel):
290290
def finalize_options(self):
291291
super().finalize_options()
292-
# Add CUDA version as local version identifier (e.g., 0.1.0+cu128)
293292
if not BUILD_NO_CUDA:
293+
# Set stable ABI tag: cp312-abi3 instead of cp312-cp312
294+
# This indicates the extension uses Python's Limited API
295+
self.py_limited_api = "cp312"
296+
297+
# Add CUDA version as local version identifier (e.g., 0.1.0+cu128)
294298
cuda_version = get_cuda_version()
295299
if cuda_version and self.distribution.metadata.version:
296300
cuda_tag = f"cu{cuda_version[0]}{cuda_version[1]}"
@@ -343,9 +347,9 @@ def get_packages():
343347

344348
setup_kwargs.update({
345349
"packages": get_packages(),
346-
"name": "comfy-kitchen-no-cuda",
350+
"name": "comfy-kitchen",
347351
"version": version,
348-
"description": f"{description} (CPU-only, no CUDA)",
352+
"description": f"{description} (CPU-only)",
349353
"include_package_data": False,
350354
"install_requires": [
351355
"torch>=2.5.0",

0 commit comments

Comments
 (0)