Skip to content

Commit 958291a

Browse files
committed
refactor: use the base operator safe_squeeze.
1 parent 17318d5 commit 958291a

File tree

5 files changed

+13
-65
lines changed

5 files changed

+13
-65
lines changed

src/mrinufft/operators/base.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,19 @@ def check_shape(self, *, image=None, ksp=None):
192192
if image is None and ksp is None:
193193
raise ValueError("Nothing to check, provides image or ksp arguments")
194194

195+
def _safe_squeeze(self, arr):
196+
"""Squeeze the first two dimensions of shape of the operator."""
197+
if self.squeeze_dims:
198+
try:
199+
arr = arr.squeeze(axis=1)
200+
except ValueError:
201+
pass
202+
try:
203+
arr = arr.squeeze(axis=0)
204+
except ValueError:
205+
pass
206+
return arr
207+
195208
@abstractmethod
196209
def op(self, data):
197210
"""Compute operator transform.
@@ -795,19 +808,6 @@ def _grad_calibless(self, image_data, obs_data):
795808
grad /= self.norm_factor
796809
return grad.reshape(B, C, *XYZ)
797810

798-
def _safe_squeeze(self, arr):
799-
"""Squeeze the first two dimensions of shape of the operator."""
800-
if self.squeeze_dims:
801-
try:
802-
arr = arr.squeeze(axis=1)
803-
except ValueError:
804-
pass
805-
try:
806-
arr = arr.squeeze(axis=0)
807-
except ValueError:
808-
pass
809-
return arr
810-
811811

812812
def power_method(
813813
max_iter: int,

src/mrinufft/operators/interfaces/cufinufft.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -749,19 +749,6 @@ def _dc_calibless_device(self, image_data, obs_data):
749749
grad /= self.norm_factor
750750
return grad
751751

752-
def _safe_squeeze(self, arr):
753-
"""Squeeze the first two dimensions of shape of the operator."""
754-
if self.squeeze_dims:
755-
try:
756-
arr = arr.squeeze(axis=1)
757-
except ValueError:
758-
pass
759-
try:
760-
arr = arr.squeeze(axis=0)
761-
except ValueError:
762-
pass
763-
return arr
764-
765752
@property
766753
def eps(self):
767754
"""Return the underlying precision parameter."""

src/mrinufft/operators/interfaces/gpunufft.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -656,19 +656,6 @@ def get_lipschitz_cst(self, max_iter=10, tolerance=1e-5, **kwargs):
656656
max_iter=max_iter, tolerance=tolerance
657657
)
658658

659-
def _safe_squeeze(self, arr):
660-
"""Squeeze the first two dimensions of shape of the operator."""
661-
if self.squeeze_dims:
662-
try:
663-
arr = arr.squeeze(axis=1)
664-
except ValueError:
665-
pass
666-
try:
667-
arr = arr.squeeze(axis=0)
668-
except ValueError:
669-
pass
670-
return arr
671-
672659
@with_numpy_cupy
673660
def data_consistency(self, image_data, obs_data):
674661
"""Compute the data consistency estimation directly on gpu.

src/mrinufft/operators/interfaces/torchkbnufft.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -166,19 +166,6 @@ def adj_op(self, coeffs, out=None):
166166
img /= self.norm_factor
167167
return self._safe_squeeze(img)
168168

169-
def _safe_squeeze(self, arr):
170-
"""Squeeze the first two dimensions of shape of the operator."""
171-
if self.squeeze_dims:
172-
try:
173-
arr = arr.squeeze(axis=1)
174-
except ValueError:
175-
pass
176-
try:
177-
arr = arr.squeeze(axis=0)
178-
except ValueError:
179-
pass
180-
return arr
181-
182169
@with_torch
183170
def data_consistency(self, data, obs_data):
184171
"""Compute the data consistency.

src/mrinufft/operators/stacked.py

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -283,19 +283,6 @@ def _adj_op_calibless(self, coeffs, img):
283283
img = self._ifftz(imgz)
284284
return img
285285

286-
def _safe_squeeze(self, arr):
287-
"""Squeeze the first two dimensions of shape of the operator."""
288-
if self.squeeze_dims:
289-
try:
290-
arr = arr.squeeze(axis=1)
291-
except ValueError:
292-
pass
293-
try:
294-
arr = arr.squeeze(axis=0)
295-
except ValueError:
296-
pass
297-
return arr
298-
299286
def get_lipschitz_cst(self, max_iter=10):
300287
"""Return the Lipschitz constant of the operator.
301288

0 commit comments

Comments
 (0)