[MPSFloatType Type Error] Running PyTorch code on Mac MPS, runs fine on TinyVGG model but type error when running efficientnet_b0 model #698
Replies: 2 comments
-
Hey @gulnuravci, did you set up your You may have to set this up at the start to make sure your objects all use the target For example, see here:
When does the error happen? When you run

Note: If you're using PyTorch 2.0+, you can set the device globally, see: https://www.learnpytorch.io/pytorch_2_intro/#11-globally-set-devices

For example:

```python
import torch

# Set the device
device = "cuda" if torch.cuda.is_available() else "cpu"

# Set the device globally
torch.set_default_device(device)

# All tensors created will be on the global device by default
layer = torch.nn.Linear(20, 30)
print(f"Layer weights are on device: {layer.weight.device}")
print(f"Layer creating data on device: {layer(torch.randn(128, 20)).device}")
```
Also see the PyTorch documentation for more: https://pytorch.org/tutorials/recipes/recipes/changing_default_device.html

You could try this with
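Since the question is about a Mac, here is a minimal sketch of the same idea that also checks for Apple's MPS backend. It assumes PyTorch 2.0+ (for `torch.set_default_device`) and a build with MPS support (`torch.backends.mps.is_available()`); the MPS, then CUDA, then CPU preference order is just one reasonable choice.

```python
import torch

# Prefer Apple's MPS backend, then CUDA, then fall back to the CPU
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# PyTorch 2.0+: tensors and modules created from here on default to `device`
torch.set_default_device(device)

layer = torch.nn.Linear(20, 30)
output = layer(torch.randn(128, 20))
print(f"Selected device: {device}")
print(f"Layer weights are on device: {layer.weight.device}")
print(f"Layer output is on device: {output.device}")
```

Note that `set_default_device` only affects tensors and modules created after the call; a model built before it still has to be moved with `.to(device)`.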
-
Hi @mrdbourke, I did some more debugging with your suggestions and here's what I found:
-
Hello!

I tried to summarize in the title, but basically I'm running my code in VS Code on a Mac and am using Apple's MPS support for PyTorch for faster computing: https://developer.apple.com/metal/pytorch/#:~:text=PyTorch%20uses%20the%20new%20Metal,and%20run%20operations%20on%20Mac..

When I run the TinyVGG model we created in 04. PyTorch Custom Datasets with `device = "mps"`, it works fine (since we correctly send things to the device). But when I run the transfer learning code we discussed in 07. PyTorch Transfer Learning, I get this type error:

`RuntimeError: Input type (MPSFloatType) and weight type (torch.FloatTensor) should be the same`

I'm aware something isn't on the correct device, but I couldn't figure out what it is from Google searches or the PyTorch documentation. Would appreciate any advice!