Skip to content

Commit fa926e7

Browse files
committed
update
1 parent a8c5801 commit fa926e7

File tree

1 file changed

+6
-1
lines changed

1 file changed

+6
-1
lines changed

tests/lora/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,8 @@ def setUp(self):
138138
def get_base_pipe_outs(self):
139139
cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
140140
all_scheduler_names = [scheduler_cls.__name__ for scheduler_cls in self.scheduler_classes]
141-
if cached_base_pipe_outs is not None and all(k in cached_base_pipe_outs for k in all_scheduler_names):
141+
# Check if all required schedulers are already cached
142+
if cached_base_pipe_outs and all(k in cached_base_pipe_outs for k in all_scheduler_names):
142143
return
143144

144145
cached_base_pipe_outs = cached_base_pipe_outs or {}
@@ -163,7 +164,11 @@ def get_base_pipeline_output(self, scheduler_cls):
163164
"""
164165
Returns the cached base pipeline output for the given scheduler.
165166
Properly handles accessing the class-level cache.
167+
Ensures cache is populated if it hasn't been already.
166168
"""
169+
# Ensure cache is populated
170+
self.get_base_pipe_outs()
171+
167172
cached_base_pipe_outs = getattr(type(self), "cached_base_pipe_outs", {})
168173
return cached_base_pipe_outs[scheduler_cls.__name__]
169174

0 commit comments

Comments
 (0)