Skip to content

Commit 840f0e4

Browse files
Fix server-async
1 parent 0beab1c commit 840f0e4

File tree

5 files changed

+288
-41
lines changed

5 files changed

+288
-41
lines changed

examples/server-async/DiffusersServer/Pipelines.py

Lines changed: 110 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import os
88
import logging
99
from pydantic import BaseModel
10+
import gc
1011

1112
logger = logging.getLogger(__name__)
1213

@@ -19,38 +20,134 @@ class TextToImageInput(BaseModel):
1920
class TextToImagePipelineSD3:
2021
def __init__(self, model_path: str | None = None):
2122
self.model_path = model_path or os.getenv("MODEL_PATH")
22-
self.pipeline: StableDiffusion3Pipeline = None
23-
self.device: str = None
24-
23+
self.pipeline: StableDiffusion3Pipeline | None = None
24+
self.device: str | None = None
25+
2526
def start(self):
27+
torch.set_float32_matmul_precision("high")
28+
29+
if hasattr(torch._inductor, 'config'):
30+
if hasattr(torch._inductor.config, 'conv_1x1_as_mm'):
31+
torch._inductor.config.conv_1x1_as_mm = True
32+
if hasattr(torch._inductor.config, 'coordinate_descent_tuning'):
33+
torch._inductor.config.coordinate_descent_tuning = True
34+
if hasattr(torch._inductor.config, 'epilogue_fusion'):
35+
torch._inductor.config.epilogue_fusion = False
36+
if hasattr(torch._inductor.config, 'coordinate_descent_check_all_directions'):
37+
torch._inductor.config.coordinate_descent_check_all_directions = True
38+
39+
if torch.cuda.is_available():
40+
torch.backends.cudnn.benchmark = True
41+
torch.backends.cuda.matmul.allow_tf32 = True
42+
torch.backends.cudnn.deterministic = False
43+
torch.backends.cudnn.allow_tf32 = True
44+
45+
2646
if torch.cuda.is_available():
2747
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
28-
logger.info("Loading CUDA")
48+
logger.info(f"Loading CUDA with model: {model_path}")
2949
self.device = "cuda"
50+
51+
torch.cuda.empty_cache()
52+
gc.collect()
53+
3054
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
3155
model_path,
3256
torch_dtype=torch.float16,
33-
).to(device=self.device)
57+
use_safetensors=True,
58+
variant="fp16" if "fp16" in model_path else None,
59+
low_cpu_mem_usage=True,
60+
)
61+
62+
self.pipeline = self.pipeline.to(device=self.device)
63+
64+
if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
65+
self.pipeline.transformer = self.pipeline.transformer.to(
66+
memory_format=torch.channels_last
67+
)
68+
logger.info("Transformer optimized with channels_last format")
69+
70+
if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
71+
self.pipeline.vae = self.pipeline.vae.to(
72+
memory_format=torch.channels_last
73+
)
74+
logger.info("VAE optimized with channels_last format")
75+
76+
try:
77+
self.pipeline.enable_xformers_memory_efficient_attention()
78+
logger.info("XFormers memory efficient attention enabled")
79+
except Exception as e:
80+
logger.info(f"XFormers not available: {e}")
81+
82+
# --- Se descarta torch.compile pero se mantiene el resto ---
83+
if torch.__version__ >= "2.0.0":
84+
logger.info("Skipping torch.compile - running without compile optimizations by design")
85+
86+
if torch.cuda.is_available():
87+
torch.cuda.empty_cache()
88+
89+
logger.info("CUDA pipeline fully optimized and ready")
90+
3491
elif torch.backends.mps.is_available():
3592
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
36-
logger.info("Loading MPS for Mac M Series")
93+
logger.info(f"Loading MPS for Mac M Series with model: {model_path}")
3794
self.device = "mps"
3895
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
3996
model_path,
4097
torch_dtype=torch.bfloat16,
98+
use_safetensors=True,
99+
low_cpu_mem_usage=True,
41100
).to(device=self.device)
101+
102+
if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
103+
self.pipeline.transformer = self.pipeline.transformer.to(
104+
memory_format=torch.channels_last
105+
)
106+
107+
if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
108+
self.pipeline.vae = self.pipeline.vae.to(
109+
memory_format=torch.channels_last
110+
)
111+
112+
113+
logger.info("MPS pipeline optimized and ready")
114+
42115
else:
43116
raise Exception("No CUDA or MPS device available")
117+
118+
# OPTIONAL WARMUP
119+
self._warmup()
120+
121+
logger.info("Pipeline initialization completed successfully")
122+
123+
def _warmup(self):
124+
if self.pipeline:
125+
logger.info("Running warmup inference...")
126+
with torch.no_grad():
127+
_ = self.pipeline(
128+
prompt="warmup",
129+
num_inference_steps=1,
130+
height=512,
131+
width=512,
132+
guidance_scale=1.0,
133+
)
134+
torch.cuda.empty_cache() if self.device == "cuda" else None
135+
logger.info("Warmup completed")
44136

45137
class TextToImagePipelineFlux:
46138
def __init__(self, model_path: str | None = None, low_vram: bool = False):
139+
"""
140+
Initialize the class with the model path.
141+
If not provided, it is read from the MODEL_PATH environment variable.
142+
"""
47143
self.model_path = model_path or os.getenv("MODEL_PATH")
48144
self.pipeline: FluxPipeline = None
49145
self.device: str = None
50146
self.low_vram = low_vram
51147

52148
def start(self):
53149
if torch.cuda.is_available():
150+
# If model_path was not set, fall back to the default model for CUDA.
54151
model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
55152
logger.info("Loading CUDA")
56153
self.device = "cuda"
@@ -63,6 +160,7 @@ def start(self):
63160
else:
64161
pass
65162
elif torch.backends.mps.is_available():
163+
# If model_path was not set, fall back to the default model for MPS.
66164
model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
67165
logger.info("Loading MPS for Mac M Series")
68166
self.device = "mps"
@@ -75,12 +173,17 @@ def start(self):
75173

76174
class TextToImagePipelineSD:
77175
def __init__(self, model_path: str | None = None):
176+
"""
177+
Initialize the class with the model path.
178+
If not provided, it is read from the MODEL_PATH environment variable.
179+
"""
78180
self.model_path = model_path or os.getenv("MODEL_PATH")
79181
self.pipeline: StableDiffusionPipeline = None
80182
self.device: str = None
81183

82184
def start(self):
83185
if torch.cuda.is_available():
186+
# If model_path was not set, fall back to the default model for CUDA.
84187
model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
85188
logger.info("Loading CUDA")
86189
self.device = "cuda"
@@ -89,6 +192,7 @@ def start(self):
89192
torch_dtype=torch.float16,
90193
).to(device=self.device)
91194
elif torch.backends.mps.is_available():
195+
# If model_path was not set, fall back to the default model for MPS.
92196
model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
93197
logger.info("Loading MPS for Mac M Series")
94198
self.device = "mps"

0 commit comments

Comments
 (0)