Skip to content

Commit 666a2f5

Browse files
Fix the VAE recompilation issue
1 parent e1184b6 commit 666a2f5

2 files changed

Lines changed: 191 additions & 45 deletions

File tree

node_openvino.py

Lines changed: 190 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,158 @@
11
import torch
22
import openvino as ov
33
from typing_extensions import override
4+
from typing import Optional
45
import openvino.frontend.pytorch.torchdynamo.execute as ov_ex
56
from comfy_api.latest import ComfyExtension, io
67
from comfy_api.torch_helpers import set_torch_compile_wrapper
78

9+
TORCH_COMPILE_KWARGS_VAE = "torch_compile_kwargs_vae"


class VAECompileWrapper:
    """
    VAE compiler wrapper that mirrors set_torch_compile_wrapper.

    Instead of permanently replacing the encoder/decoder submodules with their
    ``torch.compile`` results (the setattr approach that caused recompiles),
    the compiled modules are kept in a side table and swapped in only for the
    duration of each ``encode``/``decode`` call, so the original semantics of
    those methods are preserved.
    """

    def __init__(self, vae):
        # The outer VAE object (carries vae_options) and its torch module.
        self.vae = vae
        self.first_stage = vae.first_stage_model
        # submodule attribute name (e.g. "encoder") -> torch.compile()'d module
        self.compiled_modules = {}
        self.compile_kwargs = {}
        self.is_active = False

        # Original bound encode/decode methods, saved so remove() can restore them.
        self.original_encode = None
        self.original_decode = None

    def compile(self, backend: str, options: Optional[dict] = None,
                mode: Optional[str] = None, fullgraph=False, dynamic: Optional[bool] = None,
                keys: Optional[list[str]] = None):
        """Compile the requested VAE submodules.

        Args:
            backend: torch.compile backend name (e.g. "openvino").
            options: backend-specific options dict, forwarded to torch.compile.
            mode / fullgraph / dynamic: forwarded to torch.compile.
            keys: submodule attribute names to compile; when None, the
                ``taesd_*`` pair is used if present, else "encoder"/"decoder".
        """
        # Re-compiling: tear down any previous wrappers first.
        if self.is_active:
            self.remove()

        if keys is None:
            if hasattr(self.first_stage, "taesd_encoder"):
                keys = ["taesd_encoder", "taesd_decoder"]
            else:
                keys = ["encoder", "decoder"]

        compile_kwargs = {
            "backend": backend,
            "options": options,
            "mode": mode,
            "fullgraph": fullgraph,
            "dynamic": dynamic,
        }
        # Let torch.compile use its own defaults for anything left as None.
        compile_kwargs = {k: v for k, v in compile_kwargs.items() if v is not None}

        for key in keys:
            if not hasattr(self.first_stage, key):
                continue

            try:
                original_module = getattr(self.first_stage, key)
                # Compile without setattr: the module on first_stage stays untouched.
                compiled_module = torch.compile(original_module, **compile_kwargs)
                self.compiled_modules[key] = compiled_module
                print(f"✅ Successfully compiled VAE.{key}")
            except Exception as e:
                print(f"❌ Failed to compile VAE.{key}: {e}")

        if self.compiled_modules:
            self.compile_kwargs = compile_kwargs
            self._wrap_forward_methods()
            self.is_active = True

            # Record the kwargs on the VAE so other code can inspect them.
            if not hasattr(self.vae, 'vae_options'):
                self.vae.vae_options = {}
            self.vae.vae_options[TORCH_COMPILE_KWARGS_VAE] = compile_kwargs

    def _wrap_forward_methods(self):
        """Wrap encode/decode so the compiled modules are used at runtime."""
        if hasattr(self.first_stage, 'encode'):
            self.original_encode = self.first_stage.encode
            self.first_stage.encode = self._create_encode_wrapper()

        if hasattr(self.first_stage, 'decode'):
            self.original_decode = self.first_stage.decode
            self.first_stage.decode = self._create_decode_wrapper()

    def _swap_and_call(self, key, method, arg, what):
        """Temporarily install compiled submodule `key`, run `method(arg)`, restore.

        Running the *original* method with the compiled submodule swapped in
        preserves any extra logic encode/decode performs around the submodule
        call. On failure the original submodule is restored and the original
        method is retried as a fallback.
        """
        original_module = getattr(self.first_stage, key)
        setattr(self.first_stage, key, self.compiled_modules[key])
        try:
            return method(arg)
        except Exception as e:
            print(f"Compiled {what} execution failed, falling back to original: {e}")
            setattr(self.first_stage, key, original_module)
            return method(arg)
        finally:
            # Always leave the original submodule in place between calls.
            setattr(self.first_stage, key, original_module)

    def _create_encode_wrapper(self):
        """Create encode wrapper."""
        def encode_wrapper(x):
            # Determine which encoder to use.
            encoder_key = "taesd_encoder" if "taesd_encoder" in self.compiled_modules else "encoder"

            if encoder_key in self.compiled_modules:
                return self._swap_and_call(encoder_key, self.original_encode, x, "encoder")
            # No compiled encoder available: use the original method.
            return self.original_encode(x)

        return encode_wrapper

    def _create_decode_wrapper(self):
        """Create decode wrapper."""
        def decode_wrapper(z):
            # Determine which decoder to use.
            decoder_key = "taesd_decoder" if "taesd_decoder" in self.compiled_modules else "decoder"

            if decoder_key in self.compiled_modules:
                return self._swap_and_call(decoder_key, self.original_decode, z, "decoder")
            # No compiled decoder available: use the original method.
            return self.original_decode(z)

        return decode_wrapper

    def remove(self):
        """Remove the compilation wrapper and restore the original methods."""
        if not self.is_active:
            return

        # Restore original methods and drop the saved references.
        if self.original_encode is not None:
            self.first_stage.encode = self.original_encode
            self.original_encode = None
        if self.original_decode is not None:
            self.first_stage.decode = self.original_decode
            self.original_decode = None

        # Clean up compiled state.
        self.compiled_modules.clear()
        self.compile_kwargs.clear()
        self.is_active = False

        if hasattr(self.vae, 'vae_options') and TORCH_COMPILE_KWARGS_VAE in self.vae.vae_options:
            del self.vae.vae_options[TORCH_COMPILE_KWARGS_VAE]

        print("✅ VAE compilation removed")
