Description
There is an issue with some versions of MPS when using the `predict()` method.
This is probably a problem with the MPS backend of PyTorch or with Lightning.
The following minimal code reproduces the issue:
```python
import deeplay as dl
import torch

device = torch.device("mps")

# Large input batch (59998 single-channel 28x28 images) placed on the MPS device.
x = torch.zeros(59998, 1, 28, 28).to(device)

backbone = dl.models.BackboneResnet18(1, pool_output=True).build().to(device)
head = dl.MultiLayerPerceptron(512, [512, 512], 10).build().to(device)

# Flushed prints mark progress through the two predict() calls.
print("Start predict backbone", flush=True)
y = backbone.predict(x)
print("Start predict head", flush=True)
y = head.predict(y)
print("End predict", flush=True)
```
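As a diagnostic, the sketch below feeds a model fixed-size slices instead of one 59998-sample tensor, which can help tell whether the failure depends on input size (the existing patch only applies when `len(x) >= 1000`). `predict_in_chunks` and `chunk_size` are hypothetical names, not part of deeplay; the helper calls the module directly under `torch.no_grad()` instead of going through `predict()`, and assumes the model maps a batch to a batch along dimension 0. Whether chunking actually avoids the MPS problem is untested here.

```python
import torch
from torch import nn


def predict_in_chunks(model: nn.Module, x: torch.Tensor, chunk_size: int = 1000) -> torch.Tensor:
    """Run `model` on `x` in slices of at most `chunk_size` samples.

    Hypothetical helper (not a deeplay API): bypasses predict() and calls
    the module directly without tracking gradients.
    """
    if chunk_size < 1:
        raise ValueError("chunk_size must be positive")
    if len(x) == 0:
        raise ValueError("x must contain at least one sample")
    model.eval()  # switch dropout / batch-norm layers to inference behaviour
    outputs = []
    with torch.no_grad():
        for start in range(0, len(x), chunk_size):
            chunk = x[start:start + chunk_size]
            outputs.append(model(chunk))
    # Concatenate the per-chunk outputs along the batch dimension.
    return torch.cat(outputs, dim=0)
```

Assuming the built deeplay modules can be called like ordinary `nn.Module`s, the reproduction would become `y = predict_in_chunks(backbone, x)` followed by `y = predict_in_chunks(head, y)`. If these calls complete while `predict()` does not, that points toward input size as a factor.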
The previous patch (in `module.py`, from line 979) no longer seems to work:
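```python
# Excerpt from predict() in module.py (from line 979); x is the input batch
# and args holds any extra positional arguments.
if len(x) >= 1000:
    # Round-trip large MPS inputs through the CPU before running the model.
    if isinstance(x, torch.Tensor) and x.device.type == "mps":
        device = x.device
        x = x.cpu().to(device)
    # Rebuild args into a new list, round-tripping any MPS tensors.
    new_args = []
    for arg in args:
        if isinstance(arg, torch.Tensor) and arg.device.type == "mps":
            arg = arg.cpu().to(arg.device)
        new_args.append(arg)
    args = tuple(new_args)
    # Earlier attempt (the result is never assigned, so it has no effect):
    # for _x in (x,) + args:
    #     if isinstance(_x, torch.Tensor) and _x.device.type == "mps":
    #         _x.to("cpu").to("mps")
```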
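The same CPU round-trip can be packaged as a standalone function that treats `x` and every extra argument uniformly and does not rely on `x` being the tensor that defines `device`. `roundtrip_mps_tensors` and its `min_len` parameter are hypothetical names; the 1000-sample threshold is taken from the patch. This is a sketch of the existing workaround, not a confirmed fix for the MPS issue.

```python
import torch


def roundtrip_mps_tensors(x, *args, min_len: int = 1000):
    """Copy MPS tensors to the CPU and back, mirroring the module.py patch.

    Hypothetical helper: only tensors on an "mps" device are touched, and
    only when the batch `x` has at least `min_len` samples.
    Returns the (possibly replaced) x and a tuple of the remaining args.
    """
    def _roundtrip(value):
        if isinstance(value, torch.Tensor) and value.device.type == "mps":
            # .cpu() copies to host memory; .to(...) allocates a new tensor
            # back on the original MPS device.
            return value.cpu().to(value.device)
        return value  # non-tensors and non-MPS tensors pass through unchanged

    if len(x) < min_len:
        return x, args
    return _roundtrip(x), tuple(_roundtrip(arg) for arg in args)
```

Inside `predict()`, the block above would then reduce to `x, args = roundtrip_mps_tensors(x, *args)`.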