77import os
88import logging
99from pydantic import BaseModel
10+ import gc
1011
1112logger = logging .getLogger (__name__ )
1213
@@ -19,38 +20,134 @@ class TextToImageInput(BaseModel):
1920class TextToImagePipelineSD3 :
2021 def __init__ (self , model_path : str | None = None ):
2122 self .model_path = model_path or os .getenv ("MODEL_PATH" )
22- self .pipeline : StableDiffusion3Pipeline = None
23- self .device : str = None
24-
23+ self .pipeline : StableDiffusion3Pipeline | None = None
24+ self .device : str | None = None
25+
2526 def start (self ):
27+ torch .set_float32_matmul_precision ("high" )
28+
29+ if hasattr (torch ._inductor , 'config' ):
30+ if hasattr (torch ._inductor .config , 'conv_1x1_as_mm' ):
31+ torch ._inductor .config .conv_1x1_as_mm = True
32+ if hasattr (torch ._inductor .config , 'coordinate_descent_tuning' ):
33+ torch ._inductor .config .coordinate_descent_tuning = True
34+ if hasattr (torch ._inductor .config , 'epilogue_fusion' ):
35+ torch ._inductor .config .epilogue_fusion = False
36+ if hasattr (torch ._inductor .config , 'coordinate_descent_check_all_directions' ):
37+ torch ._inductor .config .coordinate_descent_check_all_directions = True
38+
39+ if torch .cuda .is_available ():
40+ torch .backends .cudnn .benchmark = True
41+ torch .backends .cuda .matmul .allow_tf32 = True
42+ torch .backends .cudnn .deterministic = False
43+ torch .backends .cudnn .allow_tf32 = True
44+
45+
2646 if torch .cuda .is_available ():
2747 model_path = self .model_path or "stabilityai/stable-diffusion-3.5-large"
28- logger .info ("Loading CUDA" )
48+ logger .info (f "Loading CUDA with model: { model_path } " )
2949 self .device = "cuda"
50+
51+ torch .cuda .empty_cache ()
52+ gc .collect ()
53+
3054 self .pipeline = StableDiffusion3Pipeline .from_pretrained (
3155 model_path ,
3256 torch_dtype = torch .float16 ,
33- ).to (device = self .device )
57+ use_safetensors = True ,
58+ variant = "fp16" if "fp16" in model_path else None ,
59+ low_cpu_mem_usage = True ,
60+ )
61+
62+ self .pipeline = self .pipeline .to (device = self .device )
63+
64+ if hasattr (self .pipeline , 'transformer' ) and self .pipeline .transformer is not None :
65+ self .pipeline .transformer = self .pipeline .transformer .to (
66+ memory_format = torch .channels_last
67+ )
68+ logger .info ("Transformer optimized with channels_last format" )
69+
70+ if hasattr (self .pipeline , 'vae' ) and self .pipeline .vae is not None :
71+ self .pipeline .vae = self .pipeline .vae .to (
72+ memory_format = torch .channels_last
73+ )
74+ logger .info ("VAE optimized with channels_last format" )
75+
76+ try :
77+ self .pipeline .enable_xformers_memory_efficient_attention ()
78+ logger .info ("XFormers memory efficient attention enabled" )
79+ except Exception as e :
80+ logger .info (f"XFormers not available: { e } " )
81+
82+ # --- Se descarta torch.compile pero se mantiene el resto ---
83+ if torch .__version__ >= "2.0.0" :
84+ logger .info ("Skipping torch.compile - running without compile optimizations by design" )
85+
86+ if torch .cuda .is_available ():
87+ torch .cuda .empty_cache ()
88+
89+ logger .info ("CUDA pipeline fully optimized and ready" )
90+
3491 elif torch .backends .mps .is_available ():
3592 model_path = self .model_path or "stabilityai/stable-diffusion-3.5-medium"
36- logger .info ("Loading MPS for Mac M Series" )
93+ logger .info (f "Loading MPS for Mac M Series with model: { model_path } " )
3794 self .device = "mps"
3895 self .pipeline = StableDiffusion3Pipeline .from_pretrained (
3996 model_path ,
4097 torch_dtype = torch .bfloat16 ,
98+ use_safetensors = True ,
99+ low_cpu_mem_usage = True ,
41100 ).to (device = self .device )
101+
102+ if hasattr (self .pipeline , 'transformer' ) and self .pipeline .transformer is not None :
103+ self .pipeline .transformer = self .pipeline .transformer .to (
104+ memory_format = torch .channels_last
105+ )
106+
107+ if hasattr (self .pipeline , 'vae' ) and self .pipeline .vae is not None :
108+ self .pipeline .vae = self .pipeline .vae .to (
109+ memory_format = torch .channels_last
110+ )
111+
112+
113+ logger .info ("MPS pipeline optimized and ready" )
114+
42115 else :
43116 raise Exception ("No CUDA or MPS device available" )
117+
118+ # OPTIONAL WARMUP
119+ self ._warmup ()
120+
121+ logger .info ("Pipeline initialization completed successfully" )
122+
123+ def _warmup (self ):
124+ if self .pipeline :
125+ logger .info ("Running warmup inference..." )
126+ with torch .no_grad ():
127+ _ = self .pipeline (
128+ prompt = "warmup" ,
129+ num_inference_steps = 1 ,
130+ height = 512 ,
131+ width = 512 ,
132+ guidance_scale = 1.0 ,
133+ )
134+ torch .cuda .empty_cache () if self .device == "cuda" else None
135+ logger .info ("Warmup completed" )
44136
45137class TextToImagePipelineFlux :
46138 def __init__ (self , model_path : str | None = None , low_vram : bool = False ):
139+ """
140+ Inicialización de la clase con la ruta del modelo.
141+ Si no se proporciona, se obtiene de la variable de entorno.
142+ """
47143 self .model_path = model_path or os .getenv ("MODEL_PATH" )
48144 self .pipeline : FluxPipeline = None
49145 self .device : str = None
50146 self .low_vram = low_vram
51147
52148 def start (self ):
53149 if torch .cuda .is_available ():
150+ # Si no se definió model_path, se asigna el valor por defecto para CUDA.
54151 model_path = self .model_path or "black-forest-labs/FLUX.1-schnell"
55152 logger .info ("Loading CUDA" )
56153 self .device = "cuda"
@@ -63,6 +160,7 @@ def start(self):
63160 else :
64161 pass
65162 elif torch .backends .mps .is_available ():
163+ # Si no se definió model_path, se asigna el valor por defecto para MPS.
66164 model_path = self .model_path or "black-forest-labs/FLUX.1-schnell"
67165 logger .info ("Loading MPS for Mac M Series" )
68166 self .device = "mps"
@@ -75,12 +173,17 @@ def start(self):
75173
76174class TextToImagePipelineSD :
77175 def __init__ (self , model_path : str | None = None ):
176+ """
177+ Inicialización de la clase con la ruta del modelo.
178+ Si no se proporciona, se obtiene de la variable de entorno.
179+ """
78180 self .model_path = model_path or os .getenv ("MODEL_PATH" )
79181 self .pipeline : StableDiffusionPipeline = None
80182 self .device : str = None
81183
82184 def start (self ):
83185 if torch .cuda .is_available ():
186+ # Si no se definió model_path, se asigna el valor por defecto para CUDA.
84187 model_path = self .model_path or "sd-legacy/stable-diffusion-v1-5"
85188 logger .info ("Loading CUDA" )
86189 self .device = "cuda"
@@ -89,6 +192,7 @@ def start(self):
89192 torch_dtype = torch .float16 ,
90193 ).to (device = self .device )
91194 elif torch .backends .mps .is_available ():
195+ # Si no se definió model_path, se asigna el valor por defecto para MPS.
92196 model_path = self .model_path or "sd-legacy/stable-diffusion-v1-5"
93197 logger .info ("Loading MPS for Mac M Series" )
94198 self .device = "mps"
0 commit comments