-
Notifications
You must be signed in to change notification settings - Fork 6.5k
Module Group Offloading #10503
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Module Group Offloading #10503
Changes from 42 commits
Commits
Show all changes
50 commits
Select commit
Hold shift + click to select a range
d1737e3
update
a-r-r-o-w 2783669
fix
a-r-r-o-w 6a9a3e5
non_blocking; handle parameters and buffers
a-r-r-o-w c426a34
update
a-r-r-o-w d579037
Group offloading with cuda stream prefetching (#10516)
a-r-r-o-w 5f33621
Merge branch 'main' into groupwise-offloading
a-r-r-o-w a8eabd0
update
a-r-r-o-w deda9a3
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 80ac5a7
copy model hook implementation from pab
a-r-r-o-w d2a2981
update; ~very workaround based implementation but it seems to work as…
a-r-r-o-w 01c7d22
more workarounds to make it actually work
a-r-r-o-w 22aff34
cleanup
a-r-r-o-w 42bc19b
rewrite
a-r-r-o-w 8c63bf5
update
a-r-r-o-w e09e716
make sure to sync current stream before overwriting with pinned params
a-r-r-o-w bf379c1
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 0bf0baf
better check
a-r-r-o-w b850c75
update
a-r-r-o-w 6ed9c2f
remove hook implementation to not deal with merge conflict
a-r-r-o-w 13dd337
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 073d4bc
re-add hook changes
a-r-r-o-w 8ba2bda
why use more memory when less memory do trick
a-r-r-o-w b2e838f
why still use slightly more memory when less memory do trick
a-r-r-o-w f30c55f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 5ea3d8a
optimise
a-r-r-o-w db2fd3b
add model tests
a-r-r-o-w a0160e1
add pipeline tests
a-r-r-o-w aaa9a53
update docs
a-r-r-o-w 17b2753
Merge branch 'main' into groupwise-offloading
a-r-r-o-w edf8103
add layernorm and groupnorm
a-r-r-o-w af62c93
Merge branch 'main' into groupwise-offloading
a-r-r-o-w f227e15
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 24f9273
address review comments
a-r-r-o-w 8f10d05
improve tests; add docs
a-r-r-o-w 06b411f
improve docs
a-r-r-o-w 8bd7e3b
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 904e470
Apply suggestions from code review
a-r-r-o-w 3172ed5
apply suggestions from code review
a-r-r-o-w 72aa57f
Merge branch 'main' into groupwise-offloading
a-r-r-o-w aee24bc
update tests
a-r-r-o-w db125ce
Merge branch 'main' into groupwise-offloading
a-r-r-o-w 3f20e6b
apply suggestions from review
a-r-r-o-w 840576a
enable_group_offloading -> enable_group_offload for naming consistency
a-r-r-o-w 8804d74
raise errors if multiple offloading strategies used; add relevant tests
a-r-r-o-w 954bb7d
handle .to() when group offload applied
a-r-r-o-w ba6c4a8
Merge branch 'main' into groupwise-offloading
a-r-r-o-w da88c33
refactor some repeated code
a-r-r-o-w a872e84
remove unintentional change from merge conflict
a-r-r-o-w 6be43b8
handle .cuda()
a-r-r-o-w 274b84e
Merge branch 'main' into groupwise-offloading
a-r-r-o-w File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,155 @@ | ||
| # Copyright 2024 HuggingFace Inc. | ||
| # | ||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||
| # you may not use this file except in compliance with the License. | ||
| # You may obtain a copy of the License at | ||
| # | ||
| # http://www.apache.org/licenses/LICENSE-2.0 | ||
| # | ||
| # Unless required by applicable law or agreed to in writing, software | ||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
| # See the License for the specific language governing permissions and | ||
| # limitations under the License. | ||
|
|
||
| import gc | ||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from diffusers.models import ModelMixin | ||
| from diffusers.utils.testing_utils import require_torch_gpu, torch_device | ||
|
|
||
|
|
||
| class DummyBlock(torch.nn.Module): | ||
| def __init__(self, in_features: int, hidden_features: int, out_features: int) -> None: | ||
| super().__init__() | ||
|
|
||
| self.proj_in = torch.nn.Linear(in_features, hidden_features) | ||
| self.activation = torch.nn.ReLU() | ||
| self.proj_out = torch.nn.Linear(hidden_features, out_features) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| x = self.proj_in(x) | ||
| x = self.activation(x) | ||
| x = self.proj_out(x) | ||
| return x | ||
|
|
||
|
|
||
| class DummyModel(ModelMixin): | ||
| def __init__(self, in_features: int, hidden_features: int, out_features: int, num_layers: int) -> None: | ||
| super().__init__() | ||
|
|
||
| self.linear_1 = torch.nn.Linear(in_features, hidden_features) | ||
| self.activation = torch.nn.ReLU() | ||
| self.blocks = torch.nn.ModuleList( | ||
| [DummyBlock(hidden_features, hidden_features, hidden_features) for _ in range(num_layers)] | ||
| ) | ||
| self.linear_2 = torch.nn.Linear(hidden_features, out_features) | ||
|
|
||
| def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
| x = self.linear_1(x) | ||
| x = self.activation(x) | ||
| for block in self.blocks: | ||
| x = block(x) | ||
| x = self.linear_2(x) | ||
| return x | ||
|
|
||
|
|
||
| @require_torch_gpu | ||
| class GroupOffloadTests(unittest.TestCase): | ||
| in_features = 64 | ||
| hidden_features = 256 | ||
| out_features = 64 | ||
| num_layers = 4 | ||
|
|
||
| def setUp(self): | ||
| with torch.no_grad(): | ||
| self.model = self.get_model() | ||
| self.input = torch.randn((4, self.in_features)).to(torch_device) | ||
|
|
||
| def tearDown(self): | ||
| super().tearDown() | ||
|
|
||
| del self.model | ||
| del self.input | ||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
| torch.cuda.reset_peak_memory_stats() | ||
|
|
||
| def get_model(self): | ||
| torch.manual_seed(0) | ||
| return DummyModel( | ||
| in_features=self.in_features, | ||
| hidden_features=self.hidden_features, | ||
| out_features=self.out_features, | ||
| num_layers=self.num_layers, | ||
| ) | ||
|
|
||
| def test_offloading_forward_pass(self): | ||
| @torch.no_grad() | ||
| def run_forward(model): | ||
| gc.collect() | ||
| torch.cuda.empty_cache() | ||
| torch.cuda.reset_peak_memory_stats() | ||
| self.assertTrue( | ||
| all( | ||
| module._diffusers_hook.get_hook("group_offloading") is not None | ||
| for module in model.modules() | ||
| if hasattr(module, "_diffusers_hook") | ||
| ) | ||
| ) | ||
| model.eval() | ||
| output = model(self.input)[0].cpu() | ||
| max_memory_allocated = torch.cuda.max_memory_allocated() | ||
| return output, max_memory_allocated | ||
|
|
||
| self.model.to(torch_device) | ||
| output_without_group_offloading, mem_baseline = run_forward(self.model) | ||
| self.model.to("cpu") | ||
|
|
||
| model = self.get_model() | ||
| model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=3) | ||
| output_with_group_offloading1, mem1 = run_forward(model) | ||
|
|
||
| model = self.get_model() | ||
| model.enable_group_offloading(torch_device, offload_type="block_level", num_blocks_per_group=1) | ||
| output_with_group_offloading2, mem2 = run_forward(model) | ||
|
|
||
| model = self.get_model() | ||
| model.enable_group_offloading( | ||
| torch_device, offload_type="block_level", num_blocks_per_group=1, use_stream=True | ||
| ) | ||
| output_with_group_offloading3, mem3 = run_forward(model) | ||
|
|
||
| model = self.get_model() | ||
| model.enable_group_offloading(torch_device, offload_type="leaf_level") | ||
| output_with_group_offloading4, mem4 = run_forward(model) | ||
|
|
||
| model = self.get_model() | ||
| model.enable_group_offloading(torch_device, offload_type="leaf_level", use_stream=True) | ||
| output_with_group_offloading5, mem5 = run_forward(model) | ||
|
|
||
| # Precision assertions - offloading should not impact the output | ||
| self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading1, atol=1e-5)) | ||
| self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading2, atol=1e-5)) | ||
| self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading3, atol=1e-5)) | ||
| self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading4, atol=1e-5)) | ||
| self.assertTrue(torch.allclose(output_without_group_offloading, output_with_group_offloading5, atol=1e-5)) | ||
|
|
||
| # Memory assertions - offloading should reduce memory usage | ||
| self.assertTrue(mem4 <= mem5 < mem2 < mem3 < mem1 < mem_baseline) | ||
|
|
||
| def test_error_raised_if_streams_used_and_no_cuda_device(self): | ||
| original_is_available = torch.cuda.is_available | ||
| torch.cuda.is_available = lambda: False | ||
| with self.assertRaises(ValueError): | ||
| self.model.enable_group_offloading( | ||
| onload_device=torch.device("cuda"), offload_type="leaf_level", use_stream=True | ||
| ) | ||
| torch.cuda.is_available = original_is_available | ||
|
|
||
| def test_error_raised_if_supports_group_offloading_false(self): | ||
| self.model._supports_group_offloading = False | ||
| with self.assertRaisesRegex(ValueError, "does not support group offloading"): | ||
| self.model.enable_group_offloading(onload_device=torch.device("cuda")) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.