@@ -7,15 +7,32 @@ class BaseAsyncScheduler:
77 def __init__ (self , scheduler : Any ):
88 self .scheduler = scheduler
99
10+ def __getattr__ (self , name : str ):
11+ if hasattr (self .scheduler , name ):
12+ return getattr (self .scheduler , name )
13+ raise AttributeError (f"'{ self .__class__ .__name__ } ' object has no attribute '{ name } '" )
14+
15+ def __setattr__ (self , name : str , value ):
16+ if name == 'scheduler' :
17+ super ().__setattr__ (name , value )
18+ else :
19+ if hasattr (self , 'scheduler' ) and hasattr (self .scheduler , name ):
20+ setattr (self .scheduler , name , value )
21+ else :
22+ super ().__setattr__ (name , value )
23+
1024 def clone_for_request (self , num_inference_steps : int , device : Union [str , torch .device , None ] = None , ** kwargs ):
1125 local = copy .deepcopy (self .scheduler )
12-
1326 local .set_timesteps (num_inference_steps = num_inference_steps , device = device , ** kwargs )
14-
1527 cloned = self .__class__ (local )
16-
1728 return cloned
1829
30+ def __repr__ (self ):
31+ return f"BaseAsyncScheduler({ repr (self .scheduler )} )"
32+
33+ def __str__ (self ):
34+ return f"BaseAsyncScheduler wrapping: { str (self .scheduler )} "
35+
1936
2037def async_retrieve_timesteps (
2138 scheduler ,
0 commit comments