
Issue with predict() and MPS #126

@giovannivolpe

Description


There is an issue with the predict() method on the MPS backend in some versions. The reproducer below triggers it with a large input (59,998 samples).
This is probably a problem in PyTorch's MPS backend or in Lightning.


The following minimal code reproduces the issue:

```python
import deeplay as dl
import torch

# Requires a Mac on which the MPS backend is available.
device = torch.device("mps")

# Large input: 59,998 single-channel 28x28 images on the MPS device.
x = torch.zeros(59998, 1, 28, 28).to(device)
backbone = dl.models.BackboneResnet18(1, pool_output=True).build().to(device)
head = dl.MultiLayerPerceptron(512, [512, 512], 10).build().to(device)

# Progress markers; flush=True makes each line appear immediately.
print("Start predict backbone", flush=True)
y = backbone.predict(x)
print("Start predict head", flush=True)
y = head.predict(y)
print("End predict", flush=True)
```
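If the failure is tied to pushing the whole ~60,000-sample tensor through MPS in one call (an assumption, not something confirmed above), a possible user-side workaround is to run inference in smaller slices. The sketch below uses plain PyTorch; `chunked_predict` is a hypothetical helper, not part of deeplay, and it assumes the model has at least one parameter and that `x` is non-empty.

```python
import torch
from torch import nn


def chunked_predict(model: nn.Module, x: torch.Tensor, batch_size: int = 1024) -> torch.Tensor:
    """Run inference on x in slices of at most batch_size samples (hypothetical helper).

    Each slice is moved to the model's device, and each output is moved back to
    x's device before all outputs are concatenated along the batch dimension.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be positive")
    # Assumes the model has parameters; their device is used as the compute device.
    device = next(model.parameters()).device
    model.eval()
    outputs = []
    with torch.no_grad():
        for start in range(0, len(x), batch_size):
            batch = x[start:start + batch_size].to(device)
            outputs.append(model(batch).to(x.device))
    return torch.cat(outputs, dim=0)


if __name__ == "__main__":
    # Use MPS when available, otherwise fall back to the CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10)).to(device)
    x = torch.zeros(5000, 1, 28, 28, device=device)
    y = chunked_predict(model, x, batch_size=1024)
    print(y.shape)  # torch.Size([5000, 10])
```

Since deeplay modules are `nn.Module` subclasses, the reproducer could try `chunked_predict(backbone, x)` followed by `chunked_predict(head, ...)` in place of the two `predict()` calls. Whether this avoids the MPS problem has not been verified here.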


The previous patch (module.py, from line 979) no longer seems to work:

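```python
    if len(x) >= 1000:
        # Round-trip an MPS input through the CPU to get a fresh copy on MPS.
        if isinstance(x, torch.Tensor) and x.device.type == "mps":
            device = x.device
            x = x.cpu().to(device)

        # Note: `args` is reset to an empty list here, so the loop below never
        # runs and any extra positional arguments are dropped. `device` is also
        # only defined when `x` is an MPS tensor.
        args = []
        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.device.type == "mps":
                arg = arg.cpu().to(device)
            args.append(arg)
        args = tuple(args)
        # Disabled variant: `.to()` returns a new tensor, so without an
        # assignment this loop has no effect.
        # for _x in (x,) + args:
        #     if isinstance(_x, torch.Tensor) and _x.device.type == "mps":
        #         _x.to("cpu").to("mps")
```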
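As quoted, the loop over `args` iterates over the empty list it has just created, so positional arguments are silently discarded, and `device` is undefined when `x` is not on MPS. A minimal sketch of the same CPU round-trip without those two problems is shown below. `refresh_mps_tensors` is a hypothetical helper name, and this only fixes the snippet's logic; it may not resolve the underlying MPS failure.

```python
import torch


def refresh_mps_tensors(x, *args):
    """Return (x, args) with every MPS tensor replaced by a fresh copy (hypothetical helper).

    Each tensor on an MPS device is copied to the CPU and back to its own
    device; all other values, including non-MPS tensors, pass through unchanged.
    """
    def refresh(value):
        if isinstance(value, torch.Tensor) and value.device.type == "mps":
            # Use the tensor's own device, so nothing depends on x being on MPS.
            return value.cpu().to(value.device)
        return value

    # Every positional argument is kept, in its original order.
    return refresh(x), tuple(refresh(arg) for arg in args)
```

Inside `predict`, the guarded block would then reduce to `if len(x) >= 1000: x, args = refresh_mps_tensors(x, *args)`.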
