[tests] cache non lora pipeline outputs. #12298
Merged

Commits (26)
02fd92e  cache non lora pipeline outputs. (sayakpaul)
c8afd1c  up (sayakpaul)
6c0c72d  up (sayakpaul)
4256de9  up (sayakpaul)
772c32e  up (sayakpaul)
2c47a2f  Revert "up" (sayakpaul)
9c24d1f  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
cca03df  up (sayakpaul)
53ca186  Revert "up" (sayakpaul)
336efbd  up (sayakpaul)
69f2d5c  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
9d3f707  up (sayakpaul)
4923986  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
34d0aa2  resolve big conflicts. (sayakpaul)
1569fca  add . (sayakpaul)
8f405ed  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
ead2e04  up (sayakpaul)
fbcdf8b  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
8efb5c4  up (sayakpaul)
ca424e5  up (sayakpaul)
3e0a7f9  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
e442df0  up (sayakpaul)
dcfd979  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
734b500  up (sayakpaul)
b01bf8e  up (sayakpaul)
c16d892  Merge branch 'main' into cache-non-lora-outputs (sayakpaul)
@@ -129,6 +129,34 @@ class PeftLoraLoaderMixinTests:
     text_encoder_target_modules = ["q_proj", "k_proj", "v_proj", "out_proj"]
     denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
 
+    cached_non_lora_outputs = {}
+
+    @pytest.fixture(scope="class", autouse=True)
+    def cache_non_lora_outputs(self):
+        """
+        This fixture will be executed once per test class and will populate
+        the cached_non_lora_outputs dictionary.
+        """
+        for scheduler_cls in self.scheduler_classes:
+            # Check if the output for this scheduler is already cached to avoid re-running
+            if scheduler_cls.__name__ in self.cached_non_lora_outputs:
+                continue
+
+            components, _, _ = self.get_dummy_components(scheduler_cls)
+            pipe = self.pipeline_class(**components)
+            pipe = pipe.to(torch_device)
+            pipe.set_progress_bar_config(disable=None)
+
+            # Always ensure the inputs are without the `generator`. Make sure to pass the `generator`
+            # explicitly.
+            _, _, inputs = self.get_dummy_inputs(with_generator=False)
+            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            self.cached_non_lora_outputs[scheduler_cls.__name__] = output_no_lora
+
+        # Ensures that there's no inconsistency when reusing the cache.
+        yield
+        self.cached_non_lora_outputs.clear()
+
     def get_dummy_components(self, scheduler_cls=None, use_dora=False, lora_alpha=None):
         if self.unet_kwargs and self.transformer_kwargs:
             raise ValueError("Both `unet_kwargs` and `transformer_kwargs` cannot be specified.")
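
The new fixture leans on two pytest features: `scope="class"`, so it runs once for the whole test class rather than once per test, and `autouse=True`, so every test sees the populated cache without requesting the fixture explicitly. A minimal sketch of that pattern in isolation is shown below; `SlowPipeline`, `config_names`, and the test class are hypothetical stand-ins for illustration, not diffusers code.

```python
# Minimal sketch of the class-scoped caching pattern (hypothetical names).
import pytest


class SlowPipeline:
    """Stand-in for a pipeline whose forward pass is expensive."""

    def __call__(self, seed: int) -> list:
        # Pretend this is a slow denoising loop; the output depends only on the seed.
        return [seed, seed + 1]


class TestWithCachedBaseline:
    # Class attribute, so every test in the class reads the same dict.
    cached_baselines = {}
    config_names = ["config_a", "config_b"]

    @pytest.fixture(scope="class", autouse=True)
    def cache_baselines(self):
        # Runs once for the whole class, before any test method executes.
        for name in self.config_names:
            if name in self.cached_baselines:
                continue
            self.cached_baselines[name] = SlowPipeline()(seed=0)
        yield
        # Tear down after the last test so state never leaks across classes.
        self.cached_baselines.clear()

    def test_baseline_shape(self):
        for name in self.config_names:
            assert len(self.cached_baselines[name]) == 2

    def test_baseline_matches_fresh_run(self):
        for name in self.config_names:
            assert self.cached_baselines[name] == SlowPipeline()(seed=0)
```

Because the cache lives on the class and the fixture only clears it after the final test yields back, each configuration's baseline is computed at most once per class, which is the saving the diff above is after.
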
@@ -320,13 +348,7 @@ def test_simple_inference(self):
         Tests a simple inference and makes sure it works as expected
         """
         for scheduler_cls in self.scheduler_classes:
-            components, text_lora_config, _ = self.get_dummy_components(scheduler_cls)
-            pipe = self.pipeline_class(**components)
-            pipe = pipe.to(torch_device)
-            pipe.set_progress_bar_config(disable=None)
-
-            _, _, inputs = self.get_dummy_inputs()
-            output_no_lora = pipe(**inputs)[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
     def test_simple_inference_with_text_lora(self):
@@ -341,7 +363,7 @@ def test_simple_inference_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
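
Swapping the cached tensor in for a fresh call, as in the hunk above, is only sound because the fixture's baseline run is seeded the same way the replaced calls were (`generator=torch.manual_seed(0)`), so both runs would produce identical outputs. A tiny sketch of that determinism argument, using a hypothetical `fake_denoise` helper rather than a real pipeline:

```python
# Illustrative only: a seeded generator makes the cached baseline a drop-in
# replacement for recomputing the output in every test.
import torch


def fake_denoise(generator: torch.Generator) -> torch.Tensor:
    # Stand-in for a pipeline forward pass that consumes random noise.
    return torch.randn(2, 3, generator=generator)


cached = fake_denoise(generator=torch.manual_seed(0))      # computed once in the fixture
recomputed = fake_denoise(generator=torch.manual_seed(0))  # what each test used to do

assert torch.allclose(cached, recomputed)
```
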
@@ -424,7 +446,7 @@ def test_low_cpu_mem_usage_with_loading(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -480,7 +502,7 @@ def test_simple_inference_with_text_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -518,7 +540,7 @@ def test_simple_inference_with_text_lora_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -550,7 +572,7 @@ def test_simple_inference_with_text_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -585,7 +607,7 @@ def test_simple_inference_with_text_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -636,7 +658,7 @@ def test_simple_inference_with_partial_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -687,7 +709,7 @@ def test_simple_inference_save_pretrained_with_text_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config=None)
@@ -730,7 +752,7 @@ def test_simple_inference_with_text_denoiser_lora_save_load(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -771,7 +793,7 @@ def test_simple_inference_with_text_denoiser_lora_and_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -815,7 +837,7 @@ def test_simple_inference_with_text_lora_denoiser_fused(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -853,7 +875,7 @@ def test_simple_inference_with_text_denoiser_lora_unloaded(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, denoiser = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -932,7 +954,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1061,7 +1083,7 @@ def test_simple_inference_with_text_denoiser_block_scale(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
             self.assertTrue(check_if_lora_correctly_set(pipe.text_encoder), "Lora not correctly set in text encoder")
@@ -1118,7 +1140,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_block_lora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1281,7 +1303,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_delete_adapter(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1375,7 +1397,7 @@ def test_simple_inference_with_text_denoiser_multi_adapter_weighted(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config, "adapter-1")
@@ -1619,7 +1641,7 @@ def test_simple_inference_with_text_lora_denoiser_fused_multi(
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1700,7 +1722,7 @@ def test_lora_scale_kwargs_match_fusion(self, expected_atol: float = 1e-3, expec
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -1755,7 +1777,7 @@ def test_simple_inference_with_dora(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_dora_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_dora_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_dora_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -1887,7 +1909,7 @@ def test_logs_info_when_no_lora_keys_found(self):
         pipe.set_progress_bar_config(disable=None)
 
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
-        original_out = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        original_out = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
         no_op_state_dict = {"lora_foo": torch.tensor(2.0), "lora_bar": torch.tensor(3.0)}
         logger = logging.get_logger("diffusers.loaders.peft")
@@ -1933,7 +1955,7 @@ def test_set_adapters_match_attention_kwargs(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
             self.assertTrue(output_no_lora.shape == self.output_shape)
 
             pipe, _ = self.add_adapters_to_pipeline(pipe, text_lora_config, denoiser_lora_config)
@@ -2287,7 +2309,7 @@ def test_lora_adapter_metadata_save_load_inference(self, lora_alpha):
         pipe = self.pipeline_class(**components).to(torch_device)
         _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+        output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
         self.assertTrue(output_no_lora.shape == self.output_shape)
 
         pipe, _ = self.add_adapters_to_pipeline(
@@ -2337,7 +2359,7 @@ def test_inference_load_delete_load_adapters(self):
             pipe.set_progress_bar_config(disable=None)
             _, _, inputs = self.get_dummy_inputs(with_generator=False)
 
-            output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
+            output_no_lora = self.cached_non_lora_outputs[scheduler_cls.__name__]
 
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
                 pipe.text_encoder.add_adapter(text_lora_config)