# Pipelines.py
-
from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
import torch
import os
import logging
from pydantic import BaseModel
-import gc

logger = logging.getLogger(__name__)

@@ -22,155 +20,36 @@ def __init__(self, model_path: str | None = None):
        self.model_path = model_path or os.getenv("MODEL_PATH")
        self.pipeline: StableDiffusion3Pipeline | None = None
        self.device: str | None = None
-
+
    def start(self):
-        torch.set_float32_matmul_precision("high")
-
-        if hasattr(torch._inductor, 'config'):
-            if hasattr(torch._inductor.config, 'conv_1x1_as_mm'):
-                torch._inductor.config.conv_1x1_as_mm = True
-            if hasattr(torch._inductor.config, 'coordinate_descent_tuning'):
-                torch._inductor.config.coordinate_descent_tuning = True
-            if hasattr(torch._inductor.config, 'epilogue_fusion'):
-                torch._inductor.config.epilogue_fusion = False
-            if hasattr(torch._inductor.config, 'coordinate_descent_check_all_directions'):
-                torch._inductor.config.coordinate_descent_check_all_directions = True
-
-        if torch.cuda.is_available():
-            torch.backends.cudnn.benchmark = True
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.deterministic = False
-            torch.backends.cudnn.allow_tf32 = True
-
        if torch.cuda.is_available():
            model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
-            logger.info(f"Loading CUDA with model: {model_path}")
+            logger.info("Loading CUDA")
            self.device = "cuda"
-
-            torch.cuda.empty_cache()
-            gc.collect()
-
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.float16,
-                use_safetensors=True,
-                variant="fp16" if "fp16" in model_path else None,
-                low_cpu_mem_usage=True,
-            )
-
-            self.pipeline = self.pipeline.to(device=self.device)
-
-            if hasattr(self.pipeline, 'enable_vae_slicing'):
-                self.pipeline.enable_vae_slicing()
-                logger.info("VAE slicing enabled - will reduce memory spikes during decoding")
-
-            if hasattr(self.pipeline, 'enable_vae_tiling'):
-                self.pipeline.enable_vae_tiling()
-                logger.info("VAE tiling enabled - will allow processing larger images")
-
-            if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
-                self.pipeline.transformer = self.pipeline.transformer.to(
-                    memory_format=torch.channels_last
-                )
-                logger.info("Transformer optimized with channels_last format")
-
-            if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
-                self.pipeline.vae = self.pipeline.vae.to(
-                    memory_format=torch.channels_last
-                )
-
-                if hasattr(self.pipeline.vae, 'enable_slicing'):
-                    self.pipeline.vae.enable_slicing()
-                    logger.info("VAE slicing activated directly in the VAE")
-
-                if hasattr(self.pipeline.vae, 'enable_tiling'):
-                    self.pipeline.vae.enable_tiling()
-                    logger.info("VAE tiling activated directly on the VAE")
-
-                logger.info("VAE optimized with channels_last format")
-
-            try:
-                self.pipeline.enable_xformers_memory_efficient_attention()
-                logger.info("XFormers memory efficient attention enabled")
-            except Exception as e:
-                logger.info(f"XFormers not available: {e}")
-
-            logger.info("Skipping torch.compile - running without compile optimizations by design")
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            logger.info("CUDA pipeline fully optimized and ready")
-
+            ).to(device=self.device)
        elif torch.backends.mps.is_available():
            model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
-            logger.info(f"Loading MPS for Mac M Series with model: {model_path}")
+            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
-
            self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                model_path,
                torch_dtype=torch.bfloat16,
-                use_safetensors=True,
-                low_cpu_mem_usage=True,
            ).to(device=self.device)
-
-            if hasattr(self.pipeline, 'enable_vae_slicing'):
-                self.pipeline.enable_vae_slicing()
-                logger.info("VAE slicing enabled in MPS")
-
-            if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
-                self.pipeline.transformer = self.pipeline.transformer.to(
-                    memory_format=torch.channels_last
-                )
-
-            if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
-                self.pipeline.vae = self.pipeline.vae.to(
-                    memory_format=torch.channels_last
-                )
-
-            logger.info("MPS pipeline optimized and ready")
-
        else:
            raise Exception("No CUDA or MPS device available")
-
-
-        self._warmup()
-
-        logger.info("Pipeline initialization completed successfully")
-
-    def _warmup(self):
-        if self.pipeline:
-            logger.info("Running warmup inference...")
-            with torch.no_grad():
-                _ = self.pipeline(
-                    prompt="warmup",
-                    num_inference_steps=1,
-                    height=512,
-                    width=512,
-                    guidance_scale=1.0,
-                )
-
-            if self.device == "cuda":
-                torch.cuda.synchronize()
-                torch.cuda.empty_cache()
-
-            gc.collect()
-            logger.info("Warmup completed with memory cleanup")

class TextToImagePipelineFlux:
    def __init__(self, model_path: str | None = None, low_vram: bool = False):
-        """
-        Initialize the class with the model path.
-        If none is provided, it is read from the environment variable.
-        """
        self.model_path = model_path or os.getenv("MODEL_PATH")
-        self.pipeline: FluxPipeline = None
-        self.device: str = None
+        self.pipeline: FluxPipeline | None = None
+        self.device: str | None = None
        self.low_vram = low_vram

    def start(self):
        if torch.cuda.is_available():
-            # If model_path was not set, assign the default value for CUDA.
            model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
            logger.info("Loading CUDA")
            self.device = "cuda"
@@ -183,7 +62,6 @@ def start(self):
            else:
                pass
        elif torch.backends.mps.is_available():
-            # If model_path was not set, assign the default value for MPS.
            model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
@@ -196,17 +74,12 @@ def start(self):

class TextToImagePipelineSD:
    def __init__(self, model_path: str | None = None):
-        """
-        Initialize the class with the model path.
-        If none is provided, it is read from the environment variable.
-        """
        self.model_path = model_path or os.getenv("MODEL_PATH")
-        self.pipeline: StableDiffusionPipeline = None
-        self.device: str = None
+        self.pipeline: StableDiffusionPipeline | None = None
+        self.device: str | None = None

    def start(self):
        if torch.cuda.is_available():
-            # If model_path was not set, assign the default value for CUDA.
            model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
            logger.info("Loading CUDA")
            self.device = "cuda"
@@ -215,7 +88,6 @@ def start(self):
                torch_dtype=torch.float16,
            ).to(device=self.device)
        elif torch.backends.mps.is_available():
-            # If model_path was not set, assign the default value for MPS.
            model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
            logger.info("Loading MPS for Mac M Series")
            self.device = "mps"
@@ -224,4 +96,4 @@ def start(self):
                torch_dtype=torch.float16,
            ).to(device=self.device)
        else:
-            raise Exception("No CUDA or MPS device available")
+            raise Exception("No CUDA or MPS device available")
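
For reference, each wrapper in this file is driven the same way after the simplification: construct it, call start() to pick a device and load the weights, then invoke the underlying diffusers pipeline through the .pipeline attribute. A minimal usage sketch (assuming this module is importable as Pipelines; the prompt, step count, and output filename are illustrative, not part of the change):

    from Pipelines import TextToImagePipelineSD

    # Model path comes from the constructor argument, the MODEL_PATH env var,
    # or the per-device default ("sd-legacy/stable-diffusion-v1-5").
    text2img = TextToImagePipelineSD()
    text2img.start()  # selects CUDA or MPS and loads the weights

    # StableDiffusionPipeline returns an object whose .images list holds PIL images.
    result = text2img.pipeline(
        prompt="a photo of an astronaut riding a horse",  # illustrative prompt
        num_inference_steps=25,
        guidance_scale=7.5,
    )
    result.images[0].save("output.png")  # illustrative output path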