@@ -137,11 +137,12 @@ def _cache_base_pipeline_output(self):
         # Get or create the cache on the class (not instance)
         if not hasattr(type(self), "cached_base_pipe_outs"):
             setattr(type(self), "cached_base_pipe_outs", {})
-
+
         cached_base_pipe_outs = type(self).cached_base_pipe_outs
-
+
         all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
         if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
+            __import__("ipdb").set_trace()
             return

         for scheduler_cls in self.scheduler_classes:
@@ -158,12 +159,15 @@ def _cache_base_pipeline_output(self):
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
             output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
             cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
-
+
         # Update the class attribute
         setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)

     def get_base_pipeline_output(self, scheduler_cls):
-        self._cache_base_pipeline_output()
+        """
+        Returns the cached base pipeline output for the given scheduler.
+        Cache is populated during setUp, so this just retrieves the value.
+        """
         return type(self).cached_base_pipe_outs[scheduler_cls.__name__]

     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
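The diff stores the expensive no-LoRA pipeline outputs on the test *class* rather than the instance, so the base run happens once per scheduler and is reused across test methods, each of which gets a fresh instance. Below is a minimal, self-contained sketch of that class-level caching pattern; `scheduler_names`, `_expensive_run`, and `ExampleTests` are hypothetical stand-ins for the real pipeline setup, and the `ipdb` breakpoint on the cache-hit path (a debugging aid in the diff) is left out:

```python
import unittest


class BaseOutputCacheMixin:
    # Hypothetical scheduler names; the real code derives these
    # from self.scheduler_classes.
    scheduler_names = ["DDIMScheduler", "EulerScheduler"]

    def _expensive_run(self, scheduler_name):
        # Stand-in for building a pipeline and running inference.
        return f"output-for-{scheduler_name}"

    def _cache_base_outputs(self):
        # Store the cache on the class (not the instance) so it survives
        # across test methods, which each run on a fresh instance.
        if not hasattr(type(self), "cached_outs"):
            type(self).cached_outs = {}
        cache = type(self).cached_outs
        # Early return when every scheduler's output is already cached.
        if all(name in cache for name in self.scheduler_names):
            return
        for name in self.scheduler_names:
            cache[name] = self._expensive_run(name)

    def get_base_output(self, scheduler_name):
        # Pure lookup; the cache was populated during setUp.
        return type(self).cached_outs[scheduler_name]


class ExampleTests(BaseOutputCacheMixin, unittest.TestCase):
    def setUp(self):
        self._cache_base_outputs()

    def test_lookup(self):
        self.assertEqual(
            self.get_base_output("DDIMScheduler"), "output-for-DDIMScheduler"
        )


if __name__ == "__main__":
    unittest.main()
```

Because the cache is written through `type(self)`, each concrete test class that first populates it gets its own dictionary, while repeated `setUp` calls on that class take the early-return path instead of rerunning inference.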