1313# limitations under the License.
1414
1515import math
16- from typing import List , Optional , Union , TYPE_CHECKING
16+ from typing import List , Optional , Union , TYPE_CHECKING , Dict , Tuple
1717
1818import torch
1919
@@ -156,7 +156,11 @@ def cleanup_models(self, denoiser: torch.nn.Module) -> None:
156156 for hook_name in self ._skip_layer_hook_names :
157157 registry .remove_hook (hook_name , recurse = True )
158158
159- def prepare_inputs (self , data : "BlockState" ) -> List ["BlockState" ]:
159+ def prepare_inputs (self , data : "BlockState" , input_fields : Optional [Dict [str , Union [str , Tuple [str , str ]]]] = None ) -> List ["BlockState" ]:
160+
161+ if input_fields is None :
162+ input_fields = self ._input_fields
163+
160164 if self .num_conditions == 1 :
161165 tuple_indices = [0 ]
162166 input_predictions = ["pred_cond" ]
@@ -168,7 +172,7 @@ def prepare_inputs(self, data: "BlockState") -> List["BlockState"]:
168172 input_predictions = ["pred_cond" , "pred_uncond" , "pred_cond_skip" ]
169173 data_batches = []
170174 for i in range (self .num_conditions ):
171- data_batch = self ._prepare_batch (self . _input_fields , data , tuple_indices [i ], input_predictions [i ])
175+ data_batch = self ._prepare_batch (input_fields , data , tuple_indices [i ], input_predictions [i ])
172176 data_batches .append (data_batch )
173177 return data_batches
174178
0 commit comments