Skip to content

Commit 40f12d2

Browse files
committed
update
1 parent 1e08566 commit 40f12d2

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

tests/lora/utils.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -137,11 +137,12 @@ def _cache_base_pipeline_output(self):
137137
# Get or create the cache on the class (not instance)
138138
if not hasattr(type(self), "cached_base_pipe_outs"):
139139
setattr(type(self), "cached_base_pipe_outs", {})
140-
140+
141141
cached_base_pipe_outs = type(self).cached_base_pipe_outs
142-
142+
143143
all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
144144
if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
145+
__import__("ipdb").set_trace()
145146
return
146147

147148
for scheduler_cls in self.scheduler_classes:
@@ -158,12 +159,15 @@ def _cache_base_pipeline_output(self):
158159
_, _, inputs = self.get_dummy_inputs(with_generator=False)
159160
output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
160161
cached_base_pipe_outs[scheduler_cls.__name__] = output_no_lora
161-
162+
162163
# Update the class attribute
163164
setattr(type(self), "cached_base_pipe_outs", cached_base_pipe_outs)
164165

165166
def get_base_pipeline_output(self, scheduler_cls):
166-
self._cache_base_pipeline_output()
167+
"""
168+
Returns the cached base pipeline output for the given scheduler.
169+
Cache is populated during setUp, so this just retrieves the value.
170+
"""
167171
return type(self).cached_base_pipe_outs[scheduler_cls.__name__]
168172

169173
def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):

0 commit comments

Comments
 (0)