
Commit ed617fe

Optimizations in examples/server-async
1 parent 8a238c3 commit ed617fe

File tree

3 files changed: +178 -158 lines changed


examples/server-async/DiffusersServer/Pipelines.py

Lines changed: 32 additions & 9 deletions
@@ -42,7 +42,6 @@ def start(self):
         torch.backends.cudnn.deterministic = False
         torch.backends.cudnn.allow_tf32 = True
 
-
         if torch.cuda.is_available():
             model_path = self.model_path or "stabilityai/stable-diffusion-3.5-large"
             logger.info(f"Loading CUDA with model: {model_path}")
@@ -61,6 +60,14 @@ def start(self):
6160

6261
self.pipeline = self.pipeline.to(device=self.device)
6362

63+
if hasattr(self.pipeline, 'enable_vae_slicing'):
64+
self.pipeline.enable_vae_slicing()
65+
logger.info("VAE slicing enabled - will reduce memory spikes during decoding")
66+
67+
if hasattr(self.pipeline, 'enable_vae_tiling'):
68+
self.pipeline.enable_vae_tiling()
69+
logger.info("VAE tiling enabled - will allow processing larger images")
70+
6471
if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
6572
self.pipeline.transformer = self.pipeline.transformer.to(
6673
memory_format=torch.channels_last
@@ -71,6 +78,15 @@ def start(self):
                 self.pipeline.vae = self.pipeline.vae.to(
                     memory_format=torch.channels_last
                 )
+
+                if hasattr(self.pipeline.vae, 'enable_slicing'):
+                    self.pipeline.vae.enable_slicing()
+                    logger.info("VAE slicing activated directly in the VAE")
+
+                if hasattr(self.pipeline.vae, 'enable_tiling'):
+                    self.pipeline.vae.enable_tiling()
+                    logger.info("VAE tiling activated directly on the VAE")
+
                 logger.info("VAE optimized with channels_last format")
 
             try:
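
Note on the two hunks above: enable_vae_slicing()/enable_vae_tiling() on the pipeline and enable_slicing()/enable_tiling() on the VAE module are both guarded with hasattr, so startup still succeeds on pipelines or diffusers versions that do not expose them. A minimal standalone sketch of the same pattern follows; the checkpoint id, prompt, and variable names are illustrative and not taken from the commit.

# Minimal sketch of the VAE memory optimizations used in Pipelines.py.
# Assumes a recent diffusers release that ships StableDiffusion3Pipeline and a CUDA device.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3.5-large",   # example checkpoint, as in the CUDA branch
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
).to("cuda")

# Pipeline-level helpers, when the pipeline exposes them.
if hasattr(pipe, "enable_vae_slicing"):
    pipe.enable_vae_slicing()    # decode latents one image at a time
if hasattr(pipe, "enable_vae_tiling"):
    pipe.enable_vae_tiling()     # decode in overlapping tiles for large images

# Fallback: the same switches directly on the VAE module.
if hasattr(pipe.vae, "enable_slicing"):
    pipe.vae.enable_slicing()
if hasattr(pipe.vae, "enable_tiling"):
    pipe.vae.enable_tiling()

image = pipe("a lighthouse at dusk", num_inference_steps=28).images[0]
image.save("check.png")

Slicing decodes the latent batch one image at a time and tiling decodes each image in overlapping spatial tiles, trading a little decode speed for a lower peak-VRAM footprint.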
@@ -79,9 +95,7 @@ def start(self):
             except Exception as e:
                 logger.info(f"XFormers not available: {e}")
 
-            # --- torch.compile is discarded but the rest is kept ---
-            if torch.__version__ >= "2.0.0":
-                logger.info("Skipping torch.compile - running without compile optimizations by design")
+            logger.info("Skipping torch.compile - running without compile optimizations by design")
 
             if torch.cuda.is_available():
                 torch.cuda.empty_cache()
@@ -92,13 +106,18 @@ def start(self):
92106
model_path = self.model_path or "stabilityai/stable-diffusion-3.5-medium"
93107
logger.info(f"Loading MPS for Mac M Series with model: {model_path}")
94108
self.device = "mps"
109+
95110
self.pipeline = StableDiffusion3Pipeline.from_pretrained(
96111
model_path,
97112
torch_dtype=torch.bfloat16,
98113
use_safetensors=True,
99114
low_cpu_mem_usage=True,
100115
).to(device=self.device)
101-
116+
117+
if hasattr(self.pipeline, 'enable_vae_slicing'):
118+
self.pipeline.enable_vae_slicing()
119+
logger.info("VAE slicing enabled in MPS")
120+
102121
if hasattr(self.pipeline, 'transformer') and self.pipeline.transformer is not None:
103122
self.pipeline.transformer = self.pipeline.transformer.to(
104123
memory_format=torch.channels_last
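
The MPS branch in this hunk runs when the CUDA branch above is not taken; the device check itself sits outside the changed lines, so the following is only a hedged sketch of how that selection is typically wired up, not code from the commit.

# Hedged sketch: CUDA-first device selection with an Apple Silicon (MPS) fallback.
import torch

if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"   # Mac M-series GPUs
else:
    raise Exception("No CUDA or MPS device available")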
@@ -108,14 +127,13 @@ def start(self):
             self.pipeline.vae = self.pipeline.vae.to(
                 memory_format=torch.channels_last
             )
-
 
             logger.info("MPS pipeline optimized and ready")
 
         else:
             raise Exception("No CUDA or MPS device available")
 
-        # OPTIONAL WARMUP
+
         self._warmup()
 
         logger.info("Pipeline initialization completed successfully")
@@ -131,8 +149,13 @@ def _warmup(self):
                 width=512,
                 guidance_scale=1.0,
             )
-        torch.cuda.empty_cache() if self.device == "cuda" else None
-        logger.info("Warmup completed")
+
+        if self.device == "cuda":
+            torch.cuda.synchronize()
+            torch.cuda.empty_cache()
+
+        gc.collect()
+        logger.info("Warmup completed with memory cleanup")
 
 class TextToImagePipelineFlux:
     def __init__(self, model_path: str | None = None, low_vram: bool = False):
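
The rewritten end of _warmup() replaces the old one-line conditional expression with explicit cleanup: wait for queued GPU work, release the CUDA caching allocator, then run the Python garbage collector. A standalone sketch of that pattern follows; the function name, prompt, and step count are illustrative assumptions, not from the commit.

# Hedged sketch of the warmup-then-cleanup pattern added in this commit.
import gc
import torch

def run_warmup(pipeline, device: str) -> None:
    # One cheap generation to trigger kernel selection, cudnn autotuning, and caching.
    pipeline(
        prompt="warmup",
        num_inference_steps=1,
        height=512,
        width=512,
        guidance_scale=1.0,
    )

    if device == "cuda":
        torch.cuda.synchronize()   # make sure all queued GPU work has finished
        torch.cuda.empty_cache()   # return cached allocator blocks to the driver

    gc.collect()                   # drop lingering Python references to large tensors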
