[tests] model-level device_map clarifications #11681
        
Changes from 22 commits
```diff
@@ -1083,6 +1083,42 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self):
         assert loaded_model
         assert new_output.sample.shape == (4, 4, 16, 16)
 
+    @parameterized.expand(
+        [
+            (-1, "You can't pass device_map as a negative int"),
+            ("foo", "When passing device_map as a string, the value needs to be a device name"),
+        ]
+    )
+    def test_wrong_device_map_raises_error(self, device_map, msg_substring):
+        with self.assertRaises(ValueError) as err_ctx:
+            _ = self.model_class.from_pretrained(
+                "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+            )
+
+        assert msg_substring in str(err_ctx.exception)
+
+    @parameterized.expand([0, "cuda", torch.device("cuda"), torch.device("cuda:0")])
+    @require_torch_gpu
+    def test_passing_non_dict_device_map_works(self, device_map):
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map=device_map
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
+    @parameterized.expand([("", "cuda"), ("", torch.device("cuda"))])
+    @require_torch_gpu
+    def test_passing_dict_device_map_works(self, name, device_map):
+        # There are other valid dict-based `device_map` values too. It's best to refer to
+        # the docs for those: https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap.
+        _, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        loaded_model = self.model_class.from_pretrained(
+            "hf-internal-testing/unet2d-sharded-dummy-subfolder", subfolder="unet", device_map={name: device_map}
+        )
+        output = loaded_model(**inputs_dict)
+        assert output.sample.shape == (4, 4, 16, 16)
+
     @require_peft_backend
     def test_load_attn_procs_raise_warning(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
```

Review thread on `test_passing_non_dict_device_map_works`:

Can we test some more cases like:  Additionally, we should probably run device map tests for all models IMO (can be taken up in a future PR).

I agree this would be fantastic, but we can probably tackle that in a separate PR, and leave the scope of this one to tests/docs/bugfixes/assertions.

We have a bunch of  I can shift the current ones being added through this PR to

Feel free to add that in a separate PR.
It looks like this is just documenting the scalar cases. The bit that I need docs for is the dictionary convention. `{'': device.type}` as the simplest valid input is extremely hard to guess; there really needs to be an explanation of what the key of the dictionary means.
Cc: @SunMarc @stevhliu. How is this documented in `transformers`?
You can find this here https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap
Also, can you fix the typing for the `DiffusionPipeline` `from_pretrained` for `device_map`? For this specific function, we only allow the `balanced` value.
Done in 407b67f.
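For readers following along, a sketch of what such a restriction can look like. This is a hypothetical standalone stand-in, not the actual signature from 407b67f; the point is just that `Literal["balanced"]` expresses "the only supported string value" at the type level:

```python
from typing import Any, Dict, Literal, Optional


# Hedged sketch: DiffusionPipeline.from_pretrained only accepts "balanced" as
# its string device_map, so Literal["balanced"] documents that in the type.
# This function is an illustration, not the real diffusers implementation.
def from_pretrained(repo_id: str, *, device_map: Optional[Literal["balanced"]] = None, **kwargs: Any) -> Dict[str, Any]:
    # Type checkers flag other strings statically; validate at runtime too.
    if device_map is not None and device_map != "balanced":
        raise ValueError(f"Only the 'balanced' device_map is supported here, got {device_map!r}.")
    # Real loading logic would go here; return the parsed arguments instead.
    return {"repo_id": repo_id, "device_map": device_map}


print(from_pretrained("some/repo", device_map="balanced")["device_map"])  # balanced
```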
@Birch-san I clarified the docs to include the case of `{"": torch.device("cuda")}` and have added tests for it, too. For other possible and valid `dict` inputs to `device_map`, I would have to defer you to https://huggingface.co/docs/accelerate/en/concept_guides/big_model_inference#the-devicemap; as you can notice, it's hard to specify all of them beforehand without doing a bit of investigation.

So, I would suggest loading your model with the "auto" `device_map` first, and then printing `model.hf_device_map` to get a much better handle on the placement. That gives you a reasonable starting point which you can then tweak.
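A sketch of that workflow. The loading step is commented out since it needs a CUDA machine and network access (`UNet2DModel` and the checkpoint ID are taken from this PR's tests), and `auto_map` below is a made-up example of what `model.hf_device_map` might contain:

```python
# Step 1 (commented out; needs a GPU): load with device_map="auto" and
# inspect the placement accelerate chose.
#
#   from diffusers import UNet2DModel
#   model = UNet2DModel.from_pretrained(
#       "hf-internal-testing/unet2d-sharded-dummy-subfolder",
#       subfolder="unet",
#       device_map="auto",
#   )
#   print(model.hf_device_map)

# Step 2: hf_device_map is a plain dict of submodule name -> device.
# Suppose "auto" produced the (made-up) split below:
auto_map = {"conv_in": 0, "down_blocks": 0, "up_blocks": "cpu", "conv_out": "cpu"}

# Copy it, tweak entries, and pass the result back via
# from_pretrained(..., device_map=tweaked) -- e.g. move the CPU-offloaded
# submodules onto GPU 0 as well:
tweaked = {name: 0 if device == "cpu" else device for name, device in auto_map.items()}
print(tweaked)  # {'conv_in': 0, 'down_blocks': 0, 'up_blocks': 0, 'conv_out': 0}
```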
It's not clear to me from the accelerate docs how the key is used. The fact that `''` works suggests there's some kind of pattern matching or special-casing, which isn't documented.
Going to defer to @SunMarc for that (again).
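For anyone else puzzling over this: as far as I can tell from the accelerate guide linked above, the keys of a dict `device_map` are dotted submodule names, a key claims a module when it equals the name or is a parent of it, and the empty string `""` is a parent of everything (hence `{"": "cuda"}` places the whole model on one device). A rough sketch of that resolution rule; `resolve_device` is a hypothetical illustration, not an accelerate API:

```python
# Illustrative sketch of the (assumed) convention: a device_map key matches a
# submodule when it is "", equals the dotted module name, or is an ancestor of
# it; the longest matching key wins. resolve_device is hypothetical, written
# only to illustrate the rule -- it is not part of accelerate.
def resolve_device(module_name: str, device_map: dict) -> str:
    best_key = None
    for key in device_map:
        # "" matches everything; otherwise the key must be the module itself
        # or a parent module (the name continues with a ".").
        if key == "" or module_name == key or module_name.startswith(key + "."):
            if best_key is None or len(key) > len(best_key):
                best_key = key  # longest (most specific) prefix wins
    if best_key is None:
        raise KeyError(f"No device found for module {module_name!r}")
    return device_map[best_key]


device_map = {"": "cpu", "unet.down_blocks": "cuda:0"}
print(resolve_device("unet.up_blocks.0", device_map))    # cpu (falls back to "")
print(resolve_device("unet.down_blocks.1", device_map))  # cuda:0 (specific key wins)
```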