Skip to content

Commit 7cfee77

Browse files
Fix BaseAsyncScheduler
1 parent a519915 commit 7cfee77

File tree

1 file changed

+20
-3
lines changed

1 file changed

+20
-3
lines changed

examples/server-async/DiffusersServer/utils/scheduler.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,32 @@ class BaseAsyncScheduler:
77
def __init__(self, scheduler: Any):
88
self.scheduler = scheduler
99

10+
def __getattr__(self, name: str):
11+
if hasattr(self.scheduler, name):
12+
return getattr(self.scheduler, name)
13+
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")
14+
15+
def __setattr__(self, name: str, value):
16+
if name == 'scheduler':
17+
super().__setattr__(name, value)
18+
else:
19+
if hasattr(self, 'scheduler') and hasattr(self.scheduler, name):
20+
setattr(self.scheduler, name, value)
21+
else:
22+
super().__setattr__(name, value)
23+
1024
def clone_for_request(self, num_inference_steps: int, device: Union[str, torch.device, None] = None, **kwargs):
1125
local = copy.deepcopy(self.scheduler)
12-
1326
local.set_timesteps(num_inference_steps=num_inference_steps, device=device, **kwargs)
14-
1527
cloned = self.__class__(local)
16-
1728
return cloned
1829

30+
def __repr__(self):
31+
return f"BaseAsyncScheduler({repr(self.scheduler)})"
32+
33+
def __str__(self):
34+
return f"BaseAsyncScheduler wrapping: {str(self.scheduler)}"
35+
1936

2037
def async_retrieve_timesteps(
2138
scheduler,

0 commit comments

Comments
 (0)