8156

9157
class TorchCompileDiffusionOpenVINO(io.ComfyNode):
10158
@classmethod
@@ -16,10 +164,7 @@ def define_schema(cls) -> io.Schema:
16164
category="OpenVINO",
17165
inputs=[
18166
io.Model.Input("model"),
19-
io.Combo.Input(
20-
"device",
21-
options=available_devices,
22-
),
167+
io.Combo.Input("device", options=available_devices),
23168
],
24169
outputs=[io.Model.Output()],
25170
is_experimental=True,
@@ -31,9 +176,12 @@ def execute(cls, model, device) -> io.NodeOutput:
31176
ov_ex.compiled_cache.clear()
32177
ov_ex.req_cache.clear()
33178
ov_ex.partitioned_modules.clear()
179+
34180
m = model.clone()
35181
set_torch_compile_wrapper(
36-
model=m, backend="openvino", options={"device": device}
182+
model=m,
183+
backend="openvino",
184+
options={"device": device}
37185
)
38186
return io.NodeOutput(m)
39187

@@ -48,64 +196,62 @@ def define_schema(cls) -> io.Schema:
48196
category="OpenVINO",
49197
inputs=[
50198
io.Vae.Input("vae"),
51-
io.Combo.Input(
52-
"device",
53-
options=available_devices,
54-
),
55-
io.Boolean.Input(
56-
"compile_encoder",
57-
default=True,
58-
),
59-
io.Boolean.Input(
60-
"compile_decoder",
61-
default=True,
62-
),
199+
io.Combo.Input("device", options=available_devices),
200+
io.Boolean.Input("compile_encoder", default=True),
201+
io.Boolean.Input("compile_decoder", default=True),
202+
io.Boolean.Input("remove_compile", default=False,
203+
tooltip="Remove VAE compilation"),
63204
],
64205
outputs=[io.Vae.Output()],
65206
is_experimental=True,
66207
)
67208

68209
@classmethod
def execute(cls, vae, device, compile_encoder, compile_decoder, remove_compile) -> io.NodeOutput:
    """Compile (or un-compile) the VAE encoder/decoder for the OpenVINO backend."""
    # Reset dynamo and the OpenVINO executor caches so graphs from a
    # previous run are not reused.
    torch._dynamo.reset()
    ov_ex.compiled_cache.clear()
    ov_ex.req_cache.clear()
    ov_ex.partitioned_modules.clear()

    # One wrapper per VAE instance, created lazily and reused across runs.
    if not hasattr(vae, '_compile_wrapper'):
        vae._compile_wrapper = VAECompileWrapper(vae)
    wrapper = vae._compile_wrapper

    # When removal is requested, undo any prior compilation and return as-is.
    if remove_compile:
        wrapper.remove()
        return io.NodeOutput(vae)

    # Pick the taesd_* submodules when the VAE exposes them, else the
    # standard encoder/decoder pair.
    first_stage = vae.first_stage_model
    has_taesd = hasattr(first_stage, "taesd_encoder")

    keys = []
    if compile_encoder:
        keys.append("taesd_encoder" if has_taesd else "encoder")
    if compile_decoder:
        keys.append("taesd_decoder" if has_taesd else "decoder")

    if keys:
        wrapper.compile(
            backend="openvino",
            options={"device": device},
            keys=keys,
        )

    return io.NodeOutput(vae)
103246

104247

105248
class OpenVINOTorchCompileExtension(ComfyExtension):
    """ComfyUI extension exposing the OpenVINO torch.compile nodes."""

    @override
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        """Return the node classes registered by this extension."""
        return [
            TorchCompileDiffusionOpenVINO,
            TorchCompileVAEOpenVINO,
        ]
109255

110256

111257
async def comfy_entrypoint() -> OpenVINOTorchCompileExtension:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
[project]
22
name = "comfyui-openvino"
33
description = "OpenVINO node is designed for optimizing the performance of model inference in ComfyUI by leveraging Intel OpenVINO toolkits. It can support running model on Intel CPU, GPU and NPU device."
4-
version = "1.1.1"
4+
version = "1.1.2"
55
license = {file = "LICENSE"}
66

77
[project.urls]

0 commit comments

Comments
 (0)