-
Notifications
You must be signed in to change notification settings - Fork 4
Expand file tree
/
Copy pathhunyuan_vae_simple.py
More file actions
301 lines (232 loc) · 9.11 KB
/
hunyuan_vae_simple.py
File metadata and controls
301 lines (232 loc) · 9.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
"""
HunyuanImage-3.0 Simple VAE Manager - V2 Unified Node Support
Simple, hook-free VAE management using explicit .to() calls.
Author: Eric Hiss (GitHub: EricRollei)
License: Dual License (Non-Commercial and Commercial Use)
Copyright (c) 2025 Eric Hiss. All rights reserved.
"""
import logging
import time
from enum import Enum
from typing import Any, Optional
import torch
# Module-level logger shared by the manager class and the tiling helpers below.
logger = logging.getLogger(__name__)
class VAEPlacement(Enum):
    """VAE placement strategy."""

    # Keep the VAE resident on the GPU for the whole session.
    ALWAYS_GPU = "always_gpu"
    # Shuttle the VAE: moved to GPU for decode, parked back on CPU after.
    MANAGED = "managed"
class SimpleVAEManager:
    """
    Simple VAE management without hooks.

    This manager uses explicit .to() calls instead of hooks to control
    VAE placement. This avoids conflicts with other ComfyUI extensions
    and provides predictable behavior.

    Usage:
        manager = SimpleVAEManager(model, VAEPlacement.MANAGED)
        manager.setup_initial_placement()

        # Before decode:
        manager.prepare_for_decode()
        images = vae.decode(latents)
        manager.cleanup_after_decode()
    """

    def __init__(
        self,
        model: Any,
        placement: VAEPlacement,
        device: str = "cuda:0",
        dtype: torch.dtype = torch.bfloat16,
    ):
        """
        Initialize VAE manager.

        Args:
            model: The Hunyuan model containing a .vae attribute.
            placement: VAE placement strategy.
            device: Target GPU device.
            dtype: Target dtype for floating-point VAE weights
                (default bfloat16). Non-floating tensors are never cast.
        """
        self.model = model
        self.placement = placement
        self.device = torch.device(device)
        self.dtype = dtype
        # Tracks whether we believe the VAE currently resides on self.device.
        self._vae_on_gpu = False
        self._original_device = None

    @property
    def vae(self):
        """Access the VAE from the model.

        Raises:
            AttributeError: If the model has no 'vae' attribute.
        """
        if not hasattr(self.model, 'vae'):
            raise AttributeError("Model does not have a 'vae' attribute")
        return self.model.vae

    @property
    def is_on_gpu(self) -> bool:
        """Check if VAE is currently on GPU."""
        return self._vae_on_gpu

    def setup_initial_placement(self) -> None:
        """
        Set up initial VAE placement based on configuration.

        Call this after model loading to establish VAE position.
        """
        if not hasattr(self.model, 'vae'):
            logger.warning("Model has no VAE attribute, skipping VAE setup")
            return
        if self.placement == VAEPlacement.ALWAYS_GPU:
            self._move_vae_to_gpu()
            logger.info("VAE placement: always_gpu (%s)", self.device)
        else:
            self._move_vae_to_cpu()
            logger.info("VAE placement: managed (parked on CPU)")

    def prepare_for_decode(self) -> float:
        """
        Prepare VAE for decoding by moving it to the GPU if needed.

        Returns:
            Time taken to move the VAE in seconds (0.0 if already on GPU).
        """
        if self._vae_on_gpu:
            logger.debug("VAE already on GPU, no movement needed")
            return 0.0
        start_time = time.time()
        self._move_vae_to_gpu()
        elapsed = time.time() - start_time
        logger.info("VAE moved to GPU in %.2fs for decode", elapsed)
        return elapsed

    def cleanup_after_decode(self) -> float:
        """
        Clean up VAE after decoding.

        For MANAGED placement, moves the VAE back to CPU.
        For ALWAYS_GPU, does nothing.

        Returns:
            Time taken to move the VAE in seconds (0.0 if not moved).
        """
        if self.placement == VAEPlacement.ALWAYS_GPU:
            logger.debug("VAE placement is always_gpu, keeping on GPU")
            return 0.0
        if not self._vae_on_gpu:
            logger.debug("VAE already on CPU, no movement needed")
            return 0.0
        start_time = time.time()
        self._move_vae_to_cpu()
        elapsed = time.time() - start_time
        # Return the freed VRAM to the allocator pool. Guarded so this is a
        # clean no-op on CPU-only installs.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        logger.info("VAE moved back to CPU in %.2fs", elapsed)
        return elapsed

    def cleanup(self) -> None:
        """
        Break the reference to the model to allow garbage collection.

        Without this, vae_manager.model holds a strong ref to the entire
        ~150GB model, preventing gc even after the cache drops its ref.
        """
        self.model = None
        self._vae_on_gpu = False
        logger.info("SimpleVAEManager cleanup: model reference released")

    def _move_vae_to_gpu(self) -> None:
        """Move VAE to the target device, casting float tensors to self.dtype."""
        if not hasattr(self.model, 'vae'):
            return
        vae = self.model.vae
        # Move the whole module first. Module.to() with a dtype only casts
        # floating/complex tensors, which is the behavior we want.
        vae.to(device=self.device, dtype=self.dtype)
        # Sweep any stragglers (some tensors may be lazily materialized).
        # Only cast floating-point tensors -- forcing integer params/buffers
        # to self.dtype would corrupt them. Meta tensors have no storage and
        # cannot be moved.
        for param in vae.parameters():
            if param.device != self.device and param.device.type != "meta":
                if torch.is_floating_point(param):
                    param.data = param.data.to(device=self.device, dtype=self.dtype)
                else:
                    param.data = param.data.to(device=self.device)
        for buffer in vae.buffers():
            if buffer.device != self.device and buffer.device.type != "meta":
                if torch.is_floating_point(buffer):
                    buffer.data = buffer.data.to(device=self.device, dtype=self.dtype)
                else:
                    buffer.data = buffer.data.to(device=self.device)
        # Block until the async host->device copies finish so a decode can
        # start immediately after this returns.
        if self.device.type == "cuda":
            torch.cuda.synchronize()
        self._vae_on_gpu = True

    def _move_vae_to_cpu(self) -> None:
        """Move VAE to CPU (dtype is left unchanged)."""
        if not hasattr(self.model, 'vae'):
            return
        vae = self.model.vae
        cpu_device = torch.device("cpu")
        # Move the whole module, then sweep any tensors .to() missed.
        vae.to(device=cpu_device)
        for param in vae.parameters():
            # Meta tensors have no storage; skip them like the buffer loop.
            if param.device != cpu_device and param.device.type != "meta":
                param.data = param.data.to(device=cpu_device)
        for buffer in vae.buffers():
            if buffer.device != cpu_device and buffer.device.type != "meta":
                buffer.data = buffer.data.to(device=cpu_device)
        self._vae_on_gpu = False

    def get_vae_memory_gb(self) -> float:
        """Get estimated VAE memory usage in GB (parameters + buffers)."""
        if not hasattr(self.model, 'vae'):
            return 0.0
        total_bytes = 0
        vae = self.model.vae
        for param in vae.parameters():
            total_bytes += param.numel() * param.element_size()
        for buffer in vae.buffers():
            # Meta tensors occupy no real memory; don't count them.
            if buffer.device.type != "meta":
                total_bytes += buffer.numel() * buffer.element_size()
        return total_bytes / (1024 ** 3)

    def get_status(self) -> dict:
        """Get VAE manager status as a plain dict (for logging/UI)."""
        return {
            "placement": self.placement.value,
            "is_on_gpu": self._vae_on_gpu,
            "device": str(self.device),
            "dtype": str(self.dtype),
            "memory_gb": self.get_vae_memory_gb(),
        }

    def __repr__(self) -> str:
        status = "GPU" if self._vae_on_gpu else "CPU"
        return f"SimpleVAEManager(placement={self.placement.value}, status={status})"
