Skip to content

Commit a519915

Browse files
Update examples/server-async/utils/*
1 parent 06bb136 commit a519915

File tree

4 files changed

+44
-19
lines changed

4 files changed

+44
-19
lines changed

examples/server-async/DiffusersServer/serverasync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from pydantic import BaseModel
66
from .Pipelines import TextToImagePipelineSD3, TextToImagePipelineFlux, TextToImagePipelineSD, logger
77
import logging
8-
from ..utils import RequestScopedPipeline
8+
from .utils import RequestScopedPipeline
99
from diffusers import *
1010
import random
1111
import uuid
File renamed without changes.

examples/server-async/utils/requestscopedpipeline.py renamed to examples/server-async/DiffusersServer/utils/requestscopedpipeline.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@
33
import threading
44
import torch
55
from diffusers.utils import logging
6+
from .scheduler import BaseAsyncScheduler, async_retrieve_timesteps
7+
68

79
logger = logging.get_logger(__name__)
810

@@ -27,14 +29,19 @@ def __init__(
2729
mutable_attrs: Optional[Iterable[str]] = None,
2830
auto_detect_mutables: bool = True,
2931
tensor_numel_threshold: int = 1_000_000,
30-
tokenizer_lock: Optional[threading.Lock] = None
32+
tokenizer_lock: Optional[threading.Lock] = None,
33+
wrap_scheduler: bool = True
3134
):
3235
self._base = pipeline
3336
self.unet = getattr(pipeline, "unet", None)
3437
self.vae = getattr(pipeline, "vae", None)
3538
self.text_encoder = getattr(pipeline, "text_encoder", None)
3639
self.components = getattr(pipeline, "components", None)
3740

41+
if wrap_scheduler and hasattr(pipeline, 'scheduler') and pipeline.scheduler is not None:
42+
if not isinstance(pipeline.scheduler, BaseAsyncScheduler):
43+
pipeline.scheduler = BaseAsyncScheduler(pipeline.scheduler)
44+
3845
self._mutable_attrs = list(mutable_attrs) if mutable_attrs is not None else list(self.DEFAULT_MUTABLE_ATTRS)
3946
self._tokenizer_lock = tokenizer_lock if tokenizer_lock is not None else threading.Lock()
4047

@@ -48,17 +55,24 @@ def _make_local_scheduler(self, num_inference_steps: int, device: Optional[str]
4855
if base_sched is None:
4956
return None
5057

51-
if hasattr(base_sched, "clone_for_request"):
52-
try:
53-
return base_sched.clone_for_request(num_inference_steps=num_inference_steps, device=device, **clone_kwargs)
54-
except Exception as e:
55-
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
58+
if not isinstance(base_sched, BaseAsyncScheduler):
59+
wrapped_scheduler = BaseAsyncScheduler(base_sched)
60+
else:
61+
wrapped_scheduler = base_sched
5662

5763
try:
58-
return copy.deepcopy(base_sched)
64+
return wrapped_scheduler.clone_for_request(
65+
num_inference_steps=num_inference_steps,
66+
device=device,
67+
**clone_kwargs
68+
)
5969
except Exception as e:
60-
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
61-
return base_sched
70+
logger.debug(f"clone_for_request failed: {e}; falling back to deepcopy()")
71+
try:
72+
return copy.deepcopy(wrapped_scheduler)
73+
except Exception as e:
74+
logger.warning(f"Deepcopy of scheduler failed: {e}. Returning original scheduler (*risky*).")
75+
return wrapped_scheduler
6276

6377
def _autodetect_mutables(self, max_attrs: int = 40):
6478
if not self._auto_detect_mutables:
@@ -197,7 +211,16 @@ def generate(self, *args, num_inference_steps: int = 50, device: Optional[str] =
197211

198212
if local_scheduler is not None:
199213
try:
200-
setattr(local_pipe, "scheduler", local_scheduler)
214+
timesteps, num_steps, configured_scheduler = async_retrieve_timesteps(
215+
local_scheduler.scheduler,
216+
num_inference_steps=num_inference_steps,
217+
device=device,
218+
return_scheduler=True,
219+
**{k: v for k, v in kwargs.items() if k in ['timesteps', 'sigmas']}
220+
)
221+
222+
final_scheduler = BaseAsyncScheduler(configured_scheduler)
223+
setattr(local_pipe, "scheduler", final_scheduler)
201224
except Exception:
202225
logger.warning("Could not set scheduler on local pipe; proceeding without replacing scheduler.")
203226

examples/server-async/utils/scheduler.py renamed to examples/server-async/DiffusersServer/utils/scheduler.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,16 @@
55

66
class BaseAsyncScheduler:
77
def __init__(self, scheduler: Any):
8-
pass
9-
10-
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device] = None):
11-
# I leave it as an example of what the Scheduler should do to implement it later
12-
"""local = copy.deepcopy(self)
13-
local.set_timesteps(num_inference_steps=num_inference_steps, device=device)
14-
return local"""
15-
pass
8+
self.scheduler = scheduler
9+
10+
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
11+
local = copy.deepcopy(self.scheduler)
12+
13+
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
14+
15+
cloned = self.__class__(local)
16+
17+
return cloned
1618

1719

1820
def async_retrieve_timesteps(

0 commit comments

Comments
 (0)