diff --git a/README-pypi.md b/README-pypi.md
index be0de197..f6173669 100644
--- a/README-pypi.md
+++ b/README-pypi.md
@@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/README.md b/README.md
index 674d41d3..a9f507c2 100644
--- a/README.md
+++ b/README.md
@@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur
Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline.
+- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)**
+
+ Creating custom scatterers of arbitrary shapes.
+
+- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)**
+
+ Creating custom scatterers in the shape of bacteria.
+
# Examples
These are examples of how DeepTrack2 can be used on real datasets:
diff --git a/deeptrack/features.py b/deeptrack/features.py
index 4bdfad38..be82b558 100644
--- a/deeptrack/features.py
+++ b/deeptrack/features.py
@@ -218,11 +218,11 @@
"OneOf",
"OneOfDict",
"LoadImage",
- "SampleToMasks", # TODO ***CM*** revise this after elimination of Image
+ "SampleToMasks",
"AsType",
"ChannelFirst2d",
- "Upscale", # TODO ***CM*** revise and check PyTorch afrer elimin. Image
- "NonOverlapping", # TODO ***CM*** revise + PyTorch afrer elimin. Image
+ "Upscale",
+ "NonOverlapping",
"Store",
"Squeeze",
"Unsqueeze",
@@ -7398,7 +7398,7 @@ class SampleToMasks(Feature):
Returns
-------
- Image or np.ndarray
+ np.ndarray
The final mask image with the specified number of layers.
Raises
@@ -7460,7 +7460,7 @@ class SampleToMasks(Feature):
def __init__(
self: Feature,
- transformation_function: Callable[[Image], Image],
+ transformation_function: Callable[[np.ndarray], np.ndarray, torch.Tensor],
number_of_masks: PropertyLike[int] = 1,
output_region: PropertyLike[tuple[int, int, int, int]] = None,
merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add",
@@ -7493,17 +7493,17 @@ def __init__(
def get(
self: Feature,
- image: np.ndarray | Image,
- transformation_function: Callable[[Image], Image],
+ image: np.ndarray,
+ transformation_function: Callable[list[np.ndarray] | np.ndarray | torch.Tensor],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Apply the transformation function to a single image.
Parameters
----------
- image: np.ndarray | Image
+ image: np.ndarray
The input image.
- transformation_function: Callable[[Image], Image]
+ transformation_function: Callable[[np.ndarray], np.ndarray]
Function to transform the image.
**kwargs: dict[str, Any]
Additional parameters.
@@ -7515,13 +7515,13 @@ def get(
"""
- return transformation_function(image)
+ return transformation_function(image.array)
def _process_and_get(
self: Feature,
- images: list[np.ndarray] | np.ndarray | list[Image] | Image,
+ images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor,
**kwargs: Any,
- ) -> Image | np.ndarray:
+ ) -> np.ndarray:
"""Process a list of images and generate a multi-layer mask.
Parameters
@@ -7540,50 +7540,44 @@ def _process_and_get(
"""
# Handle list of images.
- if isinstance(images, list) and len(images) != 1:
- list_of_labels = super()._process_and_get(images, **kwargs)
- if not self._wrap_array_with_image:
- for idx, (label, image) in enumerate(zip(list_of_labels,
- images)):
- list_of_labels[idx] = \
- Image(label, copy=False).merge_properties_from(image)
- else:
- if isinstance(images, list):
- images = images[0]
- list_of_labels = []
- for prop in images.properties:
-
- if "position" in prop:
+ # if isinstance(images, list) and len(images) != 1:
+ list_of_labels = super()._process_and_get(images, **kwargs)
- inp = Image(np.array(images))
- inp.append(prop)
- out = Image(self.get(inp, **kwargs))
- out.merge_properties_from(inp)
- list_of_labels.append(out)
+ from deeptrack.scatterers import ScatteredObject
+
+ for idx, (label, image) in enumerate(zip(list_of_labels, images)):
+ list_of_labels[idx] = \
+ ScatteredObject(array=label, properties=image.properties.copy(), role=image.role)
# Create an empty output image.
output_region = kwargs["output_region"]
- output = np.zeros(
+ output = xp.zeros(
(
output_region[2] - output_region[0],
output_region[3] - output_region[1],
kwargs["number_of_masks"],
- )
+ ),
+ dtype=list_of_labels[0].array.dtype,
)
from deeptrack.optics import _get_position
# Merge masks into the output.
- for label in list_of_labels:
- position = _get_position(label)
- p0 = np.round(position - output_region[0:2])
+ for volume in list_of_labels:
+ label = volume.array
+ position = _get_position(volume)
+
+ p0 = xp.round(position - xp.asarray(output_region[0:2]))
+ p0 = p0.astype(xp.int64)
- if np.any(p0 > output.shape[0:2]) or \
- np.any(p0 + label.shape[0:2] < 0):
+
+ if xp.any(p0 > xp.asarray(output.shape[:2])) or \
+ xp.any(p0 + xp.asarray(label.shape[:2]) < 0):
continue
- crop_x = int(-np.min([p0[0], 0]))
- crop_y = int(-np.min([p0[1], 0]))
+ crop_x = (-xp.minimum(p0[0], 0)).item()
+ crop_y = (-xp.minimum(p0[1], 0)).item()
+
crop_x_end = int(
label.shape[0]
- np.max([p0[0] + label.shape[0] - output.shape[0], 0])
@@ -7635,9 +7629,10 @@ def _process_and_get(
p0[0] : p0[0] + labelarg.shape[0],
p0[1] : p0[1] + labelarg.shape[1],
label_index,
- ] = (output_slice[..., label_index] != 0) | (
+ ] = xp.logical_or(
+ output_slice[..., label_index] != 0,
labelarg[..., label_index] != 0
- )
+ )
elif merge == "mul":
output[
@@ -7657,11 +7652,6 @@ def _process_and_get(
labelarg[..., label_index],
)
- if not self._wrap_array_with_image:
- return output
- output = Image(output)
- for label in list_of_labels:
- output.merge_properties_from(label)
return output
@@ -8082,23 +8072,23 @@ def get(
# Ensure factor is a tuple of three integers.
if np.size(factor) == 1:
- factor = (factor,) * 3
+ factor = (factor, factor, 1)
elif len(factor) != 3:
raise ValueError(
"Factor must be an integer or a tuple of three integers."
)
-
+
# Create a context for upscaling and perform computation.
ctx = create_context(None, None, None, *factor)
with units.context(ctx):
image = self.feature(image)
- # Downscale the result to the original resolution.
- import skimage.measure
+ # # Downscale the result to the original resolution.
+ # import skimage.measure
- image = skimage.measure.block_reduce(
- image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
- )
+ # image = skimage.measure.block_reduce(
+ # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean
+ # )
return image
@@ -8356,7 +8346,7 @@ def get(
list_of_volumes = [list_of_volumes]
for _ in range(max_iters):
-
+
list_of_volumes = [
self._resample_volume_position(volume)
for volume in list_of_volumes
@@ -8411,32 +8401,41 @@ def _check_non_overlapping(
- If bounding cubes overlap, voxel-level checks are performed.
"""
+ from deeptrack.scatterers import ScatteredObject
- from skimage.morphology import isotropic_erosion, isotropic_dilation
-
- from deeptrack.augmentations import CropTight, Pad
+ from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend
from deeptrack.optics import _get_position
min_distance = self.min_distance()
crop = CropTight()
+
+ new_volumes = []
- if min_distance < 0:
- list_of_volumes = [
- Image(
- crop(isotropic_erosion(volume != 0, -min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
- else:
- pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True)
- list_of_volumes = [
- Image(
- crop(isotropic_dilation(pad(volume) != 0, min_distance/2)),
- copy=False,
- ).merge_properties_from(volume)
- for volume in list_of_volumes
- ]
+ for volume in list_of_volumes:
+ arr = volume.array
+ mask = arr != 0
+
+ if min_distance < 0:
+ new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend())
+ else:
+ pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True)
+ new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend())
+ new_arr = crop(new_arr)
+
+ if self.get_backend() == "torch":
+ new_arr = new_arr.to(dtype=arr.dtype)
+ else:
+ new_arr = new_arr.astype(arr.dtype)
+
+ new_volume = ScatteredObject(
+ array=new_arr,
+ properties=volume.properties.copy(),
+ role=volume.role,
+ )
+
+ new_volumes.append(new_volume)
+
+ list_of_volumes = new_volumes
min_distance = 1
# The position of the top left corner of each volume (index (0, 0, 0)).
@@ -8472,10 +8471,10 @@ def _check_non_overlapping(
volume_bounding_cube[i], volume_bounding_cube[j]
)
overlapping_volume_1 = self._get_overlapping_volume(
- list_of_volumes[i], volume_bounding_cube[i], overlapping_cube
+ list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube
)
overlapping_volume_2 = self._get_overlapping_volume(
- list_of_volumes[j], volume_bounding_cube[j], overlapping_cube
+ list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube
)
# If either the overlapping regions are empty, the volumes do not
@@ -8710,8 +8709,12 @@ def _check_volumes_non_overlapping(
"""
# Get the positions of the non-zero voxels of each volume.
- positions_1 = np.argwhere(volume_1)
- positions_2 = np.argwhere(volume_2)
+ if self.get_backend() == "torch":
+ positions_1 = torch.nonzero(volume_1, as_tuple=False)
+ positions_2 = torch.nonzero(volume_2, as_tuple=False)
+ else:
+ positions_1 = np.argwhere(volume_1)
+ positions_2 = np.argwhere(volume_2)
# if positions_1.size == 0 or positions_2.size == 0:
# return True # If either volume is empty, they are "non-overlapping"
@@ -8732,9 +8735,14 @@ def _check_volumes_non_overlapping(
# Check that the non-zero voxels of the volumes are at least
# min_distance apart.
- return np.all(
- cdist(positions_1, positions_2) > min_distance
- )
+ if self.get_backend() == "torch":
+ dist = torch.cdist(
+ positions_1.float(),
+ positions_2.float(),
+ )
+ return bool((dist > min_distance).all())
+ else:
+ return np.all(cdist(positions_1, positions_2) > min_distance)
def _resample_volume_position(
self: NonOverlapping,
@@ -8750,7 +8758,7 @@ def _resample_volume_position(
Parameters
----------
- volume: np.ndarray or Image
+ volume: np.ndarray
The 3D volume whose position is to be resampled. The volume must
have a `properties` attribute containing dictionaries with
`position` and `_position_sampler` keys.
@@ -8771,12 +8779,12 @@ def _resample_volume_position(
"""
- for pdict in volume.properties:
- if "position" in pdict and "_position_sampler" in pdict:
- new_position = pdict["_position_sampler"]()
- if isinstance(new_position, Quantity):
- new_position = new_position.to("pixel").magnitude
- pdict["position"] = new_position
+ pdict = volume.properties
+ if "position" in pdict and "_position_sampler" in pdict:
+ new_position = pdict["_position_sampler"]()
+ if isinstance(new_position, Quantity):
+ new_position = new_position.to("pixel").magnitude
+ pdict["position"] = new_position
return volume
@@ -9594,3 +9602,73 @@ def get(
res = res[0]
return res
+
+### Move to math?
+def isotropic_dilation(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_dilation
+ return isotropic_dilation(mask, radius)
+
+ # torch backend
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ return (y[0, 0] > 0)
+
+
+def isotropic_erosion(
+ mask,
+ radius: float,
+ *,
+ backend: str,
+ device=None,
+ dtype=None,
+):
+ if radius <= 0:
+ return mask
+
+ if backend == "numpy":
+ from skimage.morphology import isotropic_erosion
+ return isotropic_erosion(mask, radius)
+
+ import torch
+
+ r = int(np.ceil(radius))
+ kernel = torch.ones(
+ (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1),
+ device=device or mask.device,
+ dtype=dtype or torch.float32,
+ )
+
+ x = mask.to(dtype=kernel.dtype)[None, None]
+ y = torch.nn.functional.conv3d(
+ x,
+ kernel,
+ padding=r,
+ )
+
+ required = kernel.numel()
+ return (y[0, 0] >= required)
diff --git a/deeptrack/holography.py b/deeptrack/holography.py
index 380969cf..141cc540 100644
--- a/deeptrack/holography.py
+++ b/deeptrack/holography.py
@@ -101,7 +101,7 @@ def get_propagation_matrix(
def get_propagation_matrix(
shape: tuple[int, int],
to_z: float,
- pixel_size: float,
+ pixel_size: float | tuple[float, float],
wavelength: float,
dx: float = 0,
dy: float = 0
@@ -118,8 +118,8 @@ def get_propagation_matrix(
The dimensions of the optical field (height, width).
to_z: float
Propagation distance along the z-axis.
- pixel_size: float
- The physical size of each pixel in the optical field.
+ pixel_size: float | tuple[float, float]
+ Physical pixel size. If scalar, isotropic pixels are assumed.
wavelength: float
The wavelength of the optical field.
dx: float, optional
@@ -140,14 +140,22 @@ def get_propagation_matrix(
"""
+ if pixel_size is None:
+ pixel_size = get_active_voxel_size()
+
+ if np.isscalar(pixel_size):
+ pixel_size = (pixel_size, pixel_size)
+
+ px, py = pixel_size
+
k = 2 * np.pi / wavelength
yr, xr, *_ = shape
x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2
y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2
- x = 2 * np.pi / pixel_size * x / xr
- y = 2 * np.pi / pixel_size * y / yr
+ x = 2 * np.pi / px * x / xr
+ y = 2 * np.pi / py * y / yr
KXk, KYk = np.meshgrid(x, y)
KXk = KXk.astype(complex)
diff --git a/deeptrack/optics.py b/deeptrack/optics.py
index 5149bdae..d0b5d66b 100644
--- a/deeptrack/optics.py
+++ b/deeptrack/optics.py
@@ -137,11 +137,13 @@ def _pad_volume(
from __future__ import annotations
from pint import Quantity
-from typing import Any
+from typing import Any, TYPE_CHECKING
import warnings
import numpy as np
-from scipy.ndimage import convolve
+from scipy.ndimage import convolve # might be removed later
+import torch
+import torch.nn.functional as F
from deeptrack.backend.units import (
ConversionTable,
@@ -152,12 +154,22 @@ def _pad_volume(
from deeptrack.math import AveragePooling
from deeptrack.features import propagate_data_to_dependencies
from deeptrack.features import DummyFeature, Feature, StructuralFeature
-from deeptrack.image import Image, pad_image_to_fft
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike, PropertyLike
from deeptrack import image
from deeptrack import units_registry as u
+from deeptrack import TORCH_AVAILABLE, image
+from deeptrack.backend import xp
+from deeptrack.scatterers import ScatteredObject
+
+if TORCH_AVAILABLE:
+ import torch
+
+if TYPE_CHECKING:
+ import torch
+
#TODO ***??*** revise Microscope - torch, typing, docstring, unit test
class Microscope(StructuralFeature):
@@ -186,7 +198,7 @@ class Microscope(StructuralFeature):
Methods
-------
- `get(image: Image or None, **kwargs: Any) -> Image`
+ `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray`
Simulates the imaging process using the defined optical system and
returns the resulting image.
@@ -238,13 +250,13 @@ def __init__(
self._sample = self.add_feature(sample)
self._objective = self.add_feature(objective)
- self._sample.store_properties()
+ # self._sample.store_properties()
def get(
self: Microscope,
- image: Image | None,
+ image: np.ndarray | None,
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Generate an image of the sample using the defined optical system.
This method processes the sample through the optical system to
@@ -252,14 +264,14 @@ def get(
Parameters
----------
- image: Image | None
+ image: np.ndarray | None
The input image to be processed. If None, a new image is created.
**kwargs: Any
Additional parameters for the imaging process.
Returns
-------
- Image: Image
+ image: np.ndarray
The processed image after applying the optical system.
Examples
@@ -279,17 +291,18 @@ def get(
# Grab properties from the objective to pass to the sample
additional_sample_kwargs = self._objective.properties()
+ contrast_type = getattr(self._objective, "contrast_type", None)
+ if contrast_type is None:
+ raise RuntimeError(
+ f"{self._objective.__class__.__name__} must define `contrast_type` "
+ "(e.g. 'intensity' or 'refractive_index')."
+ )
- # Calculate required output image for the given upscale
- # This way of providing the upscale will be deprecated in the future
- # in favor of dt.Upscale().
- _upscale_given_by_optics = additional_sample_kwargs["upscale"]
- if np.array(_upscale_given_by_optics).size == 1:
- _upscale_given_by_optics = (_upscale_given_by_optics,) * 3
+ additional_sample_kwargs["contrast_type"] = contrast_type
with u.context(
create_context(
- *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics
+ *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics
)
):
@@ -329,26 +342,21 @@ def get(
volume_samples = [
scatterer
for scatterer in list_of_scatterers
- if not scatterer.get_property("is_field", default=False)
+ if scatterer.role == "volume"
]
# All scatterers that are defined as fields.
field_samples = [
scatterer
for scatterer in list_of_scatterers
- if scatterer.get_property("is_field", default=False)
+ if scatterer.role == "field"
]
-
+
# Merge all volumes into a single volume.
sample_volume, limits = _create_volume(
volume_samples,
**additional_sample_kwargs,
)
- sample_volume = Image(sample_volume)
-
- # Merge all properties into the volume.
- for scatterer in volume_samples + field_samples:
- sample_volume.merge_properties_from(scatterer)
# Let the objective know about the limits of the volume and all the fields.
propagate_data_to_dependencies(
@@ -359,33 +367,17 @@ def get(
imaged_sample = self._objective.resolve(sample_volume)
- # Upscale given by the optics needs to be handled separately.
- if _upscale_given_by_optics != (1, 1, 1):
- imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))(
- imaged_sample
- )
-
- # Merge with input
- if not image:
- if not self._wrap_array_with_image and isinstance(imaged_sample, Image):
- return imaged_sample._value
+ # Handling upscale from dt.Upscale() here to eliminate Image
+ # wrapping issues.
+ if np.any(np.array(upscale) != 1):
+ ux, uy = upscale[:2]
+ if contrast_type == "intensity":
+ print("Using sum pooling for intensity downscaling.")
+ imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample)
else:
- return imaged_sample
-
- if not isinstance(image, list):
- image = [image]
- for i in range(len(image)):
- image[i].merge_properties_from(imaged_sample)
- return image
-
- # def _no_wrap_format_input(self, *args, **kwargs) -> list:
- # return self._image_wrapped_format_input(*args, **kwargs)
+ imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample)
- # def _no_wrap_process_and_get(self, *args, **feature_input) -> list:
- # return self._image_wrapped_process_and_get(*args, **feature_input)
-
- # def _no_wrap_process_output(self, *args, **feature_input):
- # return self._image_wrapped_process_output(*args, **feature_input)
+ return imaged_sample
#TODO ***??*** revise Optics - torch, typing, docstring, unit test
@@ -757,19 +749,18 @@ def _pupil(
W, H = np.meshgrid(y, x)
RHO = (W ** 2 + H ** 2).astype(complex)
- pupil_function = Image((RHO < 1) + 0.0j, copy=False)
+ pupil_function = (RHO < 1) + 0.0j
# Defocus
- z_shift = Image(
+ z_shift = (
2
* np.pi
* refractive_index_medium
/ wavelength
* voxel_size[2]
- * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO),
- copy=False,
+ * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO)
)
- z_shift._value[z_shift._value.imag != 0] = 0
+ z_shift[z_shift.imag != 0] = 0
try:
z_shift = np.nan_to_num(z_shift, False, 0, 0, 0)
@@ -1007,7 +998,7 @@ class Fluorescence(Optics):
Methods
-------
- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image`
+ `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray`
Simulates the imaging process using a fluorescence microscope.
Examples
@@ -1023,13 +1014,14 @@ class Fluorescence(Optics):
1.4
"""
+ contrast_type = "intensity"
def get(
self: Fluorescence,
illuminated_volume: ArrayLike[complex],
limits: ArrayLike[int],
**kwargs: Any,
- ) -> Image:
+ ) -> ArrayLike[complex]:
"""Simulates the imaging process using a fluorescence microscope.
This method convolves the 3D illuminated volume with a pupil function
@@ -1048,7 +1040,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
A 2D image object representing the fluorescence projection.
Notes
@@ -1066,7 +1058,7 @@ def get(
>>> optics = dt.Fluorescence(
... NA=1.4, wavelength=0.52e-6, magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> properties = optics.properties()
>>> filtered_properties = {
@@ -1118,9 +1110,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1)), copy=False
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
@@ -1156,12 +1146,12 @@ def get(
field = np.fft.ifft2(convolved_fourier_field)
# # Discard remaining imaginary part (should be 0 up to rounding error)
field = np.real(field)
- output_image._value[:, :, 0] += field[
+ output_image[:, :, 0] += field[
: padded_volume.shape[0], : padded_volume.shape[1]
]
output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
- output_image.properties = illuminated_volume.properties + pupils.properties
+ # output_image.properties = illuminated_volume.properties + pupils.properties
return output_image
@@ -1234,7 +1224,7 @@ class Brightfield(Optics):
-------
`get(illuminated_volume: array_like[complex],
limits: array_like[int, int], fields: array_like[complex],
- **kwargs: Any) -> Image`
+ **kwargs: Any) -> np.ndarray`
Simulates imaging with brightfield microscopy.
@@ -1250,6 +1240,8 @@ class Brightfield(Optics):
"""
+ contrast_type = "refractive_index"
+
__conversion_table__ = ConversionTable(
working_distance=(u.meter, u.meter),
)
@@ -1260,7 +1252,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Simulates imaging with brightfield microscopy.
This method propagates light through the given volume, applying
@@ -1285,7 +1277,7 @@ def get(
Returns
-------
- Image: Image
+ image: np.ndarray
Processed image after simulating the brightfield imaging process.
Examples
@@ -1300,7 +1292,7 @@ def get(
... wavelength=0.52e-6,
... magnification=60,
... )
- >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex))
+ >>> volume = np.ones((128, 128, 10), dtype=complex)
>>> limits = np.array([[0, 128], [0, 128], [0, 10]])
>>> fields = np.array([np.ones((162, 162), dtype=complex)])
>>> properties = optics.properties()
@@ -1345,7 +1337,7 @@ def get(
if output_region[3] is None
else int(output_region[3] - limits[1, 0] + pad[3])
)
-
+
padded_volume = padded_volume[
output_region[0] : output_region[2],
output_region[1] : output_region[3],
@@ -1353,9 +1345,7 @@ def get(
]
z_limits = limits[2, :]
- output_image = Image(
- np.zeros((*padded_volume.shape[0:2], 1))
- )
+ output_image = np.zeros((*padded_volume.shape[0:2], 1))
index_iterator = range(padded_volume.shape[2])
z_iterator = np.linspace(
@@ -1414,7 +1404,25 @@ def get(
light_in_focus = light_in * shifted_pupil
if len(fields) > 0:
- field = np.sum(fields, axis=0)
+ # field = np.sum(fields, axis=0)
+ field_arrays = []
+
+ for fs in fields:
+ # fs is a ScatteredField
+ arr = fs.array
+
+ # Enforce (H, W, 1) shape
+ if arr.ndim == 2:
+ arr = arr[..., None]
+
+ if arr.ndim != 3 or arr.shape[-1] != 1:
+ raise ValueError(
+ f"Expected field of shape (H, W, 1), got {arr.shape}"
+ )
+
+ field_arrays.append(arr)
+
+ field = np.sum(field_arrays, axis=0)
light_in_focus += field[..., 0]
shifted_pupil = np.fft.fftshift(pupils[-1])
light_in_focus = light_in_focus * shifted_pupil
@@ -1426,7 +1434,7 @@ def get(
: padded_volume.shape[0], : padded_volume.shape[1]
]
output_image = np.expand_dims(output_image, axis=-1)
- output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]])
+ output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]]
if not kwargs.get("return_field", False):
output_image = np.square(np.abs(output_image))
@@ -1436,7 +1444,7 @@ def get(
# output_image = output_image * np.exp(1j * -np.pi / 4)
# output_image = output_image + 1
- output_image.properties = illuminated_volume.properties
+ # output_image.properties = illuminated_volume.properties
return output_image
@@ -1631,7 +1639,7 @@ def get(
limits: ArrayLike[int],
fields: ArrayLike[complex],
**kwargs: Any,
- ) -> Image:
+ ) -> np.ndarray:
"""Retrieve the darkfield image of the illuminated volume.
Parameters
@@ -1802,7 +1810,7 @@ def get(
#TODO ***??*** revise _get_position - torch, typing, docstring, unit test
def _get_position(
- image: Image,
+ scatterer: ScatteredObject,
mode: str = "corner",
return_z: bool = False,
) -> np.ndarray:
@@ -1826,26 +1834,23 @@ def _get_position(
num_outputs = 2 + return_z
- if mode == "corner" and image.size > 0:
+ if mode == "corner" and scatterer.array.size > 0:
import scipy.ndimage
- image = image.to_numpy()
-
- shift = scipy.ndimage.center_of_mass(np.abs(image))
+ shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array))
if np.isnan(shift).any():
- shift = np.array(image.shape) / 2
+ shift = np.array(scatterer.array.shape) / 2
else:
shift = np.zeros((num_outputs))
- position = np.array(image.get_property("position", default=None))
+ position = np.array(scatterer.get_property("position", default=None))
if position is None:
return position
scale = np.array(get_active_scale())
-
if len(position) == 3:
position = position * scale + 0.5 * (scale - 1)
if return_z:
@@ -1856,7 +1861,7 @@ def _get_position(
elif len(position) == 2:
if return_z:
outp = (
- np.array([position[0], position[1], image.get_property("z", default=0)])
+ np.array([position[0], position[1], scatterer.get_property("z", default=0)])
* scale
- shift
+ 0.5 * (scale - 1)
@@ -1868,6 +1873,58 @@ def _get_position(
return position
+def _bilinear_interpolate_numpy(
+ scatterer: np.ndarray, x_off: float, y_off: float
+) -> np.ndarray:
+ """Apply bilinear subpixel interpolation in the x–y plane (NumPy)."""
+ kernel = np.array(
+ [
+ [0.0, 0.0, 0.0],
+ [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
+ [0.0, x_off * (1 - y_off), x_off * y_off],
+ ]
+ )
+ out = np.zeros_like(scatterer)
+ for z in range(scatterer.shape[2]):
+ if np.iscomplexobj(scatterer):
+ out[:, :, z] = (
+ convolve(np.real(scatterer[:, :, z]), kernel, mode="constant")
+ + 1j
+ * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant")
+ )
+ else:
+ out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant")
+ return out
+
+
+def _bilinear_interpolate_torch(
+ scatterer: torch.Tensor, x_off: float, y_off: float
+) -> torch.Tensor:
+ """Apply bilinear subpixel interpolation in the x–y plane (Torch).
+
+ Uses grid_sample for autograd-friendly interpolation.
+ """
+ H, W, D = scatterer.shape
+
+ # Normalized shifts in [-1,1]
+ x_shift = 2 * x_off / (W - 1)
+ y_shift = 2 * y_off / (H - 1)
+
+ yy, xx = torch.meshgrid(
+ torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype),
+ torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype),
+ indexing="ij",
+ )
+ grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2)
+ grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2)
+
+ inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W)
+
+ out = F.grid_sample(inp, grid, mode="bilinear",
+ padding_mode="zeros", align_corners=True)
+ return out.squeeze(1).permute(1, 2, 0) # (H,W,D)
+
+
#TODO ***??*** revise _create_volume - torch, typing, docstring, unit test
def _create_volume(
list_of_scatterers: list,
@@ -1903,6 +1960,12 @@ def _create_volume(
Spatial limits of the volume.
"""
+ contrast_type = kwargs.get("contrast_type", None)
+ if contrast_type is None:
+ raise RuntimeError(
+ "_create_volume requires a contrast_type "
+ "(e.g. 'intensity' or 'refractive_index')"
+ )
if not isinstance(list_of_scatterers, list):
list_of_scatterers = [list_of_scatterers]
@@ -1927,24 +1990,28 @@ def _create_volume(
# This accounts for upscale doing AveragePool instead of SumPool. This is
# a bit of a hack, but it works for now.
- fudge_factor = scale[0] * scale[1] / scale[2]
+ # fudge_factor = scale[0] * scale[1] / scale[2]
for scatterer in list_of_scatterers:
-
position = _get_position(scatterer, mode="corner", return_z=True)
- if scatterer.get_property("intensity", None) is not None:
- intensity = scatterer.get_property("intensity")
- scatterer_value = intensity * fudge_factor
- elif scatterer.get_property("refractive_index", None) is not None:
- refractive_index = scatterer.get_property("refractive_index")
- scatterer_value = (
- refractive_index - refractive_index_medium
- )
+ if contrast_type == "intensity":
+ value = scatterer.get_property("intensity", None)
+ if value is None:
+ raise ValueError("Scatterer has no intensity.")
+ scatterer_value = value
+
+ elif contrast_type == "refractive_index":
+ ri = scatterer.get_property("refractive_index", None)
+ if ri is None:
+ raise ValueError("Scatterer has no refractive_index.")
+ scatterer_value = ri - refractive_index_medium
+
else:
- scatterer_value = scatterer.get_property("value")
+ raise RuntimeError(f"Unknown contrast_type: {contrast_type}")
- scatterer = scatterer * scatterer_value
+ # Scale the array accordingly
+ scatterer.array = scatterer.array * scatterer_value
if limits is None:
limits = np.zeros((3, 2), dtype=np.int32)
@@ -1952,26 +2019,25 @@ def _create_volume(
limits[:, 1] = np.floor(position).astype(np.int32) + 1
if (
- position[0] + scatterer.shape[0] < OR[0]
+ position[0] + scatterer.array.shape[0] < OR[0]
or position[0] > OR[2]
- or position[1] + scatterer.shape[1] < OR[1]
+ or position[1] + scatterer.array.shape[1] < OR[1]
or position[1] > OR[3]
):
continue
- padded_scatterer = Image(
- np.pad(
- scatterer,
+ # Pad scatterer to avoid edge effects during interpolation
+ padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible?
+ scatterer.array,
[(2, 2), (2, 2), (2, 2)],
"constant",
constant_values=0,
)
- )
- padded_scatterer.merge_properties_from(scatterer)
-
- scatterer = padded_scatterer
- position = _get_position(scatterer, mode="corner", return_z=True)
- shape = np.array(scatterer.shape)
+ padded_scatterer = ScatteredObject(
+ array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role,
+ )
+ position = _get_position(padded_scatterer, mode="corner", return_z=True)
+ shape = np.array(padded_scatterer.array.shape)
if position is None:
RuntimeWarning(
@@ -1980,36 +2046,20 @@ def _create_volume(
)
continue
- splined_scatterer = np.zeros_like(scatterer)
-
x_off = position[0] - np.floor(position[0])
y_off = position[1] - np.floor(position[1])
- kernel = np.array(
- [
- [0, 0, 0],
- [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off],
- [0, x_off * (1 - y_off), x_off * y_off],
- ]
- )
-
- for z in range(scatterer.shape[2]):
- if splined_scatterer.dtype == complex:
- splined_scatterer[:, :, z] = (
- convolve(
- np.real(scatterer[:, :, z]), kernel, mode="constant"
- )
- + convolve(
- np.imag(scatterer[:, :, z]), kernel, mode="constant"
- )
- * 1j
- )
- else:
- splined_scatterer[:, :, z] = convolve(
- scatterer[:, :, z], kernel, mode="constant"
- )
+
+ if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed
+ splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off)
+ elif isinstance(padded_scatterer.array, torch.Tensor):
+ splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off)
+ else:
+ raise TypeError(
+ f"Unsupported array type {type(padded_scatterer.array)}. "
+ "Expected np.ndarray or torch.Tensor."
+ )
- scatterer = splined_scatterer
position = np.floor(position)
new_limits = np.zeros(limits.shape, dtype=np.int32)
for i in range(3):
@@ -2039,6 +2089,7 @@ def _create_volume(
within_volume_position = position - limits[:, 0]
# NOTE: Maybe shouldn't be additive.
+ # give options: sum default, but also mean, max, min, or
volume[
int(within_volume_position[0]) :
int(within_volume_position[0] + shape[0]),
@@ -2048,5 +2099,72 @@ def _create_volume(
int(within_volume_position[2]) :
int(within_volume_position[2] + shape[2]),
- ] += scatterer
+ ] += splined_scatterer
return volume, limits
+
+# this should be moved to math
+class _CenteredPoolingBase:
+ def __init__(self, pool_size: tuple[int, int, int]):
+ px, py, pz = pool_size
+ if pz != 1:
+ raise ValueError("Only pz=1 supported.")
+ self.px = int(px)
+ self.py = int(py)
+
+ def _crop_center(self, array):
+ H, W = array.shape[:2]
+ px, py = self.px, self.py
+
+ crop_h = (H // px) * px
+ crop_w = (W // py) * py
+
+ off_h = (H - crop_h) // 2
+ off_w = (W - crop_w) // 2
+
+ return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...]
+
+ def _pool_numpy(self, array, func):
+ import skimage.measure
+ array = self._crop_center(array)
+ pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2)
+ return skimage.measure.block_reduce(array, pool_shape, func)
+
+ def _pool_torch(self, array, sum_pool=False):
+ px, py = self.px, self.py
+ array = self._crop_center(array)
+
+ extra = array.shape[2:]
+ C = int(np.prod(extra)) if extra else 1
+ x = array.reshape(1, C, array.shape[0], array.shape[1])
+
+ pooled = torch.nn.functional.avg_pool2d(
+ x, kernel_size=(px, py), stride=(px, py)
+ )
+ if sum_pool:
+ pooled = pooled * (px * py)
+
+ return pooled.reshape(
+ (pooled.shape[2], pooled.shape[3]) + extra
+ )
+
+class AveragePoolingCM(_CenteredPoolingBase):
+ """Center-preserving average pooling (intensive quantities)."""
+
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array, np.mean)
+ elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array, sum_pool=False)
+ else:
+ raise TypeError("Unsupported array type.")
+
+class SumPoolingCM(_CenteredPoolingBase):
+ """Center-preserving sum pooling (extensive quantities)."""
+
+ def __call__(self, array):
+ if isinstance(array, np.ndarray):
+ return self._pool_numpy(array, np.sum)
+ elif TORCH_AVAILABLE and isinstance(array, torch.Tensor):
+ return self._pool_torch(array, sum_pool=True)
+ else:
+ raise TypeError("Unsupported array type.")
\ No newline at end of file
diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py
index 04a7c5ea..b942888c 100644
--- a/deeptrack/scatterers.py
+++ b/deeptrack/scatterers.py
@@ -166,6 +166,7 @@
import numpy as np
from numpy.typing import NDArray
from pint import Quantity
+from dataclasses import dataclass, field
from deeptrack.holography import get_propagation_matrix
from deeptrack.backend.units import (
@@ -175,7 +176,7 @@
)
from deeptrack.backend import mie
from deeptrack.features import Feature, MERGE_STRATEGY_APPEND
-from deeptrack.image import pad_image_to_fft, Image
+from deeptrack.image import pad_image_to_fft
from deeptrack.types import ArrayLike
from deeptrack import units_registry as u
@@ -258,7 +259,7 @@ def __init__(
**kwargs,
) -> None:
# Ignore warning to help with comparison with arrays.
- if upsample is not 1: # noqa: F632
+ if upsample != 1: # noqa: F632
warnings.warn(
f"Setting upsample != 1 is deprecated. "
f"Please, instead use dt.Upscale(f, factor={upsample})"
@@ -296,7 +297,7 @@ def _process_and_get(
upsample_axes=None,
crop_empty=True,
**kwargs
- ) -> list[Image] | list[np.ndarray]:
+ ) -> list[np.ndarray]:
# Post processes the created object to handle upsampling,
# as well as cropping empty slices.
if not self._processed_properties:
@@ -310,7 +311,7 @@ def _process_and_get(
voxel_size = get_active_voxel_size()
# Calls parent _process_and_get.
- new_image = super()._process_and_get(
+ new_image = super(Scatterer, self)._process_and_get(
*args,
voxel_size=voxel_size,
upsample=upsample,
@@ -333,28 +334,33 @@ def _process_and_get(
new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))]
new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))]
- return [Image(new_image)]
+ # # Copy properties
+ # props = kwargs.copy()
+ return [self._wrap_output(new_image, kwargs)]
+
+ def _wrap_output(self, array, props) -> ScatteredObject:
+ # """Must be overridden in subclasses to wrap output correctly."""
+ # raise NotImplementedError
+ return ScatteredObject(
+ array=array,
+ properties=props.copy(),
+ role = self.role,
+ )
- def _no_wrap_format_input(
- self,
- *args,
- **kwargs
- ) -> list:
- return self._image_wrapped_format_input(*args, **kwargs)
+# class VolumeScatterer(Scatterer):
+# """Abstract scatterer producing ScatteredVolume outputs."""
+# def _wrap_output(self, array, props) -> ScatteredVolume:
+# return ScatteredVolume(
+# array=array,
+# properties=props.copy(),
+# )
- def _no_wrap_process_and_get(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_and_get(*args, **feature_input)
-
- def _no_wrap_process_output(
- self,
- *args,
- **feature_input
- ) -> list:
- return self._image_wrapped_process_output(*args, **feature_input)
+# class FieldScatterer(Scatterer):
+# def _wrap_output(self, array, props) -> ScatteredField:
+# return ScatteredField(
+# array=array,
+# properties=props.copy(),
+# )
#TODO ***??*** revise PointParticle - torch, typing, docstring, unit test
@@ -381,6 +387,7 @@ class PointParticle(Scatterer):
for `Brightfield` and `intensity` for `Fluorescence`).
"""
+ role = "volume"
def __init__(
self: PointParticle,
@@ -394,7 +401,7 @@ def __init__(
def get(
self: PointParticle,
- image: Image | np.ndarray,
+ image: np.ndarray,
**kwarg: Any,
) -> NDArray[Any] | torch.Tensor:
"""Evaluate and return the scatterer volume."""
@@ -440,6 +447,7 @@ class Ellipse(Scatterer):
before rotation.
"""
+ role = "volume"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -545,6 +553,7 @@ class Sphere(Scatterer):
Upsamples the calculations of the pixel occupancy fraction.
"""
+ role = "volume"
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
@@ -559,7 +568,7 @@ def __init__(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
voxel_size: float,
**kwargs
@@ -620,6 +629,8 @@ class Ellipsoid(Scatterer):
"""
+ role = "volume"
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
rotation=(u.radian, u.radian),
@@ -694,7 +705,7 @@ def _process_properties(
def get(
self,
- image: Image | np.ndarray,
+ image: np.ndarray,
radius: float,
rotation: ArrayLike[float] | float,
voxel_size: float,
@@ -826,6 +837,8 @@ class MieScatterer(Scatterer):
"""
+ role = "field"
+
__conversion_table__ = ConversionTable(
radius=(u.meter, u.meter),
polarization_angle=(u.radian, u.radian),
@@ -856,6 +869,7 @@ def __init__(
illumination_angle: float=0,
amp_factor: float=1,
phase_shift_correction: bool=False,
+ # pupil: ArrayLike=[], # Daniel
**kwargs,
) -> None:
if polarization_angle is not None:
@@ -864,11 +878,11 @@ def __init__(
"Please use input_polarization instead"
)
input_polarization = polarization_angle
- kwargs.pop("is_field", None)
+ # kwargs.pop("is_field", None) # remove
kwargs.pop("crop_empty", None)
super().__init__(
- is_field=True,
+ is_field=True, # remove
crop_empty=False,
L=L,
offset_z=offset_z,
@@ -889,6 +903,7 @@ def __init__(
illumination_angle=illumination_angle,
amp_factor=amp_factor,
phase_shift_correction=phase_shift_correction,
+ # pupil=pupil, # Daniel
**kwargs,
)
@@ -1014,7 +1029,8 @@ def get_plane_in_polar_coords(
shape: int,
voxel_size: ArrayLike[float],
plane_position: float,
- illumination_angle: float
+ illumination_angle: float,
+ # k: float, # Daniel
) -> tuple[float, float, float, float]:
"""Computes the coordinates of the plane in polar form."""
@@ -1027,15 +1043,24 @@ def get_plane_in_polar_coords(
R2_squared = X ** 2 + Y ** 2
R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z.
+
+ # # DANIEL
+ # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0]
+ # # is dimensionally ok?
+ # sin_theta=Q/(k)
+ # pupil_mask=sin_theta<1
+ # cos_theta=np.zeros(sin_theta.shape)
+ # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2)
# Fet the angles.
cos_theta = Z / R3
+
illumination_cos_theta = (
np.cos(np.arccos(cos_theta) + illumination_angle)
)
phi = np.arctan2(Y, X)
- return R3, cos_theta, illumination_cos_theta, phi
+ return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel
def get(
self,
@@ -1060,6 +1085,7 @@ def get(
illumination_angle: float,
amp_factor: float,
phase_shift_correction: bool,
+ # pupil: ArrayLike, # Daniel
**kwargs,
) -> ArrayLike[float]:
"""Abstract method to initialize the Mie scatterer"""
@@ -1067,8 +1093,9 @@ def get(
# Get size of the output.
xSize, ySize = self.get_xy_size(output_region, padding)
voxel_size = get_active_voxel_size()
+ scale = get_active_scale()
arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex)
- position = np.array(position) * voxel_size[: len(position)]
+ position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)]
pupil_physical_size = working_distance * np.tan(collection_angle) * 2
@@ -1076,7 +1103,11 @@ def get(
ratio = offset_z / (working_distance - z)
- # Position of pbjective relative particle.
+ # Wave vector.
+ k = 2 * np.pi / wavelength * refractive_index_medium
+
+
+ # Position of objective relative particle.
relative_position = np.array(
(
position_objective[0] - position[0],
@@ -1085,12 +1116,13 @@ def get(
)
)
- # Get field evaluation plane at offset_z.
+ # Get field evaluation plane at offset_z. # , pupil_mask # Daniel
R3_field, cos_theta_field, illumination_angle_field, phi_field =\
self.get_plane_in_polar_coords(
arr.shape, voxel_size,
relative_position * ratio,
- illumination_angle
+ illumination_angle,
+ # k # Daniel
)
cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field)
@@ -1108,7 +1140,7 @@ def get(
sin_phi_field / ratio
)
- # If the beam is within the pupil.
+ # If the beam is within the pupil. Remove if Daniel
pupil_mask = (x_farfield - position_objective[0]) ** 2 + (
y_farfield - position_objective[1]
) ** 2 < (pupil_physical_size / 2) ** 2
@@ -1146,9 +1178,6 @@ def get(
* illumination_angle_field
)
- # Wave vector.
- k = 2 * np.pi / wavelength * refractive_index_medium
-
# Harmonics.
A, B = coefficients(L)
PI, TAU = mie.harmonics(illumination_angle_field, L)
@@ -1165,12 +1194,15 @@ def get(
[E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)]
)
+ # Daniel
+ # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor
arr[pupil_mask] = (
-1j
/ (k * R3_field)
* np.exp(1j * k * R3_field)
* (S2 * S2_coef + S1 * S1_coef)
) / amp_factor
+
# For phase shift correction (a multiplication of the field
# by exp(1j * k * z)).
@@ -1188,15 +1220,23 @@ def get(
-mask.shape[1] // 2 : mask.shape[1] // 2,
]
mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2))
-
arr = arr * mask
+ # Not sure if needed... CM
+ # if len(pupil)>0:
+ # c_pix=[arr.shape[0]//2,arr.shape[1]//2]
+
+ # arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil
+
+ # Daniel
+ # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr))))
fourier_field = np.fft.fft2(arr)
propagation_matrix = get_propagation_matrix(
fourier_field.shape,
- pixel_size=voxel_size[2],
+ pixel_size=voxel_size[:2], # this needs a double check
wavelength=wavelength / refractive_index_medium,
+ # to_z=(-z), # Daniel
to_z=(-offset_z - z),
dy=(
relative_position[0] * ratio
@@ -1206,11 +1246,12 @@ def get(
dx=(
relative_position[1] * ratio
+ position[1]
- + (padding[1] - arr.shape[1] / 2) * voxel_size[1]
+ + (padding[2] - arr.shape[1] / 2) * voxel_size[1] # check if padding is top, bottom, left, right
),
)
+
fourier_field = (
- fourier_field * propagation_matrix * np.exp(-1j * k * offset_z)
+ fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) # Remove last part (from exp)) if Daniel
)
if return_fft:
@@ -1275,6 +1316,8 @@ class MieSphere(MieScatterer):
"""
+ role = "field"
+
def __init__(
self,
radius: float = 1e-6,
@@ -1377,6 +1420,8 @@ class MieStratifiedSphere(MieScatterer):
"""
+ role = "field"
+
def __init__(
self,
radius: ArrayLike[float] = [1e-6],
@@ -1412,3 +1457,51 @@ def inner(
refractive_index=refractive_index,
**kwargs,
)
+
+
+@dataclass
+class ScatteredObject:
+ """Base class for scatterers (volumes and fields)."""
+
+ array: ArrayLike
+ properties: dict[str, Any] = field(default_factory=dict)
+ role: Literal["volume", "field"] = "volume"
+
+ @property
+ def ndim(self) -> int:
+ """Number of dimensions of the underlying array."""
+ return self.array.ndim
+
+ @property
+ def shape(self) -> int:
+ """Number of dimensions of the underlying array."""
+ return self.array.shape
+
+ @property
+ def pos3d(self) -> np.ndarray:
+ return np.array([*self.position, self.z], dtype=float)
+
+ @property
+ def position(self) -> np.ndarray:
+ pos = self.properties.get("position", None)
+ if pos is None:
+ return None
+ pos = np.asarray(pos, dtype=float)
+ if pos.ndim == 2 and pos.shape[0] == 1:
+ pos = pos[0]
+ return pos
+
+ def as_array(self) -> ArrayLike:
+ """Return the underlying array.
+
+ Notes
+ -----
+ The raw array is also directly available as ``scatterer.array``.
+ This method exists mainly for API compatibility and clarity.
+
+ """
+
+ return self.array
+
+ def get_property(self, key: str, default: Any = None) -> Any:
+ return getattr(self, key, self.properties.get(key, default))