def enable_vae_tiling(model: Any, tile_size: int = 256) -> bool:
    """
    Enable tiled VAE decoding for memory efficiency.

    This is separate from VAE placement - tiling happens during decode
    regardless of where the VAE is located.

    Args:
        model: Model with a .vae attribute.
        tile_size: Tile size for spatial tiling (used by the attribute-style
            interface only).

    Returns:
        True if tiling was enabled, False if the model has no VAE or the
        VAE exposes neither tiling interface.
    """
    if not hasattr(model, 'vae'):
        logger.warning("Model has no VAE, cannot enable tiling")
        return False
    vae = model.vae

    # Preferred path: diffusers-style VAEs expose an enable_tiling() method.
    if hasattr(vae, 'enable_tiling'):
        vae.enable_tiling()
        logger.info("VAE tiling enabled via enable_tiling()")
        return True

    # Fallback: attribute-style VAEs toggle tile_latent_input directly.
    if hasattr(vae, 'tile_latent_input'):
        vae.tile_latent_input = True
        if hasattr(vae, 'tile_sample_min_size'):
            vae.tile_sample_min_size = tile_size
        logger.info("VAE tiling enabled via tile_latent_input (size=%d)", tile_size)
        return True

    logger.debug("VAE does not support tiling")
    return False
def disable_vae_tiling(model: Any) -> bool:
    """
    Turn off tiled VAE decoding.

    Args:
        model: Model holding a .vae attribute.

    Returns:
        True when tiling was switched off, False otherwise.
    """
    vae = getattr(model, 'vae', None)
    if vae is None:
        return False

    # Diffusers-style interface first, attribute toggle as the fallback.
    if hasattr(vae, 'disable_tiling'):
        vae.disable_tiling()
        logger.info("VAE tiling disabled via disable_tiling()")
        return True

    if hasattr(vae, 'tile_latent_input'):
        vae.tile_latent_input = False
        logger.info("VAE tiling disabled via tile_latent_input")
        return True

    return False