
Commit 442361b

Modified TorchOperator for cupy integration
Use DLPack, as explained in https://docs.cupy.dev/en/stable/reference/interoperability.html, to allow moving arrays on the GPU between PyTorch and CuPy.
1 parent ed991d9 commit 442361b
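
The pattern the commit relies on is the DLPack exchange described in the CuPy interoperability docs linked above: a torch tensor is exposed to CuPy (and back) without copying, so the data never leaves the GPU. A minimal sketch of that round-trip (illustrative only; the names t, a, t2 are hypothetical and not part of the diff):

    import cupy as cp
    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    t = torch.ones(3, device='cuda')     # torch tensor living on the GPU
    a = cp.fromDlpack(to_dlpack(t))      # expose it to cupy without copying
    a *= 2.0                             # any cupy computation
    t2 = from_dlpack(a.toDlpack())       # hand the result back to torch, still on GPU

The modified forward/backward below use exactly these two conversions whenever ``pylops=True`` and the device is not ``'cpu'``.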

2 files changed: +57 -9 lines

pylops_gpu/TorchOperator.py

Lines changed: 42 additions & 6 deletions
@@ -1,5 +1,13 @@
 import torch

+from torch.utils.dlpack import from_dlpack, to_dlpack
+from pylops.utils import deps
+
+if deps.cupy_enabled:
+    import cupy as cp
+else:
+    cp = None
+

 class _TorchOperator(torch.autograd.Function):
     """Wrapper class for PyLops operators into Torch functions
@@ -16,20 +24,47 @@ def forward(ctx, x, forw, adj, pylops, device):
         ctx.pylops = pylops
         ctx.device = device

+        # prepare input
         if ctx.pylops:
-            x = x.cpu().detach().numpy()
+            if ctx.device == 'cpu':
+                # bring x to cpu and numpy
+                x = x.cpu().detach().numpy()
+            else:
+                # pass x to cupy using DLPack
+                x = cp.fromDlpack(to_dlpack(x))
+
+        # apply forward operator
         y = ctx.forw(x)
+
+        # prepare output
         if ctx.pylops:
-            y = torch.from_numpy(y).to(ctx.device)
+            if ctx.device == 'cpu':
+                # move y to torch and device
+                y = torch.from_numpy(y).to(ctx.device)
+            else:
+                # move y back to torch using DLPack
+                y = from_dlpack(y.toDlpack())
         return y

     @staticmethod
     def backward(ctx, y):
+        # prepare input
         if ctx.pylops:
-            y = y.cpu().detach().numpy()
+            if ctx.device == 'cpu':
+                y = y.cpu().detach().numpy()
+            else:
+                # pass y to cupy using DLPack
+                y = cp.fromDlpack(to_dlpack(y))
+
+        # apply adjoint operator
         x = ctx.adj(y)
+
+        # prepare output
         if ctx.pylops:
-            x = torch.from_numpy(x).to(ctx.device)
+            if ctx.device == 'cpu':
+                x = torch.from_numpy(x).to(ctx.device)
+            else:
+                x = from_dlpack(x.toDlpack())
         return x, None, None, None, None


@@ -75,6 +110,7 @@ def __init__(self, Op, batch=False, pylops=False, device='cpu'):
         else:
             self.matvec = lambda x: Op.matmat(x, kfirst=True)
             self.rmatvec = lambda x: Op.rmatmat(x, kfirst=True)
+        self.Top = _TorchOperator.apply

     def apply(self, x):
         """Apply forward pass to input vector
@@ -90,5 +126,5 @@ def apply(self, x):
             Output array resulting from the application of the operator to ``x``.

         """
-        return _TorchOperator.apply(x, self.matvec, self.rmatvec,
-                                    self.pylops, self.device)
+        return self.Top(x, self.matvec, self.rmatvec,
+                        self.pylops, self.device)
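
For context, a hypothetical end-to-end usage sketch of the modified operator on GPU (not part of the commit). It assumes the constructor only stores ``matvec``/``rmatvec`` and the ``pylops``/``device`` flags, as the hunks above suggest, and uses a toy stand-in operator whose matvec/rmatvec act on cupy arrays; the names CupyDiagonal, Dop and d are made up for illustration:

    import cupy as cp
    import torch
    import pylops_gpu

    class CupyDiagonal:
        # toy stand-in for a pylops operator acting on cupy arrays
        def __init__(self, d):
            self.d = d
        def matvec(self, x):
            return self.d * x
        def rmatvec(self, y):
            return self.d * y

    d = cp.arange(1., 11.)                      # diagonal, already on the GPU
    Dop = pylops_gpu.TorchOperator(CupyDiagonal(d),
                                   pylops=True, device='cuda')
    x = torch.ones(10, dtype=torch.double, device='cuda', requires_grad=True)
    y = Dop.apply(x)       # torch -> cupy (DLPack) -> matvec -> torch, all on GPU
    y.sum().backward()     # the gradient flows through rmatvec the same way
    print(x.grad)          # expected to equal d, returned as a torch tensor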

tutorials/ad.py

Lines changed: 15 additions & 3 deletions
@@ -19,6 +19,7 @@
 import numpy as np
 import torch
 import matplotlib.pyplot as plt
+from torch.autograd import gradcheck

 import pylops_gpu
 from pylops_gpu.utils.backend import device
@@ -61,15 +62,15 @@
 # :math:`\mathbf{v}` that we have provided to PyTorch ``backward``.

 nx, ny = 10, 6
-x0 = torch.arange(nx, dtype=torch.float32, requires_grad=True)
+x0 = torch.arange(nx, dtype=torch.double, requires_grad=True)

 # Forward
-A = torch.normal(0., 1., (ny, nx))
+A = torch.normal(0., 1., (ny, nx), dtype=torch.double)
 Aop = pylops_gpu.TorchOperator(pylops_gpu.MatrixMult(A))
 y = Aop.apply(torch.sin(x0))

 # AD
-v = torch.ones(ny)
+v = torch.ones(ny, dtype=torch.double)
 y.backward(v, retain_graph=True)
 adgrad = x0.grad

@@ -81,6 +82,17 @@
 print('AD gradient: ', adgrad)
 print('Analytical gradient: ', anagrad)

+
+###############################################################################
+# Similarly we can use :func:`torch.autograd.gradcheck` directly from
+# PyTorch. Note that doubles must be used for this to succeed with very small
+# ``eps`` and ``atol``.
+input = (torch.arange(nx, dtype=torch.double, requires_grad=True),
+         Aop.matvec, Aop.rmatvec, Aop.pylops, Aop.device)
+test = gradcheck(Aop.Top, input, eps=1e-6, atol=1e-4)
+print(test)
+
+
 ###############################################################################
 # Note that while matrix-vector multiplication could have been performed using
 # the native PyTorch operator :func:`torch.matmul`, in this case we have shown
