Skip to content
Draft
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 103 additions & 26 deletions pylops/medical/ct.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pylops import LinearOperator
from pylops.utils import deps
from pylops.utils.backend import get_array_module, get_module_name, to_numpy
from pylops.utils.decorators import reshaped
from pylops.utils.typing import DTypeLike, InputDimsLike, NDArray

Expand All @@ -19,6 +20,7 @@

if astra_message is None:
import astra
import astra.experimental


class CT2D(LinearOperator):
Expand Down Expand Up @@ -50,7 +52,11 @@ class CT2D(LinearOperator):
Distance between origin and detector along the source-origin line
(only for "proj_geom_type=fanflat")
projector_type : :obj:`int`, optional
Type of projection geometry (``strip``, or ``line``, or ``linear``)
Type of projection kernel (``strip``, or ``line``, or ``linear``).
For ``cuda`` computation engine, only ``cuda`` (hardware-accelerated ``linear``)
type is supported.
engine : :obj:`str`, optional
Engine used for computation (``cpu`` or ``cuda``).
dtype : :obj:`str`, optional
Type of elements in input array.
name : :obj:`str`, optional
Expand Down Expand Up @@ -86,54 +92,125 @@ class CT2D(LinearOperator):
def __init__(
self,
dims: InputDimsLike,
det_width: int,
det_count: float,
det_width: float,
det_count: int,
thetas: NDArray,
proj_geom_type: Optional[str] = "parallel",
source_origin_dist: float = None,
origin_detector_dist: float = None,
projector_type: Optional[str] = "strip",
projector_type: Optional[str] = None,
engine="cpu",
dtype: DTypeLike = "float64",
name: str = "C",
) -> None:
if astra_message is not None:
raise NotImplementedError(astra_message)

# create volume and projection geometries
self.vol_geom = astra.create_vol_geom(dims)
if proj_geom_type == "parallel":
self.dims = dims
self.det_width = det_width
self.det_count = det_count
self.thetas = to_numpy(thetas) # ASTRA can only consume angles as a NumPy array
self.proj_geom_type = proj_geom_type
self.source_origin_dist = source_origin_dist
self.origin_detector_dist = origin_detector_dist

# make "strip" projector type default for cpu engine and only allow cuda otherwise
if engine == "cpu":
if projector_type is None:
self.projector_type = "strip"
else:
self.projector_type = projector_type
if projector_type == "cuda":
logging.warning("'cuda' projector type specified with 'cpu' engine.")
elif engine == "cuda":
if projector_type in [None, "cuda"]:
self.projector_type = "cuda"
else:
raise ValueError(
"Only 'cuda' projector type is supported for 'cuda' engine."
)

# create create volume and projection geometries as well as projector
self._init_geometries()
if engine == "cuda":
# efficient GPU data exchange only implemented for 3D data in ASTRA, so we
# emulate 2D geometry as 3D case with 1 slice
self._init_1_slice_3d_geometries()

super().__init__(
dtype=np.dtype(dtype), dims=dims, dimsd=(len(thetas), det_count), name=name
)

def _init_geometries(self):
self.vol_geom = astra.create_vol_geom(self.dims)
if self.proj_geom_type == "parallel":
self.proj_geom = astra.create_proj_geom(
proj_geom_type, det_width, det_count, thetas
"parallel", self.det_width, self.det_count, self.thetas
)
else:
elif self.proj_geom_type == "fanflat":
self.proj_geom = astra.create_proj_geom(
proj_geom_type,
det_width,
det_count,
thetas,
source_origin_dist,
origin_detector_dist,
"fanflat",
self.det_width,
self.det_count,
self.thetas,
self.source_origin_dist,
self.origin_detector_dist,
)

