11import torch
22import openvino as ov
33from typing_extensions import override
4+ from typing import Optional
45import openvino .frontend .pytorch .torchdynamo .execute as ov_ex
56from comfy_api .latest import ComfyExtension , io
67from comfy_api .torch_helpers import set_torch_compile_wrapper
78
9+ TORCH_COMPILE_KWARGS_VAE = "torch_compile_kwargs_vae"
10+
11+
12+ class VAECompileWrapper :
13+ """
14+ VAE compiler wrapper that mirrors set_torch_compile_wrapper
15+ Dynamically swaps modules during forward instead of using setattr directly
16+ """
17+ def __init__ (self , vae ):
18+ self .vae = vae
19+ self .first_stage = vae .first_stage_model
20+ self .compiled_modules = {}
21+ self .compile_kwargs = {}
22+ self .is_active = False
23+
24+ # Store original forward methods
25+ self .original_encode = None
26+ self .original_decode = None
27+
28+ def compile (self , backend : str , options : Optional [dict ] = None ,
29+ mode : Optional [str ] = None , fullgraph = False , dynamic : Optional [bool ] = None ,
30+ keys : Optional [list [str ]] = None ):
31+ """Compile specified VAE modules"""
32+
33+ # Clean previous compilation
34+ if self .is_active :
35+ self .remove ()
36+
37+ # Determine keys to compile
38+ if keys is None :
39+ keys = []
40+ if hasattr (self .first_stage , "taesd_encoder" ):
41+ keys = ["taesd_encoder" , "taesd_decoder" ]
42+ else :
43+ keys = ["encoder" , "decoder" ]
44+
45+ # Compile arguments
46+ compile_kwargs = {
47+ "backend" : backend ,
48+ "options" : options ,
49+ "mode" : mode ,
50+ "fullgraph" : fullgraph ,
51+ "dynamic" : dynamic ,
52+ }
53+ compile_kwargs = {k : v for k , v in compile_kwargs .items () if v is not None }
54+
55+ # Compile each module
56+ for key in keys :
57+ if not hasattr (self .first_stage , key ):
58+ continue
59+
60+ try :
61+ original_module = getattr (self .first_stage , key )
62+ # ✅ Only compile module without setattr
63+ compiled_module = torch .compile (original_module , ** compile_kwargs )
64+ self .compiled_modules [key ] = compiled_module
65+ print (f"✅ Successfully compiled VAE.{ key } " )
66+ except Exception as e :
67+ print (f"❌ Failed to compile VAE.{ key } : { e } " )
68+
69+ if self .compiled_modules :
70+ self .compile_kwargs = compile_kwargs
71+ self ._wrap_forward_methods ()
72+ self .is_active = True
73+
74+ # Store into vae_options
75+ if not hasattr (self .vae , 'vae_options' ):
76+ self .vae .vae_options = {}
77+ self .vae .vae_options [TORCH_COMPILE_KWARGS_VAE ] = compile_kwargs
78+
79+ def _wrap_forward_methods (self ):
80+ """Wrap encode/decode to use compiled modules at runtime"""
81+
82+ # Save original methods
83+ if hasattr (self .first_stage , 'encode' ):
84+ self .original_encode = self .first_stage .encode
85+ self .first_stage .encode = self ._create_encode_wrapper ()
86+
87+ if hasattr (self .first_stage , 'decode' ):
88+ self .original_decode = self .first_stage .decode
89+ self .first_stage .decode = self ._create_decode_wrapper ()
90+
91+ def _create_encode_wrapper (self ):
92+ """Create encode wrapper"""
93+ def encode_wrapper (x ):
94+ # Determine which encoder to use
95+ encoder_key = "taesd_encoder" if "taesd_encoder" in self .compiled_modules else "encoder"
96+
97+ if encoder_key in self .compiled_modules :
98+ # Temporarily replace encoder
99+ original_encoder = getattr (self .first_stage , encoder_key )
100+ try :
101+ # ✅ Use compiled encoder
102+ compiled_encoder = self .compiled_modules [encoder_key ]
103+ return compiled_encoder (x )
104+ except Exception as e :
105+ print (f"Compiled encoder execution failed, falling back to original: { e } " )
106+ return original_encoder (x )
107+ else :
108+ # Use original method
109+ return self .original_encode (x )
110+
111+ return encode_wrapper
112+
113+ def _create_decode_wrapper (self ):
114+ """Create decode wrapper"""
115+ def decode_wrapper (z ):
116+ # Determine which decoder to use
117+ decoder_key = "taesd_decoder" if "taesd_decoder" in self .compiled_modules else "decoder"
118+
119+ if decoder_key in self .compiled_modules :
120+ # Temporarily replace decoder
121+ original_decoder = getattr (self .first_stage , decoder_key )
122+ try :
123+ # ✅ Use compiled decoder
124+ compiled_decoder = self .compiled_modules [decoder_key ]
125+ return compiled_decoder (z )
126+ except Exception as e :
127+ print (f"Compiled decoder execution failed, falling back to original: { e } " )
128+ return original_decoder (z )
129+ else :
130+ # Use original method
131+ return self .original_decode (z )
132+
133+ return decode_wrapper
134+
135+ def remove (self ):
136+ """Remove compilation wrapper"""
137+ if not self .is_active :
138+ return
139+
140+ # Restore original methods
141+ if self .original_encode is not None :
142+ self .first_stage .encode = self .original_encode
143+ if self .original_decode is not None :
144+ self .first_stage .decode = self .original_decode
145+
146+ # Clean up
147+ self .compiled_modules .clear ()
148+ self .compile_kwargs .clear ()
149+ self .is_active = False
150+
151+ if hasattr (self .vae , 'vae_options' ) and TORCH_COMPILE_KWARGS_VAE in self .vae .vae_options :
152+ del self .vae .vae_options [TORCH_COMPILE_KWARGS_VAE ]
153+
154+ print ("✅ VAE compilation removed" )
155+
8156
9157class TorchCompileDiffusionOpenVINO (io .ComfyNode ):
10158 @classmethod
@@ -16,10 +164,7 @@ def define_schema(cls) -> io.Schema:
16164 category = "OpenVINO" ,
17165 inputs = [
18166 io .Model .Input ("model" ),
19- io .Combo .Input (
20- "device" ,
21- options = available_devices ,
22- ),
167+ io .Combo .Input ("device" , options = available_devices ),
23168 ],
24169 outputs = [io .Model .Output ()],
25170 is_experimental = True ,
@@ -31,9 +176,12 @@ def execute(cls, model, device) -> io.NodeOutput:
31176 ov_ex .compiled_cache .clear ()
32177 ov_ex .req_cache .clear ()
33178 ov_ex .partitioned_modules .clear ()
179+
34180 m = model .clone ()
35181 set_torch_compile_wrapper (
36- model = m , backend = "openvino" , options = {"device" : device }
182+ model = m ,
183+ backend = "openvino" ,
184+ options = {"device" : device }
37185 )
38186 return io .NodeOutput (m )
39187
@@ -48,64 +196,62 @@ def define_schema(cls) -> io.Schema:
48196 category = "OpenVINO" ,
49197 inputs = [
50198 io .Vae .Input ("vae" ),
51- io .Combo .Input (
52- "device" ,
53- options = available_devices ,
54- ),
55- io .Boolean .Input (
56- "compile_encoder" ,
57- default = True ,
58- ),
59- io .Boolean .Input (
60- "compile_decoder" ,
61- default = True ,
62- ),
199+ io .Combo .Input ("device" , options = available_devices ),
200+ io .Boolean .Input ("compile_encoder" , default = True ),
201+ io .Boolean .Input ("compile_decoder" , default = True ),
202+ io .Boolean .Input ("remove_compile" , default = False ,
203+ tooltip = "Remove VAE compilation" ),
63204 ],
64205 outputs = [io .Vae .Output ()],
65206 is_experimental = True ,
66207 )
67208
68209 @classmethod
69- def execute (cls , vae , device , compile_encoder , compile_decoder ) -> io .NodeOutput :
210+ def execute (cls , vae , device , compile_encoder , compile_decoder , remove_compile ) -> io .NodeOutput :
70211 torch ._dynamo .reset ()
71212 ov_ex .compiled_cache .clear ()
72213 ov_ex .req_cache .clear ()
73214 ov_ex .partitioned_modules .clear ()
215+
216+ # Get or create wrapper
217+ if not hasattr (vae , '_compile_wrapper' ):
218+ vae ._compile_wrapper = VAECompileWrapper (vae )
219+
220+ wrapper = vae ._compile_wrapper
221+
222+ # Remove compilation if requested
223+ if remove_compile :
224+ wrapper .remove ()
225+ return io .NodeOutput (vae )
226+
227+ # Otherwise compile as requested
228+ keys = []
229+ first_stage = vae .first_stage_model
230+ has_taesd = hasattr (first_stage , "taesd_encoder" )
231+
74232 if compile_encoder :
75- encoder_name = "encoder"
76- if hasattr (vae .first_stage_model , "taesd_encoder" ):
77- encoder_name = "taesd_encoder"
78-
79- setattr (
80- vae .first_stage_model ,
81- encoder_name ,
82- torch .compile (
83- getattr (vae .first_stage_model , encoder_name ),
84- backend = "openvino" ,
85- options = {"device" : device },
86- ),
87- )
233+ keys .append ("taesd_encoder" if has_taesd else "encoder" )
234+
88235 if compile_decoder :
89- decoder_name = "decoder"
90- if hasattr (vae .first_stage_model , "taesd_decoder" ):
91- decoder_name = "taesd_decoder"
92-
93- setattr (
94- vae .first_stage_model ,
95- decoder_name ,
96- torch .compile (
97- getattr (vae .first_stage_model , decoder_name ),
98- backend = "openvino" ,
99- options = {"device" : device },
100- ),
236+ keys .append ("taesd_decoder" if has_taesd else "decoder" )
237+
238+ if keys :
239+ wrapper .compile (
240+ backend = "openvino" ,
241+ options = {"device" : device },
242+ keys = keys ,
101243 )
244+
102245 return io .NodeOutput (vae )
103246
104247
105248class OpenVINOTorchCompileExtension (ComfyExtension ):
106249 @override
107250 async def get_node_list (self ) -> list [type [io .ComfyNode ]]:
108- return [TorchCompileDiffusionOpenVINO , TorchCompileVAEOpenVINO ]
251+ return [
252+ TorchCompileDiffusionOpenVINO ,
253+ TorchCompileVAEOpenVINO ,
254+ ]
109255
110256
111257async def comfy_entrypoint () -> OpenVINOTorchCompileExtension :
0 commit comments