forked from BobJohnson24/ComfyUI-INT8-Fast
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathint8_unet_loader.py
More file actions
86 lines (72 loc) · 3.7 KB
/
int8_unet_loader.py
File metadata and controls
86 lines (72 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import os
import torch
import folder_paths
import comfy.sd
import comfy.utils
from .int8_quant import Int8TensorwiseOps
class UNetLoaderINTW8A8:
"""
Load INT8 tensorwise quantized diffusion models.
Uses Int8TensorwiseOps for direct int8 loading.
Inference uses fast torch._int_mm for blazing speed. (insert rocket emoji, fire emoji to taste)
"""
@classmethod
def INPUT_TYPES(s):
return {
"required": {
"unet_name": (folder_paths.get_filename_list("diffusion_models"),),
"weight_dtype": (["default", "fp8_e4m3fn", "fp16", "bf16"],),
"model_type": (["flux2", "z-image", "chroma", "wan", "ltx2", "qwen"],),
}
}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
CATEGORY = "loaders"
DESCRIPTION = "Load INT8 tensorwise quantized models with fast torch._int_mm inference."
def load_unet(self, unet_name, weight_dtype, model_type):
unet_path = folder_paths.get_full_path("diffusion_models", unet_name)
# Use Int8TensorwiseOps for proper direct int8 loading
model_options = {"custom_operations": Int8TensorwiseOps}
# We need to peek at the model type to set exclusions for Flux
# ComfyUI loads metadata before the full model
from comfy.sd import load_diffusion_model
# Reset exclusions and pre-quantization flag (in case this is the second load)
Int8TensorwiseOps.excluded_names = []
Int8TensorwiseOps._is_prequantized = None
# Check explicit model_type for exclusions
if model_type == "flux2":
Int8TensorwiseOps.excluded_names = [
'img_in', 'time_in', 'guidance_in', 'txt_in', 'final_layer',
'double_stream_modulation_img', 'double_stream_modulation_txt',
'single_stream_modulation',
]
elif model_type == "z-image":
Int8TensorwiseOps.excluded_names = [
'cap_embedder', 't_embedder', 'x_embedder', 'cap_pad_token', 'context_refiner',
'final_layer', 'noise_refiner', 'adaLN',
'x_pad_token',
]
elif model_type == "chroma":
Int8TensorwiseOps.excluded_names = [
'distilled_guidance_layer', 'final_layer', 'img_in', 'txt_in', 'nerf_image_embedder',
'nerf_blocks', 'nerf_final_layer_conv', '__x0__', 'nerf_final_layer_conv',
]
elif model_type == "qwen":
Int8TensorwiseOps.excluded_names = [
'time_text_embed', 'img_in', 'norm_out', 'proj_out', 'txt_in'
]
elif model_type == "wan":
Int8TensorwiseOps.excluded_names = [
'patch_embedding', 'text_embedding', 'time_embedding', 'time_projection', 'head',
'img_emb',
]
elif model_type == "ltx2":
Int8TensorwiseOps.excluded_names = [
'adaln_single', 'audio_adaln_single', 'audio_caption_projection', 'audio_patchify_proj', 'audio_proj_out',
'audio_scale_shift_table', 'av_ca_a2v_gate_adaln_single', 'av_ca_audio_scale_shift_adaln_single', 'av_ca_v2a_gate_adaln_single',
'av_ca_video_scale_shift_adaln_single', 'caption_projection', 'patchify_proj', 'proj_out', 'scale_shift_table',
]
#print(f"Applying model-specific exclusions to Int8TensorwiseOps: {Int8TensorwiseOps.excluded_names}")
# Load model directly - Int8TensorwiseOps handles int8 weights natively
model = load_diffusion_model(unet_path, model_options=model_options)
return (model,)