@@ -42,7 +42,6 @@ def start(self):
         torch.backends.cudnn.deterministic = False
         torch.backends.cudnn.allow_tf32 = True
 
-
         if torch.cuda.is_available():
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
             logger.info(f"Loading CUDA with model: {model_path}")
@@ -61,6 +60,14 @@ def start(self):
 
             self.pipeline = self.pipeline.to(device=self.device)
 
+            if hasattr(self.pipeline, 'enable_vae_slicing'):
+                self.pipeline.enable_vae_slicing()
+                logger.info("VAE slicing enabled - will reduce memory spikes during decoding")
+
+            if hasattr(self.pipeline, 'enable_vae_tiling'):
+                self.pipeline.enable_vae_tiling()
+                logger.info("VAE tiling enabled - will allow processing larger images")
+
             if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
                 self.pipeline.transformer = self.pipeline.transformer.to(
                     memory_format=torch.channels_last
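The two pipeline-level switches added above are standard diffusers memory helpers: slicing decodes latents one image at a time to flatten the VRAM spike at the end of generation, and tiling decodes the image in overlapping tiles so resolutions that would otherwise exceed memory still fit. A minimal standalone sketch of the same pattern, assuming a CUDA machine and the stable-diffusion-3.5-large checkpoint used in this branch (the diff itself guards both calls with hasattr):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Decode latents one image at a time instead of as one batch;
# slightly slower for large batches, but peak decode memory drops sharply.
pipe.enable_vae_slicing()

# Decode in overlapping tiles so large resolutions fit in VRAM.
pipe.enable_vae_tiling()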
@@ -71,6 +78,15 @@ def start(self):
                 self.pipeline.vae = self.pipeline.vae.to(
                     memory_format=torch.channels_last
                 )
+
+                if hasattr(self.pipeline.vae, 'enable_slicing'):
+                    self.pipeline.vae.enable_slicing()
+                    logger.info("VAE slicing activated directly on the VAE")
+
+                if hasattr(self.pipeline.vae, 'enable_tiling'):
+                    self.pipeline.vae.enable_tiling()
+                    logger.info("VAE tiling activated directly on the VAE")
+
                 logger.info("VAE optimized with channels_last format")
 
             try:
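This second pair of guards reaches past the pipeline and flips the same switches on the VAE module itself. In diffusers the pipeline-level helpers shown earlier typically just forward to these module methods, so the direct calls act as a fallback for pipeline classes that do not expose the wrappers. A sketch of the direct form, assuming the checkpoint ships the usual AutoencoderKL in its vae subfolder:

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",
    subfolder="vae",
    torch_dtype=torch.bfloat16,
)

# Same effect as the pipeline wrappers, applied at the module level.
vae.enable_slicing()
vae.enable_tiling()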
@@ -79,9 +95,7 @@ def start(self):
             except Exception as e:
                 logger.info(f"XFormers not available: {e}")
 
-            # --- torch.compile is discarded but everything else is kept ---
-            if torch.__version__ >= "2.0.0":
-                logger.info("Skipping torch.compile - running without compile optimizations by design")
+            logger.info("Skipping torch.compile - running without compile optimizations by design")
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
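The surrounding try/except enables xformers attention only when the optional dependency is importable, and the new unconditional log line replaces the old version-gated torch.compile branch. A compact sketch of that guard, reusing pipe from the first sketch above and assuming a module-level logger:

import logging

logger = logging.getLogger(__name__)

try:
    # Optional dependency: memory-efficient attention kernels from xformers.
    pipe.enable_xformers_memory_efficient_attention()
    logger.info("XFormers memory-efficient attention enabled")
except Exception as e:
    # Fall back to the default attention implementation.
    logger.info(f"XFormers not available: {e}")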
@@ -92,13 +106,18 @@ def start(self):
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
             logger.info(f"Loading MPS for Mac M Series with model: {model_path}")
             self.device = "mps"
+
             self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                 model_path,
                 torch_dtype=torch.bfloat16,
                 use_safetensors=True,
                 low_cpu_mem_usage=True,
             ).to(device=self.device)
-
+
+            if hasattr(self.pipeline, 'enable_vae_slicing'):
+                self.pipeline.enable_vae_slicing()
+                logger.info("VAE slicing enabled on MPS")
+
             if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
                 self.pipeline.transformer = self.pipeline.transformer.to(
                     memory_format=torch.channels_last
@@ -108,14 +127,13 @@ def start(self):
                 self.pipeline.vae = self.pipeline.vae.to(
                     memory_format=torch.channels_last
                 )
-
 
             logger.info("MPS pipeline optimized and ready")
 
         else:
             raise Exception("No CUDA or MPS device available")
 
-        # OPTIONAL WARMUP
+
         self._warmup()
 
         logger.info("Pipeline initialization completed successfully")
@@ -131,8 +149,13 @@ def _warmup(self):
                 width=512,
                 guidance_scale=1.0,
             )
-        torch.cuda.empty_cache() if self.device == "cuda" else None
-        logger.info("Warmup completed")
+
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
+        gc.collect()
+        logger.info("Warmup completed with memory cleanup")
 
 class TextToImagePipelineFlux:
     def __init__(self, model_path: str | None = None, low_vram: bool = False):
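The rewritten warmup tail makes the cleanup explicit: on CUDA it synchronizes first, so every queued kernel finishes before the cache is released, and then runs the Python garbage collector on all devices. Note that gc.collect() relies on an import gc elsewhere in the module, outside this hunk. A sketch of the same sequence as a standalone helper (_cleanup_after_warmup is an illustrative name, not from the diff):

import gc
import torch

def _cleanup_after_warmup(device: str) -> None:
    if device == "cuda":
        # Wait for in-flight kernels, then return cached blocks to the driver.
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    # Drop Python-level references left over from the warmup pass.
    gc.collect()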