diff --git a/README-pypi.md b/README-pypi.md index be0de197..f6173669 100644 --- a/README-pypi.md +++ b/README-pypi.md @@ -93,6 +93,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: diff --git a/README.md b/README.md index 674d41d3..a9f507c2 100644 --- a/README.md +++ b/README.md @@ -97,6 +97,14 @@ Here you find a series of notebooks that give you an overview of the core featur Using PyTorch gradients to fit a Gaussian generated by a DeepTrack2 pipeline. +- DTGS171A **[Creating Custom Scatterers](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171A_custom_scatterers.ipynb)** + + Creating custom scatterers of arbitrary shapes. + +- DTGS171B **[Creating Custom Scatterers: Bacteria](https://github.com/DeepTrackAI/DeepTrack2/blob/develop/tutorials/1-getting-started/DTGS171B_custom_scatterers_bacteria.ipynb)** + + Creating custom scatterers in the shape of bacteria. + # Examples These are examples of how DeepTrack2 can be used on real datasets: diff --git a/deeptrack/features.py b/deeptrack/features.py index 4bdfad38..be82b558 100644 --- a/deeptrack/features.py +++ b/deeptrack/features.py @@ -218,11 +218,11 @@ "OneOf", "OneOfDict", "LoadImage", - "SampleToMasks", # TODO ***CM*** revise this after elimination of Image + "SampleToMasks", "AsType", "ChannelFirst2d", - "Upscale", # TODO ***CM*** revise and check PyTorch afrer elimin. Image - "NonOverlapping", # TODO ***CM*** revise + PyTorch afrer elimin. Image + "Upscale", + "NonOverlapping", "Store", "Squeeze", "Unsqueeze", @@ -7398,7 +7398,7 @@ class SampleToMasks(Feature): Returns ------- - Image or np.ndarray + np.ndarray The final mask image with the specified number of layers. Raises @@ -7460,7 +7460,7 @@ class SampleToMasks(Feature): def __init__( self: Feature, - transformation_function: Callable[[Image], Image], + transformation_function: Callable[[np.ndarray], np.ndarray, torch.Tensor], number_of_masks: PropertyLike[int] = 1, output_region: PropertyLike[tuple[int, int, int, int]] = None, merge_method: PropertyLike[str | Callable | list[str | Callable]] = "add", @@ -7493,17 +7493,17 @@ def __init__( def get( self: Feature, - image: np.ndarray | Image, - transformation_function: Callable[[Image], Image], + image: np.ndarray, + transformation_function: Callable[list[np.ndarray] | np.ndarray | torch.Tensor], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Apply the transformation function to a single image. Parameters ---------- - image: np.ndarray | Image + image: np.ndarray The input image. - transformation_function: Callable[[Image], Image] + transformation_function: Callable[[np.ndarray], np.ndarray] Function to transform the image. **kwargs: dict[str, Any] Additional parameters. @@ -7515,13 +7515,13 @@ def get( """ - return transformation_function(image) + return transformation_function(image.array) def _process_and_get( self: Feature, - images: list[np.ndarray] | np.ndarray | list[Image] | Image, + images: list[np.ndarray] | np.ndarray | list[torch.Tensor] | torch.Tensor, **kwargs: Any, - ) -> Image | np.ndarray: + ) -> np.ndarray: """Process a list of images and generate a multi-layer mask. Parameters @@ -7540,50 +7540,44 @@ def _process_and_get( """ # Handle list of images. - if isinstance(images, list) and len(images) != 1: - list_of_labels = super()._process_and_get(images, **kwargs) - if not self._wrap_array_with_image: - for idx, (label, image) in enumerate(zip(list_of_labels, - images)): - list_of_labels[idx] = \ - Image(label, copy=False).merge_properties_from(image) - else: - if isinstance(images, list): - images = images[0] - list_of_labels = [] - for prop in images.properties: - - if "position" in prop: + # if isinstance(images, list) and len(images) != 1: + list_of_labels = super()._process_and_get(images, **kwargs) - inp = Image(np.array(images)) - inp.append(prop) - out = Image(self.get(inp, **kwargs)) - out.merge_properties_from(inp) - list_of_labels.append(out) + from deeptrack.scatterers import ScatteredObject + + for idx, (label, image) in enumerate(zip(list_of_labels, images)): + list_of_labels[idx] = \ + ScatteredObject(array=label, properties=image.properties.copy(), role=image.role) # Create an empty output image. output_region = kwargs["output_region"] - output = np.zeros( + output = xp.zeros( ( output_region[2] - output_region[0], output_region[3] - output_region[1], kwargs["number_of_masks"], - ) + ), + dtype=list_of_labels[0].array.dtype, ) from deeptrack.optics import _get_position # Merge masks into the output. - for label in list_of_labels: - position = _get_position(label) - p0 = np.round(position - output_region[0:2]) + for volume in list_of_labels: + label = volume.array + position = _get_position(volume) + + p0 = xp.round(position - xp.asarray(output_region[0:2])) + p0 = p0.astype(xp.int64) - if np.any(p0 > output.shape[0:2]) or \ - np.any(p0 + label.shape[0:2] < 0): + + if xp.any(p0 > xp.asarray(output.shape[:2])) or \ + xp.any(p0 + xp.asarray(label.shape[:2]) < 0): continue - crop_x = int(-np.min([p0[0], 0])) - crop_y = int(-np.min([p0[1], 0])) + crop_x = (-xp.minimum(p0[0], 0)).item() + crop_y = (-xp.minimum(p0[1], 0)).item() + crop_x_end = int( label.shape[0] - np.max([p0[0] + label.shape[0] - output.shape[0], 0]) @@ -7635,9 +7629,10 @@ def _process_and_get( p0[0] : p0[0] + labelarg.shape[0], p0[1] : p0[1] + labelarg.shape[1], label_index, - ] = (output_slice[..., label_index] != 0) | ( + ] = xp.logical_or( + output_slice[..., label_index] != 0, labelarg[..., label_index] != 0 - ) + ) elif merge == "mul": output[ @@ -7657,11 +7652,6 @@ def _process_and_get( labelarg[..., label_index], ) - if not self._wrap_array_with_image: - return output - output = Image(output) - for label in list_of_labels: - output.merge_properties_from(label) return output @@ -8082,23 +8072,23 @@ def get( # Ensure factor is a tuple of three integers. if np.size(factor) == 1: - factor = (factor,) * 3 + factor = (factor, factor, 1) elif len(factor) != 3: raise ValueError( "Factor must be an integer or a tuple of three integers." ) - + # Create a context for upscaling and perform computation. ctx = create_context(None, None, None, *factor) with units.context(ctx): image = self.feature(image) - # Downscale the result to the original resolution. - import skimage.measure + # # Downscale the result to the original resolution. + # import skimage.measure - image = skimage.measure.block_reduce( - image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean - ) + # image = skimage.measure.block_reduce( + # image, (factor[0], factor[1]) + (1,) * (image.ndim - 2), np.mean + # ) return image @@ -8356,7 +8346,7 @@ def get( list_of_volumes = [list_of_volumes] for _ in range(max_iters): - + list_of_volumes = [ self._resample_volume_position(volume) for volume in list_of_volumes @@ -8411,32 +8401,41 @@ def _check_non_overlapping( - If bounding cubes overlap, voxel-level checks are performed. """ + from deeptrack.scatterers import ScatteredObject - from skimage.morphology import isotropic_erosion, isotropic_dilation - - from deeptrack.augmentations import CropTight, Pad + from deeptrack.augmentations import CropTight, Pad # these are not compatibles with torch backend from deeptrack.optics import _get_position min_distance = self.min_distance() crop = CropTight() + + new_volumes = [] - if min_distance < 0: - list_of_volumes = [ - Image( - crop(isotropic_erosion(volume != 0, -min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] - else: - pad = Pad(px = [int(np.ceil(min_distance/2))]*6, keep_size=True) - list_of_volumes = [ - Image( - crop(isotropic_dilation(pad(volume) != 0, min_distance/2)), - copy=False, - ).merge_properties_from(volume) - for volume in list_of_volumes - ] + for volume in list_of_volumes: + arr = volume.array + mask = arr != 0 + + if min_distance < 0: + new_arr = isotropic_erosion(mask, -min_distance / 2, backend=self.get_backend()) + else: + pad = Pad(px=[int(np.ceil(min_distance / 2))] * 6, keep_size=True) + new_arr = isotropic_dilation(pad(mask) != 0 , min_distance / 2, backend=self.get_backend()) + new_arr = crop(new_arr) + + if self.get_backend() == "torch": + new_arr = new_arr.to(dtype=arr.dtype) + else: + new_arr = new_arr.astype(arr.dtype) + + new_volume = ScatteredObject( + array=new_arr, + properties=volume.properties.copy(), + role=volume.role, + ) + + new_volumes.append(new_volume) + + list_of_volumes = new_volumes min_distance = 1 # The position of the top left corner of each volume (index (0, 0, 0)). @@ -8472,10 +8471,10 @@ def _check_non_overlapping( volume_bounding_cube[i], volume_bounding_cube[j] ) overlapping_volume_1 = self._get_overlapping_volume( - list_of_volumes[i], volume_bounding_cube[i], overlapping_cube + list_of_volumes[i].array, volume_bounding_cube[i], overlapping_cube ) overlapping_volume_2 = self._get_overlapping_volume( - list_of_volumes[j], volume_bounding_cube[j], overlapping_cube + list_of_volumes[j].array, volume_bounding_cube[j], overlapping_cube ) # If either the overlapping regions are empty, the volumes do not @@ -8710,8 +8709,12 @@ def _check_volumes_non_overlapping( """ # Get the positions of the non-zero voxels of each volume. - positions_1 = np.argwhere(volume_1) - positions_2 = np.argwhere(volume_2) + if self.get_backend() == "torch": + positions_1 = torch.nonzero(volume_1, as_tuple=False) + positions_2 = torch.nonzero(volume_2, as_tuple=False) + else: + positions_1 = np.argwhere(volume_1) + positions_2 = np.argwhere(volume_2) # if positions_1.size == 0 or positions_2.size == 0: # return True # If either volume is empty, they are "non-overlapping" @@ -8732,9 +8735,14 @@ def _check_volumes_non_overlapping( # Check that the non-zero voxels of the volumes are at least # min_distance apart. - return np.all( - cdist(positions_1, positions_2) > min_distance - ) + if self.get_backend() == "torch": + dist = torch.cdist( + positions_1.float(), + positions_2.float(), + ) + return bool((dist > min_distance).all()) + else: + return np.all(cdist(positions_1, positions_2) > min_distance) def _resample_volume_position( self: NonOverlapping, @@ -8750,7 +8758,7 @@ def _resample_volume_position( Parameters ---------- - volume: np.ndarray or Image + volume: np.ndarray The 3D volume whose position is to be resampled. The volume must have a `properties` attribute containing dictionaries with `position` and `_position_sampler` keys. @@ -8771,12 +8779,12 @@ def _resample_volume_position( """ - for pdict in volume.properties: - if "position" in pdict and "_position_sampler" in pdict: - new_position = pdict["_position_sampler"]() - if isinstance(new_position, Quantity): - new_position = new_position.to("pixel").magnitude - pdict["position"] = new_position + pdict = volume.properties + if "position" in pdict and "_position_sampler" in pdict: + new_position = pdict["_position_sampler"]() + if isinstance(new_position, Quantity): + new_position = new_position.to("pixel").magnitude + pdict["position"] = new_position return volume @@ -9594,3 +9602,73 @@ def get( res = res[0] return res + +### Move to math? +def isotropic_dilation( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_dilation + return isotropic_dilation(mask, radius) + + # torch backend + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + return (y[0, 0] > 0) + + +def isotropic_erosion( + mask, + radius: float, + *, + backend: str, + device=None, + dtype=None, +): + if radius <= 0: + return mask + + if backend == "numpy": + from skimage.morphology import isotropic_erosion + return isotropic_erosion(mask, radius) + + import torch + + r = int(np.ceil(radius)) + kernel = torch.ones( + (1, 1, 2 * r + 1, 2 * r + 1, 2 * r + 1), + device=device or mask.device, + dtype=dtype or torch.float32, + ) + + x = mask.to(dtype=kernel.dtype)[None, None] + y = torch.nn.functional.conv3d( + x, + kernel, + padding=r, + ) + + required = kernel.numel() + return (y[0, 0] >= required) diff --git a/deeptrack/holography.py b/deeptrack/holography.py index 380969cf..141cc540 100644 --- a/deeptrack/holography.py +++ b/deeptrack/holography.py @@ -101,7 +101,7 @@ def get_propagation_matrix( def get_propagation_matrix( shape: tuple[int, int], to_z: float, - pixel_size: float, + pixel_size: float | tuple[float, float], wavelength: float, dx: float = 0, dy: float = 0 @@ -118,8 +118,8 @@ def get_propagation_matrix( The dimensions of the optical field (height, width). to_z: float Propagation distance along the z-axis. - pixel_size: float - The physical size of each pixel in the optical field. + pixel_size: float | tuple[float, float] + Physical pixel size. If scalar, isotropic pixels are assumed. wavelength: float The wavelength of the optical field. dx: float, optional @@ -140,14 +140,22 @@ def get_propagation_matrix( """ + if pixel_size is None: + pixel_size = get_active_voxel_size() + + if np.isscalar(pixel_size): + pixel_size = (pixel_size, pixel_size) + + px, py = pixel_size + k = 2 * np.pi / wavelength yr, xr, *_ = shape x = np.arange(0, xr, 1) - xr / 2 + (xr % 2) / 2 y = np.arange(0, yr, 1) - yr / 2 + (yr % 2) / 2 - x = 2 * np.pi / pixel_size * x / xr - y = 2 * np.pi / pixel_size * y / yr + x = 2 * np.pi / px * x / xr + y = 2 * np.pi / py * y / yr KXk, KYk = np.meshgrid(x, y) KXk = KXk.astype(complex) diff --git a/deeptrack/optics.py b/deeptrack/optics.py index 5149bdae..d0b5d66b 100644 --- a/deeptrack/optics.py +++ b/deeptrack/optics.py @@ -137,11 +137,13 @@ def _pad_volume( from __future__ import annotations from pint import Quantity -from typing import Any +from typing import Any, TYPE_CHECKING import warnings import numpy as np -from scipy.ndimage import convolve +from scipy.ndimage import convolve # might be removed later +import torch +import torch.nn.functional as F from deeptrack.backend.units import ( ConversionTable, @@ -152,12 +154,22 @@ def _pad_volume( from deeptrack.math import AveragePooling from deeptrack.features import propagate_data_to_dependencies from deeptrack.features import DummyFeature, Feature, StructuralFeature -from deeptrack.image import Image, pad_image_to_fft +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike, PropertyLike from deeptrack import image from deeptrack import units_registry as u +from deeptrack import TORCH_AVAILABLE, image +from deeptrack.backend import xp +from deeptrack.scatterers import ScatteredObject + +if TORCH_AVAILABLE: + import torch + +if TYPE_CHECKING: + import torch + #TODO ***??*** revise Microscope - torch, typing, docstring, unit test class Microscope(StructuralFeature): @@ -186,7 +198,7 @@ class Microscope(StructuralFeature): Methods ------- - `get(image: Image or None, **kwargs: Any) -> Image` + `get(image: np.ndarray or None, **kwargs: Any) -> np.ndarray` Simulates the imaging process using the defined optical system and returns the resulting image. @@ -238,13 +250,13 @@ def __init__( self._sample = self.add_feature(sample) self._objective = self.add_feature(objective) - self._sample.store_properties() + # self._sample.store_properties() def get( self: Microscope, - image: Image | None, + image: np.ndarray | None, **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Generate an image of the sample using the defined optical system. This method processes the sample through the optical system to @@ -252,14 +264,14 @@ def get( Parameters ---------- - image: Image | None + image: np.ndarray | None The input image to be processed. If None, a new image is created. **kwargs: Any Additional parameters for the imaging process. Returns ------- - Image: Image + image: np.ndarray The processed image after applying the optical system. Examples @@ -279,17 +291,18 @@ def get( # Grab properties from the objective to pass to the sample additional_sample_kwargs = self._objective.properties() + contrast_type = getattr(self._objective, "contrast_type", None) + if contrast_type is None: + raise RuntimeError( + f"{self._objective.__class__.__name__} must define `contrast_type` " + "(e.g. 'intensity' or 'refractive_index')." + ) - # Calculate required output image for the given upscale - # This way of providing the upscale will be deprecated in the future - # in favor of dt.Upscale(). - _upscale_given_by_optics = additional_sample_kwargs["upscale"] - if np.array(_upscale_given_by_optics).size == 1: - _upscale_given_by_optics = (_upscale_given_by_optics,) * 3 + additional_sample_kwargs["contrast_type"] = contrast_type with u.context( create_context( - *additional_sample_kwargs["voxel_size"], *_upscale_given_by_optics + *additional_sample_kwargs["voxel_size"]#, *_upscale_given_by_optics ) ): @@ -329,26 +342,21 @@ def get( volume_samples = [ scatterer for scatterer in list_of_scatterers - if not scatterer.get_property("is_field", default=False) + if scatterer.role == "volume" ] # All scatterers that are defined as fields. field_samples = [ scatterer for scatterer in list_of_scatterers - if scatterer.get_property("is_field", default=False) + if scatterer.role == "field" ] - + # Merge all volumes into a single volume. sample_volume, limits = _create_volume( volume_samples, **additional_sample_kwargs, ) - sample_volume = Image(sample_volume) - - # Merge all properties into the volume. - for scatterer in volume_samples + field_samples: - sample_volume.merge_properties_from(scatterer) # Let the objective know about the limits of the volume and all the fields. propagate_data_to_dependencies( @@ -359,33 +367,17 @@ def get( imaged_sample = self._objective.resolve(sample_volume) - # Upscale given by the optics needs to be handled separately. - if _upscale_given_by_optics != (1, 1, 1): - imaged_sample = AveragePooling((*_upscale_given_by_optics[:2], 1))( - imaged_sample - ) - - # Merge with input - if not image: - if not self._wrap_array_with_image and isinstance(imaged_sample, Image): - return imaged_sample._value + # Handling upscale from dt.Upscale() here to eliminate Image + # wrapping issues. + if np.any(np.array(upscale) != 1): + ux, uy = upscale[:2] + if contrast_type == "intensity": + print("Using sum pooling for intensity downscaling.") + imaged_sample = SumPoolingCM((ux, uy, 1))(imaged_sample) else: - return imaged_sample - - if not isinstance(image, list): - image = [image] - for i in range(len(image)): - image[i].merge_properties_from(imaged_sample) - return image - - # def _no_wrap_format_input(self, *args, **kwargs) -> list: - # return self._image_wrapped_format_input(*args, **kwargs) + imaged_sample = AveragePoolingCM((ux, uy, 1))(imaged_sample) - # def _no_wrap_process_and_get(self, *args, **feature_input) -> list: - # return self._image_wrapped_process_and_get(*args, **feature_input) - - # def _no_wrap_process_output(self, *args, **feature_input): - # return self._image_wrapped_process_output(*args, **feature_input) + return imaged_sample #TODO ***??*** revise Optics - torch, typing, docstring, unit test @@ -757,19 +749,18 @@ def _pupil( W, H = np.meshgrid(y, x) RHO = (W ** 2 + H ** 2).astype(complex) - pupil_function = Image((RHO < 1) + 0.0j, copy=False) + pupil_function = (RHO < 1) + 0.0j # Defocus - z_shift = Image( + z_shift = ( 2 * np.pi * refractive_index_medium / wavelength * voxel_size[2] - * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO), - copy=False, + * np.sqrt(1 - (NA / refractive_index_medium) ** 2 * RHO) ) - z_shift._value[z_shift._value.imag != 0] = 0 + z_shift[z_shift.imag != 0] = 0 try: z_shift = np.nan_to_num(z_shift, False, 0, 0, 0) @@ -1007,7 +998,7 @@ class Fluorescence(Optics): Methods ------- - `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> Image` + `get(illuminated_volume: array_like[complex], limits: array_like[int, int], **kwargs: Any) -> np.ndarray` Simulates the imaging process using a fluorescence microscope. Examples @@ -1023,13 +1014,14 @@ class Fluorescence(Optics): 1.4 """ + contrast_type = "intensity" def get( self: Fluorescence, illuminated_volume: ArrayLike[complex], limits: ArrayLike[int], **kwargs: Any, - ) -> Image: + ) -> ArrayLike[complex]: """Simulates the imaging process using a fluorescence microscope. This method convolves the 3D illuminated volume with a pupil function @@ -1048,7 +1040,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray A 2D image object representing the fluorescence projection. Notes @@ -1066,7 +1058,7 @@ def get( >>> optics = dt.Fluorescence( ... NA=1.4, wavelength=0.52e-6, magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> properties = optics.properties() >>> filtered_properties = { @@ -1118,9 +1110,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)), copy=False - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) @@ -1156,12 +1146,12 @@ def get( field = np.fft.ifft2(convolved_fourier_field) # # Discard remaining imaginary part (should be 0 up to rounding error) field = np.real(field) - output_image._value[:, :, 0] += field[ + output_image[:, :, 0] += field[ : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] - output_image.properties = illuminated_volume.properties + pupils.properties + # output_image.properties = illuminated_volume.properties + pupils.properties return output_image @@ -1234,7 +1224,7 @@ class Brightfield(Optics): ------- `get(illuminated_volume: array_like[complex], limits: array_like[int, int], fields: array_like[complex], - **kwargs: Any) -> Image` + **kwargs: Any) -> np.ndarray` Simulates imaging with brightfield microscopy. @@ -1250,6 +1240,8 @@ class Brightfield(Optics): """ + contrast_type = "refractive_index" + __conversion_table__ = ConversionTable( working_distance=(u.meter, u.meter), ) @@ -1260,7 +1252,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Simulates imaging with brightfield microscopy. This method propagates light through the given volume, applying @@ -1285,7 +1277,7 @@ def get( Returns ------- - Image: Image + image: np.ndarray Processed image after simulating the brightfield imaging process. Examples @@ -1300,7 +1292,7 @@ def get( ... wavelength=0.52e-6, ... magnification=60, ... ) - >>> volume = dt.Image(np.ones((128, 128, 10), dtype=complex)) + >>> volume = np.ones((128, 128, 10), dtype=complex) >>> limits = np.array([[0, 128], [0, 128], [0, 10]]) >>> fields = np.array([np.ones((162, 162), dtype=complex)]) >>> properties = optics.properties() @@ -1345,7 +1337,7 @@ def get( if output_region[3] is None else int(output_region[3] - limits[1, 0] + pad[3]) ) - + padded_volume = padded_volume[ output_region[0] : output_region[2], output_region[1] : output_region[3], @@ -1353,9 +1345,7 @@ def get( ] z_limits = limits[2, :] - output_image = Image( - np.zeros((*padded_volume.shape[0:2], 1)) - ) + output_image = np.zeros((*padded_volume.shape[0:2], 1)) index_iterator = range(padded_volume.shape[2]) z_iterator = np.linspace( @@ -1414,7 +1404,25 @@ def get( light_in_focus = light_in * shifted_pupil if len(fields) > 0: - field = np.sum(fields, axis=0) + # field = np.sum(fields, axis=0) + field_arrays = [] + + for fs in fields: + # fs is a ScatteredField + arr = fs.array + + # Enforce (H, W, 1) shape + if arr.ndim == 2: + arr = arr[..., None] + + if arr.ndim != 3 or arr.shape[-1] != 1: + raise ValueError( + f"Expected field of shape (H, W, 1), got {arr.shape}" + ) + + field_arrays.append(arr) + + field = np.sum(field_arrays, axis=0) light_in_focus += field[..., 0] shifted_pupil = np.fft.fftshift(pupils[-1]) light_in_focus = light_in_focus * shifted_pupil @@ -1426,7 +1434,7 @@ def get( : padded_volume.shape[0], : padded_volume.shape[1] ] output_image = np.expand_dims(output_image, axis=-1) - output_image = Image(output_image[pad[0] : -pad[2], pad[1] : -pad[3]]) + output_image = output_image[pad[0] : -pad[2], pad[1] : -pad[3]] if not kwargs.get("return_field", False): output_image = np.square(np.abs(output_image)) @@ -1436,7 +1444,7 @@ def get( # output_image = output_image * np.exp(1j * -np.pi / 4) # output_image = output_image + 1 - output_image.properties = illuminated_volume.properties + # output_image.properties = illuminated_volume.properties return output_image @@ -1631,7 +1639,7 @@ def get( limits: ArrayLike[int], fields: ArrayLike[complex], **kwargs: Any, - ) -> Image: + ) -> np.ndarray: """Retrieve the darkfield image of the illuminated volume. Parameters @@ -1802,7 +1810,7 @@ def get( #TODO ***??*** revise _get_position - torch, typing, docstring, unit test def _get_position( - image: Image, + scatterer: ScatteredObject, mode: str = "corner", return_z: bool = False, ) -> np.ndarray: @@ -1826,26 +1834,23 @@ def _get_position( num_outputs = 2 + return_z - if mode == "corner" and image.size > 0: + if mode == "corner" and scatterer.array.size > 0: import scipy.ndimage - image = image.to_numpy() - - shift = scipy.ndimage.center_of_mass(np.abs(image)) + shift = scipy.ndimage.center_of_mass(np.abs(scatterer.array)) if np.isnan(shift).any(): - shift = np.array(image.shape) / 2 + shift = np.array(scatterer.array.shape) / 2 else: shift = np.zeros((num_outputs)) - position = np.array(image.get_property("position", default=None)) + position = np.array(scatterer.get_property("position", default=None)) if position is None: return position scale = np.array(get_active_scale()) - if len(position) == 3: position = position * scale + 0.5 * (scale - 1) if return_z: @@ -1856,7 +1861,7 @@ def _get_position( elif len(position) == 2: if return_z: outp = ( - np.array([position[0], position[1], image.get_property("z", default=0)]) + np.array([position[0], position[1], scatterer.get_property("z", default=0)]) * scale - shift + 0.5 * (scale - 1) @@ -1868,6 +1873,58 @@ def _get_position( return position +def _bilinear_interpolate_numpy( + scatterer: np.ndarray, x_off: float, y_off: float +) -> np.ndarray: + """Apply bilinear subpixel interpolation in the x–y plane (NumPy).""" + kernel = np.array( + [ + [0.0, 0.0, 0.0], + [0.0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], + [0.0, x_off * (1 - y_off), x_off * y_off], + ] + ) + out = np.zeros_like(scatterer) + for z in range(scatterer.shape[2]): + if np.iscomplexobj(scatterer): + out[:, :, z] = ( + convolve(np.real(scatterer[:, :, z]), kernel, mode="constant") + + 1j + * convolve(np.imag(scatterer[:, :, z]), kernel, mode="constant") + ) + else: + out[:, :, z] = convolve(scatterer[:, :, z], kernel, mode="constant") + return out + + +def _bilinear_interpolate_torch( + scatterer: torch.Tensor, x_off: float, y_off: float +) -> torch.Tensor: + """Apply bilinear subpixel interpolation in the x–y plane (Torch). + + Uses grid_sample for autograd-friendly interpolation. + """ + H, W, D = scatterer.shape + + # Normalized shifts in [-1,1] + x_shift = 2 * x_off / (W - 1) + y_shift = 2 * y_off / (H - 1) + + yy, xx = torch.meshgrid( + torch.linspace(-1, 1, H, device=scatterer.device, dtype=scatterer.dtype), + torch.linspace(-1, 1, W, device=scatterer.device, dtype=scatterer.dtype), + indexing="ij", + ) + grid = torch.stack((xx + x_shift, yy + y_shift), dim=-1) # (H,W,2) + grid = grid.unsqueeze(0).repeat(D, 1, 1, 1) # (D,H,W,2) + + inp = scatterer.permute(2, 0, 1).unsqueeze(1) # (D,1,H,W) + + out = F.grid_sample(inp, grid, mode="bilinear", + padding_mode="zeros", align_corners=True) + return out.squeeze(1).permute(1, 2, 0) # (H,W,D) + + #TODO ***??*** revise _create_volume - torch, typing, docstring, unit test def _create_volume( list_of_scatterers: list, @@ -1903,6 +1960,12 @@ def _create_volume( Spatial limits of the volume. """ + contrast_type = kwargs.get("contrast_type", None) + if contrast_type is None: + raise RuntimeError( + "_create_volume requires a contrast_type " + "(e.g. 'intensity' or 'refractive_index')" + ) if not isinstance(list_of_scatterers, list): list_of_scatterers = [list_of_scatterers] @@ -1927,24 +1990,28 @@ def _create_volume( # This accounts for upscale doing AveragePool instead of SumPool. This is # a bit of a hack, but it works for now. - fudge_factor = scale[0] * scale[1] / scale[2] + # fudge_factor = scale[0] * scale[1] / scale[2] for scatterer in list_of_scatterers: - position = _get_position(scatterer, mode="corner", return_z=True) - if scatterer.get_property("intensity", None) is not None: - intensity = scatterer.get_property("intensity") - scatterer_value = intensity * fudge_factor - elif scatterer.get_property("refractive_index", None) is not None: - refractive_index = scatterer.get_property("refractive_index") - scatterer_value = ( - refractive_index - refractive_index_medium - ) + if contrast_type == "intensity": + value = scatterer.get_property("intensity", None) + if value is None: + raise ValueError("Scatterer has no intensity.") + scatterer_value = value + + elif contrast_type == "refractive_index": + ri = scatterer.get_property("refractive_index", None) + if ri is None: + raise ValueError("Scatterer has no refractive_index.") + scatterer_value = ri - refractive_index_medium + else: - scatterer_value = scatterer.get_property("value") + raise RuntimeError(f"Unknown contrast_type: {contrast_type}") - scatterer = scatterer * scatterer_value + # Scale the array accordingly + scatterer.array = scatterer.array * scatterer_value if limits is None: limits = np.zeros((3, 2), dtype=np.int32) @@ -1952,26 +2019,25 @@ def _create_volume( limits[:, 1] = np.floor(position).astype(np.int32) + 1 if ( - position[0] + scatterer.shape[0] < OR[0] + position[0] + scatterer.array.shape[0] < OR[0] or position[0] > OR[2] - or position[1] + scatterer.shape[1] < OR[1] + or position[1] + scatterer.array.shape[1] < OR[1] or position[1] > OR[3] ): continue - padded_scatterer = Image( - np.pad( - scatterer, + # Pad scatterer to avoid edge effects during interpolation + padded_scatterer_arr = np.pad( #Use Pad instead and make it torch-compatible? + scatterer.array, [(2, 2), (2, 2), (2, 2)], "constant", constant_values=0, ) - ) - padded_scatterer.merge_properties_from(scatterer) - - scatterer = padded_scatterer - position = _get_position(scatterer, mode="corner", return_z=True) - shape = np.array(scatterer.shape) + padded_scatterer = ScatteredObject( + array=padded_scatterer_arr, properties=scatterer.properties.copy(), role=scatterer.role, + ) + position = _get_position(padded_scatterer, mode="corner", return_z=True) + shape = np.array(padded_scatterer.array.shape) if position is None: RuntimeWarning( @@ -1980,36 +2046,20 @@ def _create_volume( ) continue - splined_scatterer = np.zeros_like(scatterer) - x_off = position[0] - np.floor(position[0]) y_off = position[1] - np.floor(position[1]) - kernel = np.array( - [ - [0, 0, 0], - [0, (1 - x_off) * (1 - y_off), (1 - x_off) * y_off], - [0, x_off * (1 - y_off), x_off * y_off], - ] - ) - - for z in range(scatterer.shape[2]): - if splined_scatterer.dtype == complex: - splined_scatterer[:, :, z] = ( - convolve( - np.real(scatterer[:, :, z]), kernel, mode="constant" - ) - + convolve( - np.imag(scatterer[:, :, z]), kernel, mode="constant" - ) - * 1j - ) - else: - splined_scatterer[:, :, z] = convolve( - scatterer[:, :, z], kernel, mode="constant" - ) + + if isinstance(padded_scatterer.array, np.ndarray): # get_backend is a method of Features and not exposed + splined_scatterer = _bilinear_interpolate_numpy(padded_scatterer.array, x_off, y_off) + elif isinstance(padded_scatterer.array, torch.Tensor): + splined_scatterer = _bilinear_interpolate_torch(padded_scatterer.array, x_off, y_off) + else: + raise TypeError( + f"Unsupported array type {type(padded_scatterer.array)}. " + "Expected np.ndarray or torch.Tensor." + ) - scatterer = splined_scatterer position = np.floor(position) new_limits = np.zeros(limits.shape, dtype=np.int32) for i in range(3): @@ -2039,6 +2089,7 @@ def _create_volume( within_volume_position = position - limits[:, 0] # NOTE: Maybe shouldn't be additive. + # give options: sum default, but also mean, max, min, or volume[ int(within_volume_position[0]) : int(within_volume_position[0] + shape[0]), @@ -2048,5 +2099,72 @@ def _create_volume( int(within_volume_position[2]) : int(within_volume_position[2] + shape[2]), - ] += scatterer + ] += splined_scatterer return volume, limits + +# this should be moved to math +class _CenteredPoolingBase: + def __init__(self, pool_size: tuple[int, int, int]): + px, py, pz = pool_size + if pz != 1: + raise ValueError("Only pz=1 supported.") + self.px = int(px) + self.py = int(py) + + def _crop_center(self, array): + H, W = array.shape[:2] + px, py = self.px, self.py + + crop_h = (H // px) * px + crop_w = (W // py) * py + + off_h = (H - crop_h) // 2 + off_w = (W - crop_w) // 2 + + return array[off_h:off_h+crop_h, off_w:off_w+crop_w, ...] + + def _pool_numpy(self, array, func): + import skimage.measure + array = self._crop_center(array) + pool_shape = (self.px, self.py) + (1,) * (array.ndim - 2) + return skimage.measure.block_reduce(array, pool_shape, func) + + def _pool_torch(self, array, sum_pool=False): + px, py = self.px, self.py + array = self._crop_center(array) + + extra = array.shape[2:] + C = int(np.prod(extra)) if extra else 1 + x = array.reshape(1, C, array.shape[0], array.shape[1]) + + pooled = torch.nn.functional.avg_pool2d( + x, kernel_size=(px, py), stride=(px, py) + ) + if sum_pool: + pooled = pooled * (px * py) + + return pooled.reshape( + (pooled.shape[2], pooled.shape[3]) + extra + ) + +class AveragePoolingCM(_CenteredPoolingBase): + """Center-preserving average pooling (intensive quantities).""" + + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array, np.mean) + elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array, sum_pool=False) + else: + raise TypeError("Unsupported array type.") + +class SumPoolingCM(_CenteredPoolingBase): + """Center-preserving sum pooling (extensive quantities).""" + + def __call__(self, array): + if isinstance(array, np.ndarray): + return self._pool_numpy(array, np.sum) + elif TORCH_AVAILABLE and isinstance(array, torch.Tensor): + return self._pool_torch(array, sum_pool=True) + else: + raise TypeError("Unsupported array type.") \ No newline at end of file diff --git a/deeptrack/scatterers.py b/deeptrack/scatterers.py index 04a7c5ea..b942888c 100644 --- a/deeptrack/scatterers.py +++ b/deeptrack/scatterers.py @@ -166,6 +166,7 @@ import numpy as np from numpy.typing import NDArray from pint import Quantity +from dataclasses import dataclass, field from deeptrack.holography import get_propagation_matrix from deeptrack.backend.units import ( @@ -175,7 +176,7 @@ ) from deeptrack.backend import mie from deeptrack.features import Feature, MERGE_STRATEGY_APPEND -from deeptrack.image import pad_image_to_fft, Image +from deeptrack.image import pad_image_to_fft from deeptrack.types import ArrayLike from deeptrack import units_registry as u @@ -258,7 +259,7 @@ def __init__( **kwargs, ) -> None: # Ignore warning to help with comparison with arrays. - if upsample is not 1: # noqa: F632 + if upsample != 1: # noqa: F632 warnings.warn( f"Setting upsample != 1 is deprecated. " f"Please, instead use dt.Upscale(f, factor={upsample})" @@ -296,7 +297,7 @@ def _process_and_get( upsample_axes=None, crop_empty=True, **kwargs - ) -> list[Image] | list[np.ndarray]: + ) -> list[np.ndarray]: # Post processes the created object to handle upsampling, # as well as cropping empty slices. if not self._processed_properties: @@ -310,7 +311,7 @@ def _process_and_get( voxel_size = get_active_voxel_size() # Calls parent _process_and_get. - new_image = super()._process_and_get( + new_image = super(Scatterer, self)._process_and_get( *args, voxel_size=voxel_size, upsample=upsample, @@ -333,28 +334,33 @@ def _process_and_get( new_image = new_image[:, ~np.all(new_image == 0, axis=(0, 2))] new_image = new_image[:, :, ~np.all(new_image == 0, axis=(0, 1))] - return [Image(new_image)] + # # Copy properties + # props = kwargs.copy() + return [self._wrap_output(new_image, kwargs)] + + def _wrap_output(self, array, props) -> ScatteredObject: + # """Must be overridden in subclasses to wrap output correctly.""" + # raise NotImplementedError + return ScatteredObject( + array=array, + properties=props.copy(), + role = self.role, + ) - def _no_wrap_format_input( - self, - *args, - **kwargs - ) -> list: - return self._image_wrapped_format_input(*args, **kwargs) +# class VolumeScatterer(Scatterer): +# """Abstract scatterer producing ScatteredVolume outputs.""" +# def _wrap_output(self, array, props) -> ScatteredVolume: +# return ScatteredVolume( +# array=array, +# properties=props.copy(), +# ) - def _no_wrap_process_and_get( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_and_get(*args, **feature_input) - - def _no_wrap_process_output( - self, - *args, - **feature_input - ) -> list: - return self._image_wrapped_process_output(*args, **feature_input) +# class FieldScatterer(Scatterer): +# def _wrap_output(self, array, props) -> ScatteredField: +# return ScatteredField( +# array=array, +# properties=props.copy(), +# ) #TODO ***??*** revise PointParticle - torch, typing, docstring, unit test @@ -381,6 +387,7 @@ class PointParticle(Scatterer): for `Brightfield` and `intensity` for `Fluorescence`). """ + role = "volume" def __init__( self: PointParticle, @@ -394,7 +401,7 @@ def __init__( def get( self: PointParticle, - image: Image | np.ndarray, + image: np.ndarray, **kwarg: Any, ) -> NDArray[Any] | torch.Tensor: """Evaluate and return the scatterer volume.""" @@ -440,6 +447,7 @@ class Ellipse(Scatterer): before rotation. """ + role = "volume" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -545,6 +553,7 @@ class Sphere(Scatterer): Upsamples the calculations of the pixel occupancy fraction. """ + role = "volume" __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), @@ -559,7 +568,7 @@ def __init__( def get( self, - image: Image | np.ndarray, + image: np.ndarray, radius: float, voxel_size: float, **kwargs @@ -620,6 +629,8 @@ class Ellipsoid(Scatterer): """ + role = "volume" + __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), rotation=(u.radian, u.radian), @@ -694,7 +705,7 @@ def _process_properties( def get( self, - image: Image | np.ndarray, + image: np.ndarray, radius: float, rotation: ArrayLike[float] | float, voxel_size: float, @@ -826,6 +837,8 @@ class MieScatterer(Scatterer): """ + role = "field" + __conversion_table__ = ConversionTable( radius=(u.meter, u.meter), polarization_angle=(u.radian, u.radian), @@ -856,6 +869,7 @@ def __init__( illumination_angle: float=0, amp_factor: float=1, phase_shift_correction: bool=False, + # pupil: ArrayLike=[], # Daniel **kwargs, ) -> None: if polarization_angle is not None: @@ -864,11 +878,11 @@ def __init__( "Please use input_polarization instead" ) input_polarization = polarization_angle - kwargs.pop("is_field", None) + # kwargs.pop("is_field", None) # remove kwargs.pop("crop_empty", None) super().__init__( - is_field=True, + is_field=True, # remove crop_empty=False, L=L, offset_z=offset_z, @@ -889,6 +903,7 @@ def __init__( illumination_angle=illumination_angle, amp_factor=amp_factor, phase_shift_correction=phase_shift_correction, + # pupil=pupil, # Daniel **kwargs, ) @@ -1014,7 +1029,8 @@ def get_plane_in_polar_coords( shape: int, voxel_size: ArrayLike[float], plane_position: float, - illumination_angle: float + illumination_angle: float, + # k: float, # Daniel ) -> tuple[float, float, float, float]: """Computes the coordinates of the plane in polar form.""" @@ -1027,15 +1043,24 @@ def get_plane_in_polar_coords( R2_squared = X ** 2 + Y ** 2 R3 = np.sqrt(R2_squared + Z ** 2) # Might be +z instead of -z. + + # # DANIEL + # Q = np.sqrt(R2_squared)/voxel_size[0]**2*2*np.pi/shape[0] + # # is dimensionally ok? + # sin_theta=Q/(k) + # pupil_mask=sin_theta<1 + # cos_theta=np.zeros(sin_theta.shape) + # cos_theta[pupil_mask]=np.sqrt(1-sin_theta[pupil_mask]**2) # Fet the angles. cos_theta = Z / R3 + illumination_cos_theta = ( np.cos(np.arccos(cos_theta) + illumination_angle) ) phi = np.arctan2(Y, X) - return R3, cos_theta, illumination_cos_theta, phi + return R3, cos_theta, illumination_cos_theta, phi#, pupil_mask # Daniel def get( self, @@ -1060,6 +1085,7 @@ def get( illumination_angle: float, amp_factor: float, phase_shift_correction: bool, + # pupil: ArrayLike, # Daniel **kwargs, ) -> ArrayLike[float]: """Abstract method to initialize the Mie scatterer""" @@ -1067,8 +1093,9 @@ def get( # Get size of the output. xSize, ySize = self.get_xy_size(output_region, padding) voxel_size = get_active_voxel_size() + scale = get_active_scale() arr = pad_image_to_fft(np.zeros((xSize, ySize))).astype(complex) - position = np.array(position) * voxel_size[: len(position)] + position = np.array(position) * scale[: len(position)] * voxel_size[: len(position)] pupil_physical_size = working_distance * np.tan(collection_angle) * 2 @@ -1076,7 +1103,11 @@ def get( ratio = offset_z / (working_distance - z) - # Position of pbjective relative particle. + # Wave vector. + k = 2 * np.pi / wavelength * refractive_index_medium + + + # Position of objective relative particle. relative_position = np.array( ( position_objective[0] - position[0], @@ -1085,12 +1116,13 @@ def get( ) ) - # Get field evaluation plane at offset_z. + # Get field evaluation plane at offset_z. # , pupil_mask # Daniel R3_field, cos_theta_field, illumination_angle_field, phi_field =\ self.get_plane_in_polar_coords( arr.shape, voxel_size, relative_position * ratio, - illumination_angle + illumination_angle, + # k # Daniel ) cos_phi_field, sin_phi_field = np.cos(phi_field), np.sin(phi_field) @@ -1108,7 +1140,7 @@ def get( sin_phi_field / ratio ) - # If the beam is within the pupil. + # If the beam is within the pupil. Remove if Daniel pupil_mask = (x_farfield - position_objective[0]) ** 2 + ( y_farfield - position_objective[1] ) ** 2 < (pupil_physical_size / 2) ** 2 @@ -1146,9 +1178,6 @@ def get( * illumination_angle_field ) - # Wave vector. - k = 2 * np.pi / wavelength * refractive_index_medium - # Harmonics. A, B = coefficients(L) PI, TAU = mie.harmonics(illumination_angle_field, L) @@ -1165,12 +1194,15 @@ def get( [E[i] * B[i] * PI[i] + E[i] * A[i] * TAU[i] for i in range(0, L)] ) + # Daniel + # arr[pupil_mask] = (S2 * S2_coef + S1 * S1_coef)/amp_factor arr[pupil_mask] = ( -1j / (k * R3_field) * np.exp(1j * k * R3_field) * (S2 * S2_coef + S1 * S1_coef) ) / amp_factor + # For phase shift correction (a multiplication of the field # by exp(1j * k * z)). @@ -1188,15 +1220,23 @@ def get( -mask.shape[1] // 2 : mask.shape[1] // 2, ] mask = np.exp(-0.5 * (x ** 2 + y ** 2) / ((sigma) ** 2)) - arr = arr * mask + # Not sure if needed... CM + # if len(pupil)>0: + # c_pix=[arr.shape[0]//2,arr.shape[1]//2] + + # arr[c_pix[0]-pupil.shape[0]//2:c_pix[0]+pupil.shape[0]//2,c_pix[1]-pupil.shape[1]//2:c_pix[1]+pupil.shape[1]//2]*=pupil + + # Daniel + # fourier_field = -np.fft.ifft2(np.fft.fftshift(np.fft.fft2(np.fft.fftshift(arr)))) fourier_field = np.fft.fft2(arr) propagation_matrix = get_propagation_matrix( fourier_field.shape, - pixel_size=voxel_size[2], + pixel_size=voxel_size[:2], # this needs a double check wavelength=wavelength / refractive_index_medium, + # to_z=(-z), # Daniel to_z=(-offset_z - z), dy=( relative_position[0] * ratio @@ -1206,11 +1246,12 @@ def get( dx=( relative_position[1] * ratio + position[1] - + (padding[1] - arr.shape[1] / 2) * voxel_size[1] + + (padding[2] - arr.shape[1] / 2) * voxel_size[1] # check if padding is top, bottom, left, right ), ) + fourier_field = ( - fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) + fourier_field * propagation_matrix * np.exp(-1j * k * offset_z) # Remove last part (from exp)) if Daniel ) if return_fft: @@ -1275,6 +1316,8 @@ class MieSphere(MieScatterer): """ + role = "field" + def __init__( self, radius: float = 1e-6, @@ -1377,6 +1420,8 @@ class MieStratifiedSphere(MieScatterer): """ + role = "field" + def __init__( self, radius: ArrayLike[float] = [1e-6], @@ -1412,3 +1457,51 @@ def inner( refractive_index=refractive_index, **kwargs, ) + + +@dataclass +class ScatteredObject: + """Base class for scatterers (volumes and fields).""" + + array: ArrayLike + properties: dict[str, Any] = field(default_factory=dict) + role: Literal["volume", "field"] = "volume" + + @property + def ndim(self) -> int: + """Number of dimensions of the underlying array.""" + return self.array.ndim + + @property + def shape(self) -> int: + """Number of dimensions of the underlying array.""" + return self.array.shape + + @property + def pos3d(self) -> np.ndarray: + return np.array([*self.position, self.z], dtype=float) + + @property + def position(self) -> np.ndarray: + pos = self.properties.get("position", None) + if pos is None: + return None + pos = np.asarray(pos, dtype=float) + if pos.ndim == 2 and pos.shape[0] == 1: + pos = pos[0] + return pos + + def as_array(self) -> ArrayLike: + """Return the underlying array. + + Notes + ----- + The raw array is also directly available as ``scatterer.array``. + This method exists mainly for API compatibility and clarity. + + """ + + return self.array + + def get_property(self, key: str, default: Any = None) -> Any: + return getattr(self, key, self.properties.get(key, default))