11import numpy as np
22
33from pylops import LinearOperator
4+ from pylops .utils ._internal import _value_or_list_like_to_tuple
45
56
67class Pad (LinearOperator ):
@@ -60,14 +61,19 @@ class Pad(LinearOperator):
6061 def __init__ (self , dims , pad , dtype = "float64" ):
6162 if np .any (np .array (pad ) < 0 ):
6263 raise ValueError ("Padding must be positive or zero" )
63- self .dims = dims
64+ self .reshape = False if isinstance (dims , int ) else True
65+ self .dims = _value_or_list_like_to_tuple (dims )
6466 self .pad = pad
65- self .reshape = False if isinstance (self .dims , int ) else True
6667 if self .reshape :
67- self .dimsd = [dim + p [0 ] + p [1 ] for dim , p in zip (dims , pad )]
68+ dimsd = [
69+ dim + before + after
70+ for dim , (before , after ) in zip (self .dims , self .pad )
71+ ]
6872 else :
69- self .dimsd = dims + pad [0 ] + pad [1 ]
70- self .shape = (np .prod (np .array (self .dimsd )), np .prod (np .array (self .dims )))
73+ dimsd = [self .dims [0 ] + pad [0 ] + pad [1 ]]
74+ self .dimsd = tuple (dimsd )
75+
76+ self .shape = (np .prod (self .dimsd ), np .prod (self .dims ))
7177 self .dtype = np .dtype (dtype )
7278 self .explicit = False
7379
@@ -82,8 +88,8 @@ def _matvec(self, x):
8288 def _rmatvec (self , x ):
8389 if self .reshape :
8490 y = x .reshape (self .dimsd )
85- for ax , pad in enumerate (self .pad ):
86- y = np .take (y , np .arange (pad [ 0 ], pad [ 0 ] + self .dims [ax ]), axis = ax )
91+ for ax , ( before , _ ) in enumerate (self .pad ):
92+ y = np .take (y , np .arange (before , before + self .dims [ax ]), axis = ax )
8793 else :
88- y = x [self .pad [0 ] : self .pad [0 ] + self .dims ]
94+ y = x [self .pad [0 ] : self .pad [0 ] + self .dims [ 0 ] ]
8995 return y .ravel ()
0 commit comments