
Commit b052d27

We keep the implementation simple in examples/server-async
1 parent ed617fe commit b052d27

5 files changed, +58 -370 lines changed
Pipelines.py: 9 additions & 137 deletions

@@ -1,13 +1,11 @@
 # Pipelines.py
-
 from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3 import StableDiffusion3Pipeline
 from diffusers.pipelines.flux.pipeline_flux import FluxPipeline
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
 import torch
 import os
 import logging
 from pydantic import BaseModel
-import gc
 
 logger = logging.getLogger(__name__)
 
@@ -22,155 +20,36 @@ def __init__(self, model_path: str | None = None):
         self.model_path = model_path or os.getenv("MODEL_PATH")
         self.pipeline: StableDiffusion3Pipeline | None = None
         self.device: str | None = None
-
+
     def start(self):
-        torch.set_float32_matmul_precision("high")
-
-        if hasattr(torch._inductor, 'config'):
-            if hasattr(torch._inductor.config, 'conv_1x1_as_mm'):
-                torch._inductor.config.conv_1x1_as_mm = True
-            if hasattr(torch._inductor.config, 'coordinate_descent_tuning'):
-                torch._inductor.config.coordinate_descent_tuning = True
-            if hasattr(torch._inductor.config, 'epilogue_fusion'):
-                torch._inductor.config.epilogue_fusion = False
-            if hasattr(torch._inductor.config, 'coordinate_descent_check_all_directions'):
-                torch._inductor.config.coordinate_descent_check_all_directions = True
-
-        if torch.cuda.is_available():
-            torch.backends.cudnn.benchmark = True
-            torch.backends.cuda.matmul.allow_tf32 = True
-            torch.backends.cudnn.deterministic = False
-            torch.backends.cudnn.allow_tf32 = True
-
         if torch.cuda.is_available():
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
-            logger.info(f"Loading CUDA with model: {model_path}")
+            logger.info("Loading CUDA")
             self.device = "cuda"
-
-            torch.cuda.empty_cache()
-            gc.collect()
-
             self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                 model_path,
                 torch_dtype=torch.float16,
-                use_safetensors=True,
-                variant="fp16" if "fp16" in model_path else None,
-                low_cpu_mem_usage=True,
-            )
-
-            self.pipeline = self.pipeline.to(device=self.device)
-
-            if hasattr(self.pipeline, 'enable_vae_slicing'):
-                self.pipeline.enable_vae_slicing()
-                logger.info("VAE slicing enabled - will reduce memory spikes during decoding")
-
-            if hasattr(self.pipeline, 'enable_vae_tiling'):
-                self.pipeline.enable_vae_tiling()
-                logger.info("VAE tiling enabled - will allow processing larger images")
-
-            if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
-                self.pipeline.transformer = self.pipeline.transformer.to(
-                    memory_format=torch.channels_last
-                )
-                logger.info("Transformer optimized with channels_last format")
-
-            if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
-                self.pipeline.vae = self.pipeline.vae.to(
-                    memory_format=torch.channels_last
-                )
-
-                if hasattr(self.pipeline.vae, 'enable_slicing'):
-                    self.pipeline.vae.enable_slicing()
-                    logger.info("VAE slicing activated directly in the VAE")
-
-                if hasattr(self.pipeline.vae, 'enable_tiling'):
-                    self.pipeline.vae.enable_tiling()
-                    logger.info("VAE tiling activated directly on the VAE")
-
-                logger.info("VAE optimized with channels_last format")
-
-            try:
-                self.pipeline.enable_xformers_memory_efficient_attention()
-                logger.info("XFormers memory efficient attention enabled")
-            except Exception as e:
-                logger.info(f"XFormers not available: {e}")
-
-            logger.info("Skipping torch.compile - running without compile optimizations by design")
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
-            logger.info("CUDA pipeline fully optimized and ready")
-
+            ).to(device=self.device)
         elif torch.backends.mps.is_available():
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
-            logger.info(f"Loading MPS for Mac M Series with model: {model_path}")
+            logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
-
             self.pipeline = StableDiffusion3Pipeline.from_pretrained(
                 model_path,
                 torch_dtype=torch.bfloat16,
-                use_safetensors=True,
-                low_cpu_mem_usage=True,
             ).to(device=self.device)
-
-            if hasattr(self.pipeline, 'enable_vae_slicing'):
-                self.pipeline.enable_vae_slicing()
-                logger.info("VAE slicing enabled in MPS")
-
-            if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
-                self.pipeline.transformer = self.pipeline.transformer.to(
-                    memory_format=torch.channels_last
-                )
-
-            if hasattr(self.pipeline, 'vae') and self.pipeline.vae is not None:
-                self.pipeline.vae = self.pipeline.vae.to(
-                    memory_format=torch.channels_last
-                )
-
-            logger.info("MPS pipeline optimized and ready")
-
         else:
             raise Exception("No CUDA or MPS device available")
-
-
-        self._warmup()
-
-        logger.info("Pipeline initialization completed successfully")
-
-    def _warmup(self):
-        if self.pipeline:
-            logger.info("Running warmup inference...")
-            with torch.no_grad():
-                _ = self.pipeline(
-                    prompt="warmup",
-                    num_inference_steps=1,
-                    height=512,
-                    width=512,
-                    guidance_scale=1.0,
-                )
-
-            if self.device == "cuda":
-                torch.cuda.synchronize()
-                torch.cuda.empty_cache()
-
-            gc.collect()
-            logger.info("Warmup completed with memory cleanup")
 
 class TextToImagePipelineFlux:
     def __init__(self, model_path: str | None = None, low_vram: bool = False):
-        """
-        Initialize the class with the model path.
-        If none is provided, it is read from the environment variable.
-        """
         self.model_path = model_path or os.getenv("MODEL_PATH")
-        self.pipeline: FluxPipeline = None
-        self.device: str = None
+        self.pipeline: FluxPipeline | None = None
+        self.device: str | None = None
         self.low_vram = low_vram
 
     def start(self):
         if torch.cuda.is_available():
-            # If model_path was not set, assign the default value for CUDA.
            model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
            logger.info("Loading CUDA")
            self.device = "cuda"
@@ -183,7 +62,6 @@ def start(self):
             else:
                 pass
         elif torch.backends.mps.is_available():
-            # If model_path was not set, assign the default value for MPS.
             model_path = self.model_path or "black-forest-labs/FLUX.1-schnell"
             logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
@@ -196,17 +74,12 @@ def start(self):
 
 class TextToImagePipelineSD:
     def __init__(self, model_path: str | None = None):
-        """
-        Initialize the class with the model path.
-        If none is provided, it is read from the environment variable.
-        """
         self.model_path = model_path or os.getenv("MODEL_PATH")
-        self.pipeline: StableDiffusionPipeline = None
-        self.device: str = None
+        self.pipeline: StableDiffusionPipeline | None = None
+        self.device: str | None = None
 
     def start(self):
         if torch.cuda.is_available():
-            # If model_path was not set, assign the default value for CUDA.
             model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
             logger.info("Loading CUDA")
             self.device = "cuda"
@@ -215,7 +88,6 @@ def start(self):
                 torch_dtype=torch.float16,
             ).to(device=self.device)
         elif torch.backends.mps.is_available():
-            # If model_path was not set, assign the default value for MPS.
             model_path = self.model_path or "sd-legacy/stable-diffusion-v1-5"
             logger.info("Loading MPS for Mac M Series")
             self.device = "mps"
@@ -224,4 +96,4 @@ def start(self):
                 torch_dtype=torch.float16,
             ).to(device=self.device)
         else:
-            raise Exception("No CUDA or MPS device available")
+            raise Exception("No CUDA or MPS device available")
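
For context, a minimal sketch of how the simplified SD3 class can be driven after this commit. The import path assumes the package directory is named DiffusersServer, and the prompt and generation parameters are illustrative, not part of the commit:

# Sketch only: the package name and the generation parameters below are assumptions.
from DiffusersServer.Pipelines import TextToImagePipelineSD3

pipe = TextToImagePipelineSD3()  # falls back to the MODEL_PATH env var, then the per-device default model
pipe.start()                     # selects CUDA or MPS and loads the model, or raises

# After start(), pipe.pipeline is a plain diffusers StableDiffusion3Pipeline
result = pipe.pipeline(
    prompt="a photo of an astronaut riding a horse",
    num_inference_steps=28,
    guidance_scale=7.0,
)
result.images[0].save("output.png")

With the warmup pass, gc/empty_cache calls, and hasattr-guarded VAE/attention tweaks gone, all generation behavior now lives in the underlying diffusers pipeline call.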
Lines changed: 0 additions & 1 deletion

@@ -1,3 +1,2 @@
 from .Pipelines import TextToImagePipelineSD3
-from .superpipeline import SuperPipelinesT2Img
 from .create_server import create_inference_server_Async as DiffusersServerApp
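
With the superpipeline re-export removed, these are the names left on the package surface. A sketch of the downstream imports that still work, again assuming the package directory is named DiffusersServer:

# Assumption: the package directory is DiffusersServer.
from DiffusersServer import TextToImagePipelineSD3, DiffusersServerApp

# DiffusersServerApp is an alias for create_inference_server_Async;
# TextToImagePipelineSD3 is the simplified SD3 wrapper shown above.
# Code that imported SuperPipelinesT2Img from this package will now fail at import time.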
