Skip to content

Commit a5dbca4

Browse files
committed
fix: tweak cg reconstructor.
1 parent 8f66003 commit a5dbca4

File tree

2 files changed

+19
-8
lines changed

2 files changed

+19
-8
lines changed

src/snake/toolkit/reconstructors/__init__.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,11 @@
22

33
from .base import BaseReconstructor
44
from .pysap import ZeroFilledReconstructor, SequentialReconstructor
5+
from .cg import ConjugateGradientReconstructor
56

6-
7-
__all__ = ["BaseReconstructor", "ZeroFilledReconstructor", "SequentialReconstructor"]
7+
__all__ = [
8+
"BaseReconstructor",
9+
"ZeroFilledReconstructor",
10+
"SequentialReconstructor",
11+
"ConjugateGradientReconstructor",
12+
]

src/snake/toolkit/reconstructors/cg.py

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,22 +26,26 @@ class ConjugateGradientReconstructor(ZeroFilledReconstructor):
2626

2727
__reconstructor_name__ = "cg"
2828

29-
max_iter: int
30-
tol: float
31-
density_compensation: str | False | None = False
29+
max_iter: int = 50
30+
tol: float = 1e-4
31+
density_compensation: str | bool | None = False
3232
nufft_backend: str = "cufinufft"
3333

34+
def __str__(self) -> str:
35+
"""Return a string representation of the object."""
36+
return f"{self.__reconstructor_name__}_{self.max_iter}_{self.tol:.0e}"
37+
3438
def _reconstruct_nufft(self, data_loader: NonCartesianFrameDataLoader) -> NDArray:
3539
"""Reconstruct the data using the NUFFT operator."""
3640
from mrinufft.extras.gradient import cg
3741

3842
nufft_operator = init_nufft(
3943
data_loader,
4044
density_compensation=self.density_compensation,
41-
backend=self.nufft_backend,
45+
nufft_backend=self.nufft_backend,
4246
)
4347
final_images = np.empty(
44-
(data_loader.n_frames, *data_loader.shape), dtype=np.float32
48+
(data_loader.n_frames, *data_loader.shape), dtype=np.complex64
4549
)
4650

4751
for i in tqdm(range(data_loader.n_frames)):
@@ -52,7 +56,9 @@ def _reconstruct_nufft(self, data_loader: NonCartesianFrameDataLoader) -> NDArra
5256
)[0, :, :2]
5357
data = np.reshape(data, (data.shape[0], data_loader.n_shots, -1))
5458
for j in range(data.shape[1]):
55-
final_images[i, :, :, j] = cg(nufft_operator, data[:, j])
59+
final_images[i, :, :, j] = cg(
60+
nufft_operator, data[:, j],num_iter=self.max_iter, tol=self.tol
61+
)
5662
else:
5763
final_images[i] = cg(
5864
nufft_operator, data, num_iter=self.max_iter, tol=self.tol

0 commit comments

Comments
 (0)