int8_dynamic_lora.py

import torch
import folder_paths
import comfy.utils
import comfy.lora
import logging
from torch import nn


class INT8DynamicLoraLoader:
    """Load a LoRA and register it as a dynamic (runtime) patch set for DynamicLoRAHook."""

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL",),
                "lora_name": (folder_paths.get_filename_list("loras"),),
                "strength": ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01}),
            }
        }

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "load_lora"
    CATEGORY = "loaders"

    def load_lora(self, model, lora_name, strength):
        if strength == 0:
            return (model,)

        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
        model_patcher = model.clone()

        # 1. Build the patch map: model key -> LoRA weights for the UNet.
        key_map = {}
        if model_patcher.model.model_type.name != "ModelType.CLIP":
            key_map = comfy.lora.model_lora_keys_unet(model_patcher.model, key_map)
        patch_dict = comfy.lora.load_lora(lora, key_map, log_missing=True)

        # 2. Register the global hook (if not already registered).
        from .int8_quant import DynamicLoRAHook
        DynamicLoRAHook.register(model_patcher.model.diffusion_model)

        # 3. Add the LoRA to the dynamic LoRA list in transformer_options.
        # This way ComfyUI's cloning handles everything and the patches are non-sticky.
        if "transformer_options" not in model_patcher.model_options:
            model_patcher.model_options["transformer_options"] = {}
        opts = model_patcher.model_options["transformer_options"]
        if "dynamic_loras" not in opts:
            opts["dynamic_loras"] = []
        else:
            # Shallow-copy the list to avoid modifying the parent patcher's list.
            opts["dynamic_loras"] = opts["dynamic_loras"].copy()
        opts["dynamic_loras"].append({
            "name": lora_name,
            "strength": strength,
            "patches": patch_dict,
        })
        return (model_patcher,)


class INT8DynamicLoraStack:
    """
    Apply multiple LoRAs in one node for efficiency.
    """

    @classmethod
    def INPUT_TYPES(s):
        inputs = {
            "required": {"model": ("MODEL",)},
            "optional": {},
        }
        lora_list = ["None"] + folder_paths.get_filename_list("loras")
        for i in range(1, 11):
            inputs["optional"][f"lora_{i}"] = (lora_list,)
            inputs["optional"][f"strength_{i}"] = ("FLOAT", {"default": 1.0, "min": -20.0, "max": 20.0, "step": 0.01})
        return inputs

    RETURN_TYPES = ("MODEL",)
    FUNCTION = "apply_stack"
    CATEGORY = "loaders"

    def apply_stack(self, model, **kwargs):
        loader = INT8DynamicLoraLoader()
        current_model = model
        for i in range(1, 11):
            lora_name = kwargs.get(f"lora_{i}")
            strength = kwargs.get(f"strength_{i}", 0)
            if lora_name and lora_name != "None" and strength != 0:
                # This could be optimized by NOT cloning and re-hooking up to 10 times
                # (see the sketch after this class), but for simplicity/reliability we
                # reuse the single-LoRA loader.
                (current_model,) = loader.load_lora(current_model, lora_name, strength)
        return (current_model,)
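

# --- Illustrative only --------------------------------------------------------
# A minimal sketch (not wired into the node mappings below) of the single-clone
# optimization mentioned in apply_stack: clone the patcher and register the hook
# once, then append every LoRA entry to the same "dynamic_loras" list. It reuses
# only calls already used above; treat it as an untested reference, not part of
# this node pack's API.
def _apply_stack_single_clone(model, entries):
    """entries: iterable of (lora_name, strength) pairs with strength != 0."""
    from .int8_quant import DynamicLoRAHook

    model_patcher = model.clone()
    DynamicLoRAHook.register(model_patcher.model.diffusion_model)

    if "transformer_options" not in model_patcher.model_options:
        model_patcher.model_options["transformer_options"] = {}
    opts = model_patcher.model_options["transformer_options"]
    opts["dynamic_loras"] = list(opts.get("dynamic_loras", []))

    key_map = comfy.lora.model_lora_keys_unet(model_patcher.model, {})
    for lora_name, strength in entries:
        lora_path = folder_paths.get_full_path("loras", lora_name)
        lora = comfy.utils.load_torch_file(lora_path, safe_load=True)
        opts["dynamic_loras"].append({
            "name": lora_name,
            "strength": strength,
            "patches": comfy.lora.load_lora(lora, key_map, log_missing=True),
        })
    return model_patcher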


NODE_CLASS_MAPPINGS = {
    "INT8DynamicLoraLoader": INT8DynamicLoraLoader,
    "INT8DynamicLoraStack": INT8DynamicLoraStack,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "INT8DynamicLoraLoader": "Load LoRA INT8 (Dynamic)",
    "INT8DynamicLoraStack": "INT8 LoRA Stack (Dynamic)",
}