2121
2222if  TYPE_CHECKING :
2323    from  ..models .attention_processor  import  AttentionProcessor 
24+     from  ..pipelines .modular_pipeline  import  BlockState 
2425
2526
2627logger  =  get_logger (__name__ )  # pylint: disable=invalid-name 
@@ -30,14 +31,15 @@ class BaseGuidance:
3031    r"""Base class providing the skeleton for implementing guidance techniques.""" 
3132
3233    _input_predictions  =  None 
34+     _identifier_key  =  "__guidance_identifier__" 
3335
3436    def  __init__ (self , start : float  =  0.0 , stop : float  =  1.0 ):
3537        self ._start  =  start 
3638        self ._stop  =  stop 
3739        self ._step : int  =  None 
3840        self ._num_inference_steps : int  =  None 
3941        self ._timestep : torch .LongTensor  =  None 
40-         self ._preds : Dict [str , torch . Tensor ]  =  {} 
42+         self ._input_fields : Dict [str , Union [ str ,  Tuple [ str ,  str ]]]  =  None 
4143        self ._num_outputs_prepared : int  =  0 
4244        self ._enabled  =  True 
4345
@@ -65,28 +67,64 @@ def set_state(self, step: int, num_inference_steps: int, timestep: torch.LongTen
6567        self ._step  =  step 
6668        self ._num_inference_steps  =  num_inference_steps 
6769        self ._timestep  =  timestep 
68-         self ._preds  =  {}
6970        self ._num_outputs_prepared  =  0 
7071
72+     def  set_input_fields (self , ** kwargs : Dict [str , Union [str , Tuple [str , str ]]]) ->  None :
73+         """ 
74+         Set the input fields for the guidance technique. The input fields are used to specify the names of the 
75+         returned attributes containing the prepared data after `prepare_inputs` is called. The prepared data is 
76+         obtained from the values of the provided keyword arguments to this method. 
77+ 
78+         Args: 
79+             **kwargs (`Dict[str, Union[str, Tuple[str, str]]]`): 
80+                 A dictionary where the keys are the names of the fields that will be used to store the data once 
81+                 it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, 
82+                 which is used to look up the required data provided for preparation. 
83+ 
84+                 If a string is provided, it will be used as the conditional data (or unconditional if used with 
85+                 a guidance method that requires it). If a tuple of length 2 is provided, the first element must 
86+                 be the conditional data identifier and the second element must be the unconditional data identifier 
87+                 or None. 
88+ 
89+                 Example: 
90+                  
91+                 ``` 
92+                 data = {"prompt_embeds": <some tensor>, "negative_prompt_embeds": <some tensor>, "latents": <some tensor>} 
93+ 
94+                 BaseGuidance.set_input_fields( 
95+                     latents="latents", 
96+                     prompt_embeds=("prompt_embeds", "negative_prompt_embeds"), 
97+                 ) 
98+                 ``` 
99+         """ 
100+         for  key , value  in  kwargs .items ():
101+             is_string  =  isinstance (value , str )
102+             is_tuple_of_str_with_len_2  =  isinstance (value , tuple ) and  len (value ) ==  2  and  all (isinstance (v , str ) for  v  in  value )
103+             if  not  (is_string  or  is_tuple_of_str_with_len_2 ):
104+                 raise  ValueError (
105+                     f"Expected `set_input_fields` to be called with a string or a tuple of string with length 2, but got { type (value )}   for key { key }  ." 
106+                 )
107+         self ._input_fields  =  kwargs 
108+     
71109    def  prepare_models (self , denoiser : torch .nn .Module ) ->  None :
72110        """ 
73111        Prepares the models for the guidance technique on a given batch of data. This method should be overridden in 
74112        subclasses to implement specific model preparation logic. 
75113        """ 
76114        pass 
77115
78-     def  prepare_inputs (self , denoiser :  torch . nn . Module ,  * args :  Union [ Tuple [ torch . Tensor ],  List [ torch . Tensor ]] ) ->  Tuple [ List [torch . Tensor ], ... ]:
116+     def  prepare_inputs (self , data :  "BlockState" ) ->  List ["BlockState" ]:
79117        raise  NotImplementedError ("BaseGuidance::prepare_inputs must be implemented in subclasses." )
80118
81-     def  prepare_outputs (self , denoiser : torch .nn .Module , pred : torch .Tensor ) ->  None :
82-         raise  NotImplementedError ("BaseGuidance::prepare_outputs must be implemented in subclasses." )
83- 
84-     def  __call__ (self , ** kwargs ) ->  Any :
85-         if  len (kwargs ) !=  self .num_conditions :
119+     def  __call__ (self , data : List ["BlockState" ]) ->  Any :
120+         if  not  all (hasattr (d , "noise_pred" ) for  d  in  data ):
121+             raise  ValueError ("Expected all data to have `noise_pred` attribute." )
122+         if  len (data ) !=  self .num_conditions :
86123            raise  ValueError (
87-                 f"Expected { self .num_conditions }   arguments , but got { len (kwargs )}  . Please provide  the correct number of arguments ." 
124+                 f"Expected { self .num_conditions }   data items , but got { len (data )}  . Please check  the input data ." 
88125            )
89-         return  self .forward (** kwargs )
126+         forward_inputs  =  {getattr (d , self ._identifier_key ): d .noise_pred  for  d  in  data }
127+         return  self .forward (** forward_inputs )
90128
91129    def  forward (self , * args , ** kwargs ) ->  Any :
92130        raise  NotImplementedError ("BaseGuidance::forward must be implemented in subclasses." )
@@ -102,10 +140,48 @@ def is_unconditional(self) -> bool:
102140    @property  
103141    def  num_conditions (self ) ->  int :
104142        raise  NotImplementedError ("BaseGuidance::num_conditions must be implemented in subclasses." )
105- 
106-     @property  
107-     def  outputs (self ) ->  Dict [str , torch .Tensor ]:
108-         return  self ._preds , {}
143+     
144+     @classmethod  
145+     def  _prepare_batch (cls , input_fields : Dict [str , Union [str , Tuple [str , str ]]], data : "BlockState" , tuple_index : int , identifier : str ) ->  "BlockState" :
146+         """ 
147+         Prepares a batch of data for the guidance technique. This method is used in the `prepare_inputs` method of 
148+         the `BaseGuidance` class. It prepares the batch based on the provided tuple index. 
149+ 
150+         Args: 
151+             input_fields (`Dict[str, Union[str, Tuple[str, str]]]`): 
152+                 A dictionary where the keys are the names of the fields that will be used to store the data once 
153+                 it is prepared with `prepare_inputs`. The values can be either a string or a tuple of length 2, 
154+                 which is used to look up the required data provided for preparation. 
155+                 If a string is provided, it will be used as the conditional data (or unconditional if used with 
156+                 a guidance method that requires it). If a tuple of length 2 is provided, the first element must 
157+                 be the conditional data identifier and the second element must be the unconditional data identifier 
158+                 or None. 
159+             data (`BlockState`): 
160+                 The input data to be prepared. 
161+             tuple_index (`int`): 
162+                 The index to use when accessing input fields that are tuples. 
163+          
164+         Returns: 
165+             `BlockState`: The prepared batch of data. 
166+         """ 
167+         from  ..pipelines .modular_pipeline  import  BlockState 
168+ 
169+         if  input_fields  is  None :
170+             raise  ValueError ("Input fields have not been set. Please call `set_input_fields` before preparing inputs." )
171+         data_batch  =  {}
172+         for  key , value  in  input_fields .items ():
173+             try :
174+                 if  isinstance (value , str ):
175+                     data_batch [key ] =  getattr (data , value )
176+                 elif  isinstance (value , tuple ):
177+                     data_batch [key ] =  getattr (data , value [tuple_index ])
178+                 else :
179+                     # We've already checked that value is a string or a tuple of strings with length 2 
180+                     pass 
181+             except  AttributeError :
182+                 raise  ValueError (f"Expected `data` to have attribute(s) { value }  , but it does not. Please check the input data." )
183+         data_batch [cls ._identifier_key ] =  identifier 
184+         return  BlockState (** data_batch )
109185
110186
111187def  rescale_noise_cfg (noise_cfg , noise_pred_text , guidance_rescale = 0.0 ):
0 commit comments