11import torch
22
3+ from torch .utils .dlpack import from_dlpack , to_dlpack
4+ from pylops .utils import deps
5+
6+ if deps .cupy_enabled :
7+ import cupy as cp
8+ else :
9+ cp = None
10+
311
412class _TorchOperator (torch .autograd .Function ):
513 """Wrapper class for PyLops operators into Torch functions
@@ -16,20 +24,47 @@ def forward(ctx, x, forw, adj, pylops, device):
1624 ctx .pylops = pylops
1725 ctx .device = device
1826
27+ # prepare input
1928 if ctx .pylops :
20- x = x .cpu ().detach ().numpy ()
29+ if ctx .device == 'cpu' :
30+ # bring x to cpu and numpy
31+ x = x .cpu ().detach ().numpy ()
32+ else :
33+ # pass x to cupy using DLPack
34+ x = cp .fromDlpack (to_dlpack (x ))
35+
36+ # apply forward operator
2137 y = ctx .forw (x )
38+
39+ # prepare output
2240 if ctx .pylops :
23- y = torch .from_numpy (y ).to (ctx .device )
41+ if ctx .device == 'cpu' :
42+ # move y to torch and device
43+ y = torch .from_numpy (y ).to (ctx .device )
44+ else :
45+ # move y to torch and device
46+ y = from_dlpack (y .toDlpack ())
2447 return y
2548
2649 @staticmethod
2750 def backward (ctx , y ):
51+ # prepare input
2852 if ctx .pylops :
29- y = y .cpu ().detach ().numpy ()
53+ if ctx .device == 'cpu' :
54+ y = y .cpu ().detach ().numpy ()
55+ else :
56+ # pass x to cupy using DLPack
57+ y = cp .fromDlpack (to_dlpack (y ))
58+
59+ # apply adjoint operator
3060 x = ctx .adj (y )
61+
62+ # prepare output
3163 if ctx .pylops :
32- x = torch .from_numpy (x ).to (ctx .device )
64+ if ctx .device == 'cpu' :
65+ x = torch .from_numpy (x ).to (ctx .device )
66+ else :
67+ x = from_dlpack (x .toDlpack ())
3368 return x , None , None , None , None
3469
3570
@@ -75,6 +110,7 @@ def __init__(self, Op, batch=False, pylops=False, device='cpu'):
75110 else :
76111 self .matvec = lambda x : Op .matmat (x , kfirst = True )
77112 self .rmatvec = lambda x : Op .rmatmat (x , kfirst = True )
113+ self .Top = _TorchOperator .apply
78114
79115 def apply (self , x ):
80116 """Apply forward pass to input vector
@@ -90,5 +126,5 @@ def apply(self, x):
90126 Output array resulting from the application of the operator to ``x``.
91127
92128 """
93- return _TorchOperator . apply (x , self .matvec , self .rmatvec ,
94- self .pylops , self .device )
129+ return self . Top (x , self .matvec , self .rmatvec ,
130+ self .pylops , self .device )
0 commit comments