forked from BobJohnson24/ComfyUI-INT8-Fast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path__init__.py
More file actions
114 lines (94 loc) · 4.12 KB
/
__init__.py
File metadata and controls
114 lines (94 loc) · 4.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
"""
int88 - Fast INT8 Tensorwise Quantization for ComfyUI
Provides:
- Int8TensorwiseOps: Custom operations for direct int8 weight loading
- OTUNetLoaderW8A8: Load int8 quantized diffusion models
- OTCheckpointLoaderW8A8: Load int8 quantized checkpoints
Uses torch._int_mm for fast inference.
"""
import logging
import torch
# =============================================================================
# Layout Registration
# =============================================================================
def _register_layouts():
"""
Register the Int8Tensorwise layout with ComfyUI's model management.
"""
try:
from comfy.quant_ops import QUANT_ALGOS, register_layout_class, QuantizedLayout
class Int8TensorwiseLayout(QuantizedLayout):
"""Minimal layout class to satisfy ComfyUI's registry requirements."""
class Params:
def __init__(self, scale=None, orig_dtype=None, orig_shape=None, **kwargs):
self.scale = scale
self.orig_dtype = orig_dtype
self.orig_shape = orig_shape
def clone(self):
return Int8TensorwiseLayout.Params(
scale=self.scale.clone() if isinstance(self.scale, torch.Tensor) else self.scale,
orig_dtype=self.orig_dtype,
orig_shape=self.orig_shape
)
@classmethod
def state_dict_tensors(cls, qdata, params):
return {"": qdata, "weight_scale": params.scale}
@classmethod
def dequantize(cls, qdata, params):
return qdata.float() * params.scale
# Register the class
register_layout_class("Int8TensorwiseLayout", Int8TensorwiseLayout)
# Register the Algo Config
QUANT_ALGOS.setdefault(
"int8_tensorwise",
{
"storage_t": torch.int8,
# We include input_scale here so ComfyUI extracts it from checkpoints if present,
# even though our LinearW8A8 implementation explicitly ignores it.
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "Int8TensorwiseLayout",
}
)
except ImportError:
logging.warning("Int88: ComfyUI Quantization system not found (Update ComfyUI?)")
except Exception as e:
logging.error(f"Int88: Failed to register layouts: {e}")
# =============================================================================
# Module Initialization
# =============================================================================
# 1. Register Layouts
# Runs at import time so the "int8_tensorwise" algo is in ComfyUI's registry
# before any model loading; failures are logged inside, never raised.
_register_layouts()
# 2. Export Custom Ops (for external use)
# Optional export: Int8TensorwiseOps is None when the sibling module is
# missing, so downstream code must check for None before use.
try:
    from .int8_quant import Int8TensorwiseOps
except ImportError:
    Int8TensorwiseOps = None
# 3. Node Mappings
# Wrap imports in try/except to prevent total failure if dependencies are missing
try:
    from .int8_unet_loader import UNetLoaderINTW8A8
    from .int8_lora import INT8LoraLoader, INT8LoraLoaderStack
    from .int8_dynamic_lora import INT8DynamicLoraLoader, INT8DynamicLoraStack
    # ComfyUI node-ID -> implementing class. NOTE(review): the node ID
    # "OTUNetLoaderW8A8" deliberately differs from the class name
    # UNetLoaderINTW8A8 — presumably kept for saved-workflow compatibility;
    # confirm before renaming either side.
    NODE_CLASS_MAPPINGS = {
        "OTUNetLoaderW8A8": UNetLoaderINTW8A8,
        "INT8LoraLoader": INT8LoraLoader,
        "INT8LoraLoaderStack": INT8LoraLoaderStack,
        "INT8DynamicLoraLoader": INT8DynamicLoraLoader,
        "INT8DynamicLoraStack": INT8DynamicLoraStack,
    }
    # Human-readable titles shown in the ComfyUI node menu, keyed by node ID.
    NODE_DISPLAY_NAME_MAPPINGS = {
        "OTUNetLoaderW8A8": "Load Diffusion Model INT8 (W8A8)",
        "INT8LoraLoader": "Load LoRA INT8 (Stochastic)",
        "INT8LoraLoaderStack": "INT8 LoRA Stack (Stochastic)",
        "INT8DynamicLoraLoader": "Load LoRA INT8 (Dynamic)",
        "INT8DynamicLoraStack": "INT8 LoRA Stack (Dynamic)",
    }
except ImportError as e:
    # All-or-nothing: if ANY loader module fails to import, no nodes are
    # exposed, rather than registering a partial/broken set.
    logging.error(f"Int88: Failed to import nodes: {e}")
    NODE_CLASS_MAPPINGS = {}
    NODE_DISPLAY_NAME_MAPPINGS = {}
# Explicit public API of the package (ComfyUI reads the two mappings).
__all__ = [
    "NODE_CLASS_MAPPINGS",
    "NODE_DISPLAY_NAME_MAPPINGS",
    "Int8TensorwiseOps",
]