-
Couldn't load subscription status.
- Fork 6.5k
[tests] add test for hotswapping + compilation on resolution changes #11825
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from 2 commits
Commits
Show all changes
8 commits
Select commit
Hold shift + click to select a range
4da07a7
add resolution changes tests to hotswapping test suite.
sayakpaul 9f1c83f
fixes
sayakpaul 7fba82c
docs
sayakpaul 3f76c09
Merge branch 'main' into resolution-hotswap-tests
sayakpaul 2076a53
explain duck shapes
sayakpaul 579fb76
fix
sayakpaul 2dc11a2
Merge branch 'main' into resolution-hotswap-tests
sayakpaul 9afccbc
Merge branch 'main' into resolution-hotswap-tests
sayakpaul File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1350,7 +1350,6 @@ def test_model_parallelism(self): | |
| new_model = self.model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory) | ||
| # Making sure part of the model will actually end up offloaded | ||
| self.assertSetEqual(set(new_model.hf_device_map.values()), {0, 1}) | ||
| print(f" new_model.hf_device_map:{new_model.hf_device_map}") | ||
|
|
||
| self.check_device_map_is_respected(new_model, new_model.hf_device_map) | ||
|
|
||
|
|
@@ -2019,6 +2018,8 @@ class LoraHotSwappingForModelTesterMixin: | |
|
|
||
| """ | ||
|
|
||
| different_shapes_for_compilation = None | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| def tearDown(self): | ||
| # It is critical that the dynamo cache is reset for each test. Otherwise, if the test re-uses the same model, | ||
| # there will be recompilation errors, as torch caches the model when run in the same process. | ||
|
|
@@ -2046,7 +2047,9 @@ def get_linear_module_name_other_than_attn(self, model): | |
| ] | ||
| return linear_names[0] | ||
|
|
||
| def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_modules1=None): | ||
| def check_model_hotswap( | ||
| self, do_compile, rank0, rank1, target_modules0, target_modules1=None, different_resolutions=None | ||
| ): | ||
| """ | ||
| Check that hotswapping works on a small unet. | ||
|
|
||
|
|
@@ -2056,6 +2059,7 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ | |
| - hotswap the second adapter | ||
| - check that the outputs are correct | ||
| - optionally compile the model | ||
| - optionally check if recompilations happen on different shapes | ||
|
|
||
| Note: We set rank == alpha here because save_lora_adapter does not save the alpha scalings, thus the test would | ||
| fail if the values are different. Since rank != alpha does not matter for the purpose of this test, this is | ||
|
|
@@ -2110,19 +2114,30 @@ def check_model_hotswap(self, do_compile, rank0, rank1, target_modules0, target_ | |
| model.load_lora_adapter(file_name0, safe_serialization=True, adapter_name="adapter0", prefix=None) | ||
|
|
||
| if do_compile: | ||
| model = torch.compile(model, mode="reduce-overhead") | ||
| model = torch.compile(model, mode="reduce-overhead", dynamic=different_resolutions is not None) | ||
|
|
||
| with torch.inference_mode(): | ||
| output0_after = model(**inputs_dict)["sample"] | ||
| assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) | ||
| # additionally check if dynamic compilation works. | ||
| if different_resolutions is not None: | ||
| for height, width in self.different_shapes_for_compilation: | ||
| new_inputs_dict = self.prepare_dummy_input(height=height, width=width) | ||
| _ = model(**new_inputs_dict) | ||
| else: | ||
| output0_after = model(**inputs_dict)["sample"] | ||
| assert torch.allclose(output0_before, output0_after, atol=tol, rtol=tol) | ||
|
||
|
|
||
| # hotswap the 2nd adapter | ||
| model.load_lora_adapter(file_name1, adapter_name="adapter0", hotswap=True, prefix=None) | ||
|
|
||
| # we need to call forward to potentially trigger recompilation | ||
| with torch.inference_mode(): | ||
| output1_after = model(**inputs_dict)["sample"] | ||
| assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) | ||
| if different_resolutions is not None: | ||
| for height, width in self.different_shapes_for_compilation: | ||
| new_inputs_dict = self.prepare_dummy_input(height=height, width=width) | ||
| _ = model(**new_inputs_dict) | ||
| else: | ||
| output1_after = model(**inputs_dict)["sample"] | ||
| assert torch.allclose(output1_before, output1_after, atol=tol, rtol=tol) | ||
|
|
||
| # check error when not passing valid adapter name | ||
| name = "does-not-exist" | ||
|
|
@@ -2240,3 +2255,21 @@ def test_hotswap_second_adapter_targets_more_layers_raises(self): | |
| do_compile=True, rank0=8, rank1=8, target_modules0=target_modules0, target_modules1=target_modules1 | ||
| ) | ||
| assert any("Hotswapping adapter0 was unsuccessful" in log for log in cm.output) | ||
|
|
||
| @parameterized.expand([(11, 11), (7, 13), (13, 7)]) | ||
| @require_torch_version_greater("2.7.1") | ||
| def test_hotswapping_compile_on_different_shapes(self, rank0, rank1): | ||
| different_shapes_for_compilation = self.different_shapes_for_compilation | ||
| if different_shapes_for_compilation is None: | ||
| pytest.skip(f"Skipping as `different_shapes_for_compilation` is not set for {self.__class__.__name__}.") | ||
| torch.fx.experimental._config.use_duck_shape = False | ||
sayakpaul marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| target_modules = ["to_q", "to_k", "to_v", "to_out.0"] | ||
| with torch._dynamo.config.patch(error_on_recompile=True): | ||
| self.check_model_hotswap( | ||
| do_compile=True, | ||
| rank0=rank0, | ||
| rank1=rank1, | ||
| target_modules0=target_modules, | ||
| different_resolutions=different_shapes_for_compilation, | ||
| ) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unrelated but hopefully okay :-)