# create projector
self.proj_id = astra.create_projector(
projector_type, self.proj_geom, self.vol_geom
self.projector_id = astra.create_projector(
self.projector_type, self.proj_geom, self.vol_geom
)
super().__init__(
dtype=np.dtype(dtype), dims=dims, dimsd=(len(thetas), det_count), name=name

def _init_1_slice_3d_geometries(self):
self._3d_vol_geom = astra.create_vol_geom(*self.dims, 1)
if self.proj_geom_type == "parallel":
self._3d_proj_geom = astra.create_proj_geom(
"parallel3d", 1.0, self.det_width, 1, self.det_count, self.thetas
)
elif self.proj_geom_type == "fanflat":
self._3d_proj_geom = astra.create_proj_geom(
"cone",
1.0,
self.det_width,
1,
self.det_count,
self.thetas,
self.source_origin_dist,
self.origin_detector_dist,
)
self._3d_projector_id = astra.create_projector(
"cuda3d", self._3d_proj_geom, self._3d_vol_geom
)

@reshaped
def _matvec(self, x):
y_id, y = astra.create_sino(x, self.proj_id)
astra.data2d.delete(y_id)
ncp = get_array_module(x)
backend = get_module_name(ncp)
if backend == "numpy":
y_id, y = astra.create_sino(x, self.projector_id)
astra.data2d.delete(y_id)
else:
# Ensure x and y are 1-slice 3D arrays
x = ncp.expand_dims(x, axis=0)
y = ncp.empty_like(x, shape=astra.geom_size(self._3d_proj_geom))
astra.experimental.direct_FP3D(self._3d_projector_id, x, y)
return y

@reshaped
def _rmatvec(self, x):
y_id, y = astra.create_backprojection(x, self.proj_id)
astra.data2d.delete(y_id)
ncp = get_array_module(x)
backend = get_module_name(ncp)
if backend == "numpy":
y_id, y = astra.create_backprojection(x, self.projector_id)
astra.data2d.delete(y_id)
else:
# Ensure x and y are 1-slice 3D arrays
x = ncp.expand_dims(x, axis=0)
y = ncp.empty_like(x, shape=astra.geom_size(self._3d_vol_geom))
astra.experimental.direct_BP3D(self._3d_projector_id, y, x)
return y

def __del__(self):
astra.projector.delete(self.proj_id)
if hasattr(self, "projector_id"):
astra.projector.delete(self.projector_id)
if hasattr(self, "_3d_projector_id"):
astra.projector.delete(self._3d_projector_id)
88 changes: 54 additions & 34 deletions pytests/test_ct.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,64 @@
import os
import platform

import numpy as np
if int(os.environ.get("TEST_CUPY_PYLOPS", 0)):
import cupy as np

backend = "cupy"
else:
import numpy as np

backend = "numpy"
import pytest

from pylops.medical import CT2D
from pylops.utils import dottest

par1 = {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "parallel",
"projector_type": "strip",
"dtype": "float64",
} # parallel, strip

par2 = {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "parallel",
"projector_type": "line",
"dtype": "float64",
} # parallel, line

par3 = {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "fanflat",
"projector_type": "strip",
"dtype": "float64",
} # fanflat, strip


@pytest.mark.skipif(
int(os.environ.get("TEST_CUPY_PYLOPS", 0)) == 1, reason="Not CuPy enabled"
)
par = {
"parallel_strip": {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "parallel",
"projector_type": "strip",
"dtype": "float64",
},
"parallel_line": {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "parallel",
"projector_type": "line",
"dtype": "float64",
},
"fanflat_strip": {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "fanflat",
"source_origin_dist": 100,
"origin_detector_dist": 0,
"projector_type": "strip_fanflat",
"dtype": "float64",
},
"cuda": {
"ny": 51,
"nx": 30,
"ntheta": 20,
"proj_geom_type": "parallel",
"projector_type": "cuda",
"dtype": "float64",
},
}


@pytest.mark.skipif(platform.system() == "Darwin", reason="Not OSX enabled")
@pytest.mark.parametrize("par", [(par1), (par2)])
@pytest.mark.parametrize("par", par.values(), ids=par.keys())
def test_CT2D(par):
"""Dot-test for CT2D operator"""
if backend == "cupy" or par["projector_type"] == "cuda":
pytest.skip("CUDA tests are failing because of severely mismatched adjoint.")

theta = np.linspace(0.0, np.pi, par["ntheta"], endpoint=False)

Cop = CT2D(
Expand All @@ -50,7 +67,10 @@ def test_CT2D(par):
par["ny"],
theta,
proj_geom_type=par["proj_geom_type"],
projector_type=par["projector_type"],
projector_type=par["projector_type"] if backend == "numpy" else "cuda",
source_origin_dist=par.get("source_origin_dist", None),
origin_detector_dist=par.get("origin_detector_dist", None),
engine="cpu" if backend == "numpy" else "cuda",
)
assert dottest(
Cop,
Expand Down
1 change: 1 addition & 0 deletions requirements-dev-gpu.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cupy-cuda12x
torch
numba
sympy
astra-toolbox
matplotlib
ipython
pytest
Expand Down
Loading