[tests] model-level device_map clarifications
#11681
Changes from all commits
a4dd7fd
5e35ac5
9b8015c
d4a380d
2672347
c3431bf
f035a0d
43c41f4
8ab7d17
7f85be4
0bd70de
962483b
eb913e2
e5820d7
407b67f
e9ccc73
08d429b
3bebf25
f359edc
00b1a06
86539e6
c3cadc6
143420e
b9d60da
```diff
@@ -46,6 +46,7 @@
     require_peft_backend,
     require_torch_accelerator,
     require_torch_accelerator_with_fp16,
+    require_torch_gpu,
     skip_mps,
     slow,
     torch_all_close,
```
```diff
@@ -1083,6 +1084,42 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)

+    @parameterized.expand(
+        [
+            (-1, "You can't pass device_map as a negative int"),
+            ("foo", "When passing device_map as a string, the value needs to be a device name"),
+        ]
+    )
+    def test_wrong_device_map_raises_error(self, device_map, msg_substring):
+        with self.assertRaises(ValueError) as err_ctx:
+            _ = self.model_class.from_pretrained(
+                "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+            )
+
+        assert msg_substring in str(err_ctx.exception)
+
+    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
+    @require_torch_gpu
+    def test_passing_non_dict_device_map_works(self, device_map):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
+    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
+    @require_torch_gpu
+    def test_passing_dict_device_map_works(self, name, device_map):
+        # There are other valid dict-based `device_map` values too. It's best to refer to
+        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```

Review comments on `test_passing_non_dict_device_map_works`:

Can we test some more cases? Additionally, we should probably run the `device_map` tests for all models IMO (can be taken up in a future PR).

I agree this would be fantastic, but we can probably tackle that in a separate PR, and leave the scope of this one to tests/docs/bugfixes/assertions.

We have a bunch of … I can shift the current ones being added through this PR to …

Feel free to add that in a separate PR.
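As a plain-Python illustration of what the new negative-int and unknown-string checks amount to, here is a hypothetical sketch. It is not the actual diffusers implementation (that validation lives inside `from_pretrained`), and the device list is illustrative only:

```python
def validate_device_map(device_map):
    # Hypothetical sketch of the checks the new tests exercise; not the
    # actual diffusers code. Scalar inputs are normalized to the dict form
    # {"": device} that accelerate understands.
    known_devices = ("cpu", "cuda", "mps", "xpu")  # assumption: illustrative list
    if isinstance(device_map, int):
        if device_map < 0:
            raise ValueError("You can't pass device_map as a negative int")
        return {"": device_map}
    if isinstance(device_map, str):
        if device_map != "auto" and device_map.split(":")[0] not in known_devices:
            raise ValueError(
                "When passing device_map as a string, the value needs to be a device name"
            )
        return {"": device_map}
    if isinstance(device_map, dict):
        return device_map
    raise ValueError(f"Unsupported device_map type: {type(device_map)}")
```

This mirrors the cases the parameterized tests cover: `-1` and `"foo"` are rejected with the asserted message substrings, while `0`, `"cuda"`, and dict inputs pass through.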
It looks like this is just documenting the scalar cases. The bit that I need docs for is the dictionary convention. `{'': device.type}` as the simplest valid input is extremely hard to guess. There really needs to be an explanation of what the key of the dictionary means.
Cc: @SunMarc @stevhliu

How is this documented in `transformers`?
You can find this here https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap
Also, can you fix the typing of `device_map` in `DiffusionPipeline.from_pretrained`? For this specific function, we only allow the `balanced` value.
Done in 407b67f.
@Birch-san I clarified the docs to include the case of `{"": torch.device("cuda")}` and have added tests for it, too. For other possible and valid `dict` inputs to `device_map`, I would have to defer you to https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap; as you can notice, it's hard to enumerate them all beforehand without doing a bit of investigation.

So, I would suggest loading your model with the `"auto"` `device_map` first, and then printing `model.hf_device_map` to get a much better handle on the placement. This way, you will have a reasonable starting point which you can then tweak.
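The inspect-then-tweak workflow suggested above could look like the following sketch. The module names and placements here are made up for illustration; in a real run, `auto_map` would be read from `model.hf_device_map` after `from_pretrained(..., device_map="auto")`:

```python
# Assumption: illustrative module names and placements; in practice this dict
# comes from `model.hf_device_map` after loading with device_map="auto".
auto_map = {
    "conv_in": 0,
    "down_blocks": 0,
    "mid_block": 0,
    "up_blocks": "cpu",
    "conv_out": "cpu",
}

# Tweak the auto-generated placement, e.g. move the decoder half onto GPU 0
# as well. The result can be passed back via `device_map=manual_map`.
manual_map = dict(auto_map, up_blocks=0, conv_out=0)
```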
It's not clear to me from the accelerate docs how the key is used. The fact that `''` works suggests there's some kind of pattern-matching or special-casing, which isn't documented.
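My reading of the accelerate docs is that a dict key matches a module when it equals the module's dotted name or is a dot-separated prefix of it, with the most specific key winning; `""` is a prefix of everything, which is why `{"": "cuda"}` places the whole model on one device. A sketch of that convention (an assumption for illustration, not accelerate's actual code):

```python
def resolve_device(module_name, device_map):
    # Sketch of the key-matching convention (assumption, not accelerate's
    # actual implementation): a key matches if it equals the dotted module
    # name or is a dot-separated prefix of it; "" matches every module; the
    # longest matching key wins.
    best_key = None
    for key in device_map:
        matches = key == "" or module_name == key or module_name.startswith(key + ".")
        if matches and (best_key is None or len(key) > len(best_key)):
            best_key = key
    if best_key is None:
        raise ValueError(f"No device specified for module {module_name!r}")
    return device_map[best_key]
```

Under this reading, `{"": "cpu", "down_blocks": 0}` would put `down_blocks` and all of its children on GPU 0 and everything else on CPU.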
Going to defer to @SunMarc for that (again).