diff --git a/examples/forward_heart/conf/heart.yaml b/examples/forward_heart/conf/heart.yaml new file mode 100644 index 0000000000..f4cf2c11bc --- /dev/null +++ b/examples/forward_heart/conf/heart.yaml @@ -0,0 +1,84 @@ +hydra: + run: + # dynamic output directory according to running time and override name + dir: outputs_heart/${now:%Y-%m-%d}/${now:%H-%M-%S}/${hydra.job.override_dirname} + job: + name: ${mode} # name of logfile + chdir: false # keep current working direcotry unchaned + config: + override_dirname: + exclude_keys: + - TRAIN.checkpoint_path + - TRAIN.pretrained_model_path + - EVAL.pretrained_model_path + - mode + - output_dir + - log_freq + sweep: + # output directory for multirun + dir: ${hydra.run.dir} + subdir: ./ + +EVAL_CSV_PATH: label.csv +DATA_CSV_PATH: data_train.csv +# general settings +mode: train # running mode: train/eval +seed: 2024 +output_dir: ${hydra:run.dir} +log_freq: 200 + +# set geometry +GEOM_PATH: ./stl/ES_python_mesh.stl +BASE_PATH: ./stl/base.stl +ENDO_PATH: ./stl/endo.stl +EPI_PATH: ./stl/epi.stl + +# set working condition +E: 9 # kPa +nu: 0.45 +P: 1.064 # kPa P_ENDO + +# model settings +MODEL: + input_keys: ["x","y","z"] + output_keys: ["u","v","w"] + num_layers: 10 + hidden_size: 20 + activation: "silu" + weight_norm: true + +# training settings +TRAIN: + epochs: 200 + iters_per_epoch: 1000 + lr_scheduler: + epochs: ${TRAIN.epochs} + iters_per_epoch: ${TRAIN.iters_per_epoch} + learning_rate: 1.0e-3 + gamma: 0.95 + decay_steps: 3000 + by_epoch: false + batch_size: + bc_base: 256 + bc_endo: 2048 + bc_epi: 32 + interior: 8000 + weight: + bc_base: {"u": 0.2, "v": 0.2, "w": 0.2} + bc_endo: {"traction_x": 0.1, "traction_y": 0.1, "traction_z": 0.1} + # bc_endo: {"traction": 1.0} + bc_epi: {"traction": 0.2} + interior: {"hooke_x": 0.2, "hooke_y": 0.2, "hooke_z": 0.2} + save_freq: 20 + eval_freq: 20 + eval_during_train: true + eval_with_no_grad: true + pretrained_model_path: null + checkpoint_path: null + +# evaluation settings +EVAL: + pretrained_model_path: null + eval_with_no_grad: true + batch_size: 1000 + num_vis: 100000 diff --git a/examples/forward_heart/equation_forward.py b/examples/forward_heart/equation_forward.py new file mode 100644 index 0000000000..3d23d5c33b --- /dev/null +++ b/examples/forward_heart/equation_forward.py @@ -0,0 +1,166 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import Optional +from typing import Tuple +from typing import Union + +import sympy as sp + +from ppsci.equation.pde import base + + +class Hooke(base.PDE): + r"""equations for umbrella opening force. + Use either (E, nu) or (lambda_, mu) to define the material properties. + + $$ + \begin{cases} + \end{cases} + $$ + + Args: + + + Examples: + >>> pde = Hooke( + ... C_t=1, k_f=None, rho=1, + ... ) + """ + + def __init__( + self, + E: Union[float, str], + nu: Union[float, str], + P: Union[float, str], + dim: int = 3, + time: bool = False, + detach_keys: Optional[Tuple[str, ...]] = None, + ): + super().__init__() + self.detach_keys = detach_keys + self.dim = dim + self.time = time + + t, x, y, z = self.create_symbols("t x y z") + normal_x, normal_y, normal_z = self.create_symbols("normal_x normal_y normal_z") + invars = (x, y) + if time: + invars = (t,) + invars + if self.dim == 3: + invars += (z,) + + u = self.create_function("u", invars) + v = self.create_function("v", invars) + w = self.create_function("w", invars) if dim == 3 else sp.Number(0) + + if isinstance(nu, str): + nu = self.create_function(nu, invars) + if isinstance(E, str): + E = self.create_function(E, invars) + if isinstance(P, str): + P = self.create_function(P, invars) + + self.E = E + self.nu = nu + self.P = P + + # compute sigma + sigma_xx = u.diff(x) + sigma_yy = v.diff(y) + sigma_zz = w.diff(z) if dim == 3 else sp.Number(0) + sigma_xy = 0.5 * (u.diff(y) + v.diff(x)) + sigma_xz = 0.5 * (u.diff(z) + w.diff(x)) if dim == 3 else sp.Number(0) + sigma_yz = 0.5 * (v.diff(z) + w.diff(y)) if dim == 3 else sp.Number(0) + + # compute t + G = E / (2 * (1 + nu)) + e = sigma_xx + sigma_yy + sigma_zz + t_xx = 2 * G * (sigma_xx + nu / (1 - 2 * nu) * e) + t_yy = 2 * G * (sigma_yy + nu / (1 - 2 * nu) * e) + t_zz = 2 * G * (sigma_zz + nu / (1 - 2 * nu) * e) + t_xy = 2 * sigma_xy * G + t_xz = 2 * sigma_xz * G + t_yz = 2 * sigma_yz * G + + # compute hooke (gradt = 0) + hooke_x = t_xx.diff(x) + t_xy.diff(y) + t_xz.diff(z) + hooke_y = t_xy.diff(x) + t_yy.diff(y) + t_yz.diff(z) + hooke_z = t_xz.diff(x) + t_yz.diff(y) + t_zz.diff(z) + + # compute traction splitly (t * n + P_endo * n = 0) + traction_x = t_xx * normal_x + t_xy * normal_y + t_xz * normal_z + P * normal_x + traction_y = t_xy * normal_x + t_yy * normal_y + t_yz * normal_z + P * normal_y + traction_z = t_xz * normal_x + t_yz * normal_y + t_zz * normal_z + P * normal_z + # traction_x = t_xx + t_xy * normal_y / normal_x + t_xz * normal_z / normal_x + # traction_y = t_xy * normal_x / normal_y + t_yy + t_yz * normal_z / normal_y + # traction_z = t_xz * normal_x / normal_z + t_yz * normal_y / normal_z + t_zz + + # compute traction (t * n) + traction_x_ = t_xx * normal_x + t_xy * normal_y + t_xz * normal_z + traction_y_ = t_xy * normal_x + t_yy * normal_y + t_yz * normal_z + traction_z_ = t_xz * normal_x + t_yz * normal_y + t_zz * normal_z + + traction = ( + traction_x_ * normal_x + traction_y_ * normal_y + traction_z_ * normal_z + ) + + # add hooke equations + self.add_equation("hooke_x", hooke_x) + self.add_equation("hooke_y", hooke_y) + if self.dim == 3: + self.add_equation("hooke_z", hooke_z) + + # add traction equations + self.add_equation("traction_x", traction_x) + self.add_equation("traction_y", traction_y) + if self.dim == 3: + self.add_equation("traction_z", traction_z) + + # add combined traction equations + self.add_equation("traction", traction) + + +""" + def Aij(self, A, i, j): + up = np.hstack((A[:i, :j], A[:i, j + 1:])) # 横向连接上方片段 + lo = np.hstack((A[i + 1:, :j], A[i + 1:, j + 1:])) # 横向连接下方片段 + M = np.vstack((up, lo)) # 纵向连接 + return ((-1) ** (i + j)) * self.det(M) # 代数余子式 + + def adjointMatrix(self, A): + n,_ = A.shape # 获取阶数n + list_i = [] + for i in range(n): # 每一行 + list_j = [] + for j in range(n): # 每一列 + list_j.append(self.Aij(A, i, j)) # 伴随阵元素 + list_i.append(list_j) + Am = np.mat(list_i) + return Am.T + + def det(self, A): # 1~3阶矩阵行列式 + n, _ = A.shape # 获取阶数n + if n == 1: + A_det = A[0, 0] + elif n == 2: + A_det = A[0, 0]*A[1, 1] - A[1, 0]*A[0, 1] + elif n == 3: + A_det = A[0, 0]*A[1, 1]*A[2, 2] + A[0, 1]*A[1, 2]*A[2, 0] + A[0, 2]*A[1, 0]*A[2, 1] - A[0, 2]*A[1, 1]*A[2, 0] - A[0, 0]*A[1, 2]*A[2, 1] - A[0, 1]*A[1, 0]*A[2, 2] + else: + raise ValueError(f'ERROR: shape of matrix({n}) is too big when computing det') + return A_det +""" diff --git a/examples/forward_heart/forward.py b/examples/forward_heart/forward.py new file mode 100644 index 0000000000..9dae0146c7 --- /dev/null +++ b/examples/forward_heart/forward.py @@ -0,0 +1,313 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import equation as eq_func +import hydra +from omegaconf import DictConfig + +import ppsci +from ppsci.utils import reader + + +def train(cfg: DictConfig): + # set model + model = ppsci.arch.MLP(**cfg.MODEL) + + # set optimizer + lr_scheduler = ppsci.optimizer.lr_scheduler.ExponentialDecay( + **cfg.TRAIN.lr_scheduler + )() + optimizer = ppsci.optimizer.Adam(lr_scheduler)(model) + + # set equation + equation = {"Hooke": eq_func.Hooke(E=cfg.E, nu=cfg.nu, P=cfg.P, dim=3)} + + # set geometry + heart = ppsci.geometry.Mesh(cfg.GEOM_PATH) + base = ppsci.geometry.Mesh(cfg.BASE_PATH) + endo = ppsci.geometry.Mesh(cfg.ENDO_PATH) + epi = ppsci.geometry.Mesh(cfg.EPI_PATH) + geom = {"geo": heart, "base": base, "endo": endo, "epi": epi} + # set bounds + BOUNDS_X, BOUNDS_Y, BOUNDS_Z = heart.bounds + + # set dataloader config + train_dataloader_cfg = { + "dataset": "NamedArrayDataset", + "iters_per_epoch": cfg.TRAIN.iters_per_epoch, + "sampler": { + "name": "BatchSampler", + "drop_last": True, + "shuffle": True, + }, + "num_workers": 1, + } + + # set constraint + bc_base = ppsci.constraint.BoundaryConstraint( + {"u": lambda d: d["u"], "v": lambda d: d["v"], "w": lambda d: d["w"]}, + {"u": 0, "v": 0, "w": 0}, + geom["base"], + {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_base}, + ppsci.loss.MSELoss("mean"), + weight_dict=cfg.TRAIN.weight.bc_base, + name="BC_BASE", + ) + bc_endo = ppsci.constraint.BoundaryConstraint( + equation["Hooke"].equations, + # {"traction_x": 0, "traction_y": 0, "traction_z": 0}, + # {"traction_x": -cfg.P, "traction_y": -cfg.P, "traction_z": -cfg.P}, + {"traction": -cfg.P}, + geom["endo"], + # test, + {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_endo}, + ppsci.loss.MSELoss("mean"), + weight_dict=cfg.TRAIN.weight.bc_endo, + name="BC_ENDO", + ) + bc_epi = ppsci.constraint.BoundaryConstraint( + equation["Hooke"].equations, + {"traction": 0}, + # {"traction_x": -cfg.P, "traction_y": -cfg.P, "traction_z": -cfg.P}, + geom["epi"], + # test, + {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.bc_epi}, + ppsci.loss.MSELoss("mean"), + weight_dict=cfg.TRAIN.weight.bc_epi, + name="BC_EPI", + ) + interior = ppsci.constraint.InteriorConstraint( + equation["Hooke"].equations, + {"hooke_x": 0, "hooke_y": 0, "hooke_z": 0}, + geom["geo"], + {**train_dataloader_cfg, "batch_size": cfg.TRAIN.batch_size.interior}, + ppsci.loss.MSELoss("mean"), + criteria=lambda x, y, z: ( + (BOUNDS_X[0] < x) + & (x < BOUNDS_X[1]) + & (BOUNDS_Y[0] < y) + & (y < BOUNDS_Y[1]) + & (BOUNDS_Z[0] < z) + & (z < BOUNDS_Z[1]) + ), + weight_dict=cfg.TRAIN.weight.interior, + name="INTERIOR", + ) + data = ppsci.constraint.SupervisedConstraint( + { + "dataset": { + "name": "IterableCSVDataset", + "file_path": cfg.DATA_CSV_PATH, + "input_keys": ("x", "y", "z"), + "label_keys": ("u", "v", "w"), + }, + }, + ppsci.loss.MSELoss("sum"), + name="DATA", + ) + + # wrap constraints together + constraint = { + bc_base.name: bc_base, + bc_endo.name: bc_endo, + bc_epi.name: bc_epi, + interior.name: interior, + data.name: data, + } + + # set validator + eval_data_dict = reader.load_csv_file( + cfg.EVAL_CSV_PATH, + ("x", "y", "z", "u", "v", "w"), + { + "x": "x", + "y": "y", + "z": "z", + "u": "u", + "v": "v", + "w": "w", + }, + ) + + input_dict = { + "x": eval_data_dict["x"], + "y": eval_data_dict["y"], + "z": eval_data_dict["z"], + } + + label_dict = { + "u": eval_data_dict["u"], + "v": eval_data_dict["v"], + "w": eval_data_dict["w"], + } + eval_dataloader_cfg = { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict, + "label": label_dict, + }, + "sampler": {"name": "BatchSampler"}, + "num_workers": 1, + } + sup_validator = ppsci.validate.SupervisedValidator( + {**eval_dataloader_cfg, "batch_size": cfg.EVAL.batch_size}, + ppsci.loss.MSELoss("mean"), + { + "u": lambda out: out["u"], + "v": lambda out: out["v"], + "w": lambda out: out["w"], + }, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="ref_u_v_w", + ) + validator = {sup_validator.name: sup_validator} + + # set visualizer(optional) + visualizer = { + "visualize_u_v_w": ppsci.visualize.VisualizerVtu( + input_dict, + { + "u": lambda out: out["u"], + "v": lambda out: out["v"], + "w": lambda out: out["w"], + }, + batch_size=cfg.EVAL.batch_size, + prefix="result_u_v_w", + ), + } + + # initialize adam solver + solver = ppsci.solver.Solver( + model, + constraint, + cfg.output_dir, + optimizer, + lr_scheduler, + cfg.TRAIN.epochs, + cfg.TRAIN.iters_per_epoch, + save_freq=cfg.TRAIN.save_freq, + log_freq=cfg.log_freq, + eval_freq=cfg.TRAIN.eval_freq, + eval_during_train=cfg.TRAIN.eval_during_train, + eval_with_no_grad=cfg.EVAL.eval_with_no_grad, + seed=cfg.seed, + equation=equation, + geom=geom, + validator=validator, + visualizer=visualizer, + checkpoint_path=cfg.TRAIN.checkpoint_path, + pretrained_model_path=cfg.TRAIN.pretrained_model_path, + # loss_aggregator=loss_aggregator, + ) + # train + solver.train() + # eval + solver.eval() + # visualize prediction after finished training + solver.visualize() + # plot loss + solver.plot_loss_history(by_epoch=True) + + +def evaluate(cfg: DictConfig): + # set models + model = ppsci.arch.MLP(**cfg.MODEL) + + # set validator + eval_data_dict = reader.load_csv_file( + cfg.EVAL_CSV_PATH, + ("x", "y", "z", "u", "v", "w"), + { + "x": "x", + "y": "y", + "z": "z", + "u": "u", + "v": "v", + "w": "w", + }, + ) + + input_dict = { + "x": eval_data_dict["x"], + "y": eval_data_dict["y"], + "z": eval_data_dict["z"], + } + + label_dict = { + "u": eval_data_dict["u"], + "v": eval_data_dict["v"], + "w": eval_data_dict["w"], + } + eval_dataloader_cfg = { + "dataset": { + "name": "NamedArrayDataset", + "input": input_dict, + "label": label_dict, + }, + "sampler": {"name": "BatchSampler"}, + "num_workers": 1, + } + sup_validator = ppsci.validate.SupervisedValidator( + {**eval_dataloader_cfg, "batch_size": cfg.EVAL.batch_size}, + ppsci.loss.MSELoss("mean"), + { + "u": lambda out: out["u"], + "v": lambda out: out["v"], + "w": lambda out: out["w"], + }, + metric={"L2Rel": ppsci.metric.L2Rel()}, + name="ref_u_v_w", + ) + validator = {sup_validator.name: sup_validator} + + # set visualizer + visualizer = { + "visualize_u_v_w": ppsci.visualize.VisualizerVtu( + input_dict, + { + "u": lambda out: out["u"], + "v": lambda out: out["v"], + "w": lambda out: out["w"], + }, + batch_size=cfg.EVAL.batch_size, + prefix="result_u_v_w", + ), + } + + # load pretrained model + solver = ppsci.solver.Solver( + model=model, + output_dir=cfg.output_dir, + validator=validator, + visualizer=visualizer, + pretrained_model_path=cfg.EVAL.pretrained_model_path, + ) + # evaluate + solver.eval() + # visualize prediction + solver.visualize() + + +@hydra.main(version_base=None, config_path="./conf", config_name="heart.yaml") +def main(cfg: DictConfig): + if cfg.mode == "train": + train(cfg) + elif cfg.mode == "eval": + evaluate(cfg) + else: + raise ValueError(f"cfg.mode should in ['train', 'eval'], but got '{cfg.mode}'") + + +if __name__ == "__main__": + main() diff --git a/jointContribution/HighResolution/README.md b/jointContribution/HighResolution/README.md new file mode 100644 index 0000000000..b10659c45c --- /dev/null +++ b/jointContribution/HighResolution/README.md @@ -0,0 +1,22 @@ + +## Dataset + +Download demo dataset: + +``` sh +cd PaddleScience/jointContribution/HighResolution +# linux +wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/patient001.zip +wget -nc https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/Hammersmith_myo2.zip +# windows +# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/patient001.zip +# curl https://paddle-org.bj.bcebos.com/paddlescience/datasets/HighResolution/Hammersmith_myo2.zip + +# unzip +unzip patient001.zip -d data +unzip Hammersmith_myo2.zip +``` + +## Run + +python main_ACDC.py diff --git a/jointContribution/HighResolution/deepali/core/__init__.py b/jointContribution/HighResolution/deepali/core/__init__.py new file mode 100644 index 0000000000..749042d40d --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/__init__.py @@ -0,0 +1,115 @@ +"""Common types and functions that operate on tensors representing different types of data. + +Besides defining common types and auxiliary functions extending the standard library, +this core library in particular defines functions which operate on objects of type ``paddle.Tensor``. +This functional API defines a set of reusable state-less functions similar to ``paddle.nn.functional``. +Object-oriented APIs use this functional API to realize their functionality. In particular, the +``forward()`` method of ``paddle.nn.Layer`` subclasses, such as for example data transformations +(cf. :mod:`.data.transforms`) and neural network components (cf. :mod:`.modules` and :mod:`.networks`) +are implemented using these functional building blocks. + +The following import statement can be used to access the functional API: + +.. code-block:: + + import deepali.core.functional as U + +""" + +from .config import DataclassConfig +from .config import join_kwargs_in_sequence +from .cube import Cube +from .enum import PaddingMode +from .enum import Sampling +from .grid import ALIGN_CORNERS +from .grid import Axes +from .grid import Grid +from .grid import grid_points_transform +from .grid import grid_transform_points +from .grid import grid_transform_vectors +from .grid import grid_vectors_transform +from .path import abspath +from .path import abspath_template +from .path import delete +from .path import make_parent_dir +from .path import make_temp_file +from .path import temp_dir +from .path import temp_file +from .path import unlink_or_mkdir +from .random import multinomial +from .types import RE_OUTPUT_KEY_INDEX +from .types import Array +from .types import Batch +from .types import Dataclass +from .types import Device +from .types import DType +from .types import Name +from .types import PathStr +from .types import Sample +from .types import Scalar +from .types import ScalarOrTuple +from .types import ScalarOrTuple1d +from .types import ScalarOrTuple2d +from .types import ScalarOrTuple3d +from .types import Shape +from .types import Size +from .types import TensorCollection +from .types import get_tensor +from .types import is_bool_dtype +from .types import is_float_dtype +from .types import is_int_dtype +from .types import is_namedtuple +from .types import is_path_str +from .types import is_uint_dtype +from .types import tensor_collection_entry + +__version__ = "0.3.2" +"""Version string of installed deepali core libraries.""" +__all__ = ( + "ALIGN_CORNERS", + "RE_OUTPUT_KEY_INDEX", + "Array", + "Axes", + "Batch", + "Cube", + "Dataclass", + "DataclassConfig", + "Device", + "DType", + "Grid", + "Name", + "PaddingMode", + "PathStr", + "Sample", + "Sampling", + "Scalar", + "ScalarOrTuple", + "ScalarOrTuple1d", + "ScalarOrTuple2d", + "ScalarOrTuple3d", + "Size", + "Shape", + "TensorCollection", + "abspath", + "abspath_template", + "delete", + "get_tensor", + "grid_points_transform", + "grid_vectors_transform", + "grid_transform_points", + "grid_transform_vectors", + "join_kwargs_in_sequence", + "make_parent_dir", + "make_temp_file", + "multinomial", + "is_bool_dtype", + "is_float_dtype", + "is_int_dtype", + "is_uint_dtype", + "is_namedtuple", + "is_path_str", + "temp_dir", + "temp_file", + "tensor_collection_entry", + "unlink_or_mkdir", +) diff --git a/jointContribution/HighResolution/deepali/core/_kornia.py b/jointContribution/HighResolution/deepali/core/_kornia.py new file mode 100644 index 0000000000..e2028fdce7 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/_kornia.py @@ -0,0 +1,497 @@ +import paddle + +from ..utils import paddle_aux + +"""Conversion functions between different representations of 3D rotations.""" +__all__ = ( + "angle_axis_to_rotation_matrix", + "angle_axis_to_quaternion", + "normalize_quaternion", + "rotation_matrix_to_angle_axis", + "rotation_matrix_to_quaternion", + "quaternion_to_angle_axis", + "quaternion_to_rotation_matrix", + "quaternion_log_to_exp", + "quaternion_exp_to_log", +) + + +def angle_axis_to_rotation_matrix(angle_axis: paddle.Tensor) -> paddle.Tensor: + """Convert 3d vector of axis-angle rotation to 3x3 rotation matrix + + Args: + angle_axis (paddle.Tensor): tensor of 3d vector of axis-angle rotations. + + Returns: + paddle.Tensor: tensor of 3x3 rotation matrices. + + Shape: + - Input: :math:`(N, 3)` + - Output: :math:`(N, 3, 3)` + + Example: + >>> input = paddle.rand(1, 3) # Nx3 + >>> output = angle_axis_to_rotation_matrix(input) # Nx3x3 + """ + if not isinstance(angle_axis, paddle.Tensor): + raise TypeError( + "Input type is not a paddle.Tensor. Got {}".format(type(angle_axis)) + ) + if not tuple(angle_axis.shape)[-1] == 3: + raise ValueError( + "Input size must be a (*, 3) tensor. Got {}".format(tuple(angle_axis.shape)) + ) + + def _compute_rotation_matrix(angle_axis, theta2, eps=1e-06): + k_one = 1.0 + theta = paddle.sqrt(x=theta2) + wxyz = angle_axis / (theta + eps) + wx, wy, wz = paddle.chunk(x=wxyz, chunks=3, axis=1) + cos_theta = paddle.cos(x=theta) + sin_theta = paddle.sin(x=theta) + r00 = cos_theta + wx * wx * (k_one - cos_theta) + r10 = wz * sin_theta + wx * wy * (k_one - cos_theta) + r20 = -wy * sin_theta + wx * wz * (k_one - cos_theta) + r01 = wx * wy * (k_one - cos_theta) - wz * sin_theta + r11 = cos_theta + wy * wy * (k_one - cos_theta) + r21 = wx * sin_theta + wy * wz * (k_one - cos_theta) + r02 = wy * sin_theta + wx * wz * (k_one - cos_theta) + r12 = -wx * sin_theta + wy * wz * (k_one - cos_theta) + r22 = cos_theta + wz * wz * (k_one - cos_theta) + rotation_matrix = paddle.concat( + x=[r00, r01, r02, r10, r11, r12, r20, r21, r22], axis=1 + ) + return rotation_matrix.view(-1, 3, 3) + + def _compute_rotation_matrix_taylor(angle_axis): + rx, ry, rz = paddle.chunk(x=angle_axis, chunks=3, axis=1) + k_one = paddle.ones_like(x=rx) + rotation_matrix = paddle.concat( + x=[k_one, -rz, ry, rz, k_one, -rx, -ry, rx, k_one], axis=1 + ) + return rotation_matrix.view(-1, 3, 3) + + _angle_axis = paddle.unsqueeze(x=angle_axis, axis=1) + x = _angle_axis + perm_9 = list(range(x.ndim)) + perm_9[1] = 2 + perm_9[2] = 1 + theta2 = paddle.matmul(x=_angle_axis, y=x.transpose(perm=perm_9)) + theta2 = paddle.squeeze(x=theta2, axis=1) + rotation_matrix_normal = _compute_rotation_matrix(angle_axis, theta2) + rotation_matrix_taylor = _compute_rotation_matrix_taylor(angle_axis) + eps = 1e-06 + mask = (theta2 > eps).view(-1, 1, 1).to(theta2.place) + mask_pos = mask.astype(dtype=theta2.dtype) + mask_neg = (~mask).astype(dtype=theta2.dtype) + batch_size = tuple(angle_axis.shape)[0] + rotation_matrix = ( + paddle.eye(num_rows=3).to(angle_axis.place).astype(dtype=angle_axis.dtype) + ) + rotation_matrix = rotation_matrix.view(1, 3, 3).repeat(batch_size, 1, 1) + rotation_matrix[(...), :3, :3] = ( + mask_pos * rotation_matrix_normal + mask_neg * rotation_matrix_taylor + ) + return rotation_matrix + + +def rotation_matrix_to_angle_axis(rotation_matrix: paddle.Tensor) -> paddle.Tensor: + """Convert 3x3 rotation matrix to Rodrigues vector. + + Args: + rotation_matrix (paddle.Tensor): rotation matrix. + + Returns: + paddle.Tensor: Rodrigues vector transformation. + + Shape: + - Input: :math:`(N, 3, 3)` + - Output: :math:`(N, 3)` + + Example: + >>> input = paddle.rand(2, 3, 3) # Nx3x3 + >>> output = rotation_matrix_to_angle_axis(input) # Nx3 + """ + if not isinstance(rotation_matrix, paddle.Tensor): + raise TypeError( + f"Input type is not a paddle.Tensor. Got {type(rotation_matrix)}" + ) + if not tuple(rotation_matrix.shape)[-2:] == (3, 3): + raise ValueError( + f"Input size must be a (*, 3, 3) tensor. Got {tuple(rotation_matrix.shape)}" + ) + quaternion: paddle.Tensor = rotation_matrix_to_quaternion(rotation_matrix) + return quaternion_to_angle_axis(quaternion) + + +def rotation_matrix_to_quaternion( + rotation_matrix: paddle.Tensor, eps: float = 1e-08 +) -> paddle.Tensor: + """Convert 3x3 rotation matrix to 4d quaternion vector. + + The quaternion vector has components in (w, x, y, z) or (x, y, z, w) format. + + .. note:: + The (x, y, z, w) order is going to be deprecated in favor of efficiency. + + Args: + rotation_matrix (paddle.Tensor): the rotation matrix to convert. + eps (float): small value to avoid zero division. Default: 1e-8. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: the rotation in quaternion. + + Shape: + - Input: :math:`(*, 3, 3)` + - Output: :math:`(*, 4)` + + Example: + >>> input = paddle.rand(4, 3, 3) # Nx3x3 + >>> output = rotation_matrix_to_quaternion(input, eps=paddle.finfo(input.dtype).eps, + ... order=QuaternionCoeffOrder.WXYZ) # Nx4 + """ + if not isinstance(rotation_matrix, paddle.Tensor): + raise TypeError( + f"Input type is not a paddle.Tensor. Got {type(rotation_matrix)}" + ) + if not tuple(rotation_matrix.shape)[-2:] == (3, 3): + raise ValueError( + f"Input size must be a (*, 3, 3) tensor. Got {tuple(rotation_matrix.shape)}" + ) + + def safe_zero_division( + numerator: paddle.Tensor, denominator: paddle.Tensor + ) -> paddle.Tensor: + eps: float = paddle.finfo(paddle_aux._STR_2_PADDLE_DTYPE(numerator.dtype)).tiny + return numerator / paddle.clip(x=denominator, min=eps) + + rotation_matrix_vec: paddle.Tensor = rotation_matrix.view( + *tuple(rotation_matrix.shape)[:-2], 9 + ) + m00, m01, m02, m10, m11, m12, m20, m21, m22 = paddle.chunk( + x=rotation_matrix_vec, chunks=9, axis=-1 + ) + trace: paddle.Tensor = m00 + m11 + m22 + + def trace_positive_cond(): + sq = paddle.sqrt(x=trace + 1.0) * 2.0 + qw = 0.25 * sq + qx = safe_zero_division(m21 - m12, sq) + qy = safe_zero_division(m02 - m20, sq) + qz = safe_zero_division(m10 - m01, sq) + return paddle.concat(x=(qw, qx, qy, qz), axis=-1) + + def cond_1(): + sq = paddle.sqrt(x=1.0 + m00 - m11 - m22 + eps) * 2.0 + qw = safe_zero_division(m21 - m12, sq) + qx = 0.25 * sq + qy = safe_zero_division(m01 + m10, sq) + qz = safe_zero_division(m02 + m20, sq) + return paddle.concat(x=(qw, qx, qy, qz), axis=-1) + + def cond_2(): + sq = paddle.sqrt(x=1.0 + m11 - m00 - m22 + eps) * 2.0 + qw = safe_zero_division(m02 - m20, sq) + qx = safe_zero_division(m01 + m10, sq) + qy = 0.25 * sq + qz = safe_zero_division(m12 + m21, sq) + return paddle.concat(x=(qw, qx, qy, qz), axis=-1) + + def cond_3(): + sq = paddle.sqrt(x=1.0 + m22 - m00 - m11 + eps) * 2.0 + qw = safe_zero_division(m10 - m01, sq) + qx = safe_zero_division(m02 + m20, sq) + qy = safe_zero_division(m12 + m21, sq) + qz = 0.25 * sq + return paddle.concat(x=(qw, qx, qy, qz), axis=-1) + + where_2 = paddle.where(condition=m11 > m22, x=cond_2(), y=cond_3()) + where_1 = paddle.where(condition=(m00 > m11) & (m00 > m22), x=cond_1(), y=where_2) + quaternion: paddle.Tensor = paddle.where( + condition=trace > 0.0, x=trace_positive_cond(), y=where_1 + ) + return quaternion + + +def normalize_quaternion( + quaternion: paddle.Tensor, eps: float = 1e-12 +) -> paddle.Tensor: + """Normalizes a quaternion. + + The quaternion should be in (x, y, z, w) format. + + Args: + quaternion (paddle.Tensor): a tensor containing a quaternion to be + normalized. The tensor can be of shape :math:`(*, 4)`. + eps (Optional[bool]): small value to avoid division by zero. + Default: 1e-12. + + Return: + paddle.Tensor: the normalized quaternion of shape :math:`(*, 4)`. + + Example: + >>> quaternion = paddle.tensor((1., 0., 1., 0.)) + >>> normalize_quaternion(quaternion) + tensor([0.7071, 0.0000, 0.7071, 0.0000]) + """ + if not isinstance(quaternion, paddle.Tensor): + raise TypeError( + "Input type is not a paddle.Tensor. Got {}".format(type(quaternion)) + ) + if not tuple(quaternion.shape)[-1] == 4: + raise ValueError( + "Input must be a tensor of shape (*, 4). Got {}".format( + tuple(quaternion.shape) + ) + ) + return paddle.nn.functional.normalize(x=quaternion, p=2.0, axis=-1, epsilon=eps) + + +def quaternion_to_rotation_matrix(quaternion: paddle.Tensor) -> paddle.Tensor: + """Converts a quaternion to a rotation matrix. + + The quaternion should be in (x, y, z, w) or (w, x, y, z) format. + + Args: + quaternion (paddle.Tensor): a tensor containing a quaternion to be + converted. The tensor can be of shape :math:`(*, 4)`. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: the rotation matrix of shape :math:`(*, 3, 3)`. + + Example: + >>> quaternion = paddle.tensor((0., 0., 0., 1.)) + >>> quaternion_to_rotation_matrix(quaternion, order=QuaternionCoeffOrder.WXYZ) + tensor([[-1., 0., 0.], + [ 0., -1., 0.], + [ 0., 0., 1.]]) + """ + if not isinstance(quaternion, paddle.Tensor): + raise TypeError(f"Input type is not a paddle.Tensor. Got {type(quaternion)}") + if not tuple(quaternion.shape)[-1] == 4: + raise ValueError( + f"Input must be a tensor of shape (*, 4). Got {tuple(quaternion.shape)}" + ) + quaternion_norm: paddle.Tensor = normalize_quaternion(quaternion) + w, x, y, z = paddle.chunk(x=quaternion_norm, chunks=4, axis=-1) + tx: paddle.Tensor = 2.0 * x + ty: paddle.Tensor = 2.0 * y + tz: paddle.Tensor = 2.0 * z + twx: paddle.Tensor = tx * w + twy: paddle.Tensor = ty * w + twz: paddle.Tensor = tz * w + txx: paddle.Tensor = tx * x + txy: paddle.Tensor = ty * x + txz: paddle.Tensor = tz * x + tyy: paddle.Tensor = ty * y + tyz: paddle.Tensor = tz * y + tzz: paddle.Tensor = tz * z + one: paddle.Tensor = paddle.to_tensor(data=1.0) + matrix: paddle.Tensor = paddle.stack( + x=( + one - (tyy + tzz), + txy - twz, + txz + twy, + txy + twz, + one - (txx + tzz), + tyz - twx, + txz - twy, + tyz + twx, + one - (txx + tyy), + ), + axis=-1, + ).view(-1, 3, 3) + if len(tuple(quaternion.shape)) == 1: + matrix = paddle.squeeze(x=matrix, axis=0) + return matrix + + +def quaternion_to_angle_axis(quaternion: paddle.Tensor) -> paddle.Tensor: + """Convert quaternion vector to angle axis of rotation. + + The quaternion should be in (x, y, z, w) or (w, x, y, z) format. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + quaternion (paddle.Tensor): tensor with quaternions. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: tensor with angle axis of rotation. + + Shape: + - Input: :math:`(*, 4)` where `*` means, any number of dimensions + - Output: :math:`(*, 3)` + + Example: + >>> quaternion = paddle.rand(2, 4) # Nx4 + >>> angle_axis = quaternion_to_angle_axis(quaternion) # Nx3 + """ + if not paddle.is_tensor(x=quaternion): + raise TypeError(f"Input type is not a paddle.Tensor. Got {type(quaternion)}") + if not tuple(quaternion.shape)[-1] == 4: + raise ValueError( + f"Input must be a tensor of shape Nx4 or 4. Got {tuple(quaternion.shape)}" + ) + q1: paddle.Tensor = paddle.to_tensor(data=[]) + q2: paddle.Tensor = paddle.to_tensor(data=[]) + q3: paddle.Tensor = paddle.to_tensor(data=[]) + cos_theta: paddle.Tensor = paddle.to_tensor(data=[]) + cos_theta = quaternion[..., 0] + q1 = quaternion[..., 1] + q2 = quaternion[..., 2] + q3 = quaternion[..., 3] + sin_squared_theta: paddle.Tensor = q1 * q1 + q2 * q2 + q3 * q3 + sin_theta: paddle.Tensor = paddle.sqrt(x=sin_squared_theta) + two_theta: paddle.Tensor = 2.0 * paddle.where( + condition=cos_theta < 0.0, + x=paddle.atan2(x=-sin_theta, y=-cos_theta), + y=paddle.atan2(x=sin_theta, y=cos_theta), + ) + k_pos: paddle.Tensor = two_theta / sin_theta + k_neg: paddle.Tensor = 2.0 * paddle.ones_like(x=sin_theta) + k: paddle.Tensor = paddle.where(condition=sin_squared_theta > 0.0, x=k_pos, y=k_neg) + angle_axis: paddle.Tensor = paddle.zeros_like(x=quaternion)[(...), :3] + angle_axis[..., 0] += q1 * k + angle_axis[..., 1] += q2 * k + angle_axis[..., 2] += q3 * k + return angle_axis + + +def quaternion_log_to_exp( + quaternion: paddle.Tensor, eps: float = 1e-08 +) -> paddle.Tensor: + """Applies exponential map to log quaternion. + + The quaternion should be in (x, y, z, w) or (w, x, y, z) format. + + Args: + quaternion (paddle.Tensor): a tensor containing a quaternion to be + converted. The tensor can be of shape :math:`(*, 3)`. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: the quaternion exponential map of shape :math:`(*, 4)`. + + Example: + >>> quaternion = paddle.tensor((0., 0., 0.)) + >>> quaternion_log_to_exp(quaternion, eps=paddle.finfo(quaternion.dtype).eps, + ... order=QuaternionCoeffOrder.WXYZ) + tensor([1., 0., 0., 0.]) + """ + if not isinstance(quaternion, paddle.Tensor): + raise TypeError(f"Input type is not a paddle.Tensor. Got {type(quaternion)}") + if not tuple(quaternion.shape)[-1] == 3: + raise ValueError( + f"Input must be a tensor of shape (*, 3). Got {tuple(quaternion.shape)}" + ) + norm_q: paddle.Tensor = paddle.linalg.norm( + x=quaternion, p=2, axis=-1, keepdim=True + ).clip(min=eps) + quaternion_vector: paddle.Tensor = quaternion * paddle.sin(x=norm_q) / norm_q + quaternion_scalar: paddle.Tensor = paddle.cos(x=norm_q) + quaternion_exp: paddle.Tensor = paddle.to_tensor(data=[]) + quaternion_exp = paddle.concat(x=(quaternion_scalar, quaternion_vector), axis=-1) + return quaternion_exp + + +def quaternion_exp_to_log( + quaternion: paddle.Tensor, eps: float = 1e-08 +) -> paddle.Tensor: + """Applies the log map to a quaternion. + + The quaternion should be in (x, y, z, w) format. + + Args: + quaternion (paddle.Tensor): a tensor containing a quaternion to be + converted. The tensor can be of shape :math:`(*, 4)`. + eps (float): A small number for clamping. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: the quaternion log map of shape :math:`(*, 3)`. + + Example: + >>> quaternion = paddle.tensor((1., 0., 0., 0.)) + >>> quaternion_exp_to_log(quaternion, eps=paddle.finfo(quaternion.dtype).eps, + ... order=QuaternionCoeffOrder.WXYZ) + tensor([0., 0., 0.]) + """ + if not isinstance(quaternion, paddle.Tensor): + raise TypeError(f"Input type is not a paddle.Tensor. Got {type(quaternion)}") + if not tuple(quaternion.shape)[-1] == 4: + raise ValueError( + f"Input must be a tensor of shape (*, 4). Got {tuple(quaternion.shape)}" + ) + quaternion_vector: paddle.Tensor = paddle.to_tensor(data=[]) + quaternion_scalar: paddle.Tensor = paddle.to_tensor(data=[]) + quaternion_scalar = quaternion[(...), 0:1] + quaternion_vector = quaternion[(...), 1:4] + norm_q: paddle.Tensor = paddle.linalg.norm( + x=quaternion_vector, p=2, axis=-1, keepdim=True + ).clip(min=eps) + quaternion_log: paddle.Tensor = ( + quaternion_vector + * paddle.acos(x=paddle.clip(x=quaternion_scalar, min=-1.0, max=1.0)) + / norm_q + ) + return quaternion_log + + +def angle_axis_to_quaternion(angle_axis: paddle.Tensor) -> paddle.Tensor: + """Convert an angle axis to a quaternion. + + The quaternion vector has components in (x, y, z, w) or (w, x, y, z) format. + + Adapted from ceres C++ library: ceres-solver/include/ceres/rotation.h + + Args: + angle_axis (paddle.Tensor): tensor with angle axis. + order (QuaternionCoeffOrder): quaternion coefficient order. Default: 'xyzw'. + Note: 'xyzw' will be deprecated in favor of 'wxyz'. + + Return: + paddle.Tensor: tensor with quaternion. + + Shape: + - Input: :math:`(*, 3)` where `*` means, any number of dimensions + - Output: :math:`(*, 4)` + + Example: + >>> angle_axis = paddle.rand(2, 3) # Nx3 + >>> quaternion = angle_axis_to_quaternion(angle_axis, order=QuaternionCoeffOrder.WXYZ) # Nx4 + """ + if not paddle.is_tensor(x=angle_axis): + raise TypeError(f"Input type is not a paddle.Tensor. Got {type(angle_axis)}") + if not tuple(angle_axis.shape)[-1] == 3: + raise ValueError( + f"Input must be a tensor of shape Nx3 or 3. Got {tuple(angle_axis.shape)}" + ) + a0: paddle.Tensor = angle_axis[(...), 0:1] + a1: paddle.Tensor = angle_axis[(...), 1:2] + a2: paddle.Tensor = angle_axis[(...), 2:3] + theta_squared: paddle.Tensor = a0 * a0 + a1 * a1 + a2 * a2 + theta: paddle.Tensor = paddle.sqrt(x=theta_squared) + half_theta: paddle.Tensor = theta * 0.5 + mask: paddle.Tensor = theta_squared > 0.0 + ones: paddle.Tensor = paddle.ones_like(x=half_theta) + k_neg: paddle.Tensor = 0.5 * ones + k_pos: paddle.Tensor = paddle.sin(x=half_theta) / theta + k: paddle.Tensor = paddle.where(condition=mask, x=k_pos, y=k_neg) + w: paddle.Tensor = paddle.where(condition=mask, x=paddle.cos(x=half_theta), y=ones) + quaternion: paddle.Tensor = paddle.zeros( + shape=(*tuple(angle_axis.shape)[:-1], 4), dtype=angle_axis.dtype + ) + quaternion[(...), 1:2] = a0 * k + quaternion[(...), 2:3] = a1 * k + quaternion[(...), 3:4] = a2 * k + quaternion[(...), 0:1] = w + return quaternion diff --git a/jointContribution/HighResolution/deepali/core/affine.py b/jointContribution/HighResolution/deepali/core/affine.py new file mode 100644 index 0000000000..651c4190ab --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/affine.py @@ -0,0 +1,519 @@ +import re +from typing import Optional +from typing import Union + +import paddle + +from .linalg import homogeneous_transform +from .tensor import as_float_tensor +from .tensor import atleast_1d +from .tensor import cat_scalars +from .types import Array +from .types import Device +from .types import Scalar +from .types import Shape + +__all__ = ( + "affine_rotation_matrix", + "apply_transform", + "euler_rotation_matrix", + "euler_rotation_angles", + "euler_rotation_order", + "identity_transform", + "rotation_matrix", + "scaling_transform", + "shear_matrix", + "translation", + "transform_points", + "transform_vectors", +) + + +def apply_transform( + transform: paddle.Tensor, points: paddle.Tensor, vectors: bool = False +) -> paddle.Tensor: + """Alias for :func:`.homogeneous_transform`.""" + return homogeneous_transform(transform, points, vectors=vectors) + + +def affine_rotation_matrix(matrix: paddle.Tensor) -> paddle.Tensor: + """Get orthonormal rotation matrix from (homogeneous) affine transformation. + + This function assumes the following order of elementary transformations: + 1) Scaling, 2) Shearing, 3) Rotation, and 4) Translation. + + See also FullAffineTransform, AffineTransform, RigidTransform, etc. + + Args: + matrix: Affine transformation as tensor of shape (..., 3, 3) or (..., 3, 4). + + Returns: + Orthonormal rotation matrices with determinant 1 as tensor of shape (..., 3, 3). + + """ + if not isinstance(matrix, paddle.Tensor): + raise TypeError("affine_rotation_matrix() 'matrix' must be paddle.Tensor") + if ( + matrix.ndim < 2 + or tuple(matrix.shape)[-2] != 3 + or tuple(matrix.shape)[-1] not in (3, 4) + ): + raise ValueError( + "affine_rotation_matrix() 'matrix' must have shape (..., 3, 3|4)" + ) + matrix = matrix[(...), :3].clone() + sx: paddle.Tensor = paddle.linalg.norm(x=matrix[..., 0], p=2, axis=-1) + matrix[..., 0] = matrix[..., 0].div(sx.unsqueeze(axis=-1)) + tansxy = matrix[..., 0].mul(matrix[..., 1]).sum(axis=-1) + matrix[..., 1] = matrix[..., 1].sub(matrix[..., 0].mul(tansxy.unsqueeze(axis=-1))) + sy: paddle.Tensor = paddle.linalg.norm(x=matrix[..., 1], p=2, axis=-1) + matrix[..., 1] = matrix[..., 1].div(sy.unsqueeze(axis=-1)) + tansxy = tansxy.div(sy) + tansxz = matrix[..., 0].mul(matrix[..., 2]).sum(axis=-1) + matrix[..., 2] = matrix[..., 2].sub(matrix[..., 0].mul(tansxz.unsqueeze(axis=-1))) + tansyz = matrix[..., 1].mul(matrix[..., 2]).sum(axis=-1) + matrix[..., 2] = matrix[..., 2].sub(matrix[..., 1].mul(tansyz.unsqueeze(axis=-1))) + sz: paddle.Tensor = paddle.linalg.norm(x=matrix[..., 2], p=2, axis=-1) + matrix[..., 2] = matrix[..., 2].div(sz.unsqueeze(axis=-1)) + tansxz = tansxz.div(sz) + tansyz = tansyz.div(sz) + mask = ( + matrix[..., 0] + .mul(matrix[..., 1].cross(y=matrix[..., 2], axis=-1)) + .sum(axis=-1) + .greater_equal(y=paddle.to_tensor(0)) + ) + mask = mask.unsqueeze(axis=-1).unsqueeze(axis=-1).expand_as(y=matrix) + matrix = matrix.where(condition=mask, y=-matrix) + return matrix + + +def identity_transform( + shape: Union[int, Shape], + *args, + homogeneous: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create homogeneous coordinate transformation matrix of identity mapping. + + Args: + shape: Shape of non-homogeneous point coordinates tensor ``(..., D)``, where the size of the + last dimension corresponds to the number of spatial dimensions. + homogeneous: Whether to return homogeneous transformation matrices. + dtype: Data type of output matrix. If ``None``, use default dtype. + device: Device on which to create matrix. If ``None``, use default device. + + Returns: + If ``homogeneous=Falae``, a tensor of affine matrices of shape ``(..., D, D)`` is returned, and + a tensor of homogeneous coordinate transformation matrices of shape ``(..., D, D + 1)``, otherwise. + + """ + shape_ = [int(n) for n in cat_scalars(shape, *args, device=device)] + D = shape_[-1] + J = tuple(range(D)) + matrix = paddle.zeros(shape=[*shape_, D + 1 if homogeneous else D], dtype=dtype) + matrix[..., J, J] = 1 + return matrix + + +def rotation_matrix(*args, **kwargs) -> paddle.Tensor: + """Alias for :func:`.euler_rotation_matrix`.""" + return euler_rotation_matrix(*args, **kwargs) + + +def euler_rotation_matrix( + angles: Union[Scalar, Array], + order: Optional[str] = None, + homogeneous: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Euler rotation matrices. + + Args: + angles: Scalar rotation angle for a 2D rotation, or the three angles (alpha, beta, gamma) + for an extrinsic rotation in the specified `order`. The argument can be a tensor of + shape ``(..., D)``, with angles in the last dimension. All angles must be given in + radians, and the first angle corresponds to the left-most rotation which is applied last. + order: Order in which to compose elemental rotations. For example in 3D, "zxz" (or "ZXZ") + means that the first rotation occurs about z by angle gamma, the second about x by + angle beta, and the third rotation about z again by angle alpha. In 2D, this argument + is ignored and a single rotation about z (plane normal) is applied. Alternatively, + strings of the form "Rz o Rx o Rz" can be given as argument. When the first and last + extrinsic rotation is about the same axis, the rotation is called proper, and the + angles are referred to as proper Euler angles. When each elemental rotation in X, Y, + and Z occurs exactly once, the angles are referred to as Tait-Bryan angles. + homogeneous: Whether to return homogeneous transformation matrices. + dtype: Data type of rotation matrix. If ``None``, use ``angles.dtype`` if it is + a floating point type, and ``paddle.float`` otherwise. + device: Device on which to create rotation matrix. If ``None``, use ``angles.device``. + + Returns: + paddle.Tensor of square rotation matrices of shape ``(..., D, D)`` if ``homogeneous=False``, or tensor + of homogeneous coordinate transformation matrices of shape ``(..., D, D + 1)`` otherwise. + If ``angles`` is a scalar or vector, a single rotation matrix is returned. + + See also: + - https://mathworld.wolfram.com/EulerAngles.html + - https://en.wikipedia.org/wiki/Euler_angles#Definition_by_extrinsic_rotations + + """ + if not isinstance(order, (str, type(None))): + raise TypeError("euler_rotation_matrix() 'order' must be None or str") + angles_ = atleast_1d(angles, dtype=dtype, device=device) + angles_ = as_float_tensor(angles_) + c = paddle.cos(x=angles_) + s = paddle.sin(x=angles_) + N = tuple(angles_.shape)[-1] + D = 2 if N == 1 else N + order = euler_rotation_order(order, ndim=D) + matrix = paddle.empty( + shape=tuple(angles_.shape)[:-1] + (D, D + 1 if homogeneous else D), + dtype=angles_.dtype, + ) + if homogeneous: + matrix[..., D] = 0 + if D == 2: + matrix[..., 0, 0] = c[..., 0] + matrix[..., 0, 1] = -s[..., 0] + matrix[..., 1, 0] = s[..., 0] + matrix[..., 1, 1] = c[..., 0] + elif D == 3: + if order == "XYZ": + matrix[..., 0, 0] = c[..., 1] * c[..., 2] + matrix[..., 0, 1] = -c[..., 1] * s[..., 2] + matrix[..., 0, 2] = s[..., 1] + matrix[..., 1, 0] = ( + c[..., 0] * s[..., 2] + c[..., 2] * s[..., 0] * s[..., 1] + ) + matrix[..., 1, 1] = ( + c[..., 0] * c[..., 2] - s[..., 0] * s[..., 1] * s[..., 2] + ) + matrix[..., 1, 2] = -c[..., 1] * s[..., 0] + matrix[..., 2, 0] = ( + s[..., 0] * s[..., 2] - c[..., 0] * c[..., 2] * s[..., 1] + ) + matrix[..., 2, 1] = ( + c[..., 2] * s[..., 0] + c[..., 0] * s[..., 1] * s[..., 2] + ) + matrix[..., 2, 2] = c[..., 0] * c[..., 1] + elif order == "ZYX": + matrix[..., 0, 0] = c[..., 0] * c[..., 1] + matrix[..., 0, 1] = ( + c[..., 0] * s[..., 1] * s[..., 2] - c[..., 2] * s[..., 0] + ) + matrix[..., 0, 2] = ( + s[..., 0] * s[..., 2] + c[..., 0] * c[..., 2] * s[..., 1] + ) + matrix[..., 1, 0] = c[..., 1] * s[..., 0] + matrix[..., 1, 1] = ( + c[..., 0] * c[..., 2] + s[..., 0] * s[..., 1] * s[..., 2] + ) + matrix[..., 1, 2] = ( + c[..., 2] * s[..., 0] * s[..., 1] - c[..., 0] * s[..., 2] + ) + matrix[..., 2, 0] = -s[..., 1] + matrix[..., 2, 1] = c[..., 1] * s[..., 2] + matrix[..., 2, 2] = c[..., 1] * c[..., 2] + elif order == "ZXY": + matrix[..., 0, 0] = ( + c[..., 0] * c[..., 2] - s[..., 0] * s[..., 1] * s[..., 2] + ) + matrix[..., 0, 1] = -c[..., 1] * s[..., 0] + matrix[..., 0, 2] = ( + c[..., 0] * s[..., 2] + c[..., 2] * s[..., 0] * s[..., 1] + ) + matrix[..., 1, 0] = ( + c[..., 2] * s[..., 0] + c[..., 0] * s[..., 1] * s[..., 2] + ) + matrix[..., 1, 1] = c[..., 0] * c[..., 1] + matrix[..., 1, 2] = ( + s[..., 0] * s[..., 2] - c[..., 0] * c[..., 2] * s[..., 1] + ) + matrix[..., 2, 0] = -c[..., 1] * s[..., 2] + matrix[..., 2, 1] = s[..., 1] + matrix[..., 2, 2] = c[..., 1] * c[..., 2] + elif order == "XZX": + matrix[..., 0, 0] = c[..., 1] + matrix[..., 0, 1] = -s[..., 1] * c[..., 2] + matrix[..., 0, 2] = s[..., 1] * s[..., 2] + matrix[..., 1, 0] = c[..., 0] * s[..., 1] + matrix[..., 1, 1] = ( + -s[..., 0] * s[..., 2] + c[..., 0] * c[..., 1] * c[..., 2] + ) + matrix[..., 1, 2] = ( + -s[..., 0] * c[..., 2] - c[..., 0] * c[..., 1] * s[..., 2] + ) + matrix[..., 2, 0] = s[..., 0] * s[..., 1] + matrix[..., 2, 1] = ( + c[..., 0] * s[..., 2] + s[..., 0] * c[..., 1] * c[..., 2] + ) + matrix[..., 2, 2] = ( + c[..., 0] * c[..., 2] - s[..., 0] * c[..., 1] * s[..., 2] + ) + elif order == "ZXZ": + matrix[..., 0, 0] = ( + c[..., 0] * c[..., 2] - s[..., 0] * c[..., 1] * s[..., 2] + ) + matrix[..., 0, 1] = ( + -c[..., 0] * s[..., 2] - s[..., 0] * c[..., 1] * c[..., 2] + ) + matrix[..., 0, 2] = s[..., 0] * s[..., 1] + matrix[..., 1, 0] = ( + s[..., 0] * c[..., 2] + c[..., 0] * c[..., 1] * s[..., 2] + ) + matrix[..., 1, 1] = ( + -s[..., 0] * s[..., 2] + c[..., 0] * c[..., 1] * c[..., 2] + ) + matrix[..., 1, 2] = -c[..., 0] * s[..., 1] + matrix[..., 2, 0] = s[..., 1] * s[..., 2] + matrix[..., 2, 1] = s[..., 1] * c[..., 2] + matrix[..., 2, 2] = c[..., 1] + else: + matrix[..., 0, 0] = 1 + matrix[..., 1, 1] = 1 + matrix[..., 2, 2] = 1 + for i, char in enumerate(order): + rot = paddle.empty(shape=tuple(matrix.shape), dtype=matrix.dtype) + if char == "X": + rot[..., 0, 0] = 1 + rot[..., 0, 1] = 0 + rot[..., 0, 2] = 0 + rot[..., 1, 0] = 0 + rot[..., 1, 1] = c[..., i] + rot[..., 1, 2] = -s[..., i] + rot[..., 2, 0] = 0 + rot[..., 2, 1] = s[..., i] + rot[..., 2, 2] = c[..., i] + elif char == "Y": + rot[..., 0, 0] = c[..., i] + rot[..., 0, 1] = 0 + rot[..., 0, 2] = s[..., i] + rot[..., 1, 0] = 0 + rot[..., 1, 1] = 1 + rot[..., 1, 2] = 0 + rot[..., 2, 0] = -s[..., i] + rot[..., 2, 1] = 0 + rot[..., 2, 2] = c[..., i] + elif char == "Z": + rot[..., 0, 0] = c[..., i] + rot[..., 0, 1] = -s[..., i] + rot[..., 0, 2] = 0 + rot[..., 1, 0] = s[..., i] + rot[..., 1, 1] = c[..., i] + rot[..., 1, 2] = 0 + rot[..., 2, 0] = 0 + rot[..., 2, 1] = 0 + rot[..., 2, 2] = 1 + matrix = rot if i == 0 else paddle.bmm(x=matrix, y=rot) + else: + raise ValueError( + f"Expected 'angles' to be scalar or tensor with last dimension size 3, got {N}" + ) + return matrix + + +def euler_rotation_angles( + matrix: paddle.Tensor, order: Optional[str] = None +) -> paddle.Tensor: + """Compute Euler angles from rotation matrix. + + TODO: Write test for this function and check for wich quadrant a rotation is in. + See also https://github.com/BioMedIA/MIRTK/blob/77d3f391b49b0cee9e80da774fb074995fdf415f/Modules/Numerics/src/Matrix3x3.cc#L1217. + + Args: + order: Order in which elemental rotations are composed. For example in 3D, "zxz" means + that the first rotation occurs about z, the second about x, and the third rotation + about z again. In 2D, this argument is ignored. Alternatively, strings of the form + "Rz o Rx o Rz" can be given. + + Returns: + Rotation angles in radians, where the first angle corresponds to the + leftmost elemental rotation which is applied last. + + """ + if matrix.ndim < 2: + raise ValueError( + "euler_rotation_angles() 'matrix' must be at least 2-dimensional" + ) + D = tuple(matrix.shape)[-2] + if D not in (2, 3) or tuple(matrix.shape)[-1] not in (D, D + 1): + raise ValueError( + "euler_rotation_angles() 'matrix' must have shape (N, D, D) or (N, D, D + 1) where D=2 or D=3" + ) + order = euler_rotation_order(order, ndim=D) + det = matrix[(...), :D, :D].detach().reshape(-1, D, D).det() + if not det.abs().allclose(y=paddle.to_tensor(data=1.0).to(det)).item(): + raise ValueError( + "euler_rotation_angles() 'matrix' must be rotation matrix, i.e., matrix.det().abs() = 1" + ) + if D == 2: + angles = paddle.acos(x=matrix[..., 0, 0]) + else: + angles = paddle.empty(shape=tuple(matrix.shape)[:-2] + (D,), dtype=matrix.dtype) + if order == "XZX": + angles[..., 0] = paddle.atan2(x=matrix[..., 0, 2], y=-matrix[..., 0, 1]) + angles[..., 1] = paddle.acos(x=matrix[..., 0, 0]) + angles[..., 2] = paddle.atan2(x=matrix[..., 2, 0], y=matrix[..., 1, 0]) + elif order == "ZXZ": + angles[..., 0] = paddle.atan2(x=matrix[..., 2, 0], y=matrix[..., 2, 1]) + angles[..., 1] = paddle.acos(x=matrix[..., 2, 2]) + angles[..., 2] = paddle.atan2(x=matrix[..., 0, 2], y=-matrix[..., 1, 2]) + else: + raise NotImplementedError(f"euler_rotation_angles() order={order!r}") + return angles + + +def euler_rotation_order(arg: Optional[str] = None, ndim: int = 3) -> str: + """Standardize rotation order argument.""" + if arg is not None and not isinstance(arg, str): + raise TypeError("euler_rotation_order() 'arg' must be str or None") + if not isinstance(ndim, int): + raise TypeError("euler_rotation_order() 'ndim' must be int") + if ndim == 2: + return "Z" + if ndim != 3: + raise NotImplementedError(f"euler_rotation_order() ndim={ndim}") + order = "ZXZ" if arg is None else arg + if re.match("^(R[xyz]|[XYZ])( o (R[xyz]|[XYZ]))*$", order): + order = re.subn("R([xyz])", "\\1", order).replace(" o ", "") + order = order.upper() + if not re.match("^[XYZ][XYZ][XYZ]$", order): + raise ValueError(f"euler_rotation_order() invalid argument '{arg}'") + return order + + +def scaling_transform( + scales: Union[Scalar, Array], + homogeneous: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Scaling matrices. + + Args: + scales: paddle.Tensor of anisotropic scaling factors of shape ``(..., D)``. + homogeneous: Whether to return homogeneous transformation matrices. + dtype: Data type of output matrix. If ``None``, use ``scales.dtype`` if it is + a floating point type, and ``paddle.float`` otherwise. + device: Device on which to create matrix. If ``None``, use ``scales.device``. + + Returns: + paddle.Tensor of square scaling matrices of shape ``(..., D, D)`` if ``homogeneous=False``, or tensor of + homogeneous coordinate transformation matrices of shape ``(..., D, D + 1)`` otherwise. + + """ + scales_ = atleast_1d(scales, dtype=dtype, device=device) + scales_ = as_float_tensor(scales_) + D = tuple(scales_.shape)[-1] + J = tuple(range(D)) + matrix = paddle.zeros( + shape=tuple(scales_.shape)[:-1] + (D, D + 1 if homogeneous else D), + dtype=scales_.dtype, + ) + # matrix[..., J, J] = scales_ + for j in J: + matrix[..., j, j] = scales_[..., j] + return matrix + + +def shear_matrix( + angles: Union[Scalar, Array], + homogeneous: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Shear matrices. + + Args: + angles: paddle.Tensor of anisotropic shear angles in radians of shape ``(..., N)``, + where ``N=(D * (D - 1)) / 2`` with ``D`` denoting the number of spatial dimensions. + homogeneous: Whether to return homogeneous transformation matrices. + dtype: Data type of output matrix. If ``None``, use ``angles.dtype`` if it is + a floating point type, and ``paddle.float`` otherwise. + device: Device on which to create matrix. If ``None``, use ``angles.device``. + + Returns: + paddle.Tensor of square scaling matrices of shape ``(..., D, D)`` if ``homogeneous=False``, or tensor of + homogeneous coordinate transformation matrices of shape ``(..., D, D + 1)`` otherwise. + + """ + angles_ = atleast_1d(angles, dtype=dtype, device=device) + angles_ = as_float_tensor(angles_) + N = tuple(angles_.shape)[-1] + if N == 1: + D = 2 + elif N == 3: + D = 3 + elif N == 6: + D = 4 + else: + raise ValueError( + "shear_matrix() 'angles' must have last dimension size 1 (2D), 3 (3D), or 6 (4D)" + ) + J = tuple(range(D)) + K = paddle.triu_indices(row=D, col=D, offset=1) + matrix = paddle.zeros( + shape=tuple(angles_.shape)[:-1] + (D, D + 1 if homogeneous else D), + dtype=angles_.dtype, + ) + # matrix[..., J, J] = 1 + # matrix[..., K[0], K[1]] = paddle.tan(x=angles_) + for j in J: + matrix[..., j, j] = 1 + tan_values = paddle.tan(x=angles_) + for i in range(len(K[0])): + matrix[..., K[0][i], K[1][i]] = tan_values[..., i] + return matrix + + +def translation( + offset: Union[Scalar, Array], + homogeneous: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Translation offsets / matrices. + + Args: + offset: Translation vectors of shape ``(..., D)`` or ``(..., D, 1)``. + homogeneous: Whether to return homogeneous transformation matrices. + dtype: Data type of output matrix. If ``None``, use ``offset.dtype`` if it is + a floating point type, and ``paddle.float`` otherwise. + device: Device on which to create matrix. If ``None``, use ``offset.device``. + + Returns: + Homogeneous coordinate transformation matrices of shape ``(..., D, 1)`` if ``homogeneous=False``, + or shape ``(..., D, D + 1)`` otherwise. + + """ + offset_ = atleast_1d(offset, dtype=dtype, device=device) + offset_ = as_float_tensor(offset_) + if homogeneous: + if offset_.ndim > 1: + offset_ = offset_.squeeze(axis=-1) + D = tuple(offset_.shape)[-1] + J = tuple(range(D)) + matrix = paddle.zeros( + shape=tuple(offset_.shape)[:-1] + (D, D + 1), dtype=offset_.dtype + ) + matrix[..., J, J] = 1 + matrix[..., D] = offset_ + elif offset_.ndim < 2 or tuple(offset_.shape)[-1] != 1: + matrix = offset_.unsqueeze(axis=-1) + else: + matrix = offset_ + return matrix + + +def transform_points(transforms: paddle.Tensor, points: paddle.Tensor) -> paddle.Tensor: + """Transform points by given homogeneous transformation.""" + return apply_transform(transforms, points, vectors=False) + + +def transform_vectors( + transforms: paddle.Tensor, vectors: paddle.Tensor +) -> paddle.Tensor: + """Transform vectors by given homogeneous transformation.""" + return apply_transform(transforms, vectors, vectors=True) diff --git a/jointContribution/HighResolution/deepali/core/bspline.py b/jointContribution/HighResolution/deepali/core/bspline.py new file mode 100644 index 0000000000..61a1574968 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/bspline.py @@ -0,0 +1,676 @@ +from itertools import combinations_with_replacement +from itertools import permutations +from itertools import product +from typing import Callable +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import overload + +import paddle + +from .enum import PaddingMode +from .enum import SpatialDim +from .enum import SpatialDimArg +from .grid import Grid +from .image import conv +from .image import conv1d +from .itertools import is_even_permutation +from .kernels import cubic_bspline1d +from .tensor import move_dim +from .types import ScalarOrTuple + + +@overload +def cubic_bspline_control_point_grid_size(size: int, stride: int) -> int: + ... + + +@overload +def cubic_bspline_control_point_grid_size( + size: Sequence[int], stride: int +) -> Tuple[int, ...]: + ... + + +@overload +def cubic_bspline_control_point_grid_size( + size: int, stride: Sequence[int] +) -> Tuple[int, ...]: + ... + + +@overload +def cubic_bspline_control_point_grid_size( + size: Sequence[int], stride: Sequence[int] +) -> Tuple[int, ...]: + ... + + +def cubic_bspline_control_point_grid_size( + size: ScalarOrTuple[int], stride: ScalarOrTuple[int] +) -> ScalarOrTuple[int]: + """Calculate required number of cubic B-spline coefficients for given output size.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError( + "cubic_bspline_control_point_grid_size() 'size' must be scalar or sequence" + ) + if m.less_equal(y=paddle.to_tensor(0, dtype=m.dtype)).astype("bool").any(): + raise ValueError( + "cubic_bspline_control_point_grid_size() 'size' must be positive" + ) + if s.less_equal(y=paddle.to_tensor(0, dtype=s.dtype)).astype("bool").any(): + raise ValueError( + "cubic_bspline_control_point_grid_size() 'stride' must be positive" + ) + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(s.shape)[0] > 1: + ndim = tuple(s.shape)[0] + for arg, name in zip([m, s], ["size", "stride"]): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"cubic_bspline_control_point_grid_size() {name!r} must be scalar or sequence of length {ndim}" + ) + m = m.expand(shape=ndim) + s = s.expand(shape=ndim) + n = m.div(s, rounding_mode="floor").add_(y=paddle.to_tensor(3)) + n = (m % s == 0).where(x=n, y=n.add(1)) + if isinstance(size, int) and isinstance(stride, int): + return n[0].item() + return tuple(n.tolist()) + + +def cubic_bspline_control_point_grid(grid: Grid, stride: ScalarOrTuple[int]) -> Grid: + """Get control point grid for given image grid and control point stride.""" + size = cubic_bspline_control_point_grid_size(tuple(grid.shape), stride) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=grid.place) + ) + s = s.expand(shape=grid.ndim) + return Grid( + size=size, + origin=grid.index_to_world(-s), + spacing=grid.spacing(), + direction=grid.direction(), + device=grid.place, + align_corners=True, + ) + + +@overload +def bspline_interpolation_weights( + degree: int, + stride: int, + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> paddle.Tensor: + ... + + +@overload +def bspline_interpolation_weights( + degree: int, + stride: Sequence[int], + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> Tuple[paddle.Tensor, ...]: + ... + + +def bspline_interpolation_weights( + degree: int, + stride: ScalarOrTuple[int], + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: + """Compute B-spline interpolation weights.""" + if degree == 3: + return cubic_bspline_interpolation_weights(stride, dtype=dtype, device=device) + kernels = {} + return_single_tensor = False + if isinstance(stride, int): + stride = [stride] + return_single_tensor = True + for s in stride: + if s in kernels: + continue + kernel = paddle.empty(shape=(s, degree + 1), dtype=dtype) + offset = paddle.arange(start=0, end=1, step=1 / s, dtype=kernel.dtype) + if degree % 2 == 0: + offset = offset.subtract_(y=paddle.to_tensor(offset.round())) + if degree == 2: + kernel[:, (1)] = 0.75 - offset.square() + kernel[:, (2)] = ( + offset.sub(kernel[:, (1)]) + .add_(y=paddle.to_tensor(1)) + .multiply_(y=paddle.to_tensor(0.5)) + ) + kernel[:, (0)] = -kernel[:, ([1, 2])].sum(axis=1).sub(1) + elif degree == 4: + a = offset.square() + t = a.mul(1 / 6) + t0 = t.sub(11 / 24).mul(offset) + t1 = ( + t.sub(0.25) + .multiply_(y=paddle.to_tensor(-a)) + .add_(y=paddle.to_tensor(19 / 96)) + ) + kernel[:, (0)] = ( + paddle.to_tensor(data=0.5, dtype=dtype, place=device) + .sub(offset) + .square() + ) + kernel[:, (0)] = kernel[:, (0)].mul(kernel[:, (0)].mul(1 / 24)) + kernel[:, (1)] = t1.add(t0) + kernel[:, (3)] = t1.sub(t0) + kernel[:, (4)] = ( + offset.mul(0.5) + .add_(y=paddle.to_tensor(kernel[:, (0)])) + .add_(y=paddle.to_tensor(t0)) + ) + kernel[:, (2)] = -kernel[:, ([0, 1, 3, 4])].sum(axis=1).sub(1) + elif degree == 5: + a = offset.square() + kernel[:, (5)] = offset.mul(a.square()).multiply_( + y=paddle.to_tensor(1 / 120) + ) + a = a.subtract_(y=paddle.to_tensor(offset)) + b = a.square() + offset = offset.subtract_(y=paddle.to_tensor(0.5)) + t = a.sub(3).multiply_(y=paddle.to_tensor(a)) + kernel[:, (0)] = ( + a.add(b) + .add_(y=paddle.to_tensor(1 / 5)) + .multiply_(y=paddle.to_tensor(1 / 24)) + .subtract_(y=paddle.to_tensor(kernel[:, (5)])) + ) + t0 = ( + a.sub(5) + .multiply_(y=paddle.to_tensor(a)) + .add_(y=paddle.to_tensor(46 / 5)) + .multiply_(y=paddle.to_tensor(1 / 24)) + ) + t1 = ( + t.add(4) + .multiply_(y=paddle.to_tensor(offset)) + .multiply_(y=paddle.to_tensor(-1 / 12)) + ) + kernel[:, (2)] = t0.add(t1) + kernel[:, (3)] = t0.sub(t1) + t0 = t.sub(9 / 5).multiply_(y=paddle.to_tensor(1.0 / 16.0)) + t1 = ( + b.sub(a) + .subtract_(y=paddle.to_tensor(5)) + .multiply_(y=paddle.to_tensor(offset)) + .multiply_(y=paddle.to_tensor(1.0 / 24.0)) + ) + kernel[:, (1)] = t0.add(t1) + kernel[:, (4)] = t0.sub(t1) + else: + raise NotImplementedError(f"B-spline interpolation for degree={degree}") + kernels[s] = kernel + kernels = tuple(kernels[s] for s in stride) + if return_single_tensor: + assert len(kernels) == 1 + return kernels[0] + return kernels + + +@overload +def cubic_bspline_interpolation_weights( + stride: int, + derivative: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> paddle.Tensor: + ... + + +@overload +def cubic_bspline_interpolation_weights( + stride: int, + derivative: Sequence[int], + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> Tuple[paddle.Tensor, ...]: + ... + + +@overload +def cubic_bspline_interpolation_weights( + stride: Sequence[int], + derivative: ScalarOrTuple[int] = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> Tuple[paddle.Tensor, ...]: + ... + + +def cubic_bspline_interpolation_weights( + stride: ScalarOrTuple[int], + derivative: ScalarOrTuple[int] = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: + """Compute cubic B-spline interpolation weights.""" + kernels = {} + return_single_tensor = isinstance(stride, int) and isinstance(derivative, int) + if isinstance(stride, int): + stride = [stride] * (len(derivative) if isinstance(derivative, Sequence) else 1) + if isinstance(derivative, int): + derivative = [derivative] * len(stride) + elif not isinstance(derivative, Sequence): + raise TypeError( + "cubic_bspline_interpolation_weights() 'derivative' must be int or Sequence[int]" + ) + elif len(derivative) != len(stride): + raise ValueError( + "cubic_bspline_interpolation_weights() length of 'derivative' sequence does not match 'stride'" + ) + for s, d in zip(stride, derivative): + if (s, d) in kernels: + continue + kernel = paddle.empty(shape=(s, 4), dtype=dtype) + offset = paddle.arange(start=0, end=1, step=1 / s, dtype=kernel.dtype) + if d == 0: + kernel[:, (3)] = offset.pow(y=3).multiply_(y=paddle.to_tensor(1 / 6)) + kernel[:, (0)] = ( + offset.mul(offset.sub(1)) + .multiply_(y=paddle.to_tensor(0.5)) + .add_(y=paddle.to_tensor(1 / 6)) + .subtract_(y=paddle.to_tensor(kernel[:, (3)])) + ) + kernel[:, (2)] = offset.add(kernel[:, (0)]).subtract_( + y=paddle.to_tensor(kernel[:, (3)].mul(2)) + ) + kernel[:, (1)] = -kernel[:, ([0, 2, 3])].sum(axis=1).sub(1) + elif d == 1: + kernel[:, (3)] = offset.pow(y=2).multiply_(y=paddle.to_tensor(0.5)) + kernel[:, (0)] = offset.sub(kernel[:, (3)]).subtract_( + y=paddle.to_tensor(0.5) + ) + kernel[:, (2)] = ( + kernel[:, (0)].sub(kernel[:, (3)].mul(2)).add_(y=paddle.to_tensor(1)) + ) + kernel[:, (1)] = -kernel[:, ([0, 2, 3])].sum(axis=1) + elif d == 2: + kernel[:, (3)] = offset + kernel[:, (0)] = -offset.sub(1) + kernel[:, (2)] = -offset.mul(3).subtract_(y=paddle.to_tensor(1)) + kernel[:, (1)] = offset.mul(3).subtract_(y=paddle.to_tensor(2)) + elif d == 3: + kernel[:, (3)] = 1 + kernel[:, (0)] = -1 + kernel[:, (2)] = -3 + kernel[:, (1)] = 3 + else: + kernel.fill_(value=0) + kernels[s, d] = kernel + kernels = tuple(kernels[s, d] for s, d in zip(stride, derivative)) + if return_single_tensor: + assert len(kernels) == 1 + return kernels[0] + return kernels + + +def evaluate_cubic_bspline( + data: paddle.Tensor, + stride: ScalarOrTuple[int], + size: Optional[list] = None, + shape: Optional[list] = None, + kernel: Optional[Union[paddle.Tensor, Sequence[paddle.Tensor]]] = None, + derivative: ScalarOrTuple[int] = 0, + transpose: bool = False, +) -> paddle.Tensor: + """Evaluate cubic B-spline function. + + Args: + data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, C, ..., X)``. + stride: Number of output grid points between control points plus one. This is the stride of the + transposed convolution used to upsample the control point displacements to the output size. + If a sequence of values is given, these must be the strides for the different spatial + dimensions in the order ``(sx, ...)``. + size: Spatial size of output tensor in the order ``(nx, ...)``. + shape: Spatial size of output tensor in the order ``(..., nx)``. + kernel: Precomputed cubic B-spline interpolation kernel. When multiple 1D kernels are given, + these must be in the order ``(kx, ...)``. + transpose: Whether to use separable transposed convolution as implemented in AIRLab. + When ``False``, a more efficient implementation using multi-channel convolution followed + by a reshuffling of the output is performed. This more efficient and also more accurate + implementation is adapted from the C++ code of MIRTK (``mirtk::BSplineInterpolateImageFunction``). + + Returns: + Cubic B-spline function values as tensor of shape ``(N, C, ..., X')``, where ``X' = sx * X`` + when neither output ``size`` nor ``shape`` is specified. Otherwise, the output tensor is cropped + to the requested spatial output size. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("evaluate_cubic_bspline() 'data' must be paddle.Tensor") + if not paddle.is_floating_point(x=data): + raise TypeError( + "evaluate_cubic_bspline() 'data' must have floating point dtype" + ) + if data.ndim < 3: + raise ValueError( + "evaluate_cubic_bspline() 'data' must have shape (N, C, ..., X)" + ) + if size is not None: + if shape is not None: + raise ValueError( + "evaluate_cubic_bspline() 'size' and 'shape' are mutually exclusive" + ) + shape = tuple(reversed(size)) + D = data.ndim - 2 + N = tuple(data.shape)[0] + C = tuple(data.shape)[1] + if isinstance(stride, int): + stride = [stride] * D + if transpose: + if kernel is None: + if derivative != 0: + raise NotImplementedError( + "evaluate_cubic_bspline() 'derivative' must be 0 when kernel=None and transpose=True" + ) + kernels = {} + for s in stride: + if s not in kernels: + kernels[s] = cubic_bspline1d(s) + kernel = [kernels[s] for s in stride] + stride = tuple(reversed(stride)) + if isinstance(kernel, Sequence): + kernel = tuple(reversed(kernel)) + output = conv( + data, + kernel=kernel, + stride=stride, + padding=PaddingMode.ZEROS, + transpose=True, + ) + if shape is not None: + output = output[ + (slice(0, N), slice(0, C)) + + tuple(slice(s, s + n) for s, n in zip(stride, shape)) + ] + else: + if kernel is None: + kernel = cubic_bspline_interpolation_weights( + stride=stride, + derivative=derivative, + dtype=data.dtype, + device=data.place, + ) + elif isinstance(kernel, paddle.Tensor): + kernel = [kernel] * D + elif not isinstance(kernel, Sequence): + raise TypeError( + "evaluate_cubic_bspline() 'kernel' must be paddle.Tensor or Sequence of tensors" + ) + output = data + dims = tuple(SpatialDim(dim).tensor_dim(data.ndim) for dim in range(D)) + conv_fn: Callable[..., paddle.Tensor] = [ + paddle.nn.functional.conv1d, + paddle.nn.functional.conv2d, + paddle.nn.functional.conv3d, + ][D - 1] + for dim, w in zip(dims, kernel): + weight = w.reshape( + (tuple(w.shape)[0], 1, tuple(w.shape)[1]) + (1,) * (D - 1) + ) + weight = weight.tile(repeat_times=(C,) + (1,) * (weight.ndim - 1)) + output = move_dim(output, dim, 2) + output = conv_fn(output, weight, groups=C) + output = output.reshape((N, C, tuple(w.shape)[0]) + tuple(output.shape)[2:]) + x = output + perm_10 = list(range(x.ndim)) + perm_10[2] = 3 + perm_10[3] = 2 + output = x.transpose(perm=perm_10).flatten(start_axis=2, stop_axis=3) + output = move_dim(output, 2, dim) + if shape is not None: + output = output[ + (slice(0, N), slice(0, C)) + tuple(slice(0, n) for n in shape) + ] + return output + + +def cubic_bspline_jacobian_det( + data: paddle.Tensor, stride: ScalarOrTuple[int] +) -> paddle.Tensor: + """Evaluate Jacobian determinant of cubic B-spline free-form deformation.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("cubic_bspline_jacobian_det() 'data' must be paddle.Tensor") + if not paddle.is_floating_point(x=data): + raise TypeError( + "cubic_bspline_jacobian_det() 'data' must have floating point dtype" + ) + if data.ndim < 3: + raise ValueError( + "cubic_bspline_jacobian_det() 'data' must have shape (N, C, ..., X)" + ) + D = data.ndim - 2 + C = tuple(data.shape)[1] + if C != D: + raise ValueError( + f"cubic_bspline_jacobian_det() 'data' mismatch between number of channels ({C}) and spatial dimensions ({D})" + ) + jac: Optional[paddle.Tensor] = None + for perm in permutations(range(D)): + term: Optional[paddle.Tensor] = None + for i, j in zip(range(D), perm): + derivative = [(1 if dim == j else 0) for dim in range(D)] + start_18 = data.shape[1] + i if i < 0 else i + du = evaluate_cubic_bspline( + paddle.slice(data, [1], [start_18], [start_18 + 1]), + stride=stride, + derivative=derivative, + ) + if i == j: + du = du.add_(y=paddle.to_tensor(1)) + term = du if term is None else term.multiply_(y=paddle.to_tensor(du)) + assert term is not None + if jac is None: + jac = term + elif is_even_permutation(perm): + jac = jac.add_(y=paddle.to_tensor(term)) + else: + jac = jac.subtract_(y=paddle.to_tensor(term)) + assert jac is not None + return jac + + +def cubic_bspline_jacobian_dict( + data: paddle.Tensor, + stride: ScalarOrTuple[int], + size: Optional[list] = None, + shape: Optional[list] = None, + add_identity: bool = False, +) -> Dict[Tuple[int, int], paddle.Tensor]: + """Evaluate Jacobian of cubic B-spline free-form deformation. + + Args: + data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``, + where ``D`` is the number of spatial dimensions. + stride: Number of output grid points between control points plus one. If a sequence of + values is given, these must be the strides for the different spatial dimensions in + the order ``(sx, ...)``. + size: Spatial size of output tensor in the order ``(nx, ...)``. + shape: Spatial size of output tensor in the order ``(..., nx)``. + add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form + deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline + function, by adding the identity matrix to the Jacobian of :math:`u`. + + Returns: + Dictionary of spatial derivatives with keys corresponding to (row, col) indices. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("cubic_bspline_jacobian_dict() 'data' must be paddle.Tensor") + if not paddle.is_floating_point(x=data): + raise TypeError( + "cubic_bspline_jacobian_dict() 'data' must have floating point dtype" + ) + if data.ndim < 3: + raise ValueError( + "cubic_bspline_jacobian_dict() 'data' must have shape (N, C, ..., X)" + ) + if size is not None: + if shape is not None: + raise ValueError( + "cubic_bspline_jacobian_dict() 'size' and 'shape' are mutually exclusive" + ) + shape = tuple(reversed(size)) + D = data.ndim - 2 + C = tuple(data.shape)[1] + if C != D: + raise ValueError( + f"cubic_bspline_jacobian_dict() 'data' mismatch between number of channels ({C}) and spatial dimensions ({D})" + ) + jac = {} + for i, j in combinations_with_replacement(range(D), 2): + derivative = [(1 if dim == j else 0) for dim in range(D)] + start_19 = data.shape[1] + i if i < 0 else i + coeff = paddle.slice(data, [1], [start_19], [start_19 + 1]) + deriv = evaluate_cubic_bspline( + coeff, shape=shape, stride=stride, derivative=derivative + ) + if add_identity and i == j: + deriv = deriv.add_(y=paddle.to_tensor(1)) + jac[i, j] = deriv + return { + (i, j): jac[(i, j) if i < j else (j, i)] for i, j in product(range(D), repeat=2) + } + + +def cubic_bspline_jacobian_matrix( + data: paddle.Tensor, + stride: ScalarOrTuple[int], + size: Optional[list] = None, + shape: Optional[list] = None, + add_identity: bool = False, +) -> paddle.Tensor: + """Evaluate Jacobian of cubic B-spline free-form deformation. + + Args: + data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``, + where ``D`` is the number of spatial dimensions. + stride: Number of output grid points between control points plus one. If a sequence of + values is given, these must be the strides for the different spatial dimensions in + the order ``(sx, ...)``. + size: Spatial size of output tensor in the order ``(nx, ...)``. + shape: Spatial size of output tensor in the order ``(..., nx)``. + add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form + deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline + function, by adding the identity matrix to the Jacobian of :math:`u`. + + Returns: + Full Jacobian matrices as tensors of shape ``(N, ..., X, D, D)``. + + """ + N = tuple(data.shape)[0] + D = data.ndim - 2 + jac = cubic_bspline_jacobian_dict( + data, stride=stride, shape=shape, size=size, add_identity=add_identity + ) + mat = paddle.concat(x=[jac[i, j] for i, j in product(range(D), repeat=2)], axis=1) + mat = move_dim(mat, 1, -1) + mat = mat.reshape((N,) + tuple(jac[0, 0].shape)[2:] + (D, D)) + return mat + + +def cubic_bspline_jacobian_triu( + data: paddle.Tensor, + stride: ScalarOrTuple[int], + size: Optional[list] = None, + shape: Optional[list] = None, + add_identity: bool = False, +) -> paddle.Tensor: + """Evaluate Jacobian of cubic B-spline free-form deformation. + + Args: + data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, D, ..., X)``, + where ``D`` is the number of spatial dimensions. + stride: Number of output grid points between control points plus one. If a sequence of + values is given, these must be the strides for the different spatial dimensions in + the order ``(sx, ...)``. + size: Spatial size of output tensor in the order ``(nx, ...)``. + shape: Spatial size of output tensor in the order ``(..., nx)``. + add_identity: Whether to calculate derivatives of :math:`u(x)` (False) or the free-form + deformation given by :math:`x + u(x)` (True), where :math:`u` is the cubic B-spline + function, by adding the identity matrix to the Jacobian of :math:`u`. + + Returns: + Flattened upper triangular Jacobian matrices as tensors of shape ``(N, D * (D + 1) / 2, ..., X)``. + + """ + D = data.ndim - 2 + jac = cubic_bspline_jacobian_dict( + data, stride=stride, shape=shape, size=size, add_identity=add_identity + ) + return paddle.concat( + x=[jac[i, j] for i, j in combinations_with_replacement(range(D), 2)], axis=1 + ) + + +def subdivide_cubic_bspline( + data: paddle.Tensor, + dims: Optional[Union[SpatialDimArg, Sequence[SpatialDimArg]]] = None, +) -> paddle.Tensor: + """Compute cubic B-spline coefficients for subdivided control point grid. + + Args: + data: Input control point coefficients as tensor of shape ``(N, C, ..., X)``. + dims: Spatial dimensions along which to subdivide. + + Returns: + Coefficients of subdivided cubic B-spline function. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("subdivide_cubic_bspline() 'data' must be paddle.Tensor") + if not paddle.is_floating_point(x=data): + raise TypeError( + "subdivide_cubic_bspline() 'data' must have floating point dtype" + ) + if data.ndim < 4: + raise ValueError( + "subdivide_cubic_bspline() 'data' must have shape (N, C, ..., X)" + ) + if dims is None: + dims = tuple(range(data.ndim - 2)) + elif isinstance(dims, (int, str)): + dims = [dims] + elif not isinstance(dims, Sequence): + raise TypeError( + "subdivide_cubic_bspline() 'dims' must be int, str, or Sequence thereof" + ) + dims = sorted(SpatialDim.from_arg(dim).tensor_dim(data.ndim) for dim in dims) + output = data + kernel_1 = paddle.to_tensor( + data=[0.125, 0.75, 0.125], dtype=data.dtype, place=data.place + ) + kernel_2 = paddle.to_tensor(data=[0.5, 0.5], dtype=data.dtype, place=data.place) + for dim in dims: + shape = ( + tuple(output.shape)[:dim] + + (2 * tuple(output.shape)[dim] - 1,) + + tuple(output.shape)[dim + 1 :] + ) + temp = paddle.empty(shape=shape, dtype=data.dtype) + indices = [slice(0, n) for n in shape] + indices[dim] = slice(0, shape[dim], 2) + temp[tuple(indices)] = conv1d(output, kernel_1, dim=dim, padding=1) + indices = [slice(0, n) for n in shape] + indices[dim] = slice(1, shape[dim], 2) + temp[tuple(indices)] = conv1d(output, kernel_2, dim=dim, padding=0) + output = temp + return output diff --git a/jointContribution/HighResolution/deepali/core/config.py b/jointContribution/HighResolution/deepali/core/config.py new file mode 100644 index 0000000000..db39271ef9 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/config.py @@ -0,0 +1,197 @@ +"""Auxiliary functions and classes for dealing with configuration files.""" +from __future__ import annotations + +import json +from dataclasses import asdict +from dataclasses import fields +from pathlib import Path +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Type +from typing import TypeVar + +import dacite +import yaml + +from .path import abspath_template +from .types import PathStr +from .types import is_path_str_field + +T = TypeVar("T", bound="DataclassConfig") + + +class DataclassConfig(object): + """Base class of configuration data classes.""" + + @staticmethod + def section() -> str: + """Common key prefix of configuration entries in configuration file.""" + return "" + + @classmethod + def _from_dict( + cls: Type[T], arg: Mapping[str, Any], parent: Optional[Path] = None + ) -> T: + """Create configuration from dictionary.""" + config = dacite.from_dict(cls, arg) + return config + + @classmethod + def from_dict( + cls: Type[T], arg: Mapping[str, Any], parent: Optional[Path] = None + ) -> T: + """Create configuration from dictionary.""" + config = cls._from_dict(arg, parent=parent) + config._finalize(parent) + return config + + @classmethod + def from_path(cls: Type[T], path: PathStr, section: Optional[str] = None) -> T: + """Load configuration from file.""" + path = Path(path).absolute() + text = path.read_text() + if path.suffix == ".json": + config = json.loads(text) + elif path.suffix in (".yml", ".yaml"): + config = yaml.load(text, Loader=yaml.SafeLoader) + else: + raise ValueError( + f"{cls.__name__}.from_path() 'path' has unsupported suffix {path.suffix}" + ) + if config is None: + config = {} + if section is None: + section = cls.section() + if section: + for key in section.split("."): + config = config.get(key, {}) + return cls.from_dict(config, parent=path.parent) + + @classmethod + def read(cls: Type[T], path: PathStr, section: Optional[str] = None) -> T: + """Load configuration from file.""" + return cls.from_path(path, section=section) + + def write(self, path: PathStr) -> None: + """Write configuration to file.""" + path = Path(path).absolute() + config = self.asdict() + if path.suffix == ".json": + text = json.dumps(config) + elif path.suffix in (".yml", ".yaml"): + text = yaml.safe_dump(config) + else: + raise ValueError( + f"{type(self).__name__}.write() 'path' has unsupported suffix {path.suffix}" + ) + path.parent.mkdir(parents=True, exist_ok=True) + path.write_text(text) + + def asdict(self) -> Dict[str, Any]: + """Get configuration dictionary.""" + return asdict(self) + + def _finalize(self: T, parent: Optional[Path] = None) -> None: + """Finalize parameters after loading these from input file.""" + for field in fields(self): + value = getattr(self, field.name) + if value is None: + continue + if isinstance(value, DataclassConfig): + value._finalize(parent) + elif is_path_str_field(field): + value = abspath_template(value, parent=parent) + setattr(self, field.name, value) + + def _join_kwargs_in_sequence(self, attr: str): + """Merge kwarg dictionaries in 'norm' or 'acti' sequence. + + In case of a sequence instead of a single string or dictionary, ``ConvLayer`` expects + ``norm`` and ``acti`` parameters to be a sequence of length 2, where the first entry + is the normalization or activiation layer name, respectively, and the second entry is + a dictionary of keyword arguments for the respective layer. In a YAML file, it may + be convenient, however, to specify these arguments on a single line as follows: + + .. code-block:: yaml + + norm: [batch, momentum: 0.1, eps: 0.001] + + This represents a sequence, where the first item is a string and the following items + are dictionaries with a single key. This ``__post_init__`` function merges these separate + dictionaries into a single dictionary in order to support above YAML syntax. + + The same functionality may be useful for other configuration entries, not only + ``normalization()``, ``activiation()``, or ``pooling()`` related parameters. + + Alternatively, one could use a single dictionary in the YAML configuration: + + .. code-block:: yaml + + norm: {name: batch, momentum: 0.1, eps: 0.001} + + Args: + attr: Name of dataclass attribute to modify in place. + + """ + arg = getattr(self, attr) + arg = join_kwargs_in_sequence(arg) + setattr(self, attr, arg) + + +def join_kwargs_in_sequence(arg): + """Merge kwarg dictionaries in 'norm' or 'acti' sequence. + + In case of a sequence instead of a single string or dictionary, ``ConvLayer`` expects + ``norm`` and ``acti`` parameters to be a sequence of length 2, where the first entry + is the normalization or activiation layer name, respectively, and the second entry is + a dictionary of keyword arguments for the respective layer. In a YAML file, it may + be convenient, however, to specify these arguments on a single line as follows: + + .. code-block:: yaml + + norm: [batch, momentum: 0.1, eps: 0.001] + + This represents a sequence, where the first item is a string and the following items + are dictionaries with a single key. This ``__post_init__`` function merges these separate + dictionaries into a single dictionary in order to support above YAML syntax. + + The same functionality may be useful for other configuration entries, not only + ``normalization()``, ``activiation()``, or ``pooling()`` related parameters. + + Alternatively, one could use a single dictionary in the YAML configuration: + + .. code-block:: yaml + + norm: {name: batch, momentum: 0.1, eps: 0.001} + + Args: + attr: Name of dataclass attribute to modify in place. + + """ + if not isinstance(arg, str) and isinstance(arg, Sequence): + if not all( + isinstance(item, (str, dict) if i == 0 else dict) + for i, item in enumerate(arg) + ): + raise TypeError( + "join_kwargs_in_sequence() 'arg' must be str, dict, or sequence of dicts with the first item being either a str or dict" + ) + if len(arg) == 1 and isinstance(arg[0], dict): + arg = arg[0] + elif len(arg) > 1: + start = 0 if isinstance(arg[0], dict) else 1 + kwargs = dict(arg[start]) + for i in range(start + 1, len(arg)): + cur = arg[i] + assert isinstance(cur, dict) + for name, value in cur.items(): + if name in kwargs: + raise ValueError( + f"join_kwargs_in_sequence() 'arg' has duplicate kwarg {name}" + ) + kwargs[name] = value + arg = kwargs if start == 0 else (arg[0], kwargs) + return arg diff --git a/jointContribution/HighResolution/deepali/core/cube.py b/jointContribution/HighResolution/deepali/core/cube.py new file mode 100644 index 0000000000..ec4190a1a3 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/cube.py @@ -0,0 +1,637 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Union +from typing import overload + +import numpy as np +import paddle + +from .grid import ALIGN_CORNERS +from .grid import Axes +from .grid import Grid +from .linalg import hmm +from .linalg import homogeneous_matrix +from .linalg import homogeneous_transform +from .tensor import as_tensor +from .tensor import cat_scalars +from .types import Array +from .types import Device +from .types import DType +from .types import Shape +from .types import Size + + +class Cube(object): + """Bounding box oriented in world space which defines a normalized domain. + + Coordinates of points within this domain can be either with respect to the world coordinate + system or the cube defined by the bounding box where coordinate axes are parallel to the + cube edges and have a uniform side length of 2. The latter are the normalized coordinates + used by ``paddle.nn.functional.grid_sample()``, in particular. In terms of the coordinate + transformations, a :class:`.Cube` is thus equivalent to a :class:`.Grid` with three points + along each dimension and ``align_corners=True``. + + A regular sampling :class:`.Grid`, on the other hand, subsamples the world space within the bounds + defined by the cube into a number of equally sized cells or equally spaced points, respectivey. + How the grid points relate to the faces of the cube depends on :meth:`.Grid.align_corners`. + + """ + + __slots__ = "_center", "_direction", "_extent" + + def __init__( + self, + extent: Optional[Union[Array, float]], + center: Optional[Union[Array, float]] = None, + origin: Optional[Union[Array, float]] = None, + direction: Optional[Array] = None, + device: Optional[Device] = None, + ): + """Initialize cube attributes. + + Args: + extent: Extent ``(extent_x, ...)`` of the cube in world units. + center: Cube center point ``(x, ...)`` in world space. + origin: World coordinates ``(x, ...)`` of lower left corner. + direction: Direction cosines defining orientation of cube in world space. + The direction cosines are vectors that point along the cube edges. + Each column of the matrix indicates the direction cosines of the unit vector + that is parallel to the cube edge corresponding to that dimension. + device: Device on which to store attributes. Uses ``"cpu"`` if ``None``. + + """ + extent = as_tensor(extent, device=device or "cpu") + if not extent.is_floating_point(): + extent = extent.astype(dtype="float32") + self._extent = extent + if direction is None: + direction = paddle.eye(num_rows=self.ndim, dtype=self.dtype) + self.direction_(direction) + if origin is None: + self.center_(0 if center is None else center) + elif center is None: + self.origin_(origin) + else: + self.center_(center) + if not paddle.allclose(x=origin, y=self.origin()).item(): + raise ValueError("Cube() 'center' and 'origin' are inconsistent") + + def numpy(self) -> np.ndarray: + """Get cube attributes as 1-dimensional NumPy array.""" + return np.concatenate( + [ + self._extent.numpy(), + self._center.numpy(), + self._direction.flatten().numpy(), + ], + axis=0, + ) + + @classmethod + def from_numpy( + cls, attrs: Union[Sequence[float], np.ndarray], origin: bool = False + ) -> Cube: + """Create Cube from 1-dimensional NumPy array.""" + if isinstance(attrs, np.ndarray): + seq = attrs.astype(float).tolist() + else: + seq = attrs + return cls.from_seq(seq, origin=origin) + + @classmethod + def from_seq(cls, attrs: Sequence[float], origin: bool = False) -> Cube: + """Create Cube from sequence of attribute values. + + Args: + attrs: Array of length (D + 2) * D, where ``D=2`` or ``D=3`` is the number + of spatial cube dimensions and array items are given as + ``(sx, ..., cx, ..., d11, ..., d21, ....)``, where ``(sx, ...)`` is the + cube extent, ``(cx, ...)`` the cube center coordinates, and ``(d11, ...)`` + are the cube direction cosines. The argument can be a Python list or tuple, + NumPy array, or tensor. + origin: Whether ``(cx, ...)`` specifies Cube origin rather than center. + + Returns: + Cube instance. + + """ + if len(attrs) == 8: + d = 2 + elif len(attrs) == 15: + d = 3 + else: + raise ValueError( + f"{cls.__name__}.from_seq() expected array of length 8 (D=2) or 15 (D=3)" + ) + kwargs = dict(extent=attrs[0:d], direction=attrs[2 * d :]) + if origin: + kwargs["origin"] = attrs[d : 2 * d] + else: + kwargs["center"] = attrs[d : 2 * d] + return Cube(**kwargs) + + @classmethod + def from_grid(cls, grid: Grid, align_corners: Optional[bool] = None) -> Cube: + """Get cube with respect to which normalized grid coordinates are defined.""" + if align_corners is not None: + grid = grid.align_corners(align_corners) + return cls( + extent=grid.cube_extent(), + center=grid.center(), + direction=grid.direction(), + device=grid.place, + ) + + def grid( + self, + size: Optional[Union[int, Size, Array]] = None, + shape: Optional[Union[int, Shape, Array]] = None, + spacing: Optional[Union[Array, float]] = None, + align_corners: bool = ALIGN_CORNERS, + ) -> Grid: + """Create regular sampling grid which covers the world space bounded by the cube.""" + if size is None and shape is None: + if spacing is None: + raise ValueError( + "Cube.grid() requires either the desired grid 'size'/'shape' or point 'spacing'" + ) + size = self.extent().div(spacing).round() + size = tuple(size.astype("int32").tolist()) + if align_corners: + size = tuple(n + 1 for n in size) + spacing = None + else: + if isinstance(size, int): + size = (size,) * self.ndim + if isinstance(shape, int): + shape = (shape,) * self.ndim + size = Grid(size=size, shape=shape).size() + ncells = paddle.to_tensor(data=size) + if align_corners: + ncells = ncells.subtract_(y=paddle.to_tensor(1)) + ncells = ncells.to(dtype=self.dtype, device=self.device) + grid = Grid( + size=size, + spacing=self.extent().div(ncells), + center=self.center(), + direction=self.direction(), + align_corners=align_corners, + device=self.device, + ) + with paddle.no_grad(): + if not paddle.allclose(x=grid.cube_extent(), y=self.extent()).item(): + raise ValueError( + "Cube.grid() 'size'/'shape' times 'spacing' does not match cube extent" + ) + return grid + + def dim(self) -> int: + """Number of cube dimensions.""" + return len(self._extent) + + @property + def ndim(self) -> int: + """Number of cube dimensions.""" + return len(self._extent) + + @property + def dtype(self) -> DType: + """Get data type of cube attribute tensors.""" + return self._extent.dtype + + @property + def device(self) -> Device: + """Get device on which cube attribute tensors are stored.""" + return self._extent.place + + def clone(self) -> Cube: + """Make deep copy of this instance.""" + cube = shallow_copy(self) + for name in self.__slots__: + value = getattr(self, name) + if isinstance(value, paddle.Tensor): + setattr(cube, name, value.clone()) + return cube + + def __deepcopy__(self, memo) -> Cube: + """Support copy.deepcopy to clone this cube.""" + if id(self) in memo: + return memo[id(self)] + copy = self.clone() + memo[id(self)] = copy + return copy + + @overload + def center(self) -> paddle.Tensor: + """Get center point in world space.""" + ... + + @overload + def center(self, arg: Union[float, Array], *args: float) -> Cube: + """Get new cube with same orientation and extent, but specified center point.""" + ... + + def center(self, *args) -> Union[paddle.Tensor, Cube]: + """Get center point in world space or new cube with specified center point.""" + if args: + return shallow_copy(self).center_(*args) + return self._center + + def center_(self, arg: Union[Array, float], *args: float) -> Cube: + """Set center point in world space of this cube.""" + self._center = cat_scalars( + arg, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + return self + + @overload + def origin(self) -> paddle.Tensor: + """Get world coordinates of lower left corner.""" + ... + + @overload + def origin(self, arg: Union[Array, float], *args: float) -> Cube: + """Get new cube with specified world coordinates of lower left corner.""" + ... + + def origin(self, *args) -> Union[paddle.Tensor, Cube]: + """Get origin in world space or new cube with specified origin.""" + if args: + return shallow_copy(self).origin_(*args) + offset = paddle.matmul(x=self.direction(), y=self.spacing()) + origin = self._center.sub(offset) + return origin + + def origin_(self, arg: Union[Array, float], *args: float) -> Cube: + """Set world coordinates of lower left corner.""" + center = cat_scalars( + arg, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + offset = paddle.matmul(x=self.direction(), y=self.spacing()) + self._center = center.add(offset) + return self + + def spacing(self) -> paddle.Tensor: + """Cube unit spacing in world space.""" + return self._extent.div(2) + + @overload + def direction(self) -> paddle.Tensor: + """Get edge direction cosines matrix.""" + ... + + @overload + def direction(self, arg: Union[Array, float], *args: float) -> Cube: + """Get new cube with specified edge direction cosines.""" + ... + + def direction(self, *args) -> Union[paddle.Tensor, Cube]: + """Get edge direction cosines matrix or new cube with specified orientation.""" + if args: + return shallow_copy(self).direction_(*args) + return self._direction + + def direction_(self, arg: Union[Array, float], *args: float) -> Cube: + """Set edge direction cosines matrix of this cube.""" + D = self.ndim + if args: + direction = paddle.to_tensor(data=(arg,) + args) + else: + direction = as_tensor(arg) + direction = direction.to(dtype=self.dtype, device=self.device) + if direction.ndim == 1: + if tuple(direction.shape)[0] != D * D: + raise ValueError( + f"Cube direction must be array or square matrix with numel={D * D}" + ) + direction = direction.reshape(D, D) + elif ( + direction.ndim != 2 + or tuple(direction.shape)[0] != tuple(direction.shape)[1] + or tuple(direction.shape)[0] != D + ): + raise ValueError( + f"Cube direction must be array or square matrix with numel={D * D}" + ) + with paddle.no_grad(): + if abs(direction.det().abs().item() - 1) > 0.0001: + raise ValueError( + "Cube direction cosines matrix must be valid rotation matrix" + ) + self._direction = direction + return self + + @overload + def extent(self) -> paddle.Tensor: + """Extent of cube in world space.""" + ... + + @overload + def extent(self, arg: Union[float, Array], *args, float) -> Cube: + """Get cube with same center and orientation but different extent.""" + ... + + def extent(self, *args) -> Union[paddle.Tensor, Cube]: + """Get extent of this cube or a new cube with same center and orientation but specified extent.""" + if args: + return shallow_copy(self).extent_(*args) + return self._extent + + def extent_(self, arg: Union[Array, float], *args) -> Cube: + """Set the extent of this cube, keeping center and orientation the same.""" + self._extent = cat_scalars( + arg, *args, num=self.ndim, device=self.device, dtype=self.dtype + ) + return self + + def affine(self) -> paddle.Tensor: + """Affine transformation from cube to world space, excluding translation.""" + return paddle.mm(input=self.direction(), mat2=paddle.diag(x=self.spacing())) + + def inverse_affine(self) -> paddle.Tensor: + """Affine transformation from world to cube space, excluding translation.""" + one = paddle.to_tensor(data=1, dtype=self.dtype, place=self.device) + return paddle.mm( + input=paddle.diag(x=one / self.spacing()), mat2=self.direction().t() + ) + + def transform( + self, + axes: Optional[Union[Axes, str]] = None, + to_axes: Optional[Union[Axes, str]] = None, + to_cube: Optional[Cube] = None, + vectors: bool = False, + ) -> paddle.Tensor: + """Transformation of coordinates from this cube to another cube. + + Args: + axes: Axes with respect to which input coordinates are defined. + If ``None`` and also ``to_axes`` and ``to_cube`` is ``None``, + returns the transform which maps from cube to world space. + to_axes: Axes of cube to which coordinates are mapped. Use ``axes`` if ``None``. + to_cube: Other cube. Use ``self`` if ``None``. + vectors: Whether transformation is used to rescale and reorient vectors. + + Returns: + If ``vectors=False``, a homogeneous coordinate transformation of shape ``(D, D + 1)``. + Otherwise, a square transformation matrix of shape ``(D, D)`` is returned. + + """ + if axes is None and to_axes is None and to_cube is None: + return self.transform(Axes.CUBE, Axes.WORLD, vectors=vectors) + if axes is None: + raise ValueError( + "Cube.transform() 'axes' required when 'to_axes' or 'to_cube' specified" + ) + axes = Axes(axes) + to_axes = axes if to_axes is None else Axes(to_axes) + if axes is Axes.GRID or to_axes is Axes.GRID: + raise ValueError("Cube.transform() Axes.GRID is only valid for a Grid") + if axes == to_axes and axes is Axes.CUBE_CORNERS: + axes = to_axes = Axes.CUBE + elif axes is Axes.CUBE_CORNERS and to_axes is Axes.WORLD: + axes = Axes.CUBE + elif axes is Axes.WORLD and to_axes is Axes.CUBE_CORNERS: + to_axes = Axes.CUBE + if axes is Axes.CUBE_CORNERS or to_axes is Axes.CUBE_CORNERS: + raise ValueError( + "Cube.transform() cannot map between Axes.CUBE and Axes.CUBE_CORNERS. Use Cube.grid().transform() instead." + ) + if axes == to_axes and ( + axes is Axes.WORLD or to_cube is None or to_cube == self + ): + return paddle.eye(num_rows=self.ndim, dtype=self.dtype) + if axes == to_axes: + assert axes is Axes.CUBE + cube_to_world = self.transform(Axes.CUBE, Axes.WORLD, vectors=vectors) + world_to_cube = to_cube.transform(Axes.WORLD, Axes.CUBE, vectors=vectors) + if vectors: + return paddle.mm(input=world_to_cube, mat2=cube_to_world) + return hmm(world_to_cube, cube_to_world) + if axes is Axes.CUBE: + assert to_axes is Axes.WORLD + if vectors: + return self.affine() + return homogeneous_matrix(self.affine(), self.center()) + assert axes is Axes.WORLD + assert to_axes is Axes.CUBE + if vectors: + return self.inverse_affine() + return hmm(self.inverse_affine(), -self.center()) + + def inverse_transform(self, vectors: bool = False) -> paddle.Tensor: + """Transform which maps from world to cube space.""" + return self.transform(Axes.WORLD, Axes.CUBE, vectors=vectors) + + def apply_transform( + self, + arg: Array, + axes: Union[Axes, str], + to_axes: Optional[Union[Axes, str]] = None, + to_cube: Optional[Cube] = None, + vectors: bool = False, + ) -> paddle.Tensor: + """Map point coordinates or displacement vectors from one cube to another. + + Args: + arg: Coordinates of points or displacement vectors as tensor of shape ``(..., D)``. + axes: Axes of this cube with respect to which input coordinates are defined. + to_axes: Axes of cube to which coordinates are mapped. Use ``axes`` if ``None``. + to_cube: Other cube. Use ``self`` if ``None``. + vectors: Whether ``arg`` contains displacements (``True``) or point coordinates (``False``). + + Returns: + Points or displacements with respect to ``to_cube`` and ``to_axes``. + If ``to_cube == self`` and ``to_axes == axes`` or both ``axes`` and ``to_axes`` are + ``Axes.WORLD`` and ``arg`` is a ``paddle.Tensor``, a reference to the unmodified input + tensor is returned. + + """ + axes = Axes(axes) + to_axes = axes if to_axes is None else Axes(to_axes) + if to_cube is None: + to_cube = self + tensor = as_tensor(arg) + if not tensor.is_floating_point(): + tensor = tensor.astype(self.dtype) + if axes is to_axes and axes is Axes.WORLD: + return tensor + if to_cube is not None and to_cube != self or axes is not to_axes: + matrix = self.transform(axes, to_axes, to_cube=to_cube, vectors=vectors) + matrix = matrix.unsqueeze(axis=0).to(device=tensor.place) + tensor = homogeneous_transform(matrix, tensor) + return tensor + + def transform_points( + self, + points: Array, + axes: Union[Axes, str], + to_axes: Optional[Union[Axes, str]] = None, + to_cube: Optional[Cube] = None, + ) -> paddle.Tensor: + """Map point coordinates from one cube to another. + + Args: + points: Coordinates of points to transform as tensor of shape ``(..., D)``. + axes: Axes of this cube with respect to which input coordinates are defined. + to_axes: Axes of cube to which coordinates are mapped. Use ``axes`` if ``None``. + to_cube: Other cube. Use ``self`` if ``None``. + + Returns: + Point coordinates with respect to ``to_cube`` and ``to_axes``. If ``to_cube == self`` + and ``to_axes == axes`` or both ``axes`` and ``to_axes`` are ``Axes.WORLD`` and ``arg`` + is a ``paddle.Tensor``, a reference to the unmodified input tensor is returned. + + """ + return self.apply_transform( + points, axes, to_axes, to_cube=to_cube, vectors=False + ) + + def transform_vectors( + self, + vectors: Array, + axes: Union[Axes, str], + to_axes: Optional[Union[Axes, str]] = None, + to_cube: Optional[Cube] = None, + ) -> paddle.Tensor: + """Rescale and reorient flow vectors. + + Args: + vectors: Displacement vectors as tensor of shape ``(..., D)``. + axes: Axes of this cube with respect to which input coordinates are defined. + to_axes: Axes of cube to which coordinates are mapped. Use ``axes`` if ``None``. + to_cube: Other cube. Use ``self`` if ``None``. + + Returns: + Rescaled and reoriented displacement vectors. If ``to_cube == self`` and + ``to_axes == axes`` or both ``axes`` and ``to_axes`` are ``Axes.WORLD`` and ``arg`` + is a ``paddle.Tensor``, a reference to the unmodified input tensor is returned. + + """ + return self.apply_transform( + vectors, axes, to_axes, to_cube=to_cube, vectors=True + ) + + def cube_to_world(self, coords: Array) -> paddle.Tensor: + """Map point coordinates from cube to world space. + + Args: + coords: Normalized coordinates with respect to this cube as tensor of shape ``(..., D)``. + + Returns: + Coordinates of points in world space. + + """ + return self.apply_transform(coords, Axes.CUBE, Axes.WORLD, vectors=False) + + def world_to_cube(self, points: Array) -> paddle.Tensor: + """Map point coordinates from world to cube space. + + Args: + points: Coordinates of points in world space as tensor of shape ``(..., D)``. + + Returns: + Normalized coordinates of points with respect to this cube. + + """ + return self.apply_transform(points, Axes.WORLD, Axes.CUBE, vectors=False) + + def __eq__(self, other: Any) -> bool: + """Compare this cube to another.""" + if other is self: + return True + if not isinstance(other, self.__class__): + return False + for name in self.__slots__: + value = getattr(self, name) + other_value = getattr(other, name) + if type(value) != type(other_value): + return False + if isinstance(value, paddle.Tensor): + assert isinstance(other_value, paddle.Tensor) + if tuple(value.shape) != tuple(other_value.shape): + return False + other_value = other_value.to(device=value.place) + if not paddle.allclose( + x=value, y=other_value, rtol=1e-05, atol=1e-08 + ).item(): + return False + elif value != other_value: + return False + return True + + def __repr__(self) -> str: + """String representation.""" + origin = ", ".join([f"{v:.5f}" for v in self.origin()]) + center = ", ".join([f"{v:.5f}" for v in self._center]) + direction = ", ".join([f"{v:.5f}" for v in self._direction.flatten()]) + extent = ", ".join([f"{v:.5f}" for v in self._extent]) + return ( + f"{type(self).__name__}(" + + f"origin=({origin})" + + f", center=({center})" + + f", extent=({extent})" + + f", direction=({direction})" + + f", device={repr(str(self.device))}" + + ")" + ) + + +def cube_points_transform( + cube: Cube, axes: Axes, to_cube: Cube, to_axes: Optional[Axes] = None +): + """Get linear transformation of points from one cube to another. + + Args: + cube: Sampling grid with respect to which input points are defined. + axes: Grid axes with respect to which input points are defined. + to_cube: Sampling grid with respect to which output points are defined. + to_axes: Grid axes with respect to which output points are defined. + + Returns: + Homogeneous coordinate transformation matrix as tensor of shape ``(D, D + 1)``. + + """ + return cube.transform(axes=axes, to_axes=to_axes, to_cube=to_cube, vectors=False) + + +def cube_vectors_transform( + cube: Cube, axes: Axes, to_cube: Cube, to_axes: Optional[Axes] = None +): + """Get affine transformation which maps vectors with respect to one cube to another. + + Args: + cube: Cube with respect to which (normalized) input vectors are defined. + axes: Cube axes with respect to which input vectors are defined. + to_cube: Cube with respect to which (normalized) output vectors are defined. + to_axes: Cube axes with respect to which output vectors are defined. + + Returns: + Affine transformation matrix as tensor of shape ``(D, D)``. + + """ + return cube.transform(axes=axes, to_axes=to_axes, to_cube=to_cube, vectors=True) + + +def cube_transform_points( + points: paddle.Tensor, + cube: Cube, + axes: Axes, + to_cube: Cube, + to_axes: Optional[Axes] = None, +): + return cube.transform_points(points, axes=axes, to_axes=to_axes, to_cube=to_cube) + + +def cube_transform_vectors( + vectors: paddle.Tensor, + cube: Cube, + axes: Axes, + to_cube: Cube, + to_axes: Optional[Axes] = None, +): + return cube.transform_vectors(vectors, axes=axes, to_axes=to_axes, to_cube=to_cube) diff --git a/jointContribution/HighResolution/deepali/core/enum.py b/jointContribution/HighResolution/deepali/core/enum.py new file mode 100644 index 0000000000..08a1a99bc9 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/enum.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import itertools +import re +from enum import Enum +from enum import IntEnum +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Union + + +class Sampling(Enum): + """Enumeration of image interpolation modes.""" + + AREA = "area" + BICUBIC = "bicubic" + BSPLINE = "bspline" + LINEAR = "linear" + NEAREST = "nearest" + + @classmethod + def from_arg(cls, arg: Union[Sampling, str, None]) -> Sampling: + """Create enumeration value from function argument.""" + if isinstance(arg, str): + arg = arg.lower() + if arg is None or arg in ("default", "bilinear", "trilinear"): + return cls.LINEAR + if arg == "nn": + return cls.NEAREST + return cls(arg) + + def grid_sample_mode(self, num_spatial_dim: int) -> str: + """Interpolation mode argument for paddle.nn.functional.grid_sample() for given number of spatial dimensions.""" + if self == self.LINEAR: + return "bilinear" + if self == self.NEAREST: + return "nearest" + raise ValueError( + f"paddle.nn.functional.grid_sample() does not support padding mode '{self.value}' for {num_spatial_dim}-dimensional images" + ) + + def interpolate_mode(self, num_spatial_dim: int) -> str: + """Interpolation mode argument for paddle.nn.functional.interpolate() for given number of spatial dimensions.""" + if self == self.AREA: + return "area" + if self == self.BICUBIC: + return "bicubic" + if self == self.LINEAR: + if num_spatial_dim == 1: + return "linear" + if num_spatial_dim == 2: + return "bilinear" + if num_spatial_dim == 3: + return "trilinear" + if self == self.NEAREST: + return "nearest" + raise ValueError( + f"paddle.nn.functional.interpolate() does not support padding mode '{self.value}' for {num_spatial_dim}-dimensional images" + ) + + +class PaddingMode(Enum): + """Enumeration of image extrapolation modes.""" + + NONE = "none" + CONSTANT = "constant" + BORDER = "border" + REFLECT = "reflect" + REPLICATE = "replicate" + ZEROS = "zeros" + + @classmethod + def from_arg(cls, arg: Union[PaddingMode, str, None]) -> PaddingMode: + """Create enumeration value from function argument.""" + if isinstance(arg, str): + arg = arg.lower() + if arg is None or arg == "default": + return cls.ZEROS + if arg in ("mirror", "reflection"): + return cls.REFLECT + if arg == "circular": + return cls.REPLICATE + return cls(arg) + + def conv_mode(self, num_spatial_dim: int = 3) -> str: + """Padding mode argument for paddle.nn.ConvNd().""" + if self in (self.CONSTANT, self.ZEROS): + return "zeros" + elif self == self.REFLECT: + return "reflect" + elif self == self.REPLICATE: + return "replicate" + raise ValueError( + f"paddle.nn.Conv{num_spatial_dim}d() does not support padding mode '{self.value}'" + ) + + def grid_sample_mode(self, num_spatial_dim: int) -> str: + """Padding mode argument for paddle.nn.functional.grid_sample().""" + if 2 <= num_spatial_dim <= 3: + if self in (self.CONSTANT, self.ZEROS): + return "zeros" + if self == self.BORDER: + return "border" + if self == self.REFLECT: + return "reflection" + raise ValueError( + f"paddle.nn.functional.grid_sample() does not support padding mode '{self.value}' for {num_spatial_dim}-dimensional images" + ) + + def pad_mode(self, num_spatial_dim: int) -> str: + """Padding mode argument for paddle.nn.functional.pad() for given number of spatial dimensions.""" + if self == self.CONSTANT: + return "constant" + elif self == self.REFLECT: + if 1 <= num_spatial_dim <= 2: + return "reflect" + elif self == self.REPLICATE: + if 1 <= num_spatial_dim <= 3: + return "replicate" + raise ValueError( + f"paddle.nn.functional.pad() does not support padding mode '{self.value}' for {num_spatial_dim}-dimensional images" + ) + + +class SpatialDim(IntEnum): + """Spatial image dimension selector.""" + + X = 0 + Y = 1 + Z = 2 + T = 3 + + @classmethod + def from_arg(cls, arg: Union[int, str, SpatialDim]) -> SpatialDim: + """Get enumeration value from function argument.""" + if arg in ("x", "X"): + return cls.X + if arg in ("y", "Y"): + return cls.Y + if arg in ("z", "Z"): + return cls.Z + if arg in ("t", "T"): + return cls.T + return cls(arg) + + def tensor_dim(self, ndim: int, channels_last: bool = False) -> int: + """Map spatial dimension identifier to image data tensor dimension.""" + dim = ndim - (2 if channels_last else 1) - self.value + if ( + channels_last + and (dim < 1 or dim > ndim - 2) + or not channels_last + and (dim < 2 or dim > ndim - 1) + ): + raise ValueError("SpatialDim.tensor_dim() is out-of-bounds") + return dim + + def __str__(self) -> str: + """Letter of spatial dimension.""" + return ("x", "y", "z", "t")[self.value] + + +SpatialDimArg = Union[int, str, SpatialDim] + + +class SpatialDerivativeKeys(object): + """Auxiliary functions for identifying and enumerating spatial derivatives. + + Spatial derivatives are encoded by a sequence of letters, where each letter + identifies the spatial dimension (cf. ``SpatialDim``) along which to take the + derivative. The length of the string encoding determines the order of the + derivative, i.e., how many times the input image is being derived with + respect to one or more spatial dimensions. + + """ + + @staticmethod + def check(arg: Union[str, Sequence[str]]): + """Check if given derivatives key is valid.""" + if isinstance(arg, str): + arg = (arg,) + for key in arg: + if not isinstance(key, str): + raise TypeError("Spatial derivatives key must be str") + if re.search("[^xyzt]", key): + raise ValueError( + "Spatial derivatives key must only contain letters 'x', 'y', 'z', or 't'" + ) + + @classmethod + def is_valid(cls, arg: Union[str, Sequence[str]]) -> bool: + """Check if given derivatives key is valid.""" + try: + cls.check(arg) + except (TypeError, ValueError): + return False + return True + + @staticmethod + def is_mixed(key: str) -> bool: + """Whether derivative contains mixed terms.""" + return len(set(key)) > 1 + + @staticmethod + def all(ndim: int, order: Union[int, Sequence[int]]) -> Tuple[str, ...]: + """Unmixed spatial derivatives of specified order.""" + if isinstance(order, int): + order = (order,) + keys = [] + dims = [str(SpatialDim(d)) for d in range(ndim)] + for n in order: + if n > 0: + codes = dims + for _ in range(1, n): + codes = [ + (code + letter) + for code, letter in itertools.product(codes, dims) + ] + keys.extend(codes) + return keys + + @staticmethod + def unmixed(ndim: int, order: int) -> Tuple[str, ...]: + """Unmixed spatial derivatives of specified order.""" + if order <= 0: + return () + return tuple(str(SpatialDim(d)) * order for d in range(ndim)) + + @classmethod + def unique(cls, keys: Sequence[str]) -> Set[str]: + """Unique spatial derivatives.""" + return set(cls.sorted(key) for key in keys) + + @classmethod + def sorted(cls, key: str) -> str: + """Sort letters of spatial dimensions in spatial derivative key.""" + return cls.join(sorted(cls.split(key))) + + @staticmethod + def order(arg: str) -> int: + """Order of the spatial derivative.""" + return len(arg) + + @classmethod + def max_order(cls, keys: Sequence[str]) -> int: + if not keys: + return 0 + return max(cls.order(key) for key in keys) + + @staticmethod + def split(arg: str) -> Tuple[SpatialDim, ...]: + """Split spatial derivative key into spatial dimensions enum values.""" + return tuple(SpatialDim.from_arg(letter) for letter in arg) + + @staticmethod + def join(arg: Sequence[SpatialDim]) -> str: + """Join spatial dimensions to spatial derivative key.""" + return "".join(str(x) for x in arg) diff --git a/jointContribution/HighResolution/deepali/core/flow.py b/jointContribution/HighResolution/deepali/core/flow.py new file mode 100644 index 0000000000..1994eeac5f --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/flow.py @@ -0,0 +1,482 @@ +from typing import Optional +from typing import Union + +import paddle + +from . import affine as A +from .enum import PaddingMode +from .enum import Sampling +from .grid import ALIGN_CORNERS +from .grid import Grid +from .image import _image_size +from .image import check_sample_grid +from .image import grid_reshape +from .image import grid_sample +from .image import spatial_derivatives +from .image import zeros_image +from .tensor import move_dim +from .types import Array +from .types import Device +from .types import Scalar +from .types import Shape +from .types import Size + + +def affine_flow( + matrix: paddle.Tensor, grid: Union[Grid, paddle.Tensor], channels_last: bool = False +) -> paddle.Tensor: + """Compute dense flow field from homogeneous transformation. + + Args: + matrix: Homogeneous coordinate transformation matrices of shape ``(N, D, 1)`` (translation), + ``(N, D, D)`` (affine), or ``(N, D, D + 1)`` (homogeneous), respectively. + grid: Image sampling ``Grid`` or tensor of shape ``(N, ..., X, D)`` of points at + which to sample flow fields. If an object of type ``Grid`` is given, the value + of ``grid.align_corners()`` determines if output flow vectors are with respect to + ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True), respectively. + channels_last: If ``True``, flow vector components are stored in the last dimension + of the output tensor, and first dimension otherwise. + + Returns: + paddle.Tensor of shape ``(N, C, ..., X)`` if ``channels_last=False`` and ``(N, ..., X, C)``, otherwise. + + """ + if matrix.ndim != 3: + raise ValueError( + f"affine_flow() 'matrix' must be tensor of shape (N, D, 1|D|D+1), not {tuple(matrix.shape)}" + ) + device = matrix.place + if isinstance(grid, Grid): + grid = grid.coords(device=device) + grid = grid.unsqueeze(axis=0) + elif grid.ndim < 3: + raise ValueError( + f"affine_flow() 'grid' must be tensor of shape (N, ...X, D), not {tuple(grid.shape)}" + ) + assert grid.place == device + flow = A.transform_points(matrix, grid) - grid + if not channels_last: + flow = move_dim(flow, -1, 1) + assert flow.place == device + return flow + + +def compose_flows( + a: paddle.Tensor, b: paddle.Tensor, align_corners: bool = True +) -> paddle.Tensor: + """Compute composite flow field ``c = b o a``.""" + a = move_dim(b, 1, -1) + c = paddle.nn.functional.grid_sample( + x=b, grid=a, mode="bilinear", padding_mode="border", align_corners=align_corners + ) + return c + + +def curl( + flow: paddle.Tensor, + spacing: Optional[Union[Scalar, Array]] = None, + mode: str = "central", +) -> paddle.Tensor: + """Calculate curl of vector field. + + TODO: Implement curl for 2D vector field. + + Args: + flow: Vector field as tensor of shape ``(N, 3, Z, Y, X)``. + spacing: Physical size of image voxels used to compute ``spatial_derivatives()``. + mode: Mode of ``spatial_derivatives()`` approximation. + + Returns: + In case of a 3D input vector field, output is another 3D vector field of rotation vectors, + where axis of rotation corresponds to the unit vector and rotation angle to the magnitude + of the rotation vector, as tensor of shape ``(N, 3, Z, Y, X)``. + + """ + if flow.ndim == 4: + if tuple(flow.shape)[1] != 2: + raise ValueError("curl() 'flow' must have shape (N, 2, Y, X)") + raise NotImplementedError("curl() of 2-dimensional vector field") + if flow.ndim == 5: + if tuple(flow.shape)[1] != 3: + raise ValueError("curl() 'flow' must have shape (N, 3, Z, Y, X)") + start_7 = flow.shape[1] + 0 if 0 < 0 else 0 + dx = spatial_derivatives( + paddle.slice(flow, [1], [start_7], [start_7 + 1]), + mode=mode, + which=("y", "z"), + spacing=spacing, + ) + start_8 = flow.shape[1] + 1 if 1 < 0 else 1 + dy = spatial_derivatives( + paddle.slice(flow, [1], [start_8], [start_8 + 1]), + mode=mode, + which=("x", "z"), + spacing=spacing, + ) + start_9 = flow.shape[1] + 2 if 2 < 0 else 2 + dz = spatial_derivatives( + paddle.slice(flow, [1], [start_9], [start_9 + 1]), + mode=mode, + which=("x", "y"), + spacing=spacing, + ) + rotvec = paddle.concat( + x=[dz["y"] - dy["z"], dx["z"] - dz["x"], dy["x"] - dx["y"]], axis=1 + ) + return rotvec + raise ValueError("curl() 'flow' must be 2- or 3-dimensional vector field") + + +def expv( + flow: paddle.Tensor, + scale: Optional[float] = None, + steps: Optional[int] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str] = PaddingMode.BORDER, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Group exponential maps of flow fields computed using scaling and squaring. + + Args: + flow: Batch of flow fields as tensor of shape ``(N, D, ..., X)``. + scale: Constant flow field scaling factor. + steps: Number of scaling and squaring steps. + sampling: Flow field interpolation mode. + padding: Flow field extrapolation mode. + align_corners: Whether ``flow`` vectors are defined with respect to + ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True). + + Returns: + Exponential map of input flow field. If ``steps=0``, a reference to ``flow`` is returned. + + """ + if scale is None: + scale = 1 + if steps is None: + steps = 5 + if not isinstance(steps, int): + raise TypeError("expv() 'steps' must be of type int") + if steps < 0: + raise ValueError("expv() 'steps' must be positive value") + if steps == 0: + return flow + device = flow.place + grid = Grid(shape=tuple(flow.shape)[2:], align_corners=align_corners) + grid = grid.coords(dtype=flow.dtype, device=device) + assert grid.place == device + disp = flow * (scale / 2**steps) + assert disp.place == device + for _ in range(steps): + disp = disp + warp_image( + disp, + grid, + flow=move_dim(disp, 1, -1), + mode=sampling, + padding=padding, + align_corners=align_corners, + ) + assert disp.place == device + return disp + + +def jacobian_det( + u: paddle.Tensor, mode: str = "central", channels_last: bool = False +) -> paddle.Tensor: + """Evaluate Jacobian determinant of given flow field using finite difference approximations. + + Note that for differentiable parametric spatial transformation models, an accurate Jacobian could + be calculated instead from the given transformation parameters. See for example ``cubic_bspline_jacobian_det()`` + which is specialized for a free-form deformation (FFD) determined by a continuous cubic B-spline function. + + Args: + u: Input vector field as tensor of shape ``(N, D, ..., X)`` when ``channels_last=False`` and + shape ``(N, ..., X, D)`` when ``channels_last=True``. + mode: Mode of ``spatial_derivatives()`` to use for approximating spatial partial derivatives. + channels_last: Whether input vector field has vector (channels) dimension at second or last index. + + Returns: + Scalar field of approximate Jacobian determinant values as tensor of shape ``(N, 1, ..., X)`` when + ``channels_last=False`` and ``(N, ..., X, 1)`` when ``channels_last=True``. + + """ + if u.ndim < 4: + shape_str = "(N, ..., X, D)" if channels_last else "(N, D, ..., X)" + raise ValueError( + f"jacobian_det() 'u' must be dense vector field of shape {shape_str}" + ) + shape = tuple(u.shape)[1:-1] if channels_last else tuple(u.shape)[2:] + mat = paddle.empty(shape=(tuple(u.shape)[0],) + shape + (3, 3), dtype=u.dtype) + for i, which in enumerate(("x", "y", "z")): + deriv = spatial_derivatives(u, mode=mode, which=which)[which] + if not channels_last: + deriv = move_dim(deriv, 1, -1) + mat[..., i] = deriv + for i in range(tuple(mat.shape)[-1]): + mat[..., i, i].add_(y=paddle.to_tensor(1)) + jac = mat.det().unsqueeze_(axis=-1 if channels_last else 1) + return jac + + +def normalize_flow( + data: paddle.Tensor, + size: Optional[Union[paddle.Tensor, list]] = None, + side_length: float = 2, + align_corners: bool = ALIGN_CORNERS, + channels_last: bool = False, +) -> paddle.Tensor: + """Map vectors with respect to unnormalized grid to vectors with respect to normalized grid.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("normalize_flow() 'data' must be tensor") + if not data.is_floating_point(): + data = data.astype(dtype="float32") + if size is None: + if data.ndim < 4 or tuple(data.shape)[1] != data.ndim - 2: + raise ValueError( + "normalize_flow() 'data' must have shape (N, D, ..., X) when 'size' not given" + ) + size = tuple(reversed(tuple(data.shape)[2:])) + zero = paddle.to_tensor(data=0, dtype=data.dtype, place=data.place) + size = paddle.to_tensor(data=size, dtype=data.dtype, place=data.place) + size_ = size.sub(1) if align_corners else size + if not channels_last: + data = move_dim(data, 1, -1) + if side_length != 1: + data = data.mul(side_length) + data = paddle.where(condition=size > 1, x=data.div(size_), y=zero) + if not channels_last: + data = move_dim(data, -1, 1) + return data + + +def denormalize_flow( + data: paddle.Tensor, + size: Optional[Union[paddle.Tensor, list]] = None, + side_length: float = 2, + align_corners: bool = ALIGN_CORNERS, + channels_last: bool = False, +) -> paddle.Tensor: + """Map vectors with respect to normalized grid to vectors with respect to unnormalized grid.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("denormalize_flow() 'data' must be tensors") + if size is None: + if data.ndim < 4 or tuple(data.shape)[1] != data.ndim - 2: + raise ValueError( + "denormalize_flow() 'data' must have shape (N, D, ..., X) when 'size' not given" + ) + size = tuple(reversed(tuple(data.shape)[2:])) + zero = paddle.to_tensor(data=0, dtype=data.dtype, place=data.place) + size = paddle.to_tensor(data=size, dtype=data.dtype, place=data.place) + size_ = size.sub(1) if align_corners else size + if not channels_last: + data = move_dim(data, 1, -1) + data = paddle.where(condition=size > 1, x=data.mul(size_), y=zero) + if side_length != 1: + data = data.div(side_length) + if not channels_last: + data = move_dim(data, -1, 1) + return data + + +def sample_flow( + flow: paddle.Tensor, coords: paddle.Tensor, align_corners: bool = ALIGN_CORNERS +) -> paddle.Tensor: + """Sample non-rigid flow fields at given points. + + This function samples a vector field at spatial points. The ``coords`` tensor can be of any shape, + including ``(N, M, D)``, i.e., a batch of N point sets with cardianality M. It can also be applied to + a tensor of grid points of shape ``(N, ..., X, D)`` regardless if the grid points are located at the + undeformed grid positions or an already deformed grid. The given non-rigid flow field is interpolated + at the input points ``x`` using linear interpolation. These flow vectors ``u(x)`` are returned. + + Args: + flow: Flow fields of non-rigid transformations given as tensor of shape ``(N, D, ..., X)`` + or ``(1, D, ..., X)``. If batch size is one, but the batch size of ``coords`` is greater + than one, this single flow fields is sampled at the different sets of points. + coords: Normalized coordinates of points given as tensor of shape ``(N, ..., D)`` + or ``(1, ..., D)``. If batch size is one, all flow fields are sampled at the same points. + align_corners: Whether point coordinates are with respect to ``Axes.CUBE`` (False) + or ``Axes.CUBE_CORNERS`` (True). This option is in particular passed on to the + ``grid_sample()`` function used to sample the flow vectors at the input points. + + Returns: + paddle.Tensor of shape ``(N, ..., D)``. + + """ + if not isinstance(flow, paddle.Tensor): + raise TypeError("sample_flow() 'flow' must be of type paddle.Tensor") + if flow.ndim < 4: + raise ValueError("sample_flow() 'flow' must be at least 4-dimensional tensor") + if not isinstance(coords, paddle.Tensor): + raise TypeError("sample_flow() 'coords' must be of type paddle.Tensor") + if coords.ndim < 2: + raise ValueError("sample_flow() 'coords' must be at least 2-dimensional tensor") + G = tuple(flow.shape)[0] + N = tuple(coords.shape)[0] if G == 1 else G + D = tuple(flow.shape)[1] + if tuple(coords.shape)[0] not in (1, N): + raise ValueError(f"sample_flow() 'coords' must be batch of length 1 or {N}") + if tuple(coords.shape)[-1] != D: + raise ValueError( + f"sample_flow() 'coords' must be tensor of {D}-dimensional points" + ) + x = coords.expand(shape=(N,) + tuple(coords.shape)[1:]) + t = flow.expand(shape=(N,) + tuple(flow.shape)[1:]) + g = x.reshape((N,) + (1,) * (t.ndim - 3) + (-1, D)) + u = grid_sample(t, g, padding=PaddingMode.BORDER, align_corners=align_corners) + u = move_dim(u, 1, -1) + u = u.reshape(tuple(x.shape)) + return u + + +def warp_grid( + flow: paddle.Tensor, grid: paddle.Tensor, align_corners: bool = ALIGN_CORNERS +) -> paddle.Tensor: + """Transform undeformed grid by a tensor of non-rigid flow fields. + + This function applies a non-rigid transformation to map a tensor of undeformed grid points to a + tensor of deformed grid points with the same shape as the input tensor. The input points must be + the positions of undeformed spatial grid points, because this function uses interpolation to + resize the vector fields to the size of the input ``grid``. This assumes that input points ``x`` + are the coordinates of points located on a regularly spaced undeformed grid which is aligned with + the borders of the grid domain on which the vector fields of the non-rigid transformations are + sampled, i.e., ``y = x + u``. + + If in doubt whether the input points will be sampled regularly at grid points of the domain of + the spatial transformation, use ``warp_points()`` instead. + + Args: + flow: Flow fields of non-rigid transformations given as tensor of shape ``(N, D, ..., X)`` + or ``(1, D, ..., X)``. If batch size is one, but the batch size of ``points`` is greater + than one, all point sets are transformed by the same non-rigid transformation. + grid: Coordinates of points given as tensor of shape ``(N, ..., D)`` or ``(1, ..., D)``. + If batch size is one, but multiple flow fields are given, this single point set is + transformed by each non-rigid transformation to produce ``N`` output point sets. + align_corners: Whether grid points and flow vectors are with respect to ``Axes.CUBE`` + (False) or ``Axes.CUBE_CORNERS`` (True). This option is in particular passed on to + the ``grid_reshape()`` function used to resize the flow fields to the ``grid`` shape. + + Returns: + paddle.Tensor of shape ``(N, ..., D)`` with coordinates of spatially transformed points. + + """ + if not isinstance(flow, paddle.Tensor): + raise TypeError("warp_grid() 'flow' must be of type paddle.Tensor") + if flow.ndim < 4: + raise ValueError("warp_grid() 'flow' must be at least 4-dimensional tensor") + if not isinstance(grid, paddle.Tensor): + raise TypeError("warp_grid() 'grid' must be of type paddle.Tensor") + if grid.ndim < 4: + raise ValueError("warp_grid() 'grid' must be at least 4-dimensional tensor") + G = tuple(flow.shape)[0] + N = tuple(grid.shape)[0] if G == 1 else G + D = tuple(flow.shape)[1] + if tuple(grid.shape)[0] not in (1, N): + raise ValueError(f"warp_grid() 'grid' must be batch of length 1 or {N}") + if tuple(grid.shape)[-1] != D: + raise ValueError(f"warp_grid() 'grid' must be tensor of {D}-dimensional points") + x = grid.expand(shape=(N,) + tuple(grid.shape)[1:]) + t = flow.expand(shape=(N,) + tuple(flow.shape)[1:]) + u = grid_reshape(t, tuple(grid.shape)[1:-1], align_corners=align_corners) + u = move_dim(u, 1, -1).reshape(x.shape) + y = x + u + return y + + +def warp_points( + flow: paddle.Tensor, coords: paddle.Tensor, align_corners: bool = ALIGN_CORNERS +) -> paddle.Tensor: + """Transform set of points by a tensor of non-rigid flow fields. + + This function applies a non-rigid transformation to map a tensor of spatial points to another tensor + of spatial points of the same shape as the input tensor. Unlike ``warp_grid()``, it can be used + to spatially transform any set of points which are defined with respect to the grid domain of the + non-rigid transformation, including a tensor of shape ``(N, M, D)``, i.e., a batch of N point sets with + cardianality M. It can also be applied to a tensor of grid points of shape ``(N, ..., X, D)`` regardless + if the grid points are located at the undeformed grid positions or an already deformed grid. The given + non-rigid flow field is interpolated at the input points ``x`` using linear interpolation. These flow + vectors ``u(x)`` are then added to the input points, i.e., ``y = x + u(x)``. + + Args: + flow: Flow fields of non-rigid transformations given as tensor of shape ``(N, D, ..., X)`` + or ``(1, D, ..., X)``. If batch size is one, but the batch size of ``points`` is greater + than one, all point sets are transformed by the same non-rigid transformation. + coords: Normalized coordinates of points given as tensor of shape ``(N, ..., D)`` + or ``(1, ..., D)``. If batch size is one, this single point set is deformed by each + flow field to produce ``N`` output point sets. + align_corners: Whether points and flow vectors are with respect to ``Axes.CUBE`` (False) + or ``Axes.CUBE_CORNERS`` (True). This option is in particular passed on to the + ``grid_sample()`` function used to sample the flow vectors at the input points. + + Returns: + paddle.Tensor of shape ``(N, ..., D)`` with coordinates of spatially transformed points. + + """ + x = coords + u = sample_flow(flow, coords, align_corners=align_corners) + y = x + u + return y + + +def warp_image( + data: paddle.Tensor, + grid: paddle.Tensor, + flow: Optional[paddle.Tensor] = None, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Sample data at optionally displaced grid points. + + Args: + data: Image batch tensor of shape ``(1, C, ..., X)`` or ``(N, C, ..., X)``. + grid: Grid points tensor of shape ``(..., X, D)``, ``(1, ..., X, D)``, or``(N, ..., X, D)``. + Coordinates of points at which to sample ``data`` must be with respect to ``Axes.CUBE``. + flow: Batch of flow fields of shape ``(..., X, D)``, ``(1, ..., X, D)``, or``(N, ..., X, D)``. + If specified, the flow field(s) are added to ``grid`` in order to displace the grid points. + mode: Image interpolate mode. + padding: Image extrapolation mode or constant by which to pad input ``data``. + align_corners: Whether ``grid`` extrema ``(-1, 1)`` refer to the grid boundary + edges (``align_corners=False``) or corner points (``align_corners=True``). + + Returns: + Image batch tensor of sampled data with shape determined by ``grid``. + + """ + if data.ndim < 4: + raise ValueError("warp_image() expected tensor 'data' of shape (N, C, ..., X)") + grid = check_sample_grid("warp", data, grid) + N = tuple(grid.shape)[0] + D = tuple(grid.shape)[-1] + if flow is not None: + if flow.ndim == data.ndim - 1: + flow = flow.unsqueeze(axis=0) + elif flow.ndim != data.ndim: + raise ValueError( + f"warp_image() expected 'flow' tensor with {data.ndim - 1} or {data.ndim} dimensions" + ) + if tuple(flow.shape)[0] != N: + flow = flow.expand(shape=[N, *tuple(flow.shape)[1:]]) + if tuple(flow.shape)[0] != N or tuple(flow.shape)[-1] != D: + msg = f"warp_image() expected tensor 'flow' of shape (..., X, {D})" + msg += f" or (1, ..., X, {D})" if N == 1 else f" or (1|{N}, ..., X, {D})" + raise ValueError(msg) + grid = grid + flow + assert data.place == grid.place + return grid_sample( + data, grid, mode=mode, padding=padding, align_corners=align_corners + ) + + +def zeros_flow( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: int = 1, + named: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create batch of flow fields filled with zeros for given image batch size or grid.""" + size = _image_size("zeros_flow", size, shape) + return zeros_image( + size, num=num, channels=len(size), named=named, dtype=dtype, device=device + ) diff --git a/jointContribution/HighResolution/deepali/core/functional.py b/jointContribution/HighResolution/deepali/core/functional.py new file mode 100644 index 0000000000..60e2c4090c --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/functional.py @@ -0,0 +1,224 @@ +"""Access most core library functions with a single import. + +.. code-block:: python + + import deepali.core.functional as U + +""" +from .affine import affine_rotation_matrix +from .affine import apply_transform as apply_affine_transform +from .affine import euler_rotation_angles +from .affine import euler_rotation_matrix +from .affine import euler_rotation_order +from .affine import identity_transform +from .affine import rotation_matrix +from .affine import scaling_transform +from .affine import shear_matrix +from .affine import transform_points as affine_transform_points +from .affine import transform_vectors as affine_transform_vectors +from .affine import translation +from .bspline import bspline_interpolation_weights +from .bspline import cubic_bspline_control_point_grid +from .bspline import cubic_bspline_control_point_grid_size +from .bspline import cubic_bspline_jacobian_det +from .bspline import cubic_bspline_jacobian_dict +from .bspline import cubic_bspline_jacobian_matrix +from .bspline import cubic_bspline_jacobian_triu +from .bspline import evaluate_cubic_bspline +from .bspline import subdivide_cubic_bspline +from .flow import affine_flow +from .flow import compose_flows +from .flow import denormalize_flow +from .flow import expv +from .flow import jacobian_det +from .flow import normalize_flow +from .flow import sample_flow +from .flow import warp_grid +from .flow import warp_image +from .flow import warp_points +from .flow import zeros_flow +from .image import avg_pool +from .image import center_crop +from .image import center_pad +from .image import circle_image +from .image import conv +from .image import conv1d +from .image import crop +from .image import cshape_image +from .image import dot_batch +from .image import dot_channels +from .image import downsample +from .image import empty_image +from .image import fill_border +from .image import finite_differences +from .image import flatten_channels +from .image import gaussian_pyramid +from .image import grid_image +from .image import grid_resample +from .image import grid_reshape +from .image import grid_resize +from .image import grid_sample +from .image import grid_sample_mask +from .image import image_slice +from .image import max_pool +from .image import min_pool +from .image import normalize_image +from .image import ones_image +from .image import pad +from .image import rand_sample +from .image import rescale +from .image import sample_image +from .image import spatial_derivatives +from .image import upsample +from .image import zeros_image +from .linalg import angle_axis_to_quaternion +from .linalg import angle_axis_to_rotation_matrix +from .linalg import as_homogeneous_matrix +from .linalg import as_homogeneous_tensor +from .linalg import hmm +from .linalg import homogeneous_matmul +from .linalg import homogeneous_matrix +from .linalg import homogeneous_transform +from .linalg import normalize_quaternion +from .linalg import quaternion_exp_to_log +from .linalg import quaternion_log_to_exp +from .linalg import quaternion_to_angle_axis +from .linalg import quaternion_to_rotation_matrix +from .linalg import rotation_matrix_to_angle_axis +from .linalg import rotation_matrix_to_quaternion +from .linalg import tensordot +from .linalg import vector_rotation +from .linalg import vectordot +from .math import abspow +from .math import atanh +from .math import max_difference +from .math import round_decimals +from .math import threshold +from .pointset import bounding_box +from .pointset import closest_point_distances +from .pointset import closest_point_indices +from .pointset import denormalize_grid +from .pointset import distance_matrix +from .pointset import normalize_grid +from .pointset import polyline_directions +from .pointset import polyline_tangents +from .pointset import transform_grid +from .pointset import transform_points +from .tensor import as_float_tensor +from .tensor import as_one_hot_tensor +from .tensor import as_tensor +from .tensor import atleast_1d +from .tensor import batched_index_select +from .tensor import move_dim +from .tensor import unravel_coords +from .tensor import unravel_index + +__all__ = ( + "abspow", + "as_tensor", + "as_float_tensor", + "as_one_hot_tensor", + "atanh", + "atleast_1d", + "batched_index_select", + "max_difference", + "move_dim", + "round_decimals", + "threshold", + "unravel_coords", + "unravel_index", + "affine_flow", + "affine_rotation_matrix", + "affine_transform_points", + "affine_transform_vectors", + "angle_axis_to_rotation_matrix", + "angle_axis_to_quaternion", + "apply_affine_transform", + "as_homogeneous_matrix", + "as_homogeneous_tensor", + "euler_rotation_matrix", + "euler_rotation_angles", + "euler_rotation_order", + "hmm", + "homogeneous_matmul", + "homogeneous_matrix", + "homogeneous_transform", + "identity_transform", + "normalize_quaternion", + "quaternion_to_angle_axis", + "quaternion_to_rotation_matrix", + "quaternion_log_to_exp", + "quaternion_exp_to_log", + "rotation_matrix", + "rotation_matrix_to_angle_axis", + "rotation_matrix_to_quaternion", + "scaling_transform", + "shear_matrix", + "tensordot", + "translation", + "vectordot", + "vector_rotation", + "avg_pool", + "bounding_box", + "bspline_interpolation_weights", + "center_crop", + "center_pad", + "circle_image", + "closest_point_distances", + "closest_point_indices", + "compose_flows", + "conv", + "conv1d", + "crop", + "cshape_image", + "cubic_bspline_control_point_grid", + "cubic_bspline_control_point_grid_size", + "cubic_bspline_jacobian_det", + "cubic_bspline_jacobian_dict", + "cubic_bspline_jacobian_matrix", + "cubic_bspline_jacobian_triu", + "denormalize_flow", + "denormalize_grid", + "distance_matrix", + "dot_batch", + "dot_channels", + "downsample", + "empty_image", + "evaluate_cubic_bspline", + "expv", + "flatten_channels", + "finite_differences", + "gaussian_pyramid", + "grid_image", + "image_slice", + "pad", + "fill_border", + "grid_resample", + "grid_reshape", + "grid_resize", + "grid_sample", + "grid_sample_mask", + "jacobian_det", + "max_pool", + "min_pool", + "normalize_flow", + "normalize_grid", + "normalize_image", + "ones_image", + "polyline_directions", + "polyline_tangents", + "rand_sample", + "rescale", + "sample_flow", + "sample_image", + "spatial_derivatives", + "subdivide_cubic_bspline", + "transform_grid", + "transform_points", + "upsample", + "warp_grid", + "warp_image", + "warp_points", + "zeros_flow", + "zeros_image", +) diff --git a/jointContribution/HighResolution/deepali/core/grid.py b/jointContribution/HighResolution/deepali/core/grid.py new file mode 100644 index 0000000000..24cfd6e82c --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/grid.py @@ -0,0 +1,1777 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from enum import Enum +from typing import TYPE_CHECKING +from typing import Any +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Union +from typing import overload + +import numpy as np +import paddle +from pkg_resources import parse_version + +try: + import SimpleITK as _sitk +except ImportError: + _sitk = None +if TYPE_CHECKING: + from .cube import Cube + +from .enum import SpatialDim +from .enum import SpatialDimArg +from .linalg import hmm +from .linalg import homogeneous_matrix +from .linalg import homogeneous_transform +from .math import round_decimals +from .tensor import as_tensor +from .tensor import cat_scalars +from .types import Array +from .types import Device +from .types import PathStr +from .types import ScalarOrTuple +from .types import Shape +from .types import Size +from .types import is_int_dtype + +ALIGN_CORNERS = True +"""By default, grid corner points define the domain within which the data is defined. + +The normalized coordinates with respect to the grid cube are thus in [-1, 1] between +the first and last grid points (grid cell center points). This is such that when a +grid is downsampled or upsampled, respectively, the grid points at the boundary of +the domain remain unchanged. Points are only inserted or removed between these. + +""" + + +class Axes(Enum): + """Enumeration of grid axes with respect to which grid coordinates are defined.""" + + GRID = "grid" + """Oriented along grid axes with units corresponding to voxel units with origin + with respect to world space at grid/image point with zero indices.""" + CUBE = "cube" + """Oriented along grid axes with units corresponding to unit cube :math:`[-1, 1]^D` + with origin with respect to world space at grid/image center and extrema -1/1 + coinciding with the grid border (``align_corners=False``).""" + CUBE_CORNERS = "cube_corners" + """Oriented along grid axes with units corresponding to unit cube :math:`[-1, 1]^D` + with origin with respect to world space at grid/image center and extrema -1/1 + coinciding with the grid corner points (``align_corners=True``).""" + WORLD = "world" + """Oriented along world axes (physical space) with units corresponding to grid spacing (mm).""" + + @classmethod + def from_arg(cls, arg: Union[Axes, str, None]) -> Axes: + """Create enumeration value from function argument.""" + if arg is None or arg == "default": + return cls.CUBE_CORNERS if ALIGN_CORNERS else cls.CUBE + return cls(arg) + + @classmethod + def from_grid(cls, grid: Grid) -> Axes: + """Create enumeration value from sampling grid object.""" + return cls.from_align_corners(grid.align_corners()) + + @classmethod + def from_align_corners(cls, align_corners: bool) -> Axes: + """Create enumeration value from sampling grid object.""" + return cls.CUBE_CORNERS if align_corners else cls.CUBE + + +class Grid(object): + """Oriented regularly spaced data sampling grid. + + The dimensions of :attr:`.Grid.shape` are in reverse order of the dimensions of :meth:`.Grid.size`. + The latter is consistent with SimpleITK, and the order of coordinates of :meth:`.Grid.origin`, + :meth:`.Grid.spacing`, and :meth:`.Grid.direction`. Property :attr:`.Grid.shape`, on the other hand, + is consistent with the order of dimensions of an image data tensor of type ``paddle.Tensor``. + + To not confuse :meth:`.Grid.size` with ``paddle.Tensor.size()``, it is recommended to prefer + property ``paddle.Tensor.shape``. The ``shape`` property is also known from ``numpy.ndarray.shape``. + + A :class:`.Grid` instance stores the grid center point instead of the origin corresponding to the grid + point with zero index along each grid dimension. This simplifies resizing and resampling operations, + which do not need to modify the origin explicitly, but keep the center point fixed. To get the + coordinates of the grid origin, use :meth:`.Grid.origin`. For convenience, the :class:`.Grid` + initialization function also accepts an ``origin`` instead of a ``center`` point as keyword argument. + Conversion between center point and origin are taken care internally. When both ``origin`` and ``center`` + are specified, an error is raised if these are inconsistent with one another. + + In addition, :meth:`.Grid.points`, :meth:`.Grid.transform`, and :meth`.Grid.apply_transform` support + coordinates with respect to different axes: 1) (continuous) grid indices, 2) world space, and + 3) grid-aligned cube with side length 2. The latter, i.e., :attr:`.Axes.CUBE` or :attr:`.Axes.CUBE_CORNERS` + makes coordinates independent of :meth:`.Grid.size` and :meth:`.Grid.spacing`. These normalized coordinates + are furthermore compatible with the ``grid`` argument of ``paddle.nn.functional.grid_sample()``. Use + :meth:`.Grid.cube` to obtain a :class:`.Cube` defining the data domain without spatial sampling attributes. + + """ + + __slots__ = "_size", "_center", "_spacing", "_direction", "_align_corners" + + def __init__( + self, + size: Optional[Union[Size, Array]] = None, + shape: Optional[Union[Shape, Array]] = None, + center: Optional[Union[Array, float]] = None, + origin: Optional[Union[Array, float]] = None, + spacing: Optional[Union[Array, float]] = None, + direction: Optional[Array] = None, + device: Optional[Device] = None, + align_corners: bool = ALIGN_CORNERS, + ): + """Initialize sampling grid attributes. + + Args: + size: Size of spatial grid dimensions in the order ``(X, ...)``. + shape: Size of spatial grid dimensions in the order ``(..., X)``. + Must be ``None`` or ``size`` reversed, if ``size is not None``. + Either ``size`` or ``shape`` must be specified. + center: Grid center point in world space. + origin: World coordinates of grid point with zero indices. + spacing: Size of each grid square along each dimension. + direction: Direction cosines defining orientation of grid in world space. + The direction cosines are vectors that point from one pixel to the next. + Each column of the matrix indicates the direction cosines of the unit vector + that is parallel to the lines of the grid corresponding to that dimension. + device: Device on which to store grid attributes. If ``None``, use ``"cpu"``. + align_corners: Whether position of grid corner points are preserved by grid + resampling and resizing operations by default. If ``True``, the grid + origin remains unchanged by grid resizing operations, but the grid extent + may in total change by one times the spacing between grid points. If ``False``, + the extent of the grid remains constant, but the grid origin may shift. + + """ + if device is None: + device = str("cpu").replace("cuda", "gpu") + if size is not None: + size = as_tensor(size, device=device) + if size.ndim != 1: + raise ValueError("Grid() 'size' must be 1-dimensional array") + if len(size) == 0: + raise ValueError("Grid() 'size' must be non-empty array") + if shape is None: + if size is None: + raise ValueError("Grid() 'size' or 'shape' required") + else: + shape = as_tensor(shape, device=device) + if size is None: + if shape.ndim != 1: + raise ValueError("Grid() 'shape' must be 1-dimensional array") + if len(shape) == 0: + raise ValueError("Grid() 'shape' must be non-empty array") + size = shape.flip(axis=0) + else: + with paddle.no_grad(): + if ( + len(size) != len(shape) + or shape.flip(axis=0) + .not_equal(y=paddle.to_tensor(size)) + .astype("bool") + .any() + ): + raise ValueError("Grid() 'size' and 'shape' are not compatible") + self._size = paddle.clip(x=size.astype(dtype="float32"), min=0) + self.spacing_(1 if spacing is None else spacing) + if direction is None: + direction = paddle.eye(num_rows=self.ndim) + self.direction_(direction) + if origin is None: + self.center_(0 if center is None else center) + elif center is None: + self.origin_(origin) + else: + self.center_(center) + with paddle.no_grad(): + origin = cat_scalars( + origin, num=self.ndim, dtype=self.dtype, device=self.device + ) + if not paddle.allclose(x=origin, y=self.origin()).item(): + raise ValueError("Grid() 'center' and 'origin' are inconsistent") + self._align_corners = bool(align_corners) + + def numpy(self) -> np.ndarray: + """Get grid attributes as 1-dimensional NumPy array.""" + return np.concatenate( + [ + self._size.numpy(), + self._spacing.numpy(), + self._center.numpy(), + self._direction.flatten().numpy(), + ], + axis=0, + ) + + @classmethod + def from_numpy( + cls, + attrs: Union[Sequence[float], np.ndarray], + origin: bool = False, + align_corners: bool = ALIGN_CORNERS, + ) -> Grid: + """Create Grid from 1-dimensional NumPy array.""" + if isinstance(attrs, np.ndarray): + seq = attrs.astype(float).tolist() + else: + seq = attrs + return cls.from_seq(seq, origin=origin, align_corners=align_corners) + + @classmethod + def from_seq( + cls, + attrs: Sequence[float], + origin: bool = False, + align_corners: bool = ALIGN_CORNERS, + ) -> Grid: + """Create Grid from sequence of attribute values. + + Args: + attrs: Array of length (D + 3) * D, where ``D=2`` or ``D=3`` is the number + of spatial grid dimensions and array items are given as + ``(nx, ..., sx, ..., cx, ..., d11, ..., d21, ....)``, + where ``(nx, ...)`` is the grid size, ``(sx, ...)`` the grid spacing, + ``(cx, ...)`` the grid center coordinates, and ``(d11, ...)`` + are the grid direction cosines. The argument can be a Python + list or tuple, NumPy array, or tensor. + origin: Whether ``(cx, ...)`` specifies Grid origin rather than center. + + Returns: + Grid instance. + + """ + if len(attrs) == 10: + d = 2 + elif len(attrs) == 18: + d = 3 + else: + raise ValueError( + f"{cls.__name__}.from_seq() expected array of length 10 (D=2) or 18 (D=3)" + ) + kwargs = dict( + size=attrs[0:d], + spacing=attrs[d : 2 * d], + direction=attrs[3 * d :], + align_corners=align_corners, + ) + if origin: + kwargs["origin"] = attrs[2 * d : 3 * d] + else: + kwargs["center"] = attrs[2 * d : 3 * d] + return Grid(**kwargs) + + @classmethod + def from_batch(cls, tensor: paddle.Tensor) -> Grid: + """Create default grid from (image) batch tensor. + + Args: + tensor: Batch tensor of shape ``(N, C, ..., X)``. + + Returns: + New default grid with size ``(X, ...)``. + + """ + return cls(shape=tuple(tensor.shape)[2:]) + + @classmethod + def from_file(cls, path: PathStr, align_corners: bool = ALIGN_CORNERS) -> Grid: + """Create sampling grid from image file header information.""" + if _sitk is None: + raise RuntimeError(f"{cls.__name__}.from_file() requires SimpleITK") + reader = _sitk.ImageFileReader() + reader.SetFileName(str(path)) + reader.ReadImageInformation() + return cls.from_reader(reader, align_corners=align_corners) + + @classmethod + def from_reader( + cls, reader: "_sitk.ImageFileReader", align_corners: bool = ALIGN_CORNERS + ) -> Grid: + """Create sampling grid from image file reader attributes.""" + return cls( + size=reader.GetSize(), + origin=reader.GetOrigin(), + spacing=reader.GetSpacing(), + direction=reader.GetDirection(), + align_corners=align_corners, + ) + + @classmethod + def from_sitk( + cls, image: "_sitk.Image", align_corners: bool = ALIGN_CORNERS + ) -> Grid: + """Create sampling grid from ``SimpleITK.Image`` attributes.""" + return cls( + size=image.GetSize(), + origin=image.GetOrigin(), + spacing=image.GetSpacing(), + direction=image.GetDirection(), + align_corners=align_corners, + ) + + def cube(self) -> "Cube": + """Get oriented cube defining the space of normalized coordinates.""" + from .cube import Cube + + return Cube( + extent=self.cube_extent(), + center=self.center(), + direction=self.direction(), + device=self.device, + ) + + def domain(self) -> "Cube": + """Get oriented bounding box defining the sampling grid domain in world space.""" + return self.cube() + + def dim(self) -> int: + """Number of spatial grid dimensions.""" + return len(self._size) + + @property + def ndim(self) -> int: + """Number of grid dimensions.""" + return len(self._size) + + @property + def dtype(self) -> paddle.dtype: + """Get data type of grid attribute tensors.""" + return self._size.dtype + + @property + def device(self) -> str: + """Get device on which grid attribute tensors are stored.""" + return self._size.place + + def clone(self) -> Grid: + """Make deep copy of this ``Grid`` instance.""" + grid = shallow_copy(self) + for name in self.__slots__: + value = getattr(self, name) + if isinstance(value, paddle.Tensor): + setattr(grid, name, value.clone()) + return grid + + def __deepcopy__(self, memo) -> Grid: + """Support copy.deepcopy to clone this grid.""" + if id(self) in memo: + return memo[id(self)] + copy = self.clone() + memo[id(self)] = copy + return copy + + @overload + def align_corners(self) -> bool: + """Whether resizing operations preserve grid extent (False) or corner points (True).""" + ... + + @overload + def align_corners(self, arg: bool) -> Grid: + """Set if resizing operations preserve grid extent (False) or corner points (True).""" + ... + + def align_corners(self, arg: Optional[bool] = None) -> Union[bool, Grid]: + """Whether resizing operations preserve grid extent (False) or corner points (True).""" + if arg is None: + return self._align_corners + return shallow_copy(self).align_corners_(arg) + + def align_corners_(self, arg: bool) -> Grid: + """Set if resizing operations preserve grid extent (False) or corner points (True).""" + self._align_corners = bool(arg) + return self + + def axes(self) -> Axes: + """Grid axes.""" + return Axes.from_grid(self) + + def numel(self) -> int: + """Number of grid points.""" + return self.size().size + + @staticmethod + def _round_size(size: paddle.Tensor) -> paddle.Tensor: + """Round sampling grid size attribute.""" + zero = paddle.to_tensor(data=0, dtype=size.dtype, place=size.place) + return paddle.where(condition=size.equal(y=zero), x=zero, y=size.ceil()) + + def size_tensor(self) -> paddle.Tensor: + """Sampling grid size as floating point tensor.""" + return self._round_size(self._size) + + @overload + def size(self, i: int) -> int: + """Sampleing grid size along the specified spatial dimension.""" + ... + + @overload + def size(self) -> list: + """Sampling grid size for dimensions ordered as ``(X, ...)``.""" + ... + + def size(self, i: Optional[int] = None) -> Union[int, list]: + """Sampling grid size.""" + size = self.size_tensor() + if i is None: + return tuple(int(n) for n in size) + return int(size[i]) + + @property + def shape(self) -> list: + """Sampling grid size for dimensions ordered as ``(..., X)``.""" + return tuple(int(n) for n in self.size_tensor().flip(axis=0)) + + def extent(self, i: Optional[int] = None) -> paddle.Tensor: + """Extent of sampling grid in physical world units.""" + if i is None: + return self.spacing() * self.size_tensor() + return self._spacing[i] * self.size(i) + + def cube_extent(self, i: Optional[int] = None) -> paddle.Tensor: + """Extent of sampling grid cube in physical world units.""" + if i is None: + n = self.size_tensor() + if self._align_corners: + n = n.sub(1.0) + return self.spacing().mul(n) + n = self.size(i) + if self._align_corners: + n -= 1 + return self._spacing[i] * n + + @overload + def center(self) -> paddle.Tensor: + """Get grid center point in world space.""" + ... + + @overload + def center(self, arg: Union[float, Array], *args: float) -> Grid: + """Get new grid with specified center point in world space.""" + ... + + def center( + self, arg: Union[float, Array, None] = None, *args: float + ) -> Union[paddle.Tensor, Grid]: + """Get center point in world space or new grid with specified center point, respectively.""" + if arg is None: + if args: + raise ValueError( + f"{type(self).__name__}.center() 'args' cannot be given when first 'arg' is None" + ) + return self._center + return shallow_copy(self).center_(arg, *args) + + def center_(self, arg: Union[float, Array], *args: float) -> Grid: + """Set grid center point in world space.""" + self._center = cat_scalars( + arg, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + return self + + @overload + def origin(self) -> paddle.Tensor: + """Get world coordinates of grid point with index zero.""" + ... + + @overload + def origin(self, arg: Union[float, Array], *args: float) -> Grid: + """Get new grid with specified world coordinates of grid point at index zero.""" + ... + + def origin( + self, arg: Union[float, Array, None] = None, *args: float + ) -> Union[paddle.Tensor, Grid]: + """Get grid origin in world space or new grid with specified origin, respectively.""" + if arg is None: + if args: + raise ValueError( + f"{type(self).__name__}.origin() 'args' cannot be given when first 'arg' is None" + ) + size = self.size_tensor() + offset = paddle.where( + condition=size.greater_than(y=paddle.to_tensor(0.0)), + x=size.sub(1.0), + y=size, + ).div(2.0) + offset = paddle.matmul(x=self.affine(), y=offset) + return self._center.sub(offset) + return shallow_copy(self).origin_(arg, *args) + + def origin_(self, arg: Union[float, Array], *args: float) -> Grid: + """Set world coordinates of grid point with zero index.""" + origin = cat_scalars( + arg, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + size = self.size_tensor() + offset = paddle.where( + condition=size.greater_than(y=paddle.to_tensor(0.0)), + x=size.sub(1.0), + y=size, + ).div(2.0) + offset = paddle.matmul(x=self.affine(), y=offset) + self._center = origin.add(offset) + return self + + @overload + def spacing(self) -> paddle.Tensor: + """Get spacing between grid points in world units.""" + ... + + @overload + def spacing(self, arg: Union[float, Array], *args: float) -> Grid: + """Get new grid with specified spacing between grid points in world units.""" + ... + + def spacing( + self, arg: Union[float, Array, None] = None, *args: float + ) -> Union[paddle.Tensor, Grid]: + """Get spacing between grid points in world units or new grid with specified spacing, respectively.""" + if arg is None: + if args: + raise ValueError( + f"{type(self).__name__}.spacing() 'args' cannot be given when first 'arg' is None" + ) + return self._spacing + return shallow_copy(self).spacing_(arg, *args) + + def spacing_(self, arg: Union[float, Array], *args: float) -> Grid: + """Set spacing between grid points in physical world units.""" + spacing = cat_scalars( + arg, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + if spacing.less_equal(y=paddle.to_tensor(0.0)).astype("bool").any(): + raise ValueError("Grid spacing must be positive") + self._spacing = spacing + return self + + @overload + def direction(self) -> paddle.Tensor: + """Get grid axes direction cosines matrix.""" + ... + + @overload + def direction(self, arg: Union[float, Array], *args: float) -> Grid: + """Get new grid with specified axes direction cosines.""" + ... + + def direction( + self, arg: Union[float, Array, None] = None, *args: float + ) -> Union[paddle.Tensor, Grid]: + """Get grid axes direction cosines matrix or new grid with specified direction, respectively.""" + if arg is None: + if args: + raise ValueError( + f"{type(self).__name__}.direction() 'args' cannot be given when first 'arg' is None" + ) + return self._direction + return shallow_copy(self).direction_(arg, *args) + + def direction_(self, arg: Union[float, Array], *args: float) -> Grid: + """Set grid axes direction cosines matrix of this grid.""" + D = self.ndim + direction = paddle.to_tensor(data=(arg,) + args) if args else as_tensor(arg) + direction = direction.astype(self.dtype) + if direction.ndim == 1: + if tuple(direction.shape)[0] != D * D: + raise ValueError( + f"Grid direction must be array or square matrix with numel={D * D}" + ) + direction = direction.reshape(D, D) + elif ( + direction.ndim != 2 + or tuple(direction.shape)[0] != tuple(direction.shape)[1] + or tuple(direction.shape)[0] != D + ): + raise ValueError( + f"Grid direction must be array or square matrix with numel={D * D}" + ) + with paddle.no_grad(): + if abs(paddle.linalg.det(direction).abs().item() - 1) > 0.0001: + raise ValueError( + "Grid direction cosines matrix must be valid rotation matrix" + ) + self._direction = direction + return self + + def affine(self) -> paddle.Tensor: + """Affine transformation from ``Axes.GRID`` to ``Axes.WORLD``, excluding translation of origin.""" + return paddle.mm(input=self.direction(), mat2=paddle.diag(x=self.spacing())) + + def inverse_affine(self) -> paddle.Tensor: + """Affine transformation from ``Axes.WORLD`` to ``Axes.GRID``, excluding translation of origin.""" + one = paddle.to_tensor(data=1, dtype=self.dtype, place=self.device) + return paddle.mm( + input=paddle.diag(x=one / self.spacing()), mat2=self.direction().t() + ) + + def transform( + self, + axes: Optional[Union[Axes, str]] = None, + to_axes: Optional[Union[Axes, str]] = None, + to_grid: Optional[Grid] = None, + vectors: bool = False, + ) -> paddle.Tensor: + """Transformation from one grid domain to another. + + Args: + axes: Axes with respect to which input coordinates are defined. + If ``None`` and also ``to_axes`` and ``to_cube`` is ``None``, + returns the transform which maps from cube to world space. + to_axes: Axes of grid to which coordinates are mapped. Use ``axes`` if ``None``. + to_grid: Other grid. Use ``self`` if ``None``. + vectors: Whether transformation is used to rescale and reorient vectors. + + Returns: + If ``vectors=False``, a homogeneous coordinate transformation of shape ``(D, D + 1)``. + Otherwise, a square transformation matrix of shape ``(D, D)`` is returned. + + """ + if axes is None and to_axes is None and to_grid is None: + cube_axes = Axes.CUBE_CORNERS if self._align_corners else Axes.CUBE + return self.transform(cube_axes, Axes.WORLD, vectors=vectors) + if axes is None: + raise ValueError( + "Grid.transform() 'axes' required when 'to_axes' or 'to_grid' specified" + ) + matrix = None + axes = Axes(axes) + to_axes = axes if to_axes is None else Axes(to_axes) + if to_grid is None or to_grid == self: + if axes is to_axes: + matrix = paddle.eye(num_rows=self.ndim, dtype=self.dtype) + if not vectors: + offset = paddle.zeros(shape=self.ndim, dtype=self.dtype) + matrix = homogeneous_matrix(matrix, offset=offset) + elif axes is Axes.GRID: + if to_axes is Axes.CUBE: + size = self.size_tensor() + matrix = paddle.diag(x=2 / size) + if not vectors: + one = paddle.to_tensor( + data=1, dtype=size.dtype, place=size.place + ) + matrix = homogeneous_matrix(matrix, offset=one / size - one) + elif to_axes is Axes.CUBE_CORNERS: + size = self.size_tensor() + matrix = paddle.diag(x=2 / (size - 1)) + if not vectors: + offset = paddle.to_tensor( + data=-1, dtype=size.dtype, place=size.place + ) + matrix = homogeneous_matrix(matrix, offset=offset) + elif to_axes is Axes.WORLD: + matrix = self.affine() + if not vectors: + matrix = homogeneous_matrix(matrix, offset=self.origin()) + elif axes is Axes.CUBE: + if to_axes is Axes.CUBE_CORNERS: + size = self.size_tensor() + matrix = paddle.diag(x=size / (size - 1)) + elif to_axes is Axes.GRID: + half_size = 0.5 * self.size_tensor() + matrix = paddle.diag(x=half_size) + if not vectors: + matrix = homogeneous_matrix(matrix, offset=half_size - 0.5) + elif to_axes is Axes.WORLD: + cube_to_grid = self.transform(axes, Axes.GRID, vectors=vectors) + grid_to_world = self.transform( + Axes.GRID, Axes.WORLD, vectors=vectors + ) + if vectors: + matrix = paddle.mm(input=grid_to_world, mat2=cube_to_grid) + else: + matrix = hmm(grid_to_world, cube_to_grid) + elif axes is Axes.CUBE_CORNERS: + if to_axes is Axes.CUBE: + size = self.size_tensor() + matrix = paddle.diag(x=(size - 1) / size) + elif to_axes is Axes.GRID: + scales = 0.5 * (self.size_tensor() - 1) + matrix = paddle.diag(x=scales) + if not vectors: + matrix = homogeneous_matrix(matrix, offset=scales) + elif to_axes is Axes.WORLD: + interim = Axes.GRID + cube_to_grid = self.transform(axes, interim, vectors=vectors) + grid_to_world = self.transform(interim, to_axes, vectors=vectors) + if vectors: + matrix = paddle.mm(input=grid_to_world, mat2=cube_to_grid) + else: + matrix = hmm(grid_to_world, cube_to_grid) + elif axes is Axes.WORLD: + if to_axes is Axes.CUBE or to_axes is Axes.CUBE_CORNERS: + interim = Axes.GRID + world_to_grid = self.transform(axes, interim, vectors=vectors) + grid_to_cube = self.transform(interim, to_axes, vectors=vectors) + if vectors: + matrix = paddle.mm(input=grid_to_cube, mat2=world_to_grid) + else: + matrix = hmm(grid_to_cube, world_to_grid) + elif to_axes is Axes.GRID: + matrix = self.inverse_affine() + if not vectors: + matrix = hmm(matrix, -self.origin()) + elif to_grid.ndim != self.ndim: + raise ValueError( + f"Grid.transform() 'to_grid' must have {self.ndim} spatial dimensions" + ) + else: + target_to_world = self.transform(axes, Axes.WORLD, vectors=vectors) + world_to_source = to_grid.transform(Axes.WORLD, to_axes, vectors=vectors) + if vectors: + matrix = paddle.mm(input=world_to_source, mat2=target_to_world) + else: + matrix = hmm(world_to_source, target_to_world) + if matrix is None: + raise NotImplementedError( + f"Grid.transform() for axes={axes} and to_axes={to_axes}" + ) + return matrix + + def inverse_transform(self, vectors: bool = False) -> paddle.Tensor: + """Transform which maps from world to grid cube space.""" + cube_axes = Axes.CUBE_CORNERS if self._align_corners else Axes.CUBE + return self.transform(Axes.WORLD, cube_axes, vectors=vectors) + + def apply_transform( + self, + input: Array, + axes: Axes, + to_axes: Optional[Axes] = None, + to_grid: Optional[Grid] = None, + vectors: bool = False, + decimals: Optional[int] = -1, + ) -> paddle.Tensor: + """Map point coordinates or displacement vectors from one grid to another. + + Args: + input: Points or vectors to transform as tensor of shape ``(..., D)``. + axes: Axes with respect to which input coordinates are defined. + to_axes: Axes of cube to which coordinates are mapped. Use ``axes`` if ``None``. + to_cube: Other cube. Use ``self`` if ``None``. + vectors: Whether transformation is used to rescale and reorient vectors. + decimals: If positive or zero, number of digits right of the decimal point to round to. + When mapping points to ``Axes.GRID``, ``Axes.CUBE``, or ``Axes.CUBE_CORNERS``, + this function by default (``decimals=-1``) rounds the transformed coordinates. + Explicitly set ``decimals=None`` to suppress this default rounding. + + Returns: + If ``vectors=False``, a homogeneous coordinate transformation of shape ``(D, D + 1)``. + Otherwise, a square transformation matrix of shape ``(D, D)`` is returned. + + """ + axes = Axes(axes) + to_axes = axes if to_axes is None else Axes(to_axes) + input = as_tensor(input) + if not input.is_floating_point(): + input = input.astype(self.dtype) + if to_grid is not None and to_grid != self or axes is not to_axes: + matrix = self.transform(axes, to_axes, to_grid=to_grid, vectors=vectors) + matrix = matrix.unsqueeze(axis=0).to(device=input.place) + result = homogeneous_transform(matrix, input) + else: + result = input + if decimals == -1: + if to_axes is Axes.CUBE or to_axes is Axes.CUBE_CORNERS: + decimals = 12 + elif to_axes is Axes.GRID: + decimals = 6 + if decimals is not None and decimals >= 0: + result = round_decimals(result, decimals=decimals) + return result + + def transform_points( + self, + points: Array, + axes: Axes, + to_axes: Optional[Axes] = None, + to_grid: Optional[Grid] = None, + decimals: Optional[int] = -1, + ) -> paddle.Tensor: + """Map point coordinates from one grid domain to another. + + Args: + points: Coordinates of points to transform as tensor of shape ``(..., D)``. + axes: Coordinate axes with respect to which ``points`` are defined. + to_axes: Coordinate axes to which ``points`` should be mapped to. If ``None``, use ``axes``. + to_grid: Grid with respect to which the codomain is defined. If ``None``, the target + and source sampling grids are assumed to be identical. + decimals: If positive or zero, number of digits right of the decimal point to round to. + When mapping points to codomain ``Axes.GRID``, ``Axes.CUBE``, or ``Axes.CUBE_CORNERS``, + this function by default (``decimals=-1``) rounds the transformed coordinates. + Explicitly set ``decimals=None`` to suppress this default rounding. + + Returns: + Point coordinates in ``axes`` mapped to coordinates with respect to ``to_axes``. + + """ + return self.apply_transform( + points, axes, to_axes, to_grid=to_grid, decimals=decimals + ) + + def transform_vectors( + self, + vectors: Array, + axes: Axes, + to_axes: Optional[Axes] = None, + to_grid: Optional[Grid] = None, + ) -> paddle.Tensor: + """Rescale and reorient flow vectors. + + Args: + vectors: Displacement vectors to transform, e.g., as tensor of shape ``(..., D)``. + axes: Coordinate axes with respect to which ``vectors`` are defined. + to_axes: Coordinate axes to which ``vectors`` should be mapped to. If ``None``, use ``axes``. + to_grid: Grid with respect to which ``to_axes`` is defined. If ``None``, the target + and source sampling grids are assumed to be identical. + + Returns: + Rescaled and reoriented vectors. If ``to_grid == self`` and ``to_axes == axes``, + a reference to the unmodified input ``vectors`` tensor is returned. + + """ + axes = Axes(axes) + to_axes = axes if to_axes is None else Axes(to_axes) + vectors = as_tensor(vectors) + if not vectors.is_floating_point(): + vectors = vectors.astype(self.dtype) + if axes is Axes.WORLD and to_axes is Axes.WORLD: + return vectors + if to_grid is None or to_grid == self: + if axes is not to_axes: + affine = None + scales = None + if axes is Axes.WORLD: + affine = self.inverse_affine() + elif axes is Axes.CUBE: + scales = self.size_tensor() / 2 + elif axes is Axes.CUBE_CORNERS: + scales = (self.size_tensor() - 1) / 2 + elif axes is not Axes.GRID: + raise NotImplementedError( + f"Grid.transform_vectors() for axes={axes} and to_axes={to_axes}" + ) + if to_axes is Axes.WORLD: + grid_to_world = self.affine() + if scales is None: + assert affine is None + affine = grid_to_world + else: + affine = paddle.mm( + input=grid_to_world, mat2=paddle.diag(x=scales) + ) + elif to_axes is Axes.CUBE or to_axes is Axes.CUBE_CORNERS: + num = self.size_tensor() + if to_axes is Axes.CUBE_CORNERS: + num -= 1 + grid_to_cube = 2 / num + if affine is None: + if scales is None: + scales = grid_to_cube + else: + scales *= grid_to_cube + else: + affine = paddle.mm( + input=paddle.diag(x=grid_to_cube), mat2=affine + ) + elif to_axes is not Axes.GRID: + raise NotImplementedError( + f"Grid.transform_vectors() for axes={axes} and to_axes={to_axes}" + ) + if affine is None: + assert scales is not None + scales = scales.to(vectors) + vectors = vectors * scales + else: + affine = affine.to(vectors) + tensor = vectors.reshape(-1, tuple(vectors.shape)[-1]) + vectors = paddle.nn.functional.linear( + weight=affine.T, x=tensor + ).reshape(tuple(vectors.shape)) + else: + matrix = self.transform(axes, to_axes, to_grid=to_grid, vectors=True) + matrix = matrix.to(vectors) + vectors = homogeneous_transform(matrix, vectors, vectors=True) + return vectors + + def index_to_cube( + self, indices: Array, decimals: int = -1, align_corners: Optional[bool] = None + ) -> paddle.Tensor: + """Map points from grid indices to grid-aligned cube with side length 2. + + Args: + indices: Grid point indices to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + align_corners: Whether output cube coordinates should be with respect to + ``Axes.CUBE_CORNERS`` (True) or ``Axes.CUBE`` (False), respectively. + If ``None``, use default ``self.align_corners()``. + + Returns: + Grid point indices transformed to points with respect to cube. + + """ + if align_corners is None: + align_corners = self._align_corners + to_axes = Axes.from_align_corners(align_corners) + return self.transform_points( + indices, axes=Axes.GRID, to_axes=to_axes, decimals=decimals + ) + + def cube_to_index( + self, coords: Array, decimals: int = -1, align_corners: Optional[bool] = None + ) -> paddle.Tensor: + """Map points from grid-aligned cube to grid point indices. + + Args: + coords: Normalized grid points to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + align_corners: Whether ``coords`` are with respect to ``Axes.CUBE_CORNERS`` (True) + or ``Axes.CUBE`` (False), respectively. If ``None``, use default ``self.align_corners()``. + + Returns: + Points in grid-aligned cube transformed to grid indices. + + """ + if align_corners is None: + align_corners = self._align_corners + axes = Axes.from_align_corners(align_corners) + return self.transform_points( + coords, axes=axes, to_axes=Axes.GRID, decimals=decimals + ) + + def index_to_world(self, indices: Array, decimals: int = -1) -> paddle.Tensor: + """Map points from grid indices to world coordinates. + + Args: + indices: Grid point indices to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + + Returns: + Grid point indices transformed to points in world space. + + """ + return self.transform_points( + indices, axes=Axes.GRID, to_axes=Axes.WORLD, decimals=decimals + ) + + def world_to_index(self, points: Array, decimals: int = -1) -> paddle.Tensor: + """Map points from world coordinates to grid point indices. + + Args: + points: World coordinates of points to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + + Returns: + Points in world space transformed to grid indices. + + """ + return self.transform_points( + points, axes=Axes.WORLD, to_axes=Axes.GRID, decimals=decimals + ) + + def cube_to_world( + self, coords: Array, decimals: int = -1, align_corners: Optional[bool] = None + ) -> paddle.Tensor: + """Map point coordinates from grid-aligned cube with side length 2 to world space. + + Args: + coords: Normalized grid points to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + align_corners: Whether ``coords`` are with respect to ``Axes.CUBE_CORNERS`` (True) + or ``Axes.CUBE`` (False), respectively. If ``None``, use default ``self.align_corners()``. + + Returns: + Normalized grid coordinates transformed to world space coordinates. + + """ + if align_corners is None: + align_corners = self._align_corners + axes = Axes.from_align_corners(align_corners) + return self.transform_points( + coords, axes=axes, to_axes=Axes.WORLD, decimals=decimals + ) + + def world_to_cube( + self, points: Array, decimals: int = -1, align_corners: Optional[bool] = None + ) -> paddle.Tensor: + """Map point coordinates from world space to grid-aligned cube with side length 2. + + Args: + points: World coordinates of points to transform as tensor of shape ``(..., D)``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + align_corners: Whether output cube coordinates should be with respect to + ``Axes.CUBE_CORNERS`` (True) or ``Axes.CUBE`` (False), respectively. + If ``None``, use default ``self.align_corners()``. + + Returns: + Points in world space transformed to normalized grid coordinates. + + """ + if align_corners is None: + align_corners = self._align_corners + to_axes = Axes.from_align_corners(align_corners) + return self.transform_points( + points, axes=Axes.WORLD, to_axes=to_axes, decimals=decimals + ) + + def coords( + self, + dim: Optional[int] = None, + center: bool = False, + normalize: bool = True, + align_corners: Optional[bool] = None, + channels_last: bool = True, + flip: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, + ) -> paddle.Tensor: + """Get tensor of grid point coordinates. + + Args: + dim: Return coordinates for specified dimension, where ``dim=0`` refers to + the first grid dimension, i.e., the ``x`` axis. + center: Whether to center ``Axes.GRID`` coordinates when ``normalize=False``. + normalize: Normalize coordinates to grid-aligned cube with side length 2. + The normalized grid point coordinates are in the closed interval ``[-1, +1]``. + align_corners: If ``normalize=True``, specifies whether the extrema ``(-1, 1)`` + should refer to the centers of the grid corner squares or their boundary. + Note that in both cases the returned normalized coordinates are associated + with the center points of the grid squares. When used as ``grid`` argument + of ``paddle.nn.functional.grid_sample()``, the same ``align_corners`` value + should be used for both ``Grid.coords()`` and ``grid_sample()`` calls. + If ``None``, use default ``self.align_corners()``. + channels_last: Whether to place stacked coordinates at last (``True``) or first (``False``) + output tensor dimension. Vector fields are represented with channels first after the + batch dimension, whereas point sets such as the sampling points passed to ``grid_sample`` + are represented with point coordinates in the last tensor dimension. + flip: Whether to return coordinates in the order (..., x) instead of (x, ...). + dtype: Data type of coordinates. If ``None``, uses ``paddle.int32`` as data type + for returned tensor if ``normalize=False``, and ``self.dtype`` otherwise. + device: Device on which to create tensor. If ``None``, use ``self.device``. + + Returns: + If ``dim`` is ``None``, returns a tensor of shape (...X, C) if ``channels_last=True`` (default) + or ``(C, ..., X)`` if ``channels_last=False``, where C is the number of spatial grid dimensions. + If ``normalize=Falze`` and ``center=False``, the tensor values are the multi-dimensional indices + in the closed-open interval [0, n) for each grid dimension, where n is the number of points in the + respective dimension. The first channel with index 0 is the ``x`` coordinate. If ``normalize=False`` + and ``center=True``, the indices are shifted such that index 0 corresponds to the grid center point. + If ``normalize=True``, the centered coordinates are normalized to ``(-1, 1)``, where the extrema + either correspond to the corner points of the grid (``align_corners=True``) or the grid boundary + edges (``align_cornes=False``). If ``dim`` is not ``None``, a 1D tensor with the coordinates for + this grid axis is returned. + + """ + if align_corners is None: + align_corners = self._align_corners + if dtype is None: + if normalize or center: + dtype = self.dtype + else: + dtype = "int32" + if device is None: + device = self.device + if dim is None: + shape = self.shape + else: + shape = tuple((self.size()[dim],)) + if len(shape) == 0: + return paddle.empty(shape=shape, dtype=dtype) + coords = [] + for n in shape: + if n == 1: + coord = paddle.to_tensor(data=[0], dtype=dtype, place=device) + elif normalize: + if align_corners: + spacing = 2 / (n - 1) + extrema = -1, 1 + 0.1 * spacing + else: + spacing = 2 / n + extrema = -1 + 0.5 * spacing, 1 + coord = paddle.arange(*extrema, spacing, dtype=dtype) + elif center: + radius = (n - 1) / 2 + coord = paddle.linspace(start=-radius, stop=radius, num=n, dtype=dtype) + else: + coord = paddle.arange(dtype=dtype, end=n) + coords.append(coord) + channels_dim = len(coords) if channels_last else 0 + if parse_version(paddle.__version__) < parse_version("1.10"): + coords = paddle.stack(x=paddle.meshgrid(*coords), axis=channels_dim) + else: + coords = paddle.stack(x=paddle.meshgrid(*coords), axis=channels_dim) + if not flip: + coords = paddle.flip(x=coords, axis=(channels_dim,)) + return coords + + def points( + self, + axes: Axes = Axes.WORLD, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, + ) -> paddle.Tensor: + """paddle.Tensor of grid point coordinates with respect to specified coordinate axes.""" + axes = Axes(axes) + coords = self.coords( + normalize=axes is Axes.CUBE, align_corners=False, dtype=dtype, device=device + ) + if axes is not Axes.CUBE and axes is not Axes.GRID: + coords = self.apply_transform(coords, Axes.GRID, to_axes=axes) + return coords + + def _resize( + self, size: paddle.Tensor, align_corners: Optional[bool] = None + ) -> Grid: + """Get new grid of specified size. + + The resulting ``grid``` MUST preserve floating point values of given ``size`` argument such + as in particular those passed by ``downsample()`` and ``upsample()``. Otherwise, a sequence + of resampling steps which should produce the original grid may result in a different size. + + Args: + size: Size of new grid. + align_corners: Whether to preserve positions of corner points (``True``) or grid extent (``False``). + + """ + if align_corners is None: + align_corners = self._align_corners + size = size.to(dtype=self.dtype, device=self.device) + if size.equal(y=self._size).astype("bool").all(): + return self + grid = shallow_copy(self) + grid._size = size + size = self._round_size(size) + if align_corners: + spacing = (self.extent() - self.spacing()) / (size - 1) + grid._spacing = paddle.where( + condition=self._size.greater_than(y=paddle.to_tensor(0.0)), + x=spacing, + y=self._spacing, + ) + assert paddle.allclose(x=grid.origin(), y=self.origin()).item() + else: + spacing = self.extent() / size + grid._spacing = paddle.where( + condition=self._size.greater_than(y=paddle.to_tensor(0.0)), + x=spacing, + y=self._spacing, + ) + assert paddle.allclose(x=grid.extent(), y=self.extent()).item() + return grid + + def resize( + self, + size: Union[int, Size, Array], + *args: int, + align_corners: Optional[bool] = None, + ) -> Grid: + """Create new grid of same extent with specified size. + + Specify new grid size for grid dimensions in the order (X...). Note that this is in reverse order + of ``paddle.Tensor.size()``! To set the grid size given a data tensor shape, use ``Grid.reshape()``. + + Args: + size: Size of new grid in the order ``(X, ...)``. + align_corners: Whether to preserve positions of corner points (``True``) or grid extent (``False``). + + Returns: + New grid with given ``size`` and adjusted ``Grid.spacing()``. + + """ + size = cat_scalars(size, *args, num=self.ndim, device=self.device) + if not is_int_dtype(size.dtype): + raise TypeError( + f"Grid.resize() 'size' must be integer values, got dtype={size.dtype}" + ) + if size.less_than(y=paddle.to_tensor(0)).astype("bool").any(): + raise ValueError("Grid.resize() 'size' must be all non-negative numbers") + return self._resize(size, align_corners=align_corners) + + def reshape( + self, + shape: Union[int, Shape, Array], + *args: int, + align_corners: Optional[bool] = None, + ) -> Grid: + """Create new grid of same extent with specified data tensor shape. + + The data tensor shape specifies the size of data dimensions in the order ``(..., X)``, + whereas the ``Grid.size()`` is given in reverse order ``(X, ...)``. This function is a + convenience function to change the grid size given the ``paddle.Tensor.shape`` of a data tensor. + + Args: + shape: Size of new grid in the order ``(..., X)``. + align_corners: Whether to preserve positions of corner points (``True``) or grid extent (``False``). + + Returns: + New grid with given ``shape`` and adjusted ``Grid.spacing()``. + + """ + shape = cat_scalars(shape, *args, num=self.ndim, device=self.device) + if not is_int_dtype(shape.dtype): + raise TypeError( + f"Grid.reshape() 'shape' must be integer values, got dtype={shape.dtype}" + ) + if shape.less_than(y=paddle.to_tensor(0)).astype("bool").any(): + raise ValueError("Grid.reshape() 'shape' must be all non-negative numbers") + return self._resize(shape.flip(axis=0), align_corners=align_corners) + + def resample( + self, spacing: Union[float, Array, str], *args: float, min_size: int = 1 + ) -> Grid: + """Create new grid with specified spacing. + + Args: + spacing: Desired spacing between grid points. Uses minimum or maximum grid spacing + for isotropic resampling when argument is string "min" or "max", respectively. + min_size: Minimum grid size. + + Returns: + New grid with specified spacing. The extent of the grid may be greater + than before, if the original extent is not divisible by the desired spacing. + + """ + if spacing == "min": + assert not args + spacing = self._spacing.min() + elif spacing == "max": + assert not args + spacing = self._spacing.max() + elif isinstance(spacing, str): + raise ValueError( + f"{type(self).__name__}.resample() 'spacing' str must be 'min' or 'max'" + ) + spacing = cat_scalars( + spacing, *args, num=self.ndim, dtype=self.dtype, device=self.device + ) + if paddle.allclose(x=spacing, y=self._spacing).item(): + return self + if spacing.less_equal(y=paddle.to_tensor(0)).astype("bool").any(): + raise ValueError("Grid.resample() 'spacing' must be all positive numbers") + size = self.extent().div(spacing) + size = paddle.where( + condition=self._size.greater_than(y=paddle.to_tensor(0)), + x=size.clip(min=min_size), + y=size, + ) + grid = shallow_copy(self) + grid._size = size + grid._spacing = spacing + return grid + + def pool( + self, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: ScalarOrTuple[int] = 0, + dilation: ScalarOrTuple[int] = 1, + ceil_mode: bool = False, + ) -> Grid: + """Output grid after applying pooling operation. + + Args: + kernel_size: Size of the pooling region. + stride: Stride of the pooling operation. + padding: Implicit zero paddings on both sides of the input. + dilation: Spacing between pooling kernel elements. + ceil_mode: When True, will use `ceil` instead of `floor` to compute the output size. + + Returns: + New grid corresponding to output data tensor after pooling operation. + + """ + if stride is not None: + raise NotImplementedError("Grid.pool() 'stride' currently not supported") + if padding != 0: + raise NotImplementedError("Grid.pool() 'padding' currently not supported") + if dilation != 1: + raise NotImplementedError("Grid.pool() 'dilation' currently not supported") + ks = cat_scalars( + kernel_size, num=self.ndim, dtype=self.dtype, device=self.device + ) + size = self.size_tensor() / ks + size = size.ceil() if ceil_mode else size.floor() + size = size.astype("int32") + grid = Grid( + size=size, + origin=self.index_to_world(ks.sub(1).div(2)), + spacing=self.spacing().mul(ks), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + return grid + + def avg_pool( + self, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: ScalarOrTuple[int] = 0, + ceil_mode: bool = False, + ) -> Grid: + """Output grid after applying average pooling.""" + return self.pool( + kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + + def downsample( + self, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + min_size: int = 1, + align_corners: Optional[bool] = None, + ) -> Grid: + """Create new grid with size halved the specified number of times. + + Args: + levels: Number of times the grid size is halved (>0) or doubled (<0). + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + min_size: Minimum grid size. If downsampling the grid along a spatial dimension would reduce its + size below the given minimum value, the grid is not further downsampled along this dimension. + align_corners: Whether to preserve positions of corner points (``True``) or grid extent (``False``). + + Returns: + New grid. + + """ + if not isinstance(levels, int): + raise TypeError("Grid.downsample() 'levels' must be of type int") + if not dims: + dims = tuple(dim for dim in range(self.ndim)) + dims = tuple(SpatialDim.from_arg(dim) for dim in dims) + size = self._size.clone() + scale = 2**levels + for dim in dims: + size[dim] /= scale + size = paddle.where( + condition=size.greater_equal( + y=paddle.to_tensor(min_size, dtype=size.dtype) + ), + x=size, + y=self._size, + ) + return self._resize(size, align_corners=align_corners) + + def upsample( + self, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + align_corners: Optional[bool] = None, + ) -> Grid: + """Create new grid of same extent with size doubled the specified number of times. + + Args: + levels: Number of times the grid size is doubled (>0) or halved (<0). + dims: Spatial dimensions along which to upsample. If not specified, consider all spatial dimensions. + min_size: Minimum grid size. If downsampling the grid along a spatial dimension would reduce its + size below the given minimum value, the grid is not further downsampled along this dimension. + align_corners: Whether to preserve positions of corner points (``True``) or grid extent (``False``). + + Returns: + New grid. + + """ + if not isinstance(levels, int): + raise TypeError("Grid.upsample() 'levels' must be of type int") + if not dims: + dims = tuple(dim for dim in range(self.ndim)) + dims = tuple(SpatialDim.from_arg(dim) for dim in dims) + size = self._size.clone() + scale = 2**levels + for dim in dims: + size[dim] *= scale + return self._resize(size, align_corners=align_corners) + + def pyramid( + self, + levels: int, + dims: Optional[Sequence[SpatialDimArg]] = None, + min_size: int = 0, + ) -> Dict[int, Grid]: + """Compute image size at each image resolution pyramid level. + + This function computes suitable image sizes for each level of a multi-resolution + image pyramid depending on original image size and spacing, a minimum grid size for + every level, and whether grid corners (``align_corners=True``) or grid borders + (``align_corners=False``) should be aligned. + + Args: + levels: Number of resolution levels. + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + min_size: Minimum grid size at each level. If the grid size after downsampling + would be smaller than the specified minimum size, the grid size is not reduced + for this spatial dimension. The number of resolution levels is not affected. + + Returns: + Dictionary mapping resolution level to sampling grid. The sampling grid at the + finest resolution level has index 0. The cube extent, i.e., the physical length + between grid points corresponding to cube interval ``(-1, 1)`` will be the same + for all resolution levels. + + """ + if not isinstance(levels, int): + raise TypeError("Grid.pyramid() 'levels' must be int") + if not isinstance(min_size, int): + raise TypeError("Grid.pyramid() 'min_size' must be int") + if not dims: + dims = tuple(dim for dim in range(self.ndim)) + dims = tuple(SpatialDim.from_arg(dim) for dim in dims) + m = sum([(2**i) for i in range(levels)]) if self._align_corners else 0 + sizes = {level: list(self.size()) for level in range(levels + 1)} + for dim in dims: + sizes[levels][dim] = int(0.5 + (sizes[levels][dim] + m) / 2**levels) + for level in range(levels - 1, -1, -1): + sizes[level][dim] = 2 * sizes[level + 1][dim] - 1 + for level in range(1, levels + 1): + sizes[level][dim] = (sizes[level - 1][dim] + 1) // 2 + if sizes[level][dim] < min_size: + sizes[level][dim] = sizes[level - 1][dim] + return {level: self.resize(size) for level, size in sizes.items()} + + def crop( + self, + *args: int, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + ) -> Grid: + """Create new grid with a margin along each axis removed. + + Args: + args: Crop ``margin`` specified as int arguments. + margin: Number of spatial grid points to remove (positive) or add (negative) at each border. + Use instead of ``num`` in order to symmetrically crop the input ``data`` tensor, e.g., + ``(nx, ny, nz)`` is equivalent to ``num=(nx, nx, ny, ny, nz, nz)``. + num: Number of spatial grid points to remove (positive) or add (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, nx, ny, ny)``. If a scalar is given, the input is cropped equally at all borders. + Otherwise, the given sequence must have an even length. + + Returns: + New grid with modified ``Grid.size()``, but unchanged ``Grid.spacing``. + Hence, the ``Grid.extent()`` of the new grid will be different from ``self.extent()``. + + """ + if ( + sum([1 if args else 0, 0 if num is None else 1, 0 if margin is None else 1]) + != 1 + ): + raise AssertionError( + "Grid.pad() 'args', 'margin', and 'num' are mutually exclusive" + ) + if len(args) == 1 and not isinstance(args[0], int): + margin = args[0] + elif args: + margin = args + if isinstance(margin, int): + num = margin + elif margin is not None: + num = tuple(n for n_n in ((n, n) for n in margin) for n in n_n) + assert num is not None + if isinstance(num, int): + num = (num,) * (2 * self.ndim) + else: + num = tuple(int(n) for n in num) + if len(num) % 2 != 0: + raise ValueError("Grid.crop() 'num' must be int or have even length") + if all(n == 0 for n in num): + return self + num = num + (0,) * (2 * self.ndim - len(num)) + num_ = paddle.to_tensor(data=num, dtype=self.dtype, place=self.device) + size = paddle.clip(x=self._size - num_[::2] - num_[1::2], min=1) + size = paddle.where( + condition=self._size.greater_than(y=paddle.to_tensor(0)), + x=size, + y=self._size, + ) + origin = self.index_to_world(num_[::2]) + return Grid( + size=size, + origin=origin, + spacing=self.spacing(), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + + def pad( + self, + *args: int, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + ) -> Grid: + """Create new grid with an additional margin along each axis. + + Args: + args: Pad ``margin`` specified as int arguments. + margin: Number of spatial grid points to add (positive) or remove (negative) at each border. + Use instead of ``num`` in order to symmetrically pad the input ``data`` tensor, e.g., + ``(nx, ny, nz)`` is equivalent to ``num=(nx, nx, ny, ny, nz, nz)``. + num: Number of spatial grid points to remove (positive) or add (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, nx, ny, ny)``. If a scalar is given, the input is cropped equally at all borders. + Otherwise, the given sequence must have an even length. + + Returns: + New grid with modified ``Grid.size()``, but unchanged ``Grid.spacing``. + Hence, the ``Grid.extent()`` of the new grid will be different from ``self.extent()``. + + """ + if ( + sum([1 if args else 0, 0 if num is None else 1, 0 if margin is None else 1]) + != 1 + ): + raise AssertionError( + "Grid.pad() 'args', 'margin', and 'num' are mutually exclusive" + ) + if len(args) == 1 and not isinstance(args[0], int): + margin = args[0] + elif args: + margin = args + if isinstance(margin, int): + num = margin + elif margin is not None: + num = tuple(n for n_n in ((n, n) for n in margin) for n in n_n) + assert num is not None + if isinstance(num, int): + num = (num,) * (2 * self.ndim) + else: + num = tuple(int(n) for n in num) + if len(num) % 2 != 0: + raise ValueError("Grid.pad() 'num' must be int or have even length") + if all(n == 0 for n in num): + return self + num = num + (0,) * (2 * self.ndim - len(num)) + num_ = paddle.to_tensor(data=num, dtype=self.dtype, place=self.device) + size = paddle.clip(x=self._size + num_[::2] + num_[1::2], min=1) + size = paddle.where( + condition=self._size.greater_than(y=paddle.to_tensor(0)), + x=size, + y=self._size, + ) + origin = self.index_to_world(-num_[::2]) + return Grid( + size=size, + origin=origin, + spacing=self.spacing(), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + + def center_crop(self, size: Union[int, Array], *args: int) -> Grid: + """Crop grid to specified maximum size.""" + size = cat_scalars(size, *args, num=self.ndim, device=self.device) + if not is_int_dtype(size.dtype): + raise TypeError( + f"Grid.center_crop() expected scalar or array of integer values, got dtype={size.dtype}" + ) + size = [min(m, n) for m, n in zip(self.size(), size.tolist())] + origin = [((m - n) // 2) for m, n in zip(self.size(), size)] + return Grid( + size=size, + origin=self.index_to_world(origin), + spacing=self.spacing(), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + + def center_pad(self, size: Union[int, Array], *args: int) -> Grid: + """Pad grid to specified minimum size.""" + size = cat_scalars(size, *args, num=self.ndim, device=self.device) + if not is_int_dtype(size.dtype): + raise TypeError( + f"Grid.center_crop() expected scalar or array of integer values, got dtype={size.dtype}" + ) + size = [max(m, n) for m, n in zip(self.size(), size.tolist())] + origin = [(-((n - m) // 2)) for m, n in zip(self.size(), size)] + return Grid( + size=size, + origin=self.index_to_world(origin), + spacing=self.spacing(), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + + def narrow(self, dim: int, start: int, length: int) -> Grid: + """Narrow grid along specified dimension.""" + if dim < 0 or dim > self.ndim: + raise IndexError("Grid.narrow() 'dim' is out of bounds") + size = tuple(length if d == dim else n for d, n in enumerate(self.size())) + origin = tuple(start if d == dim else 0 for d in range(self.ndim)) + return Grid( + size=size, + origin=self.index_to_world(origin), + spacing=self.spacing(), + direction=self.direction(), + align_corners=self.align_corners(), + device=self.device, + ) + + def region_of_interest( + self, start: Union[int, Array], size: Union[int, Array] + ) -> Grid: + """Get region of interest grid.""" + start = cat_scalars(start, num=self.ndim, device=self.device) + if not is_int_dtype(start.dtype): + raise TypeError( + f"Grid.region_of_interest() 'start' must be scalar or array of integer values, got dtype={start.dtype}" + ) + size = cat_scalars(size, num=self.ndim, device=self.device) + if not is_int_dtype(size.dtype): + raise TypeError( + f"Grid.region_of_interest() 'size' must be scalar or array of integer values, got dtype={size.dtype}" + ) + grid_size = self.size() + num = [ + [start[i], grid_size[i] - (start[i] + size[i])] for i in range(self.ndim) + ] + num = [n for nn in num for n in nn] + return self.crop(num=num) + + def same_domain_as(self, other: Grid) -> bool: + """Check if this and another grid cover the same cube domain.""" + if other is self: + return True + return self.domain() == other.domain() + + def __eq__(self, other: Any) -> bool: + """Compare this grid to another.""" + if other is self: + return True + if not isinstance(other, self.__class__): + return False + for name in self.__slots__: + if name == "_align_corners": + continue + value = getattr(self, name) + other_value = getattr(other, name) + if type(value) != type(other_value): + return False + if isinstance(value, paddle.Tensor): + assert isinstance(other_value, paddle.Tensor) + if tuple(value.shape) != tuple(other_value.shape): + return False + other_value = other_value.to(device=value.place) + if not paddle.allclose( + x=value, y=other_value, rtol=1e-05, atol=1e-08 + ).item(): + return False + elif value != other_value: + return False + return True + + def __repr__(self) -> str: + """String representation.""" + size = ", ".join([f"{v:>6.2f}" for v in self._size]) + center = ", ".join([f"{v:.5f}" for v in self._center]) + origin = ", ".join([f"{v:.5f}" for v in self.origin()]) + spacing = ", ".join([f"{v:.5f}" for v in self._spacing]) + direction = ", ".join([f"{v:.5f}" for v in self._direction.flatten()]) + return ( + f"{type(self).__name__}(" + + f"size=({size})" + + f", origin=({origin})" + + f", center=({center})" + + f", spacing=({spacing})" + + f", direction=({direction})" + + f", device={repr(str(self.device))}" + + f", align_corners={repr(self._align_corners)}" + + ")" + ) + + +def grid_points_transform( + grid: Grid, axes: Axes, to_grid: Grid, to_axes: Optional[Axes] = None +): + """Get linear transformation of points from one grid domain to another. + + Args: + grid: Sampling grid with respect to which input points are defined. + axes: Grid axes with respect to which input points are defined. + to_grid: Sampling grid with respect to which output points are defined. + to_axes: Grid axes with respect to which output points are defined. + + Returns: + Homogeneous coordinate transformation matrix as tensor of shape ``(D, D + 1)``. + + """ + return grid.transform(axes=axes, to_axes=to_axes, to_grid=to_grid, vectors=False) + + +def grid_vectors_transform( + grid: Grid, axes: Axes, to_grid: Grid, to_axes: Optional[Axes] = None +): + """Get affine transformation which maps vectors with respect to one grid domain to another. + + Args: + grid: Sampling grid with respect to which input vectors are defined. + axes: Grid axes with respect to which input vectors are defined. + to_grid: Sampling grid with respect to which output vectors are defined. + to_axes: Grid axes with respect to which output vectors are defined. + + Returns: + Affine transformation matrix as tensor of shape ``(D, D)``. + + """ + return grid.transform(axes=axes, to_axes=to_axes, to_grid=to_grid, vectors=True) + + +def grid_transform_points( + points: paddle.Tensor, + grid: Grid, + axes: Axes, + to_grid: Grid, + to_axes: Optional[Axes] = None, + decimals: Optional[int] = -1, +): + """Map point coordinates from one grid domain to another. + + Args: + points: Coordinates of points to transform as tensor of shape ``(..., D)``. + grid: Grid with respect to which input ``points`` are defined. + axes: Coordinate axes with respect to which ``points`` are defined. + to_grid: Grid with respect to which the codomain is defined. If ``None``, the target + and source sampling grids are assumed to be identical. + to_axes: Coordinate axes to which ``points`` should be mapped to. If ``None``, use ``axes``. + decimals: If positive or zero, number of digits right of the decimal point to round to. + When mapping points to codomain ``Axes.GRID``, ``Axes.CUBE``, or ``Axes.CUBE_CORNERS``, + this function by default (``decimals=-1``) rounds the transformed coordinates. + Explicitly set ``decimals=None`` to suppress this default rounding. + + Returns: + Point coordinates in ``axes`` mapped to coordinates with respect to ``to_axes``. + + """ + return grid.transform_points( + points, axes=axes, to_axes=to_axes, to_grid=to_grid, decimals=decimals + ) + + +def grid_transform_vectors( + vectors: paddle.Tensor, + grid: Grid, + axes: Axes, + to_grid: Grid, + to_axes: Optional[Axes] = None, +): + """Rescale and reorient flow vectors. + + Args: + vectors: Displacement vectors to transform, e.g., as tensor of shape ``(..., D)``. + grid: Grid with respect to which input ``vectors`` are defined. + axes: Coordinate axes with respect to which input ``vectors`` are defined. + to_grid: Grid with respect to which ``to_axes`` is defined. If ``None``, the target + and source sampling grids are assumed to be identical. + to_axes: Coordinate axes to which ``vectors`` should be mapped to. If ``None``, use ``axes``. + + Returns: + Rescaled and reoriented vectors. If ``to_grid == grid`` and ``to_axes == axes``, + a reference to the unmodified input ``vectors`` tensor is returned. + + """ + return grid.transform_vectors(vectors, axes=axes, to_axes=to_axes, to_grid=to_grid) diff --git a/jointContribution/HighResolution/deepali/core/image.py b/jointContribution/HighResolution/deepali/core/image.py new file mode 100644 index 0000000000..63917e5fd1 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/image.py @@ -0,0 +1,2055 @@ +import math +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from ..utils import paddle_aux +from .enum import PaddingMode +from .enum import Sampling +from .enum import SpatialDerivativeKeys +from .enum import SpatialDim +from .enum import SpatialDimArg +from .grid import ALIGN_CORNERS +from .grid import Axes +from .grid import Grid +from .grid import grid_transform_points +from .kernels import gaussian1d +from .kernels import gaussian1d_I +from .nnutils import same_padding +from .nnutils import stride_minus_kernel_padding +from .random import multinomial +from .tensor import as_tensor +from .tensor import cat_scalars +from .tensor import move_dim +from .types import Array +from .types import Device +from .types import Scalar +from .types import ScalarOrTuple +from .types import Shape +from .types import Size +from .types import is_float_dtype + + +def avg_pool( + data: paddle.Tensor, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: Optional[ScalarOrTuple[int]] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + divisor_override: Optional[Scalar] = None, +) -> paddle.Tensor: + """Average pooling of image data.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("avg_pool() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("avg_pool() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + if D == 1: + avg_pool_fn = paddle.nn.functional.avg_pool1d + elif D == 2: + avg_pool_fn = paddle.nn.functional.avg_pool2d + elif D == 3: + avg_pool_fn = paddle.nn.functional.avg_pool3d + else: + raise ValueError( + "avg_pool() number of spatial 'data' dimensions must be 1, 2, or 3" + ) + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * D + elif len(kernel_size) != D: + raise ValueError(f"avg_pool() 'kernel_size' must be scalar or {D}-tuple") + if stride is None: + stride = kernel_size + if padding is None: + padding = tuple(n // 2 for n in kernel_size) + return avg_pool_fn( + data, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + divisor_override=divisor_override, + ) + + +def max_pool( + data: paddle.Tensor, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: Optional[ScalarOrTuple[int]] = 0, + dilation: ScalarOrTuple[int] = 1, + ceil_mode: bool = False, +) -> paddle.Tensor: + """Max pooling of image data.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("max_pool() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("max_pool() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + if D == 1: + max_pool_fn = paddle.nn.functional.max_pool1d + elif D == 2: + max_pool_fn = paddle.nn.functional.max_pool2d + elif D == 3: + max_pool_fn = paddle.nn.functional.max_pool3d + else: + raise ValueError( + "max_pool() number of spatial 'data' dimensions must be 1, 2, or 3" + ) + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * D + elif len(kernel_size) != D: + raise ValueError(f"max_pool() 'kernel_size' must be scalar or {D}-tuple") + if stride is None: + stride = kernel_size + if padding is None: + padding = tuple(n // 2 for n in kernel_size) + return max_pool_fn( + data, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + +def min_pool( + data: paddle.Tensor, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: Optional[ScalarOrTuple[int]] = 0, + dilation: ScalarOrTuple[int] = 1, + ceil_mode: bool = False, +) -> paddle.Tensor: + """Min pooling of image data, i.e., negate max_pool() result of negated input data.""" + return -max_pool( + -data, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + ceil_mode=ceil_mode, + ) + + +def conv( + data: paddle.Tensor, + kernel: Union[paddle.Tensor, Sequence[Optional[paddle.Tensor]]], + stride: ScalarOrTuple[int] = 1, + dilation: ScalarOrTuple[int] = 1, + padding: Union[PaddingMode, str, ScalarOrTuple[int]] = None, + output_padding: Optional[ScalarOrTuple[int]] = None, + transpose: bool = False, +) -> paddle.Tensor: + """Convolve images in batch with a given (separable) kernel. + + Args: + data: Image batch tensor of shape ``(N, C, ..., X)``. + kernel: paddle.Tensor of shape ``(..., X)`` with weights of kernel used to filter the images + in this batch by. If the input ``data`` tensor is of non-floating point type, the + dtype of the kernel defines the intermediate data type used for convolutions. + If a 1-dimensional kernel is given, it is used as separable convolution kernel in + all spatial image dimensions. Otherwise, the kernel is applied to the last spatial + image dimensions. For example, a 2D kernel applied to a batch of 3D image volumes + is applied slice-by-slice by convolving along the y and x image axes. + In order to anisotropically convolve the input data with 1-dimensional kernels of + different sizes, a sequence of at most ``D`` 1-dimensional kernel tensors can be given, + where ``D`` is the number of spatial dimensions. If the sequence contains ``None``, + no convolution is performed along the corresponding spatial dimension. The first kernel + in the sequence is applied to the last spatial grid dimension, which corresponds to + the ``data`` tensor dimension ``X``, e.g., ``(kz, ky, kx)``. + stride: Stride by which convolution kernel is advanced. + dilation: Spacing between kernel elements. + padding: Image padding mode. If ``int``, pad with zeros the specified margin at each + side. Otherwise, use ``same_padding()`` calculated from kernel size and dilation such + that output size is equal to input size, unless ``PaddingMode.NONE`` is given. If ``None``, + use default mode ``PaddingMode.ZEROS`` with "same" padding. + output_padding: Output padding for transposed convolution. + transpose: Whether to compute transposed convolution. + + Returns: + Result of filtering operation with data type set to the image data type before convolution. + If dtype is not a floating point data type, the filtered data is being rounded and clamped. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("conv() 'data' must be paddle.Tensor") + if data.ndim < 3: + raise ValueError("conv() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + if isinstance(kernel, paddle.Tensor) and kernel.ndim == 1: + kernel = [kernel] * D + dtype = data.dtype + if isinstance(kernel, paddle.Tensor): + if isinstance(padding, int): + margin = (padding,) * kernel.ndim + padding = PaddingMode.ZEROS + elif isinstance(padding, Sequence): + margin = padding + padding = PaddingMode.ZEROS + else: + margin = same_padding(tuple(kernel.shape), dilation) + padding = PaddingMode.from_arg(padding) + kernel_dtype = kernel.dtype + else: + kernel = list(kernel) + kernel_dtype = None + kernel_size = [] + for k in kernel: + if k is None: + kernel_size.append(1) + else: + if kernel_dtype is None: + kernel_dtype = k.dtype + if k.ndim != 1: + raise ValueError( + "conv() 'kernel' must be n-dimensional tensor or sequence of 1-dimensional tensors" + ) + kernel_size.append(len(k)) + if kernel_dtype is None: + return data + if isinstance(padding, int): + margin = (padding,) * len(kernel) + padding = PaddingMode.ZEROS + elif isinstance(padding, Sequence): + margin = padding + padding = PaddingMode.ZEROS + else: + margin = same_padding(kernel_size, dilation) + padding = PaddingMode.from_arg(padding) + if sum(margin) != 0 and padding not in (PaddingMode.NONE, PaddingMode.ZEROS): + if transpose: + raise NotImplementedError( + f"conv() 'transpose=True' with padding {padding.value}" + ) + margin = tuple(reversed(margin)) + tensor = pad(data, margin=margin, mode=padding) + return conv( + tensor, kernel, stride=stride, dilation=dilation, padding=PaddingMode.NONE + ) + if not is_float_dtype(dtype): + dtype = kernel_dtype if is_float_dtype(kernel_dtype) else "float32" + tensor = data.astype(dtype) + device = tensor.place + if isinstance(kernel, paddle.Tensor): + K = kernel.ndim + if K > D: + raise ValueError("conv() 'kernel' has too many dimensions") + if K == 2: + conv_fn = ( + paddle.nn.functional.conv2d_transpose + if transpose + else paddle.nn.functional.conv2d + ) + elif K == 3: + conv_fn = ( + paddle.nn.functional.conv3d_transpose + if transpose + else paddle.nn.functional.conv3d + ) + else: + raise ValueError("conv() 'kernel' must have 1, 2, or 3 spatial dimensions") + shape_ = tuple(tensor.shape) + kernel = kernel.to(dtype=dtype, device=device) + kernel = kernel.reshape(1, 1, *tuple(kernel.shape)) + if tensor.ndim > kernel.ndim: + groups = tuple(tensor.shape)[1:-K].size + tensor = tensor.reshape( + tuple(tensor.shape)[0], groups, *tuple(tensor.shape)[-K:] + ) + else: + groups = tuple(tensor.shape)[1] + weight = kernel.expand(shape=[groups, 1, *tuple(kernel.shape)[-K:]]) + kwargs = dict( + tensor, + weight, + stride=stride, + dilation=dilation, + padding=0 if padding == PaddingMode.NONE else margin, + groups=groups, + ) + if transpose: + if output_padding is None: + output_padding = stride_minus_kernel_padding(1, stride) + kwargs["output_padding"] = output_padding + result = conv_fn(tensor, weight, **kwargs) + result = result.reshape(shape_[:-K] + tuple(result.shape)[-K:]) + else: + result = tensor + kernels = list(kernel) + for i, k in enumerate(kernels): + if k is not None: + kernels[i] = k.to(dtype=dtype, device=device) + K = len(kernels) + if isinstance(stride, int): + stride = (stride,) * K + elif len(stride) != K: + raise ValueError(f"conv() 'stride' must be int or sequence of {K} ints") + if isinstance(dilation, int): + dilation = (dilation,) * K + elif len(dilation) != K: + raise ValueError(f"conv() 'dilation' must be int or sequence of {K} ints") + if output_padding is None: + output_padding = stride_minus_kernel_padding(1, stride) + elif isinstance(output_padding, int): + output_padding = (output_padding,) * K + elif len(output_padding) != K: + raise ValueError( + f"conv() 'output_padding' must be None, int, or sequence of {K} ints" + ) + args = zip(kernels, stride, dilation, margin, output_padding) + for i, (k, s, d, p, op) in enumerate(args): + if k is None: + continue + result = conv1d( + result, + k, + dim=result.ndim - K + i, + stride=s, + dilation=d, + padding=0 if padding == PaddingMode.NONE else p, + output_padding=op, + transpose=transpose, + ) + if not paddle.is_floating_point(x=data): + result = result.round_() + result = result.clip_(min=float(data.min()), max=float(data.max())) + result = result.astype(dtype=data.dtype) + return result + + +def conv1d( + data: paddle.Tensor, + kernel: paddle.Tensor, + dim: int = -1, + stride: int = 1, + dilation: int = 1, + padding: Union[PaddingMode, str, int] = None, + output_padding: Optional[int] = None, + transpose: bool = False, + dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Convolve data with 1-dimensional kernel along specified dimension.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("conv1d() 'data' must be paddle.Tensor") + if data.ndim < 3: + raise ValueError("conv1d() 'data' must have shape (N, C, ..., X)") + if not isinstance(kernel, paddle.Tensor): + raise TypeError("conv1d() 'kernel' must be of type paddle.Tensor") + if kernel.ndim != 1: + raise ValueError("conv1d() 'kernel' must be 1-dimensional") + if dtype is None: + dtype = data.dtype + if is_float_dtype(dtype): + kernel = kernel.astype(dtype) + elif not is_float_dtype(kernel.dtype): + kernel = kernel.astype("float32") + if isinstance(padding, int): + margin = padding + padding = PaddingMode.ZEROS + else: + padding = PaddingMode.from_arg(padding) + if padding is PaddingMode.NONE: + margin = 0 + else: + margin = same_padding(tuple(kernel.shape), dilation) + result = data.astype(kernel.dtype) + result = move_dim(result, dim, -1) + shape_ = result.shape + result = result.reshape([shape_[0], -1, shape_[-1]]) + groups = result.shape[1] + weight = kernel.expand([groups, 1, kernel.shape[-1]]) + result = result.reshape(shape_[0], groups, shape_[-1]) + if margin and padding is not PaddingMode.ZEROS: + result = paddle_aux._FUNCTIONAL_PAD( + pad=(margin, margin), mode=padding.pad_mode(1), x=result + ) + margin = 0 + conv_fn = ( + paddle.nn.functional.conv1d_transpose + if transpose + else paddle.nn.functional.conv1d + ) + kwargs = dict(stride=stride, dilation=dilation, padding=margin, groups=groups) + if transpose: + if output_padding is None: + output_padding = stride_minus_kernel_padding(1, stride) + kwargs["output_padding"] = output_padding + result = conv_fn(result, weight, **kwargs) + result = result.reshape(shape_[0:-1] + list(result.shape)[-1:]) + result = move_dim(result, -1, dim) + if not is_float_dtype(dtype): + result = result.round_() + result = result.clip_(min=float(data.min()), max=float(data.max())) + result = result.astype(dtype) + return result + + +def dot_batch( + a: paddle.Tensor, b: paddle.Tensor, weight: Optional[paddle.paddle.Tensor] = None +) -> paddle.Tensor: + """Weighted dot product between batches of image batch tensors. + + Args: + a: Image data tensor of shape ``(N, C, ..., X)``. + b: Image data tensor of shape ``(N, C, ..., X)``. + + Returns: + paddle.Tensor of shape ``(N,)`` containing batchwise dot products. + + """ + return dot_channels(a, b, weight=weight).sum(axis=1) + + +def dot_channels( + a: paddle.Tensor, b: paddle.Tensor, weight: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + """Weighted dot product between channels of image batch tensors. + + Args: + a: Image data tensor of shape ``(N, C, ..., X)``. + b: Image data tensor of shape ``(N, C, ..., X)``. + + Returns: + paddle.Tensor of shape ``(N, C)`` containing channelwise dot products. + + """ + if not isinstance(a, paddle.Tensor): + raise TypeError("dot_channels() 'a' and 'b' must be tensors") + if a.ndim < 4: + raise ValueError("dot_channels() 'a' and 'b' must have shape (N, C, ..., X)") + if tuple(a.shape) != tuple(b.shape): + raise ValueError("dot_channels() 'a' and 'b' must have identical shape") + c = a * b + if weight is not None: + c *= weight + return c.view(tuple(c.shape)[0], tuple(c.shape)[1], -1).sum(axis=2) + + +def downsample( + data: paddle.Tensor, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + min_size: int = 0, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Downsample images after optional convolution with truncated Gaussian kernel. + + Args: + data: Image batch tensor of shape ``(N, C, ..., X)``. + levels: Number of times the image size is halved. If zero, a reference to the + unmodified input ``data`` tensor is returned. If negative, the images are + upsampled instead. + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian used for each downsampling step. + If a scalar or 1-element sequence is given, an isotropic Gaussian + kernel is used. Otherwise, the first value is the standard deviation + of the 1-dimensional Gaussian applied to the first grid dimension, which + is the last ``data`` tensor dimension, e.g., ``(sx, sy)``. If ``sigma=0``, + no low-pass filter is applied. If ``None``, a default value is used. + min_size: Required minimum grid size. + align_corners: Whether to preserve corner points (True) or grid extent (False). + + Returns: + Downsampled image data. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("downsample() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("downsample() 'data' must have shape (N, C, ..., X)") + if not paddle.is_floating_point(x=data): + raise TypeError("downsample() 'data' must have floating point dtype") + try: + levels = int(levels) + except TypeError: + raise TypeError("downsample() 'levels' must be scalar of type int") + if levels == 0: + return data + if levels < 0: + return upsample( + data, + levels=-levels, + dims=dims, + sigma=sigma, + mode=mode, + align_corners=align_corners, + ) + grid = Grid(shape=tuple(data.shape)[2:]) + if not dims: + dims = tuple(dim for dim in range(grid.ndim)) + dims = tuple(SpatialDim.from_arg(dim) for dim in dims) + grid = grid.downsample( + levels, dims=dims, min_size=min_size, align_corners=align_corners + ) + if sigma is None: + sigma = 0.7355 + sigma: paddle.Tensor = paddle.atleast_1d(as_tensor(sigma, dtype="float32")) + if sigma.ndim > 1: + raise ValueError("downsample() 'sigma' must be scalar or 1-dimensional tensor") + if sigma.not_equal(y=paddle.to_tensor(0.0)).astype("bool").any(): + if levels > 1: + var = paddle.zeros(shape=tuple(sigma.shape), dtype=sigma.dtype) + for level in range(levels): + var += sigma.mul(2**level).pow(y=2) + sigma = var.sqrt() + if tuple(sigma.shape)[0] == 1: + sigma = sigma.repeat(grid.ndim) + for i in range(grid.ndim): + if i in dims: + continue + sigma[i] = 0 + kernels = [] + kernels_ = {} + for i in range(grid.ndim): + std = float(sigma[i] if i < len(sigma) else 0) + if std > 0 and ( + tuple(grid.size())[i] != tuple(data.shape)[grid.ndim - i + 1] + ): + kernel = kernels_.get(std) + if kernel is None: + kernel = gaussian1d(std, dtype="float32", device=data.place) + kernels_[std] = kernel + kernels.append(kernel) + else: + kernels.append(None) + data = conv(data, reversed(kernels)) + mode = Sampling.from_arg(mode).interpolate_mode(grid.ndim) + return paddle.nn.functional.interpolate( + x=data, size=tuple(grid.shape), mode=mode, align_corners=align_corners + ) + + +def upsample( + data: paddle.Tensor, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Upsample images and opitonally deconvolve with truncated Gaussian kernel. + + Args: + data: Image batch tensor of shape ``(N, C, ..., X)``. + levels: Number of times the image size is doubled. If zero, a reference to the + unmodified input ``data`` tensor is returned. If negative, the images are + downsampled instead. + dims: Spatial dimensions along which to upsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian used for each upsampling step. + If a scalar or 1-element sequence is given, an isotropic Gaussian + kernel is used. Otherwise, the first value is the standard deviation + of the 1-dimensional Gaussian applied to the first grid dimension, which + is the last ``data`` tensor dimension, e.g., ``(sx, sy)``. If ``sigma=0`` + or ``None``, no transposed convolution is applied. + align_corners: Whether to preserve corner points (True) or grid extent (False). + + Returns: + Upsampled image data. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("upsample() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("upsample() 'data' must have shape (N, C, ..., X)") + if not paddle.is_floating_point(x=data): + raise TypeError("upsample() 'data' must have floating point dtype") + try: + levels = int(levels) + except TypeError: + raise TypeError("upsample() 'levels' must be scalar of type int") + if levels == 0: + return data + if levels < 0: + return downsample( + data, + levels=-levels, + dims=dims, + sigma=sigma, + mode=mode, + align_corners=align_corners, + ) + grid = Grid(shape=tuple(data.shape)[2:], align_corners=align_corners) + if not dims: + dims = tuple(dim for dim in range(grid.ndim)) + dims = tuple(SpatialDim.from_arg(dim) for dim in dims) + grid = grid.upsample(levels, dims=dims) + mode = Sampling.from_arg(mode).interpolate_mode(grid.ndim) + result: paddle.Tensor = paddle.nn.functional.interpolate( + x=data, size=tuple(grid.shape), mode=mode, align_corners=align_corners + ) + if sigma is not None: + sigma: paddle.Tensor = paddle.atleast_1d(as_tensor(sigma, dtype="float32")) + if sigma.ndim > 1: + raise ValueError( + "upsample() 'sigma' must be scalar or 1-dimensional tensor" + ) + if levels > 1: + var = paddle.zeros(shape=tuple(sigma.shape), dtype=sigma.dtype) + for level in range(levels): + var += sigma.mul(2**level).pow(y=2) + sigma = var.sqrt() + if tuple(sigma.shape)[0] == 1: + sigma = sigma.repeat(grid.ndim) + for i in range(grid.ndim): + if i in dims: + continue + sigma[i] = 0 + kernels = [] + kernels_ = {} + for i in range(grid.ndim): + std = float(sigma[i] if i < len(sigma) else 0) + if std > 0: + kernel = kernels_.get(std) + if kernel is None: + kernel = gaussian1d(std, dtype="float32", device=result.place) + kernels_[std] = kernel + kernels.append(kernel) + else: + kernels.append(None) + result = conv(result, reversed(kernels), transpose=True) + return result + + +def gaussian_pyramid( + data: paddle.Tensor, + levels: int, + start: int = 0, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + min_size: int = 0, + align_corners: bool = ALIGN_CORNERS, +) -> Dict[int, paddle.Tensor]: + """Create Gaussian image resolution pyramid. + + Args: + data: Image data tensor of shape ``(N, C, ..., X)``. + levels: Coarsest resolution level. + start: Finest resolution level, where 0 corresponds to the original resolution. + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian filter applied at each downsampling level. + mode: Interpolation mode for resampling image data on downsampled grid. + min_size: Minimum grid size. + align_corners: Whether to preserve corner points (True) or grid extent (False). + + Returns: + Dictionary of downsampled image tensors with keys corresponding to level indices. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("gaussian_pyramid() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("gaussian_pyramid() 'data' must have shape (N, C, ..., X)") + if not isinstance(levels, int): + raise TypeError("gaussian_pyramid() 'levels' must be of type int") + if levels < 0: + raise ValueError("gaussian_pyramid() 'levels' must be positive") + if not isinstance(start, int): + raise TypeError("gaussian_pyramid() 'start' must be of type int") + if start < 0: + raise ValueError("gaussian_pyramid() 'start' must not be negative") + pyramid = {(0): data} + if start == 0: + start = 1 + levels -= 1 + for i, level in enumerate(range(start, start + levels)): + data = downsample( + data, + levels=level if i == 0 else 1, + dims=dims, + sigma=sigma, + mode=mode, + min_size=min_size, + align_corners=align_corners, + ) + pyramid[level] = data + return pyramid + + +def crop( + data: paddle.Tensor, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, +) -> paddle.Tensor: + """Crop or pad images at border. + + Args: + data: Image batch tensor of shape ``(N, C, ..., X)``. + margin: Number of spatial grid points to remove (positive) or add (negative) at each border. + Use instead of ``num`` in order to symmetrically crop the input ``data`` tensor, e.g., + ``(nx, ny, nz)`` is equivalent to ``num=(nx, nx, ny, ny, nz, nz)``. + num: Number of spatial gird points to remove (positive) or add (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, nx, ny, ny)``. If a scalar is given, the input is cropped equally at all borders. + Otherwise, the given sequence must have an even length. + mode: Image extrapolation mode. + value: Constant value used for extrapolation if ``mode=PaddingMode.CONSTANT``. + + Returns: + Cropped or padded image batch data of shape ``(N, C, ..., X)``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("crop() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("crop() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + if num is not None and margin is not None: + raise AssertionError("crop() 'margin' and 'num' are mutually exclusive") + if isinstance(num, int): + pad_ = (-num,) * 2 * D + elif num is not None: + pad_ = tuple(-int(n) for n in num) + if len(pad_) % 2 == 1: + raise ValueError("crop() 'num' must be int or have even length") + elif isinstance(margin, int): + pad_ = (-margin,) * 2 * D + elif margin is not None: + pad_ = tuple(-int(n) for nn in ((n, n) for n in margin) for n in nn) + else: + raise AssertionError("crop() either 'margin' or 'num' is required") + if all(n == 0 for n in pad_): + return data + mode = PaddingMode.from_arg(mode) + if mode == PaddingMode.ZEROS: + mode = PaddingMode.CONSTANT + value = 0 + else: + value = float(value) + mode = mode.pad_mode(D) + return paddle_aux._FUNCTIONAL_PAD(pad=pad_, mode=mode, value=value, x=data) + + +def pad( + data: paddle.Tensor, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, +) -> paddle.Tensor: + """Pad or crop images at border. + + Args: + data: Image batch tensor of shape ``(N, C, ..., X)``. + margin: Number of spatial grid points to add (positive) or remove (negative) at each border, + Use instead of ``num`` in order to symmetrically pad the input ``data`` tensor. + num: Number of spatial gird points to add (positive) or remove (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, ny, nz)``. If a scalar is given, the input is padded equally at all borders. + Otherwise, the given sequence must have an even length. + mode: Image extrapolation mode. + value: Constant value used for extrapolation if ``mode=PaddingMode.CONSTANT``. + + Returns: + Padded or cropped image batch data of shape ``(N, C, ..., X)``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("pad() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("pad() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + if num is not None and margin is not None: + raise AssertionError("pad() 'margin' and 'num' are mutually exclusive") + if isinstance(num, int): + pad_ = (num,) * 2 * D + elif num is not None: + pad_ = tuple(int(n) for n in num) + if len(pad_) % 2 == 1: + raise ValueError("pad() 'num' must be int or have even length") + elif isinstance(margin, int): + pad_ = (margin,) * 2 * D + elif margin is not None: + pad_ = tuple(int(n) for nn in ((n, n) for n in margin) for n in nn) + else: + raise AssertionError("pad() either 'pad' or 'margin' is required") + if all(n == 0 for n in pad_): + return data + mode = PaddingMode.from_arg(mode) + if mode == PaddingMode.ZEROS: + mode = PaddingMode.CONSTANT + value = 0 + else: + value = float(value) + mode = mode.pad_mode(D) + return paddle_aux._FUNCTIONAL_PAD(pad=pad_, mode=mode, value=value, x=data) + + +def center_crop(data: paddle.Tensor, size: Union[int, Sequence[int]]) -> paddle.Tensor: + """Crop image tensor to specified maximum size. + + Args: + data: Input tensor of shape ``(N, C, ..., X)``. + size: Maximum output size, where the size of the last tensor + dimension must be given first, i.e., ``(X, ...)``. + If an ``int`` is given, all spatial output dimensions + are cropped to this maximum size. If the length of size + is less than the spatial dimensions of the ``data`` tensor, + then only the last ``len(size)`` dimensions are modified. + + Returns: + Output tensor of shape ``(N, C, ..., X)``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("center_crop() 'data' must be paddle.Tensor") + if data.dim() < 4: + raise ValueError("center_crop() 'data' must be tensor of shape (N, C, ..., X)") + sdim = data.dim() - 2 + if isinstance(size, int): + shape = (size,) * sdim + elif len(size) == 0: + return data + elif len(size) > sdim: + raise ValueError( + "center_crop() 'data' has fewer spatial dimensions than output 'size'" + ) + else: + shape = tuple(data.shape)[2:][: sdim - len(size)] + tuple(reversed(size)) + crop = [max(0, m - n) for m, n in zip(tuple(data.shape)[2:], shape)] + if sum(crop) == 0: + return data + crop = (n // 2 for n in crop) + crop = (slice(0, tuple(data.shape)[0]), slice(0, tuple(data.shape)[1])) + tuple( + slice(i, i + n) for i, n in zip(crop, shape) + ) + return data[crop] + + +def center_pad( + data: paddle.Tensor, + size: Union[int, Sequence[int]], + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, +) -> paddle.Tensor: + """Pad image tensor to specified minimum size. + + Args: + data: Input tensor of shape ``(N, C, ..., X)``. + size: Minimum output size, where the size of the last tensor + dimension must be given first, i.e., ``(X, ...)``. + If an ``int`` is given, all spatial output dimensions + are cropped or padded to this size. If the length of size + is less than the spatial dimensions of the ``data`` tensor, + then only the last ``len(size)`` dimensions are modified. + mode: PaddingMode mode (cf. ``paddle.nn.functional.pad()``). + value: Value for padding mode "constant". + + Returns: + Output tensor of shape ``(N, C, ..., X)``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("center_pad() 'data' must be paddle.Tensor") + if data.dim() < 4: + raise ValueError("center_pad() 'data' must be tensor of shape (N, C, ..., X)") + sdim = data.dim() - 2 + if isinstance(size, int): + shape = (size,) * sdim + elif len(size) == 0: + return data + elif len(size) > sdim: + raise ValueError( + "center_pad() 'data' has fewer spatial dimensions than output 'size'" + ) + else: + shape = tuple(data.shape)[2:][: sdim - len(size)] + tuple(reversed(size)) + pad = [max(0, n - m) for m, n in zip(tuple(data.shape)[2:], shape)] + pad = [(n // 2, (n + 1) // 2) for n in reversed(pad)] + pad = [n for x in pad for n in x] + if sum(pad) == 0: + return data + mode = PaddingMode.from_arg(mode) + if mode == PaddingMode.ZEROS: + mode = PaddingMode.CONSTANT + value = 0 + mode = mode.pad_mode(sdim) + return paddle_aux._FUNCTIONAL_PAD(pad=pad, mode=mode, value=value, x=data) + + +def region_of_interest( + data: paddle.Tensor, + start: ScalarOrTuple[int], + size: ScalarOrTuple[int], + padding: Union[PaddingMode, str, float] = PaddingMode.CONSTANT, + value: float = 0, +) -> paddle.Tensor: + """Extract region of interest from image tensor. + + Args: + data: Input tensor of shape ``(N, C, ..., X)``. + start: Indices of lower left corner of region of interest, e.g., ``(x, y, z)``. + size: Size of region of interest, e.g., ``(nx, ny, nz)``. + padding: Padding mode to use when extrapolating input image or constant fill value. + value: Fill value to use when ``padding=Padding.CONSTANT``. + + Returns: + paddle.Tensor of shape ``(N, C, ..., X)`` with spatial size equal to ``size``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("region_of_interest() 'data' must be paddle.Tensor") + if data.dim() < 4: + raise ValueError( + "region_of_interest() 'data' must be tensor of shape (N, C, ..., X)" + ) + sdim = data.dim() - 2 + num = [] + if isinstance(start, int): + start = (start,) * sdim + elif not isinstance(start, Sequence) or not all(isinstance(n, int) for n in start): + raise TypeError("region_of_interest() 'start' must be int or sequence of ints") + elif len(start) != 3: + raise ValueError( + f"region_of_interest() 'start' must be int or sequence of length {sdim}" + ) + if isinstance(size, int): + size = (size,) * sdim + elif not isinstance(size, Sequence) or not all(isinstance(n, int) for n in size): + raise TypeError("region_of_interest() 'size' must be int or sequence of ints") + elif len(size) != 3: + raise ValueError( + f"region_of_interest() 'size' must be int or sequence of length {sdim}" + ) + if isinstance(padding, (PaddingMode, str)): + mode = PaddingMode.from_arg(padding) + value = value + elif isinstance(padding, (int, float)): + mode = PaddingMode.CONSTANT + value = padding + else: + raise TypeError( + "region_of_interest() 'padding' must be str, Padding, or fill value" + ) + num = [ + [start[i], tuple(data.shape)[data.ndim - 1 - i] - (start[i] + size[i])] + for i in range(sdim) + ] + num = [n for nn in num for n in nn] + return crop(data, num=num, mode=mode, value=value) + + +def fill_border( + data: paddle.Tensor, + margin: ScalarOrTuple[int], + value: float = 0, + inplace: bool = False, +) -> paddle.Tensor: + """Fill image border with specified value.""" + if not isinstance(data, paddle.Tensor): + raise TypeError("fill_border() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("fill_border() 'data' must have shape (N, C, ..., X)") + if isinstance(margin, int): + margin = (margin,) * (data.ndim - 2) + if len(margin) > data.ndim - 2: + raise ValueError( + f"fill_border() 'margin' must be at most {data.ndim - 2}-dimensional" + ) + if not inplace: + data = data.clone() + for i, m in enumerate(margin): + dim = data.ndim - i - 1 + idx = slice(0, m) + idx = tuple(idx if j == dim else slice(None) for j in range(data.ndim)) + data[idx] = value + idx = slice(tuple(data.shape)[dim] - m, tuple(data.shape)[dim]) + idx = tuple(idx if j == dim else slice(None) for j in range(data.ndim)) + data[idx] = value + return data + + +def flatten_channels(data: paddle.Tensor) -> paddle.Tensor: + """Flatten image tensor channels. + + + Args: + data: Input tensor of shape ``(N, C, ..., X)``. + + Returns: + paddle.Tensor of shape ``(C, N * ... * X)``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("flatten_channels() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("flatten_channels() 'data' must have shape (N, C, ..., X)") + C = data.shape[1] + axis_order = (1, 0) + tuple(range(2, data.ndim)) + transposed = data.transpose(perm=axis_order) + return transposed.view(C, -1) + + +def grid_resample( + data: paddle.Tensor, + in_spacing: Union[float, Array], + out_spacing: Union[float, Array], + *args: float, + mode: Union[Sampling, str] = None, + padding: Union[PaddingMode, str, Scalar] = None, +) -> paddle.Tensor: + """Interpolate image on minimum bounding grid with specified spacing. + + Args: + data: Image batch tensor with shape ``(N, C, ..., X)``. + in_spacing: Current grid spacing. + out_spacing: Spacing of grid on which to sample images, where the spacing + of the first grid dimension, which is the last ``data`` dimension + must be given first, e.g., ``(sx, sy, sz)``. If a scalar value is + given, the images are resampled to this isotropic spacing. + mode: Image data interpolation mode. + padding: Image data extrapolation mode. + + Returns: + This image batch with given spacing and interpolated image tensor. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("grid_resample() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("grid_resample() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + in_spacing = cat_scalars(in_spacing, *args, num=D, device=data.place) + out_spacing = cat_scalars(out_spacing, *args, num=D, device=data.place) + mode = Sampling.from_arg(mode).interpolate_mode(D) + input_grid = Grid(shape=tuple(data.shape)[2:], spacing=in_spacing) + output_grid = input_grid.resample(out_spacing) + if tuple(output_grid.shape) == tuple(input_grid.shape): + return data + align_corners = input_grid.align_corners() + axes = Axes.from_align_corners(align_corners) + coords = output_grid.coords(align_corners=align_corners, device=data.place) + coords = grid_transform_points(coords, output_grid, axes, input_grid, axes) + return grid_sample( + data, coords, mode=mode, padding=padding, align_corners=align_corners + ) + + +def grid_reshape( + data: paddle.Tensor, + shape: Union[int, Array, Shape], + *args: int, + mode: Union[Sampling, str] = Sampling.LINEAR, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Interpolate image with specified spatial image tensor shape. + + Args: + data: Image batch tensor with shape ``(N, C, ..., X)``. + shape: Size of spatial image dimensions, where the size of the first grid + dimension, which is the last ``data`` dimension, must be given last, + e.g., ``(nz, ny, nx)``. + mode: Image data interpolation mode. + align_corners: Whether to preserve corner points (True) or grid extent (False). + + Returns: + Interpolated image data of specified ``size`` of spatial dimensions. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("grid_resample() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("grid_reshape() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + shape_ = cat_scalars(shape, *args, num=D) + return grid_resize( + data, shape_.flip(axis=0), mode=mode, align_corners=align_corners + ) + + +def grid_resize( + data: paddle.Tensor, + size: Union[int, Array, Shape], + *args: int, + mode: Union[Sampling, str] = Sampling.LINEAR, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Interpolate image with specified spatial image tensor shape. + + Args: + data: Image batch tensor with shape ``(N, C, ..., X)``. + size: Size of spatial image dimensions, where size of first grid dimension, which + is the last ``data`` dimension, must be given first, e.g., ``(nx, ny, nz)``. + mode: Image data interpolation mode. + align_corners: Whether to preserve corner points (True) or grid extent (False). + + Returns: + Interpolated image data of specified ``size`` of spatial dimensions. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("grid_resample() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("grid_resize() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + size_ = cat_scalars(size, *args, num=D) + mode_ = Sampling.from_arg(mode).interpolate_mode(D) + grid = Grid(shape=tuple(data.shape)[2:], align_corners=align_corners) + grid = grid.resize(size_) + if tuple(grid.shape) == tuple(data.shape)[2:]: + return data + if mode_ in ("area", "nearest", "nearest-exact"): + align_corners = None + res_shape = (grid.shape[-1],) + grid.shape[:-1] + return paddle.nn.functional.interpolate( + x=data, size=res_shape, mode=mode_, align_corners=align_corners + ) + + +def check_sample_grid( + func: str, data: paddle.Tensor, grid: paddle.Tensor +) -> paddle.Tensor: + """Normalize sample grid tensor.""" + if not isinstance(data, paddle.Tensor): + raise TypeError(f"{func}() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError(f"{func}() 'data' must have shape (N, C, ..., X)") + if not isinstance(grid, paddle.Tensor): + raise TypeError(f"{func}() 'grid' must be paddle.Tensor") + N = tuple(data.shape)[0] + D = data.ndim - 2 + if not paddle.is_floating_point(x=grid): + raise TypeError("{func}() 'grid' must have floating point dtype") + if tuple(grid.shape)[-1] != D: + raise ValueError( + f"Last {func}() 'grid' dimension size must match number of spatial image dimensions" + ) + if grid.ndim == data.ndim - 1: + grid = grid.unsqueeze(axis=0) + elif grid.ndim != data.ndim: + raise ValueError( + f"{func}() expected 'grid' tensor with {data.ndim - 1} or {data.ndim} dimensions" + ) + if N == 1 and tuple(grid.shape)[0] > 1: + N = tuple(grid.shape)[0] + elif N > 1 and tuple(grid.shape)[0] == 1: + grid = grid.expand(shape=[N, *tuple(grid.shape)[1:]]) + if tuple(grid.shape)[0] != N: + msg = f"{func}() expected tensor 'grid' of shape (..., X, {D})" + msg += f" or (1, ..., X, {D})" if N == 1 else f" or (1|{N}, ..., X, {D})" + raise ValueError(msg) + return grid + + +def grid_sample( + data: paddle.Tensor, + grid: paddle.Tensor, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Sample data at grid points. + + Args: + data: Image batch tensor of shape ``(1, C, ..., X)`` or ``(N, C, ..., X)``. + grid: Grid points tensor of shape ``(..., X, D)``, ``(1, ..., X, D)``, or ``(N, ..., X, D)``. + Coordinates of points at which to sample ``data`` must be with respect to ``Axes.CUBE``. + mode: Image interpolate mode. + padding: Image extrapolation mode or constant by which to pad input ``data``. + align_corners: Whether ``grid`` extrema ``(-1, 1)`` refer to the grid boundary + edges (``align_corners=False``) or corner points (``align_corners=True``). + + Returns: + Image batch tensor of sampled data with spatial shape determined by ``grid``, and batch + size ``N`` based on ``data.shape[0]`` or ``grid.shape[0]``, respectively. The data type + of the returned tensor is ``data.dtype`` if it is a floating point type or ``mode="nearest"``. + Otherwise, the output data type matches ``grid.dtype``, which must be a floating point type. + + """ + grid = check_sample_grid("grid_sample", data, grid) + if str(data.place) != str(grid.place): + raise ValueError( + "grid_sample() 'data' and 'grid' tensors must be on same device" + ) + N = tuple(grid.shape)[0] + D = tuple(grid.shape)[-1] + if tuple(data.shape)[0] != N: + data = data.expand(shape=[N, *tuple(data.shape)[1:]]) + mode = Sampling.from_arg(mode).grid_sample_mode(D) + if isinstance(padding, (PaddingMode, str)): + padding_mode = PaddingMode.from_arg(padding).grid_sample_mode(D) + padding_value = 0 + else: + padding_mode = "zeros" + padding_value = float(padding or 0) + out = data.astype(dtype=grid.dtype) + if padding_value != 0: + if out.data_ptr() == data.data_ptr(): + out = out.sub(padding_value) + else: + out = out.subtract_(y=paddle.to_tensor(padding_value)) + out = paddle.nn.functional.grid_sample( + x=out, + grid=grid, + mode=mode, + padding_mode=padding_mode, + align_corners=align_corners, + ) + if padding_mode == "zeros" and padding_value != 0: + out = out.add_(y=paddle.to_tensor(padding_value)) + if mode == "nearest" or data.is_floating_point(): + out = out.astype(dtype=data.dtype) + return out + + +def grid_sample_mask( + data: paddle.Tensor, + grid: paddle.Tensor, + threshold: float = 0, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Sample binary mask at grid points. + + Args: + data: Image batch tensor of shape ``(N, 1, ..., X)``. + grid: Grid points tensor of shape ``(..., X, D)``, ``(1, ..., X, D)``, or``(N, ..., X, D)``. + Coordinates of points at which to sample ``data`` must be with respect to ``Axes.CUBE``. + threshold: Scalar value used to binarize input mask. Values above this threshold are assigned + value 1, and values below this threshold are assigned value 0. + align_corners: Whether ``grid`` extrema ``(-1, 1)`` refer to the grid boundary edges (``False``) + or corner points (``True``), respectively. + + Returns: + Batch tensor of sampled mask with spatial shape determined by ``grid``, and floating point values + in the closed interval ``[0, 1]`` as obtained by linear interpolation of the binarized input mask. + + """ + grid = check_sample_grid("grid_sample_mask", data, grid) + if tuple(data.shape)[1] != 1: + raise ValueError("grid_sample_mask() 'data' must have single channel") + mask = data if data.dtype == "bool" else data > threshold + mask = mask.to(dtype="float32", device=data.place) + return grid_sample( + mask, + grid, + mode=Sampling.LINEAR, + padding=PaddingMode.ZEROS, + align_corners=align_corners, + ) + + +def rand_sample( + data: Union[paddle.Tensor, Sequence[paddle.Tensor]], + num_samples: int, + mask: Optional[paddle.Tensor] = None, + replacement: bool = False, + generator: Optional[paddle.Generator] = None, +) -> Union[paddle.Tensor, Sequence[paddle.Tensor]]: + """Random sampling of voxels within an image. + + Args: + data: One or more image batch tensors with shape ``(N, C, ... X)`` to sample + from at the same random image grid positions, i.e., voxel indices. Note that + all input tensors must have the same number and size of spatial dimensions. + num_samples: Number of spatial samples to draw. + mask: Optional mask to use for spatially weighted sampling. + replacement: Whether to sample with or without replacement. + generator: Random number generator to use. + + Returns: + paddle.Tensor of shape ``(N, C, num_samples)`` with input ``data`` values at randomly + sampled spatial grid points. When ``data`` is a sequence of tensors, a list + of tensors with order matching the input data is returned. + + """ + input: Sequence[paddle.Tensor] + if isinstance(data, paddle.Tensor): + input = [data] + elif isinstance(data, Sequence) and all(isinstance(x, paddle.Tensor) for x in data): + input = data + else: + raise TypeError( + "rand_sample() 'data' must be paddle.Tensor or Sequence[paddle.Tensor]" + ) + if not input: + return [] + if any(x.ndim < 3 for x in input): + raise ValueError( + "rand_sample() 'data' must be one or more tensors of shape (N, C, ..., X)" + ) + shape = tuple(input[0].shape) + if any(tuple(x.shape)[2:] != shape[2:] for x in input): + raise ValueError( + "rand_sample() 'data' tensors must have identical spatial shape" + ) + numel = shape[2:].size + if not replacement and num_samples > numel: + raise ValueError( + "rand_sample() 'num_samples' is greater than number of spatial points" + ) + input = [x.flatten(start_axis=2) for x in input] + if mask is None: + if replacement: + index = paddle.randint( + low=0, high=numel, shape=(shape[0], num_samples), dtype="int64" + ) + else: + perm = paddle.empty(shape=numel, dtype="int64") + index = paddle.empty(shape=(shape[0], num_samples), dtype="int64") + for row in index: + paddle.assign(paddle.randperm(n=numel), output=perm) + start_16 = perm.shape[0] + 0 if 0 < 0 else 0 + paddle.assign( + paddle.slice(perm, [0], [start_16], [start_16 + num_samples]), + output=row, + ) + else: + if mask.ndim < 3: + raise ValueError( + "rand_sample() 'mask' must be tensor of shape (N, C, ..., X)" + ) + if tuple(mask.shape)[2:] != shape[2:]: + raise ValueError( + "rand_sample() 'mask' has different spatial shape than 'data'" + ) + if tuple(mask.shape)[1] != 1: + raise ValueError("rand_sample() 'mask' must be scalar image tensor") + mask = ( + mask.flatten(start_axis=2).squeeze(axis=1).expand(shape=[shape[0], numel]) + ) + index = multinomial( + mask, num_samples, replacement=replacement, generator=generator + ) + index = index.unsqueeze(axis=1).repeat(1, shape[1], 1) + out = [x.take_along_axis(axis=2, indices=index) for x in input] + if len(out) == 1 and isinstance(data, paddle.Tensor): + return out[0] + return out + + +def image_slice(data: paddle.Tensor, offset: Optional[int] = None) -> paddle.Tensor: + """Get slice from image tensor. + + Args: + data: Image data tensor of shape ``(N, C, ..., Y, X)``. + offset: Slice offset. If ``None``, use ``Z // 2``. This argument is + ignored when the input image is 2-dimensional. + + Returns: + View of image tensor slice with shape ``(N, C, Y, X)``. + + """ + if data.ndim == 4: + return data + if data.ndim < 4 or data.ndim > 5: + raise ValueError("image_slice() 'data' must be 4- or 5-dimensional") + if offset is None: + offset = tuple(data.shape)[2] // 2 + start_17 = data.shape[2] + offset if offset < 0 else offset + return paddle.slice(data, [2], [start_17], [start_17 + 1]).squeeze(axis=2) + + +def normalize_image( + data: paddle.Tensor, + mode: str = "unit", + min: Optional[float] = None, + max: Optional[float] = None, + inplace: bool = False, +) -> paddle.Tensor: + """Normalize image intensities in [min, max]. + + Args: + data: Input image data. + mode: How to normalize image values: + - ``center``: Linearly rescale to [-0.5, 0.5] + - ``unit``: Linearly rescale to [0, 1]. + - ``zscore``: Linearly rescale to zero mean and unit variance. + min: Minimum intensity at which to clamp input. + max: Maximum intensity at which to clamp input. + inplace: Whether to modify ``data`` in place. + + Returns: + Normalized image data. + + Raises: + TypeError: When ``inplace=True`` and ``data.dtype`` is not a floating point data type. + + """ + if inplace: + if not data.is_floating_point(): + raise AssertionError( + "normalize_image() 'data.dtype' must be float when inplace=True" + ) + + def add_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.add_(y=paddle.to_tensor(a)) + + def sub_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.subtract_(y=paddle.to_tensor(a)) + + def mul_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.multiply_(y=paddle.to_tensor(a)) + + def clamp_fn(data: paddle.Tensor, a: float, b: float) -> paddle.Tensor: + return data.clip_(min=a, max=b) + + else: + data = data.astype(dtype="float32") + + def add_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.add(a) + + def sub_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.sub(a) + + def mul_fn(data: paddle.Tensor, a: float) -> paddle.Tensor: + return data.mul(a) + + def clamp_fn(data: paddle.Tensor, a: float, b: float) -> paddle.Tensor: + return data.clip(min=a, max=b) + + if mode in ("zscore", "z-score"): + data = clamp_fn(data, min, max) + stdev, mean = tuple( + [ + paddle.std(data, axis=None, unbiased=True, keepdim=False), + paddle.mean(data, axis=None, keepdim=False), + ] + ) + data = sub_fn(data, mean) + if stdev > 1e-15: + data = mul_fn(data, 1 / stdev) + elif mode in ("unit", "center"): + if min is None: + min = float(data.min()) + if max is None: + max = float(data.max()) + dif = max - min + mul = 1 if abs(dif) < 1e-09 else 1 / dif + add = -mul * min + if mode == "center": + add -= 0.5 + if mul != 1: + data = mul_fn(data, mul) + if add != 0: + data = add_fn(data, add) + if mode == "center": + data = clamp_fn(data, -0.5, 0.5) + else: + data = clamp_fn(data, 0, 1) + return data + + +def rescale( + data: paddle.Tensor, + min: Optional[Scalar] = None, + max: Optional[Scalar] = None, + data_min: Optional[Scalar] = None, + data_max: Optional[Scalar] = None, + dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Linearly rescale values to specified output interval. + + Args: + data: Input tensor. + min: Minimum value of output tensor. Use ``data_min`` if ``None``. + max: Maximum value of output tensor. Use ``data_max`` if ``None``. + data_min: Minimum value of input ``data``. Use ``data.min()`` if ``None``. + data_max: Maximum value of input ``data``. Use ``data.max()`` if ``None``. + dtype: Cast rescaled values to specified output data type. If ``None``, + use ``data.dtype`` if it is a floating point type, otherwise ``paddle.float``. + + Returns: + paddle.Tensor of same shape as ``data`` with specified ``dtype`` and values in closed interval ``[min, max]``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("rescale() 'data' must be paddle.Tensor") + if dtype is None: + dtype = data.dtype + if not is_float_dtype(dtype): + dtype = "float32" + if dtype.is_floating_point: + interim_dtype = dtype + elif data.dtype.is_floating_point: + interim_dtype = data.dtype + else: + interim_dtype = "float32" + data = data.astype(interim_dtype) + if data_min is None: + data_min = data.min() + data_min = float(data_min) + if data_max is None: + data_max = data.max() + data_max = float(data_max) + min = data_min if min is None else float(min) + max = data_max if max is None else float(max) + norm = data_max - data_min + if norm < 1e-15: + result = paddle.empty(shape=tuple(data.shape), dtype=data.dtype).fill_( + value=min + ) + else: + scale = (max - min) / norm + result = min + scale * (data - data_min) + if not dtype.is_floating_point: + result = result.round_() + result = result.clip_(min=min, max=max) + result = result.astype(dtype) + return result + + +def sample_image( + data: paddle.Tensor, + coords: paddle.Tensor, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + align_corners: bool = ALIGN_CORNERS, +) -> paddle.Tensor: + """Sample images at given points. + + This function samples a batch of images at spatial points. The ``coords`` tensor can be of any shape, + including ``(N, M, D)``, i.e., a batch of N point sets with cardianality M, and ``(N, ..., X, D)`` , + i.e., a (deformed) regular sampling grid (cf. ``grid_sample()``). + + Args: + data: Batch of images as tensor of shape ``(N, C, ..., X)``. If batch size is one, + but the batch size of ``coords`` is greater than one, this single image is sampled + at the different sets of points. + coords: Normalized coordinates of points given as tensor of shape ``(N, ..., D)`` + or ``(1, ..., D)``. If batch size is one, all images are sampled at the same points. + align_corners: Whether point coordinates are with respect to ``Axes.CUBE`` (False) + or ``Axes.CUBE_CORNERS`` (True). This option is in particular passed on to the + ``grid_sample()`` function used to sample the images at the given points. + + Returns: + Sampled image data as tensor of shape ``(N, C, *coords.shape[1:-1])``. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("sample_image() 'data' must be of type paddle.Tensor") + data = data.as_subclass(paddle.Tensor) + if data.ndim < 4: + raise ValueError("sample_image() 'data' must be at least 4-dimensional tensor") + if not isinstance(coords, paddle.Tensor): + raise TypeError("sample_image() 'coords' must be of type paddle.Tensor") + if coords.ndim < 2: + raise ValueError( + "sample_image() 'coords' must be at least 2-dimensional tensor" + ) + G = tuple(data.shape)[0] + N = tuple(coords.shape)[0] if G == 1 else G + D = data.ndim - 2 + if tuple(coords.shape)[0] not in (1, N): + raise ValueError(f"sample_image() 'coords' must be batch of length 1 or {N}") + if tuple(coords.shape)[-1] != D: + raise ValueError( + f"sample_image() 'coords' must be tensor of {D}-dimensional points" + ) + x = coords.expand(shape=(N,) + tuple(coords.shape)[1:]) + data = data.expand(shape=(N,) + tuple(data.shape)[1:]) + grid = x.reshape((N,) + (1,) * (data.ndim - 3) + (-1, D)) + data = grid_sample( + data, grid, mode=mode, padding=padding, align_corners=align_corners + ) + return data.reshape(tuple(data.shape)[:2] + tuple(coords.shape)[1:-1]) + + +def spatial_derivatives( + data: paddle.Tensor, + mode: str = "central", + which: Optional[Union[str, Sequence[str]]] = None, + order: Optional[int] = None, + sigma: Optional[float] = None, + spacing: Optional[Union[Scalar, Array]] = None, +) -> Dict[str, paddle.Tensor]: + """Calculate spatial image derivatives. + + Args: + data: Image data tensor of shape ``(N, C, ..., X)``. + mode: Method to use for approximating spatial image derivative. + If ``forward``, ``backward``, or ``central``, the respective finite difference + scheme is used to approximate the image derivative, optionally after smoothing + the input image with a Gaussian kernel. If ``gaussian``, the image derivative + is computed by convolving the image with a derivative of Gaussian kernel. + If ``None``, a central difference scheme is used by default. + which: String codes of spatial deriviatives to compute. See ``SpatialDerivativeKeys``. + order: Order of spatial derivative. If zero, the input ``data`` is returned. + sigma: Standard deviation of Gaussian kernel in grid units. If ``None`` or zero, + no Gaussian smoothing is used for calculation of finite differences, and a + default standard deviation of 0.4 is used when ``mode="gaussian"``. + spacing: Physical spacing between image grid points, e.g., ``(sx, sy, sz)``. + When a scalar is given, the same spacing is used for each image and spatial dimension. + If a sequence is given, it must be of length equal to the number of spatial dimensions ``D``, + and specify a separate spacing for each dimension in the order ``(x, ...)``. In order to + specify a different spacing for each image in the input ``data`` batch, a 2-dimensional + tensor must be given, where the size of the first dimension is equal to ``N``. The second + dimension can have either size 1 for an isotropic spacing, or ``D`` in case of an + anisotropic grid spacing. + + Returns: + Mapping from spatial derivative keys to corresponding tensors of the respective spatial + image derivatives of shape ``(N, C, ..., X)``. The keys are sequences of letters identifying + the spatial dimensions along which a derivative was taken (cf. ``SpatialDerivativeKeys``). + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("spatial_derivatives() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("spatial_derivatives() 'data' must have shape (N, C, ..., X)") + N = tuple(data.shape)[0] + D = data.ndim - 2 + if spacing is None: + spacing = paddle.ones(shape=(N, D), dtype="float32") + else: + spacing = as_tensor( + spacing, dtype="float32", device=str("cpu").replace("cuda", "gpu") + ) + spacing = paddle.atleast_1d(spacing) + if spacing.ndim == 1: + spacing = spacing.unsqueeze(axis=0) + if ( + spacing.ndim != 2 + or tuple(spacing.shape)[0] not in (1, N) + or tuple(spacing.shape)[1] not in (1, D) + ): + raise ValueError( + f"spatial_derivatives() 'spacing' must be scalar, {D}-dimensional vector, or 2-dimensional array of shape (1, {D}), ({N}, 1), or ({N}, {D})" + ) + spacing = spacing.expand(shape=[N, D]) + if isinstance(which, str): + which = (which,) + if which is None: + if order is None: + order = 1 + which = SpatialDerivativeKeys.all(ndim=D, order=order) + elif order is not None: + which = [arg for arg in which if len(arg) == order] + unique_keys = SpatialDerivativeKeys.unique(which) + max_order = SpatialDerivativeKeys.max_order(which) + derivs = {} + if not which: + return derivs + if not data.is_floating_point(): + data = data.astype("float32") + if mode is None: + mode = "central" + if mode in ("forward", "backward", "central", "prewitt", "sobel"): + if sigma and sigma > 0: + blur = gaussian1d(sigma, dtype="float32", device=data.place) + data = conv(data, blur, padding=PaddingMode.ZEROS) + if mode in ("prewitt", "sobel"): + avg_kernel = paddle.to_tensor( + data=[1, 1 if mode == "prewitt" else 2, 1], dtype=data.dtype + ) + avg_kernel /= avg_kernel.sum() + avg_kernel = avg_kernel.to(data.place) + fd_mode = "central" + else: + avg_kernel = None + fd_mode = mode + for i in range(max_order): + for code in unique_keys: + key = code[: i + 1] + if i < len(code) and key not in derivs: + sdim = SpatialDim.from_arg(code[i]) + result = data if i == 0 else derivs[code[:i]] + if avg_kernel is not None: + for d in (d for d in range(D) if d != sdim): + dim = SpatialDim(d).tensor_dim(result.ndim) + result = conv1d( + result, + avg_kernel, + dim=dim, + padding=len(avg_kernel) // 2, + ) + fd_spacing = spacing[:, (sdim)] + result = finite_differences( + result, sdim, mode=fd_mode, spacing=fd_spacing + ) + derivs[key] = result + derivs = {key: derivs[SpatialDerivativeKeys.sorted(key)] for key in which} + elif mode == "gaussian": + if not sigma: + sigma = 0.4 + kernel_0 = gaussian1d(sigma, normalize=False, dtype="float32") + kernel_1 = gaussian1d_I(sigma, normalize=False, dtype="float32") + norm = kernel_0.sum() + kernel_0 = kernel_0.divide_(y=paddle.to_tensor(norm)).to(data.place) + kernel_1 = kernel_1.divide_(y=paddle.to_tensor(norm)).to(data.place) + for i in range(max_order): + for code in unique_keys: + key = code[: i + 1] + if i < len(code) and key not in derivs: + sdim = SpatialDim.from_arg(code[i]) + result = data if i == 0 else derivs[code[:i]] + for d in range(D): + dim = SpatialDim(d).tensor_dim(result.ndim) + kernel = kernel_1 if sdim == d else kernel_0 + result = conv1d( + result, kernel, dim=dim, padding=len(kernel) // 2 + ) + derivs[key] = result + derivs = {key: derivs[SpatialDerivativeKeys.sorted(key)] for key in which} + else: + raise ValueError( + "spatial_derivatives() 'mode' must be 'forward', 'backward', 'central', or 'gaussian'" + ) + return derivs + + +def finite_differences( + data: paddle.Tensor, + sdim: SpatialDimArg, + mode: str = "central", + order: int = 1, + dilation: int = 1, + spacing: Union[float, Sequence[float]] = 1, +) -> paddle.Tensor: + """Calculate spatial image derivative using finite differences. + + Args: + data: Image data tensor of shape ``(N, C, ..., X)``. + sdim: Spatial dimension along which to compute spatial derivative. + mode: Finite differences to use for approximating spatial derivative. + order: Order of spatial derivative. If zero, the input ``data`` is returned. + dilation: Step size for finite differences. + spacing: Physical spacing between image grid points along dimension ``sdim``. + When a scalar is given, the same spacing is used for all images in the + input ``data`` batch. Otherwise, a separate spacing must be specified for + each image as sequence of float values. + + Returns: + paddle.Tensor of spatial derivative with respect to specified spatial dimension. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("finite_differences() 'data' must be paddle.Tensor") + if data.ndim < 4: + raise ValueError("finite_differences() 'data' must have shape (N, C, ..., X)") + if not isinstance(order, int): + raise TypeError("finite_differences() 'order' must be int") + if order < 0: + raise ValueError("finite_differences() 'order' must be non-negative") + if not isinstance(dilation, int): + raise TypeError("finite_differences() 'dilation' must be int") + if dilation < 1: + raise ValueError("finite_differences() 'dilation' must be positive") + spatial_dim = SpatialDim.from_arg(sdim) + dim = spatial_dim.tensor_dim(data.ndim) + if mode not in ("forward", "backward", "central"): + raise ValueError( + "finite_differences() 'mode' must be 'forward', 'backward', or 'central'" + ) + if order == 0: + return data + if order == 1: + if mode == "forward": + i = slice(0, tuple(data.shape)[dim] - dilation, 1) + j = slice(dilation, tuple(data.shape)[dim], 1) + p = 0, dilation + elif mode == "backward": + i = slice(dilation, tuple(data.shape)[dim], 1) + j = slice(0, tuple(data.shape)[dim] - dilation, 1) + p = dilation, 0 + else: + i = slice(0, tuple(data.shape)[dim] - 2 * dilation, 1) + j = slice(2 * dilation, tuple(data.shape)[dim], 1) + p = dilation, dilation + else: + raise NotImplementedError(f"finite_differences(..., order={order})") + i = tuple( + i if d == dim else slice(0, n, 1) for d, n in enumerate(tuple(data.shape)) + ) + j = tuple( + j if d == dim else slice(0, n, 1) for d, n in enumerate(tuple(data.shape)) + ) + N = tuple(data.shape)[0] + data = data.astype(dtype="float32") + deriv = data[j].sub(data[i]) + denom: paddle.Tensor = paddle.atleast_1d( + as_tensor(spacing, dtype=data.dtype, device=data.place) + ) + if denom.ndim > 1 or tuple(denom.shape)[0] not in (1, N): + raise ValueError( + f"finite_differences() 'spacing' must be scalar or sequence of length {N}" + ) + denom = denom.mul((2 if mode == "central" else 1) * dilation) + denom = denom.reshape((tuple(denom.shape)[0],) + (1,) * (deriv.ndim - 1)) + deriv = deriv.div(denom) + pad = [(p if d == spatial_dim else (0, 0)) for d in range(data.ndim - 2)] + pad = [n for v in pad for n in v] + return paddle_aux._FUNCTIONAL_PAD(pad=pad, mode="constant", value=0, x=deriv) + + +def _image_size( + fn_name: str, + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + ndim: Optional[int] = None, +) -> list: + """Parse 'size' and/or 'shape' argument of image creation function.""" + if size is None and shape is None: + raise AssertionError(f"{fn_name}() 'size' or 'shape' required") + if isinstance(size, Grid): + size = tuple(size.shape) + if size is not None and shape is not None and size != tuple(reversed(shape)): + raise AssertionError(f"{fn_name}() mismatch between 'size' and 'shape'") + if size is None: + if ndim and len(shape) != ndim: + raise ValueError(f"{fn_name}() 'shape' must be tuple of length {ndim}") + size = tuple(reversed(shape)) + elif isinstance(size, int): + size = size, size + elif ndim and len(size) != 2: + raise ValueError(f"{fn_name}() 'size' must be tuple of length {ndim}") + return tuple(size) + + +def circle_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + center: Optional[Sequence[int]] = None, + radius: Optional[float] = None, + sigma: float = 0, + x_max: Optional[Union[float, Sequence[float]]] = None, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Synthetic image of a circle. + + Args: + size: Spatial size in the order ``(X, Y)``. + shape: Spatial size in the order ``(Y, X)``. + num: Number ``N`` of images in batch. + center: Coordinates of center pixel in the order ``(x, y)``. + radius: Radius of circle in pixel units. + sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units. + x_max: Maximum ``x`` pixel index at which to clamp image to zero. + This can be used to create partial circles such as a half circle. + dtype: Data type of output tensor. Use ``paddle.uint8`` if ``None``. + When the output data type is a floating point type, the output tensor + values are in the interval ``[0, 1]``. Otherwise, the output values are + in the interval ``[0, 255]``. + device: Device on which to create image tensor. + + Returns: + Image tensor of shape ``(N, 1, Y, X)``. + + """ + size = _image_size("circle_image", size, shape, ndim=2) + if center is None: + center = tuple((n - 1) / 2 for n in size) + center = tuple(float(x) for x in center) + grid = Grid(size=size) + if radius is None: + radius = max(0, min(center) - 1 - math.ceil(2 * sigma)) + _dtype = "float32" + _device = str("cpu").replace("cuda", "gpu") + c = as_tensor(center, dtype=_dtype, device=_device) + x = grid.coords(normalize=False, dtype=_dtype, device=_device) + x = x.reshape(num or 1, 1, *tuple(x.shape)) - c + data = paddle.linalg.norm(x=x, axis=-1) <= radius + if x_max: + x_threshold = as_tensor(x_max, dtype=_dtype, device=_device) + if x_threshold.ndim == 0: + data &= x[..., 0] <= x_threshold + else: + data &= (x <= x_threshold).astype("bool").all(axis=-1) + data = data.astype(_dtype) + if sigma > 0: + kernel = gaussian1d(sigma, dtype=data.dtype, device=_device) + data = conv(data, kernel / kernel.sum()) + if dtype is None: + dtype = "uint8" + if not dtype.is_floating_point: + data = 255 * data / data.max() + return data.to(dtype=dtype, device=device) + + +def cshape_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + center: Optional[Sequence[float]] = None, + radius: Optional[float] = None, + width: Optional[float] = None, + sigma: float = 0, + x_max: Union[float, Sequence[float]] = 5, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Synthetic C-shaped image. + + Args: + size: Spatial size in the order ``(X, Y)``. + shape: Spatial size in the order ``(Y, X)``. + num: Number ``N`` of images in batch. + center: Coordinates of center pixel in the order ``(y, x)``. + radius: Radius of circle in pixel units. + sigma: Standard deviation of isotropic Gaussian blurring kernel in pixel units. + x_max: Maximum ``x`` pixel index at which to clamp image to zero. + This can be used to create partial circles such as a half circle. + dtype: Data type of output tensor. Use ``paddle.uint8`` if ``None``. + When the output data type is a floating point type, the output tensor + values are in the interval ``[0, 1]``. Otherwise, the output values are + in the interval ``[0, 255]``. + device: Device on which to create image tensor. + + Returns: + Image tensor of shape ``(N, 1, Y, X)``. + + """ + size = _image_size("cshape_image", size, shape, ndim=2) + if dtype is None: + dtype = "uint8" + if radius is None: + center = tuple(float(x) for x in center) + radius = max(0, min(center) - 1 - math.ceil(2 * sigma)) + if width is None: + width = radius // 2 + outer = circle_image(size, center=center, radius=radius, x_max=x_max, sigma=0) + inner = circle_image(size, center=center, radius=radius - width, sigma=0) + image = (outer - inner).astype("float32") + if sigma > 0: + kernel = gaussian1d(sigma, dtype="float32", device=image.place) + image = conv(image, kernel) + if dtype.is_floating_point: + image /= image.max() + image.clip_(min=0, max=1) + else: + image *= 255 / image.max() + image.clip_(min=0, max=255) + if num and num > 1: + image = image.expand(shape=(1,) + tuple(image.shape)[1:]) + return image.to(dtype=dtype, device=device) + + +def empty_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + channels: Optional[int] = None, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create new batch of uninitalized image data. + + Args: + size: Spatial size in the order ``(X, ...)``. + shape: Spatial size in the order ``(..., X)``. + num: Number of images in batch. + channels: Number of channels per image. + dtype: Data type of image tensor. + device: Device on which to store image data. + + Returns: + Uninitialized image batch tensor. + + """ + size = _image_size("empty_image", size, shape) + shape = (num or 1, channels or 1) + tuple(reversed(size)) + return paddle.empty(shape=shape, dtype=dtype) + + +def grid_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + stride: Optional[Union[int, Sequence[int]]] = None, + inverted: bool = False, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create batch of regularly spaced grid images. + + Args: + size: Spatial size in the order ``(X, ...)``. + shape: Spatial size in the order ``(..., X)``. + num: Number of images in batch. When ``shape`` is not a ``Grid``, must + match the size of the first dimension in ``shape`` if not ``None``. + stride: Spacing between grid lines. To draw in-plane grid lines on a + D-dimensional image where ``D>2``, specify a sequence of two stride + values, where the first stride applies to the last tensor dimension, + which corresponds to the first spatial grid dimension. + inverted: Whether to draw grid lines in black (0) over white (1) background. + dtype: Data type of image tensor. + device: Device on which to store image data. + + Returns: + Image tensor of shape ``(N, 1, ..., X)``. The default number of channels is 1. + + """ + size = _image_size("grid_image", size, shape) + data = empty_image(size, num=1, channels=1, dtype=dtype, device=device) + data.fill_(value=1 if inverted else 0) + if stride is None: + stride = 4 + if isinstance(stride, int): + stride = (stride,) * (data.ndim - 2) + if len(stride) > data.ndim - 2: + raise ValueError( + "grid_image() 'stride' length must not be greater than number of spatial dimensions" + ) + start = data.ndim - len(stride) + for dim, step in zip(range(start, data.ndim), reversed(stride)): + n = tuple(data.shape)[dim] + index = paddle.arange(start=n % step // 2, end=n, step=step, dtype="int64") + data.index_fill_(axis=dim, index=index, value=0 if inverted else 1) + return data.expand(shape=[num or 1, *tuple(data.shape)[1:]]) + + +def ones_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + channels: Optional[int] = None, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create new batch of image data filled with ones. + + Args: + size: Spatial size in the order ``(X, ...)``. + shape: Spatial size in the order ``(..., X)``. + num: Number of images in batch. + channels: Number of channels per image. + dtype: Data type of image tensor. + device: Device on which to store image data. + + Returns: + Image batch tensor filled with ones. + + """ + size = _image_size("ones_image", size, shape) + data = empty_image(size, num=num, channels=channels, dtype=dtype, device=device) + return data.fill_(value=1) + + +def zeros_image( + size: Optional[Union[int, Size, Grid]] = None, + shape: Optional[Shape] = None, + num: Optional[int] = None, + channels: Optional[int] = None, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create new batch of image data filled with zeros. + + Args: + size: Spatial size in the order ``(X, ...)``. + shape: Spatial size in the order ``(..., X)``. + num: Number of images in batch. + channels: Number of channels per image. + dtype: Data type of image tensor. + device: Device on which to store image data. + + Returns: + Image batch tensor filled with zeros. + + """ + size = _image_size("zeros_image", size, shape) + data = empty_image(size, num=num, channels=channels, dtype=dtype, device=device) + return data.fill_(value=0) diff --git a/jointContribution/HighResolution/deepali/core/itertools.py b/jointContribution/HighResolution/deepali/core/itertools.py new file mode 100644 index 0000000000..e7dd8b9635 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/itertools.py @@ -0,0 +1,62 @@ +"""Custom itertools functions.""" +from itertools import repeat +from typing import Any +from typing import Iterable +from typing import Sequence +from typing import Union + + +def is_even_permutation(permutation: Sequence[int]) -> bool: + """Checks if given permutation is even. + + Example: + >>> is_even_permutation(range(10)) + True + >>> is_even_permutation(range(10)[::-1]) + False + + """ + if len(permutation) == 1: + return True + transitions_count = 0 + for index, element in enumerate(permutation): + for next_element in permutation[index + 1 :]: + if element > next_element: + transitions_count += 1 + return not transitions_count % 2 + + +def repeat_last(arg: Union[Any, Sequence[Any]], length: int) -> Sequence[Any]: + """Repeat last element in sequence to extend it to the specified length.""" + if not isinstance(arg, str) and not isinstance(arg, Sequence): + arg = (arg,) + if not arg: + raise ValueError("repeat_last() 'arg' must have at least one value to repeat") + if len(arg) > length: + raise ValueError( + "repeat_last() 'arg' sequence length must be at most '{length}'" + ) + arg = tuple(arg) + (arg[-1],) * (length - len(arg)) + return arg + + +def zip_longest_repeat_last(*args: Iterable): + iterators = [iter(it) for it in args] + num_active = len(iterators) + if not num_active: + return + prev = None + while True: + values = [] + for i, it in enumerate(iterators): + try: + value = next(it) + except StopIteration: + num_active -= 1 + if not num_active: + return + value = prev[i] + iterators[i] = repeat(value) + values.append(value) + prev = tuple(values) + yield prev diff --git a/jointContribution/HighResolution/deepali/core/kernels.py b/jointContribution/HighResolution/deepali/core/kernels.py new file mode 100644 index 0000000000..652dd1a26d --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/kernels.py @@ -0,0 +1,352 @@ +import math +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from .linalg import tensordot +from .tensor import as_tensor +from .tensor import cat_scalars +from .types import Array +from .types import Device +from .types import Scalar + + +def bspline1d(stride: int, order: int = 4) -> paddle.Tensor: + """B-spline kernel of given order for specified control point spacing. + + Implementation adopted from AirLab: + https://github.com/airlab-unibas/airlab/blob/80c9d487c012892c395d63c6d937a67303c321d1/airlab/utils/kernelFunction.py#L218 + + This function computes the kernel recursively by convolving with a box filter (cf. Cox-de Boor's recursion formula). + The resulting kernel differs from the analytic B-spline function. This may be due to the box filter having extend + to the borders of the pixels, where it should drop to zero at pixel centers rather. + + The exact B-spline kernel of order 4 is computed by ``cubic_bspline1d()``. + + Args: + stride: Spacing between control points with respect to original (upsampled) image grid. + order: Order of B-spline kernel, where the degree of the spline polynomials is order minus 1. + + Returns: + B-spline convolution kernel. + + """ + kernel = kernel_ones = paddle.ones(shape=[1, 1, stride], dtype="float32") + for _ in range(1, order + 1): + kernel = ( + paddle.nn.functional.conv1d( + x=kernel, weight=kernel_ones, padding=stride - 1 + ) + / stride + ) + return kernel.reshape(-1) + + +def cubic_bspline_value(x: float, derivative: int = 0) -> float: + """Evaluate 1-dimensional cubic B-spline.""" + t = abs(x) + if t >= 2: + return 0 + if derivative == 0: + if t < 1: + return 2 / 3 + (0.5 * t - 1) * t**2 + return -((t - 2) ** 3) / 6 + if derivative == 1: + if t < 1: + return (1.5 * t - 2.0) * x + if x < 0: + return 0.5 * (t - 2) ** 2 + return -0.5 * (t - 2) ** 2 + if derivative == 2: + if t < 1: + return 3 * t - 2 + return -t + 2 + + +def cubic_bspline( + stride: Union[int, Sequence[int]], + *args: int, + derivative: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +): + """Get n-dimensional cubic B-spline kernel. + + Args: + stride: Spacing between control points with respect to original (upsampled) image grid. + derivative: Order of cubic B-spline derivative. + + Returns: + Cubic B-spline convolution kernel. + + """ + stride_ = cat_scalars( + stride, + *args, + derivative=derivative, + dtype="int32", + device=str("cpu").replace("cuda", "gpu"), + ).tolist() + D = len(stride_) + if D == 1: + return cubic_bspline1d( + stride_, derivative=derivative, dtype=dtype, device=device + ) + if D == 2: + return cubic_bspline2d( + stride_, derivative=derivative, dtype=dtype, device=device + ) + if D == 3: + return cubic_bspline3d( + stride_, derivative=derivative, dtype=dtype, device=device + ) + raise NotImplementedError(f"cubic_bspline() {D}-dimensional kernel") + + +def cubic_bspline1d( + stride: Union[int, Sequence[int]], + derivative: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Cubic B-spline kernel for specified control point spacing. + + Args: + stride: Spacing between control points with respect to original (upsampled) image grid. + derivative: Order of cubic B-spline derivative. + + Returns: + Cubic B-spline convolution kernel. + + """ + if dtype is None: + dtype = "float32" + if not isinstance(stride, int): + (stride,) = stride + kernel = paddle.ones(shape=4 * stride - 1, dtype="float32") + radius = tuple(kernel.shape)[0] // 2 + for i in range(tuple(kernel.shape)[0]): + kernel[i] = cubic_bspline_value((i - radius) / stride, derivative=derivative) + if device is None: + device = kernel.place + return kernel.to(device) + + +def cubic_bspline2d( + stride: Union[int, Sequence[int]], + *args: int, + derivative: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Cubic B-spline kernel for specified control point spacing. + + Args: + stride: Spacing between control points with respect to original (upsampled) image grid. + derivative: Order of cubic B-spline derivative. + + Returns: + Cubic B-spline convolution kernel. + + """ + if dtype is None: + dtype = "float32" + stride_ = cat_scalars( + stride, *args, num=2, dtype="int32", device=str("cpu").replace("cuda", "gpu") + ) + kernel = paddle.ones(shape=(4 * stride_ - 1).tolist(), dtype=dtype) + radius = [(n // 2) for n in tuple(kernel.shape)] + for j in range(tuple(kernel.shape)[1]): + w_j = cubic_bspline_value((j - radius[1]) / stride[1], derivative=derivative) + for i in range(tuple(kernel.shape)[0]): + w_i = cubic_bspline_value( + (i - radius[0]) / stride[0], derivative=derivative + ) + kernel[j, i] = w_i * w_j + if device is None: + device = kernel.place + return kernel.to(device) + + +def cubic_bspline3d( + stride: Union[int, Sequence[int]], + *args: int, + derivative: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Cubic B-spline kernel for specified control point spacing. + + Args: + stride: Spacing between control points with respect to original (upsampled) image grid. + derivative: Order of cubic B-spline derivative. + + Returns: + Cubic B-spline convolution kernel. + + """ + if dtype is None: + dtype = "float32" + stride_ = cat_scalars( + stride, *args, num=3, dtype="int32", device=str("cpu").replace("cuda", "gpu") + ) + kernel = paddle.ones(shape=(4 * stride_ - 1).tolist(), dtype="float32") + radius = [(n // 2) for n in tuple(kernel.shape)] + for k in range(tuple(kernel.shape)[2]): + w_k = cubic_bspline_value((k - radius[2]) / stride[2], derivative=derivative) + for j in range(tuple(kernel.shape)[1]): + w_j = cubic_bspline_value( + (j - radius[1]) / stride[1], derivative=derivative + ) + for i in range(tuple(kernel.shape)[0]): + w_i = cubic_bspline_value( + (i - radius[0]) / stride[0], derivative=derivative + ) + kernel[k, j, i] = w_i * w_j * w_k + if device is None: + device = kernel.place + return kernel.to(device) + + +def gaussian_kernel_radius( + sigma: Union[Scalar, Array], factor: Scalar = 3 +) -> paddle.Tensor: + """Radius of truncated Gaussian kernel. + + Args: + sigma: Standard deviation in grid units. + factor: Number of standard deviations at which to truncate. + + Returns: + Radius of truncated Gaussian kernel in grid units. + + """ + sigma = as_tensor(sigma, dtype="float32", device="cpu") + is_scalar = sigma.ndim == 0 + if is_scalar: + sigma = sigma.unsqueeze(axis=0) + if sigma.ndim != 1: + raise ValueError("gaussian() 'sigma' must be scalar or sequence") + if tuple(sigma.shape)[0] == 0: + raise ValueError("gaussian() 'sigma' must be scalar or non-empty sequence") + if sigma.less_than(y=paddle.to_tensor(0.0)).astype("bool").any(): + raise ValueError("Gaussian standard deviation must be non-negative") + factor = as_tensor(factor, dtype=sigma.dtype, device=sigma.place) + radius = sigma.mul(factor).floor().astype("int64") + if is_scalar: + radius = radius + return radius + + +def gaussian( + sigma: Union[Scalar, Array], + *args: Scalar, + normalize: bool = True, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +): + """Get n-dimensional Gaussian kernel.""" + sigma_ = cat_scalars( + sigma, *args, dtype=dtype, device=str("cpu").replace("cuda", "gpu") + ) + if not paddle.is_floating_point(x=sigma_): + if dtype is not None: + raise TypeError("Gaussian kernel dtype must be floating point type") + sigma_ = sigma_.astype("float32") + if sigma_.ndim == 0: + sigma_ = sigma_.unsqueeze(axis=0) + if sigma_.ndim != 1: + raise ValueError("gaussian() 'sigma' must be scalar or sequence") + if tuple(sigma_.shape)[0] == 0: + raise ValueError("gaussian() 'sigma' must be scalar or non-empty sequence") + kernel = gaussian1d(sigma_[0], normalize=False, dtype="float64") + for std in sigma_[1:]: + other = gaussian1d(std, normalize=False, dtype="float64") + kernel = tensordot(kernel, other, dims=0) + if normalize: + kernel /= kernel.sum() + return kernel.to(dtype=sigma_.dtype, device=device) + + +def gaussian1d( + sigma: Scalar, + radius: Optional[Union[int, paddle.Tensor]] = None, + scale: Optional[Scalar] = None, + normalize: bool = True, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Get 1-dimensional Gaussian kernel.""" + sigma = as_tensor(sigma, device="cpu") + if sigma.ndim != 0: + raise ValueError("gaussian1d() 'sigma' must be scalar") + if sigma.less_than(y=paddle.to_tensor(0.0)): + raise ValueError("Gaussian standard deviation must be non-negative") + if dtype is not None and dtype not in ("float16", "float32", "float64"): + raise TypeError("Gaussian kernel dtype must be floating point type") + if radius is None: + radius = gaussian_kernel_radius(sigma) + radius = int(radius) + if radius > 0: + size = 2 * radius + 1 + x = paddle.linspace(start=-radius, stop=radius, num=size, dtype=dtype) + sigma = sigma.to(dtype=dtype, device=device) + kernel = paddle.exp(x=-0.5 * (x / sigma) ** 2) + if scale is None: + scale = 1 / sigma.mul(math.sqrt(2 * math.pi)) + else: + scale = as_tensor(scale, dtype=dtype, device=device) + kernel *= scale + if normalize: + kernel /= kernel.sum() + else: + if scale is None: + scale = 1 + kernel = as_tensor(scale, dtype=dtype, device=device) + return kernel + + +def gaussian1d_I( + sigma: Scalar, + normalize: bool = True, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Get 1st order derivative of 1-dimensional Gaussian kernel.""" + if paddle.is_tensor(x=sigma): + sigma_ = as_tensor(sigma) + if sigma_.ndim != 0: + raise ValueError("gaussian1d() 'sigma' must be scalar") + sigma = sigma_.item() + if sigma < 0: + raise ValueError("Gaussian standard deviation must be non-negative") + if dtype is not None and not dtype.is_floating_point: + raise TypeError("Gaussian kernel dtype must be floating point type") + radius = int(gaussian_kernel_radius(sigma).item()) + if radius > 0: + size = 2 * radius + 1 + x = paddle.linspace(start=-radius, stop=radius, num=size, dtype=dtype) + norm = paddle.to_tensor( + data=1 / (sigma * math.sqrt(2 * math.pi)), dtype=dtype, place=device + ) + var = sigma**2 + kernel = norm * paddle.exp(x=-0.5 * x**2 / var) * (x / var) + if normalize: + kernel /= (norm * paddle.exp(x=-0.5 * x**2 / var)).sum() + else: + kernel = paddle.to_tensor(data=[1], dtype=dtype, place=device) + return kernel + + +def gaussian2d(sigma: Union[Scalar, Array], *args: Scalar, **kwargs) -> paddle.Tensor: + """Get 2-dimensional Gaussian kernel.""" + sigma = cat_scalars(sigma, *args, num=2, device=str("cpu").replace("cuda", "gpu")) + return gaussian(sigma, **kwargs) + + +def gaussian3d(sigma: Union[Scalar, Array], *args: Scalar, **kwargs) -> paddle.Tensor: + """Get 3-dimensional Gaussian kernel.""" + sigma = cat_scalars(sigma, *args, num=3, device=str("cpu").replace("cuda", "gpu")) + return gaussian(sigma, **kwargs) diff --git a/jointContribution/HighResolution/deepali/core/linalg.py b/jointContribution/HighResolution/deepali/core/linalg.py new file mode 100644 index 0000000000..c2aecf8c9f --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/linalg.py @@ -0,0 +1,482 @@ +from enum import Enum +from functools import reduce +from operator import mul +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle + +from ._kornia import angle_axis_to_quaternion +from ._kornia import angle_axis_to_rotation_matrix +from ._kornia import normalize_quaternion +from ._kornia import quaternion_exp_to_log +from ._kornia import quaternion_log_to_exp +from ._kornia import quaternion_to_angle_axis +from ._kornia import quaternion_to_rotation_matrix +from ._kornia import rotation_matrix_to_angle_axis +from ._kornia import rotation_matrix_to_quaternion +from .tensor import as_tensor +from .types import Device + +__all__ = ( + "as_homogeneous_matrix", + "as_homogeneous_tensor", + "hmm", + "homogeneous_matmul", + "homogeneous_matrix", + "homogeneous_transform", + "tensordot", + "vectordot", + "vector_rotation", + "angle_axis_to_rotation_matrix", + "angle_axis_to_quaternion", + "rotation_matrix_to_angle_axis", + "rotation_matrix_to_quaternion", + "quaternion_to_angle_axis", + "quaternion_to_rotation_matrix", + "quaternion_log_to_exp", + "quaternion_exp_to_log", + "normalize_quaternion", +) + + +class HomogeneousTensorType(Enum): + """Type of homogeneous transformation tensor.""" + + AFFINE = "affine" + HOMOGENEOUS = "homogeneous" + TRANSLATION = "translation" + + +def as_homogeneous_tensor( + tensor: paddle.Tensor, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> Tuple[paddle.Tensor, HomogeneousTensorType]: + """Convert tensor to homogeneous coordinate transformation.""" + tensor_ = as_tensor(tensor, dtype=dtype, device=device) + if tensor_.ndim == 0: + raise ValueError("Expected at least 1-dimensional 'tensor'") + if tensor_.ndim == 1: + tensor_ = tensor_.unsqueeze(axis=1) + if tuple(tensor_.shape)[-1] == 1: + type_ = HomogeneousTensorType.TRANSLATION + elif tuple(tensor_.shape)[-1] == tuple(tensor_.shape)[-2]: + type_ = HomogeneousTensorType.AFFINE + elif tuple(tensor_.shape)[-1] == tuple(tensor_.shape)[-2] + 1: + type_ = HomogeneousTensorType.HOMOGENEOUS + else: + raise ValueError(f"Invalid homogeneous 'tensor' shape {tuple(tensor_.shape)}") + return tensor_, type_ + + +def as_homogeneous_matrix( + tensor: paddle.Tensor, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Convert tensor to homogeneous coordinate transformation matrix. + + Args: + tensor: paddle.Tensor of translations of shape ``(D,)`` or ``(..., D, 1)``, tensor of square affine + matrices of shape ``(..., D, D)``, or tensor of homogeneous transformation matrices of + shape ``(..., D, D + 1)``. + dtype: Data type of output matrix. If ``None``, use ``tensor.dtype`` or default. + device: Device on which to create matrix. If ``None``, use ``tensor.device`` or default. + + Returns: + Homogeneous coordinate transformation matrices of shape ``(..., D, D + 1)``. If ``tensor`` has already + shape ``(..., D, D + 1)``, a reference to this tensor is returned without making a copy, unless requested + ``dtype`` and ``device`` differ from ``tensor`` (cf. ``as_tensor()``). Use ``homogeneous_matrix()`` + if a copy of the input ``tensor`` should always be made. + + + """ + tensor_, type_ = as_homogeneous_tensor(tensor, dtype=dtype, device=device) + if type_ == HomogeneousTensorType.TRANSLATION: + A = paddle.eye(num_rows=tuple(tensor_.shape)[-2], dtype=tensor_.dtype) + tensor_ = paddle.concat(x=[A, tensor_], axis=-1) + elif type_ == HomogeneousTensorType.AFFINE: + t = paddle.to_tensor(data=0, dtype=tensor_.dtype, place=tensor_.place).expand( + shape=[*tuple(tensor_.shape)[:-1], 1] + ) + tensor_ = paddle.concat(x=[tensor_, t], axis=-1) + elif type_ != HomogeneousTensorType.HOMOGENEOUS: + raise ValueError( + "Expected 'tensor' to have shape (D,), (..., D, 1), (..., D, D) or (..., D, D + 1)" + ) + return tensor_ + + +def homogeneous_transform( + transform: paddle.Tensor, points: paddle.Tensor, vectors: bool = False +) -> paddle.Tensor: + """Transform points or vectors by given homogeneous transformations. + + The data type used for matrix-vector products, as well as the data type of + the resulting tensor, is by default set to ``points.dtype``. If ``points.dtype`` + is not a floating point data type, ``transforms.dtype`` is used instead. + + Args: + transform: paddle.Tensor of translations of shape ``(D,)``, ``(D, 1)`` or ``(N, D, 1)``, tensor of + affine transformation matrices of shape ``(D, D)`` or ``(N, D, D)``, or tensor of homogeneous + matrices of shape ``(D, D + 1)`` or ``(N, D, D + 1)``, respectively. When 3-dimensional + batch of transformation matrices is given, the size of leading dimension N must be 1 + for applying the same transformation to all points, or be equal to the leading dimension + of ``points``, otherwise. All points within a given batch dimension are transformed by + the matrix of matching leading index. If size of ``points`` batch dimension is one, + the size of the leading output batch dimension is equal to the number of transforms, + each applied to the same set of input points. + points: Either 1-dimensional tensor of single point coordinates, or multi-dimensional tensor + of shape ``(N, ..., D)``, where last dimension contains the spatial coordinates in the + order ``(x, y)`` (2D) or ``(x, y, z)`` (3D), respectively. + vectors: Whether ``points`` is tensor of vectors. If ``True``, only the affine + component of the given ``transforms`` is applied without any translation offset. + If ``transforms`` is a 2-dimensional tensor of translation offsets, a tensor sharing + the data memory of the input ``points`` is returned. + + Returns: + paddle.Tensor of transformed points/vectors with the same shape as the input ``points``, except + for the size of the leading batch dimension if the size of the input ``points`` batch dimension + is one, but the ``transform`` batch contains multiple transformations. + + """ + if transform.ndim == 0: + raise TypeError("homogeneous_transform() 'transform' must be non-scalar tensor") + if transform.ndim == 1: + transform = transform.unsqueeze(axis=1) + if transform.ndim == 2: + transform = transform.unsqueeze(axis=0) + N = tuple(transform.shape)[0] + D = tuple(transform.shape)[1] + if N < 1: + raise ValueError( + "homogeneous_transform() 'transform' size of leading dimension must not be zero" + ) + if ( + transform.ndim != 3 + or tuple(transform.shape)[2] != 1 + and (1 < tuple(transform.shape)[2] < D or tuple(transform.shape)[2] > D + 1) + ): + raise ValueError( + "homogeneous_transform() 'transform' must be tensor of shape" + + " (D,), (D, 1), (D, D), (D, D + 1) or (N, D, 1), (N, D, D), (N, D, D + 1)" + ) + if points.ndim == 0: + raise TypeError("'points' must be non-scalar tensor") + if ( + points.ndim == 1 + and len(points) != D + or points.ndim > 1 + and tuple(points.shape)[-1] != D + ): + raise ValueError( + "homogeneous_transform() 'points' number of spatial dimensions does not match 'transform'" + ) + if points.ndim == 1: + output_shape = (N,) + tuple(points.shape) if N > 1 else tuple(points.shape) + points = points.expand(shape=(N,) + tuple(points.shape)) + elif N == 1: + output_shape = tuple(points.shape) + elif tuple(points.shape)[0] == 1 or tuple(points.shape)[0] == N: + output_shape = (N,) + tuple(points.shape)[1:] + points = points.expand(shape=(N,) + tuple(points.shape)[1:]) + else: + raise ValueError( + "homogeneous_transform() expected size of leading dimension of 'transform' and 'points' to be either 1 or equal" + ) + points = points.reshape(N, -1, D) + if paddle.is_floating_point(x=points): + transform = transform.astype(points.dtype) + else: + points = points.astype(transform.dtype) + if tuple(transform.shape)[2] == 1: + if not vectors: + points = points + transform[..., 0].unsqueeze(axis=1) + else: + x = transform[:, :D, :D] + perm_7 = list(range(x.ndim)) + perm_7[1] = 2 + perm_7[2] = 1 + points = paddle.bmm(x=points, y=x.transpose(perm=perm_7)) + if not vectors and tuple(transform.shape)[2] == D + 1: + points += transform[..., D].unsqueeze(axis=1) + return points.reshape(output_shape) + + +def hmm(a: paddle.Tensor, b: paddle.Tensor) -> paddle.Tensor: + """Compose two homogeneous coordinate transformations. + + Args: + a: paddle.Tensor of second homogeneous transformation. + b: paddle.Tensor of first homogeneous transformation. + + Returns: + Composite homogeneous transformation given by a tensor of shape ``(..., D, D + 1)``. + + See also: + ``homogeneous_matmul()`` + + """ + c = homogeneous_matmul(a, b) + return as_homogeneous_matrix(c) + + +def homogeneous_matmul(*args: paddle.Tensor) -> paddle.Tensor: + """Compose homogeneous coordinate transformations. + + This function performs the equivalent of a matrix-matrix product for homogeneous coordinate transformations + given as either a translation vector (tensor of shape ``(D,)`` or ``(..., D, 1)``), a tensor of square affine + matrices of shape ``(..., D, D)``, or a tensor of homogeneous coordinate transformation matrices of shape + ``(..., D, D + 1)``. The size of leading dimensions must either match, or be all 1 for one of the input tensors. + In the latter case, the same homogeneous transformation is composed with each individual trannsformation of the + tensor with leading dimension size greater than 1. + + For example, if the shape of tensor ``a`` is either ``(D,)``, ``(D, 1)``, or ``(1, D, 1)``, and the shape of tensor + ``b`` is ``(N, D, D)``, the translation given by ``a`` is applied after each affine transformation given by each + matrix in grouped batch tensor ``b``, and the shape of the composite transformation tensor is ``(N, D, D + 1)``. + + Args: + args: Tensors of homogeneous coordinate transformations, where the transformation corresponding to the first + argument is applied last, and the transformation corresponding to the last argument is applied first. + + Returns: + Composite homogeneous transformation given by tensor of shape ``(..., D, 1)``, ``(..., D, D)``, or ``(..., D, D + 1)``, + respectively, where the shape of leading dimensions is determined by input tensors. + + """ + if not args: + raise ValueError("homogeneous_matmul() at least one argument is required") + a = args[0] + dtype = a.dtype + device = a.place + if dtype not in (paddle.float16, paddle.float32, paddle.float64): + # if not dtype.is_floating_point: + for b in args[1:]: + if b.is_floating_point(): + dtype = b.dtype + break + if not dtype.is_floating_point: + dtype = "float32" + a, a_type = as_homogeneous_tensor(a, dtype=dtype) + D = tuple(a.shape)[-2] + for b in args[1:]: + b, b_type = as_homogeneous_tensor(b, dtype=dtype) + if str(b.place) != str(device): + raise RuntimeError( + "homogeneous_matmul() tensors must be on the same 'device'" + ) + if tuple(b.shape)[-2] != D: + raise ValueError( + "homogeneous_matmul() tensors have mismatching number of spatial dimensions" + + f" ({tuple(a.shape)[-2]} != {tuple(b.shape)[-2]})" + ) + leading_shape = None + a_numel = len(tuple(a.shape)[:-2]) + b_numel = len(tuple(b.shape)[:-2]) + if a_numel > 1: + if b_numel > 1 and tuple(a.shape)[:-2] != tuple(b.shape)[:-2]: + raise ValueError( + "Expected homogeneous tensors to have matching leading dimensions:" + + f" {tuple(a.shape)[:-2]} != {tuple(b.shape)[:-2]}" + ) + if b.ndim > a.ndim: + raise ValueError( + "Homogeneous tensors have different number of leading dimensions" + ) + leading_shape = tuple(a.shape)[:-2] + b = b.expand(shape=leading_shape + tuple(b.shape)[-2:]) + elif b_numel > 1: + if a.ndim > b.ndim: + raise ValueError( + "Homogeneous tensors have different number of leading dimensions" + ) + leading_shape = tuple(b.shape)[:-2] + a = a.expand(shape=leading_shape + tuple(a.shape)[-2:]) + elif a.ndim > b.ndim: + leading_shape = tuple(a.shape)[:-2] + b = b.expand(shape=tuple(a.shape)[:-2] + tuple(b.shape)[-2:]) + else: + leading_shape = tuple(b.shape)[:-2] + a = a.expand(shape=tuple(b.shape)[:-2] + tuple(a.shape)[-2:]) + assert leading_shape is not None + a = a.reshape(-1, *tuple(a.shape)[-2:]) + b = b.reshape(-1, *tuple(b.shape)[-2:]) + c, c_type = None, None + if a_type == HomogeneousTensorType.TRANSLATION: + if b_type == HomogeneousTensorType.TRANSLATION: + c = a + b + c_type = HomogeneousTensorType.TRANSLATION + elif b_type == HomogeneousTensorType.AFFINE: + c = paddle.concat(x=[b, a], axis=-1) + c_type = HomogeneousTensorType.HOMOGENEOUS + elif b_type == HomogeneousTensorType.HOMOGENEOUS: + c = b.clone() + c[..., D] += a[(...), :, (0)] + c_type = HomogeneousTensorType.HOMOGENEOUS + elif a_type == HomogeneousTensorType.AFFINE: + if b_type == HomogeneousTensorType.TRANSLATION: + t = paddle.bmm(x=a, y=b) + c = paddle.concat(x=[a, t], axis=-1) + c_type = HomogeneousTensorType.HOMOGENEOUS + elif b_type == HomogeneousTensorType.AFFINE: + c = paddle.bmm(x=a, y=b) + c_type = HomogeneousTensorType.AFFINE + elif b_type == HomogeneousTensorType.HOMOGENEOUS: + A = paddle.bmm(x=a, y=b[(...), :D]) + t = paddle.bmm(x=a[(...), :D], y=b[(...), D:]) + c = paddle.concat(x=[A, t], axis=-1) + c_type = HomogeneousTensorType.HOMOGENEOUS + elif a_type == HomogeneousTensorType.HOMOGENEOUS: + if b_type == HomogeneousTensorType.TRANSLATION: + t = paddle.bmm(x=a[(...), :D], y=b) + c = a.clone() + c[..., D] += t[(...), :, (0)] + elif b_type == HomogeneousTensorType.AFFINE: + A = paddle.bmm(x=a[(...), :D], y=b) + t = a[(...), D:] + c = paddle.concat(x=[A, t], axis=-1) + elif b_type == HomogeneousTensorType.HOMOGENEOUS: + A = paddle.bmm(x=a[(...), :D], y=b[(...), :D]) + t = a[(...), D:] + paddle.bmm(x=a[(...), :D], y=b[(...), D:]) + c = paddle.concat(x=[A, t], axis=-1) + c_type = HomogeneousTensorType.HOMOGENEOUS + assert ( + c is not None + ), "as_homogeneous_tensor() returned invalid 'type' enumeration value" + assert c_type is not None + c = c.reshape(leading_shape + tuple(c.shape)[-2:]) + assert str(c.place) == str(device) + a, a_type = c, c_type + return a + + +def homogeneous_matrix( + tensor: paddle.Tensor, + offset: Optional[paddle.paddle.Tensor] = None, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Convert square matrix or vector to homogeneous coordinate transformation matrix. + + Args: + tensor: paddle.Tensor of translations of shape ``(D,)`` or ``(..., D, 1)``, tensor of square affine + matrices of shape ``(..., D, D)``, or tensor of homogeneous transformation matrices of + shape ``(..., D, D + 1)``. + offset: Translation offset to add to homogeneous transformations of shape ``(..., D)``. + If a scalar is given, this offset is used as translation along each spatial dimension. + dtype: Data type of output matrix. If ``None``, use ``offset.dtype``. + device: Device on which to create matrix. If ``None``, use ``offset.device``. + + Returns: + Homogeneous coordinate transformation matrices of shape (..., D, D + 1). Always makes a copy + of ``tensor`` even if it has already the shape of homogeneous coordinate transformation matrices. + + """ + matrix = as_homogeneous_matrix(tensor, dtype=dtype, device=device) + if matrix is tensor: + matrix = tensor.clone() + if offset is not None: + D = tuple(matrix.shape)[-2] + if offset.ndim == 0: + offset = offset.repeat(D) + if tuple(offset.shape)[-1] != D: + raise ValueError( + f"Expected homogeneous_matrix() 'offset' to be scalar or have last dimension of size {D}" + ) + matrix[..., D] += offset + return matrix + + +def tensordot( + a: paddle.Tensor, + b: paddle.Tensor, + dims: Union[int, Sequence[int], Tuple[Sequence[int], Sequence[int]]] = 2, +) -> paddle.Tensor: + """Implements ``numpy.tensordot()`` for ``paddle.Tensor``. + + Based on https://gist.github.com/deanmark/9aec75b7dc9fa71c93c4bc85c5438777. + + """ + if isinstance(dims, int): + axes_a = list(range(-dims, 0)) + axes_b = list(range(0, dims)) + else: + axes_a, axes_b = dims + if isinstance(axes_a, int): + axes_a = [axes_a] + na = 1 + else: + na = len(axes_a) + axes_a = list(axes_a) + if isinstance(axes_b, int): + axes_b = [axes_b] + nb = 1 + else: + nb = len(axes_b) + axes_b = list(axes_b) + a = as_tensor(a) + b = as_tensor(b) + as_ = tuple(a.shape) + nda = a.ndim + bs = tuple(b.shape) + ndb = b.ndim + equal = True + if na != nb: + equal = False + else: + for k in range(na): + if as_[axes_a[k]] != bs[axes_b[k]]: + equal = False + break + if axes_a[k] < 0: + axes_a[k] += nda + if axes_b[k] < 0: + axes_b[k] += ndb + if not equal: + raise ValueError("shape-mismatch for sum") + notin = [k for k in range(nda) if k not in axes_a] + newaxes_a = notin + axes_a + N2 = 1 + for axis in axes_a: + N2 *= as_[axis] + newshape_a = int(reduce(mul, [as_[ax] for ax in notin])), N2 + olda = [as_[axis] for axis in notin] + notin = [k for k in range(ndb) if k not in axes_b] + newaxes_b = axes_b + notin + N2 = 1 + for axis in axes_b: + N2 *= bs[axis] + newshape_b = N2, int(reduce(mul, [bs[ax] for ax in notin])) + oldb = [bs[axis] for axis in notin] + at = a.transpose(perm=newaxes_a).reshape(newshape_a) + bt = b.transpose(perm=newaxes_b).reshape(newshape_b) + res = at.matmul(y=bt) + return res.reshape(olda + oldb) + + +def vectordot( + a: paddle.Tensor, b: paddle.Tensor, w: Optional[paddle.Tensor] = None, dim: int = -1 +) -> paddle.Tensor: + """Inner product of vectors over specified input tensor dimension.""" + c = a.mul(b) + if w is not None: + c.mul(w) + return c.sum(axis=dim) + + +def vector_rotation(a: paddle.Tensor, b: paddle.Tensor) -> paddle.Tensor: + """Calculate rotation matrix which aligns two 3D vectors.""" + if not isinstance(a, paddle.Tensor) or not isinstance(b, paddle.Tensor): + raise TypeError("vector_rotation() 'a' and 'b' must be of type paddle.Tensor") + if tuple(a.shape) != tuple(b.shape): + raise ValueError("vector_rotation() 'a' and 'b' must have identical shape") + a = paddle.nn.functional.normalize(x=a, p=2, axis=-1) + b = paddle.nn.functional.normalize(x=b, p=2, axis=-1) + axis = a.cross(y=b, axis=-1) + norm: paddle.Tensor = axis.norm(p=2, axis=-1, keepdim=True) + angle_axis = axis.div(norm).mul(norm.asin()) + rotation_matrix = angle_axis_to_rotation_matrix(angle_axis) + return rotation_matrix diff --git a/jointContribution/HighResolution/deepali/core/math.py b/jointContribution/HighResolution/deepali/core/math.py new file mode 100644 index 0000000000..934761251f --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/math.py @@ -0,0 +1,90 @@ +from typing import Optional +from typing import Union + +import paddle + +from ..utils import paddle_aux +from .types import Scalar + + +def abspow(x: paddle.Tensor, exponent: Union[int, float]) -> paddle.Tensor: + """Compute ``abs(x)**exponent``.""" + if exponent == 1: + return x.abs() + return x.abs().pow(y=exponent) + + +def atanh(x: paddle.Tensor) -> paddle.Tensor: + """Inverse of tanh function. + + Args: + x: Function argument. + + Returns: + Inverse of tanh function, i.e., ``y`` where ``x = tanh(y)``. + + + """ + return paddle.log1p(x=2 * x / (1 - x)) / 2 + + +def max_difference(source: paddle.Tensor, target: paddle.Tensor) -> paddle.Tensor: + """Maximum possible intensity difference. + + Note that the two input images need not be sampled on the same grid. + + Args: + source: Source image. + target: Reference image. + + Returns: + Maximum possible intensity difference. + + """ + smin, smax = source.min(), source.max() + if target is source: + tmin, tmax = smin, smax + else: + tmin, tmax = target.min(), target.max() + return paddle_aux.max(paddle.abs(x=smax - tmin), paddle.abs(x=tmax - smin)) + + +def round_decimals( + tensor: paddle.Tensor, decimals: int = 0, out: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + """Round tensor values to specified number of decimals.""" + if not decimals: + result = paddle.assign(paddle.round(tensor), output=out) + else: + scale = 10**decimals + if out is tensor: + tensor *= scale + else: + tensor = tensor * scale + result = paddle.assign(paddle.round(tensor), output=out) + result /= scale + return result + + +def threshold( + data: paddle.Tensor, min: Optional[Scalar], max: Optional[Scalar] = None +) -> paddle.Tensor: + """Get mask for given lower and upper thresholds. + + Args: + data: Input data tensor. + min: Lower threshold. If ``None``, use ``data.min()``. + max: Upper threshold. If ``None``, use ``data.max()``. + + Returns: + Boolean tensor with same shape as ``data``, where only elements with a value + greater than or equal ``min`` and less than or equal ``max`` are ``True``. + + """ + if min is None and max is None: + return paddle.ones_like(x=data, dtype="bool") + if min is None: + return data <= max + if max is None: + return data >= min + return (min <= data) & (data <= max) diff --git a/jointContribution/HighResolution/deepali/core/nnutils.py b/jointContribution/HighResolution/deepali/core/nnutils.py new file mode 100644 index 0000000000..36fbb2be27 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/nnutils.py @@ -0,0 +1,526 @@ +from collections import namedtuple +from typing import Any +from typing import Iterable +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union +from typing import overload + +import paddle + +from .types import ScalarOrTuple + + +def get_namedtuple_item(self: NamedTuple, arg: Union[int, str]) -> Any: + if isinstance(arg, str): + return getattr(self, arg) + return self[arg] + + +def namedtuple_keys(self: NamedTuple) -> Iterable[str]: + return self._fields + + +def namedtuple_values(self: NamedTuple) -> Iterable[Any]: + return self + + +def namedtuple_items(self: NamedTuple) -> Iterable[Tuple[str, Any]]: + return zip(self._fields, self) + + +def as_immutable_container( + arg: Union[paddle.Tensor, Sequence, Mapping], recursive: bool = True +) -> Union[paddle.Tensor, tuple]: + """Convert mutable container such as dict or list to an immutable container type. + + For use with ``paddle.utils.tensorboard.SummaryWriter.add_graph`` when model output is list or dict. + See error message: "Encountering a dict at the output of the tracer might cause the trace to be incorrect, + this is only valid if the container structure does not change based on the module's inputs. Consider using + a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). + If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior." + + """ + if recursive: + if isinstance(arg, Mapping): + arg = {key: as_immutable_container(value) for key, value in arg.items()} + elif isinstance(arg, Sequence): + arg = [as_immutable_container(value) for value in arg] + if isinstance(arg, Mapping): + output_type = namedtuple("Dict", sorted(arg.keys())) + output_type.__getitem__ = get_namedtuple_item + output_type.keys = namedtuple_keys + output_type.values = namedtuple_values + output_type.items = namedtuple_items + return output_type(**arg) + if isinstance(arg, list): + return tuple(arg) + return arg + + +def conv_output_size( + in_size: ScalarOrTuple[int], + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: ScalarOrTuple[int] = 0, + dilation: ScalarOrTuple[int] = 1, +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after convolution.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + d: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=dilation, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError("conv_output_size() 'in_size' must be scalar or sequence") + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(k.shape)[0] > 1: + ndim = tuple(k.shape)[0] + for arg, name in zip([k, s, d], ["kernel_size", "stride", "dilation"]): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"conv_output_size() {name!r} must be scalar or sequence of length {ndim}" + ) + k = k.expand(shape=ndim) + s = s.expand(shape=ndim) + d = d.expand(shape=ndim) + if padding == "valid": + padding = 0 + elif padding == "same": + if not s.equal(y=1).astype("bool").all(): + raise ValueError("conv_output_size() padding='same' requires stride=1") + padding = same_padding(kernel_size=kernel_size, dilation=dilation) + elif isinstance(padding, str): + raise ValueError( + "conv_output_size() 'padding' string must be 'valid' or 'same'" + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + if p.ndim != 1 or tuple(p.shape)[0] not in (1, ndim): + raise ValueError( + f"conv_output_size() 'padding' must be scalar or sequence of length {ndim}" + ) + p = p.expand(shape=ndim) + n = ( + p.mul(2) + .add_(y=paddle.to_tensor(m)) + .subtract_(y=paddle.to_tensor(k.sub(1).multiply_(y=paddle.to_tensor(d)))) + .subtract_(y=paddle.to_tensor(1)) + .astype(dtype="float32") + .divide_(y=paddle.to_tensor(s)) + .add_(y=paddle.to_tensor(1)) + .floor_() + .astype(dtype="int64") + ) + if isinstance(in_size, int): + return n[0].item() + return tuple(n.tolist()) + + +def conv_transposed_output_size( + in_size: ScalarOrTuple[int], + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: ScalarOrTuple[int] = 0, + output_padding: ScalarOrTuple[int] = 0, + dilation: ScalarOrTuple[int] = 1, +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after transposed convolution.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + o: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=output_padding, dtype="int32", place=device) + ) + d: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=dilation, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError( + "conv_transposed_output_size() 'in_size' must be scalar or sequence" + ) + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(k.shape)[0] > 1: + ndim = tuple(k.shape)[0] + for arg, name in zip( + [k, s, p, o, d], + ["kernel_size", "stride", "padding", "output_padding", "dilation"], + ): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"conv_transposed_output_size() {name!r} must be scalar or sequence of length {ndim}" + ) + k = k.expand(shape=ndim) + s = s.expand(shape=ndim) + p = p.expand(shape=ndim) + o = o.expand(shape=ndim) + d = d.expand(shape=ndim) + n = ( + m.sub(1) + .multiply_(y=paddle.to_tensor(s)) + .subtract_(y=paddle.to_tensor(p.mul(2))) + .add_(y=paddle.to_tensor(k.sub(1).multiply_(y=paddle.to_tensor(d)))) + .add_(y=paddle.to_tensor(o)) + .add_(y=paddle.to_tensor(1)) + ) + if isinstance(in_size, int): + return n.item() + return tuple(n.tolist()) + + +def pad_output_size( + in_size: ScalarOrTuple[int], padding: ScalarOrTuple[int] = 0 +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after padding.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError("pad_output_size() 'in_size' must be scalar or sequence") + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(p.shape)[0] > 1 and tuple(p.shape)[0] % 2: + ndim = tuple(p.shape)[0] // 2 + if p.ndim != 1 or tuple(p.shape)[0] not in (1, 2 * ndim): + raise ValueError( + f"pad_output_size() 'padding' must be scalar or sequence of length {2 * ndim}" + ) + p = p.expand(shape=2 * ndim) + n = p.reshape(ndim, 2).sum(axis=1).add(m) + if isinstance(in_size, int): + return n[0].item() + return tuple(n.tolist()) + + +def pool_output_size( + in_size: ScalarOrTuple[int], + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: ScalarOrTuple[int] = 0, + dilation: ScalarOrTuple[int] = 1, + ceil_mode: bool = False, +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after pooling.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + d: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=dilation, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError("pool_output_size() 'in_size' must be scalar or sequence") + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(k.shape)[0] > 1: + ndim = tuple(k.shape)[0] + for arg, name in zip( + [k, s, p, d], ["kernel_size", "stride", "padding", "dilation"] + ): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"pool_output_size() {name!r} must be scalar or sequence of length {ndim}" + ) + k = k.expand(shape=ndim) + s = s.expand(shape=ndim) + p = p.expand(shape=ndim) + d = d.expand(shape=ndim) + n = ( + p.mul(2) + .add_(y=paddle.to_tensor(m)) + .subtract_(y=paddle.to_tensor(k.sub(1).multiply_(y=paddle.to_tensor(d)))) + .subtract_(y=paddle.to_tensor(1)) + .astype(dtype="float32") + .divide_(y=paddle.to_tensor(s)) + .add_(y=paddle.to_tensor(1)) + ) + n = n.ceil() if ceil_mode else n.floor() + n = n.astype(dtype="int64") + if isinstance(in_size, int): + return n[0].item() + return tuple(n.tolist()) + + +def unpool_output_size( + in_size: ScalarOrTuple[int], + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: ScalarOrTuple[int] = 0, +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after unpooling.""" + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError("unpool_output_size() 'in_size' must be scalar or sequence") + ndim = tuple(m.shape)[0] + if ndim == 1 and tuple(k.shape)[0] > 1: + ndim = tuple(k.shape)[0] + for arg, name in zip([k, s, p], ["kernel_size", "stride", "padding"]): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"unpool_output_size() {name!r} must be scalar or sequence of length {ndim}" + ) + k = k.expand(shape=ndim) + s = s.expand(shape=ndim) + p = p.expand(shape=ndim) + n = ( + m.sub(1) + .multiply_(y=paddle.to_tensor(s)) + .subtract_(y=paddle.to_tensor(p.mul(2))) + .add(k) + ) + if isinstance(in_size, int): + return n[0].item() + return tuple(n.tolist()) + + +@overload +def same_padding(kernel_size: int, dilation: int = 1) -> int: + ... + + +@overload +def same_padding(kernel_size: Tuple[int, ...], dilation: int = 1) -> Tuple[int, ...]: + ... + + +@overload +def same_padding(kernel_size: int, dilation: Tuple[int, ...]) -> Tuple[int, ...]: + ... + + +@overload +def same_padding( + kernel_size: Tuple[int, ...], dilation: Tuple[int, ...] +) -> Tuple[int, ...]: + ... + + +def same_padding( + kernel_size: ScalarOrTuple[int], dilation: ScalarOrTuple[int] = 1 +) -> ScalarOrTuple[int]: + """Padding value needed to ensure convolution preserves input tensor shape. + + Return the padding value needed to ensure a convolution using the given kernel size produces an output of the same + shape as the input for a stride of 1, otherwise ensure a shape of the input divided by the stride rounded down. + + Raises: + NotImplementedError: When ``(kernel_size - 1) * dilation`` is an odd number. + + """ + device = str("cpu").replace("cuda", "gpu") + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int64", place=device) + ) + d: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=dilation, dtype="int64", place=device) + ) + if k.ndim != 1: + raise ValueError("same_padding() 'kernel_size' must be scalar or sequence") + ndim = tuple(k.shape)[0] + if ndim == 1 and tuple(d.shape)[0] > 1: + ndim = tuple(d.shape)[0] + for arg, name in zip([k, d], ["kernel_size", "dilation"]): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"same_padding() {name!r} must be scalar or sequence of length {ndim}" + ) + if k.sub(1).mul(d).mod(y=paddle.to_tensor(2)).equal(y=1).astype("bool").any(): + raise NotImplementedError( + f"Same padding not available for kernel_size={tuple(k.tolist())} and dilation={tuple(d.tolist())}." + ) + p = k.sub(1).div(2).mul(d).astype("int32") + if isinstance(kernel_size, int) and isinstance(dilation, int): + return p[0].item() + return tuple(p.tolist()) + + +@overload +def stride_minus_kernel_padding(kernel_size: int, stride: int) -> int: + ... + + +@overload +def stride_minus_kernel_padding( + kernel_size: Sequence[int], stride: int +) -> Tuple[int, ...]: + ... + + +@overload +def stride_minus_kernel_padding( + kernel_size: int, stride: Sequence[int] +) -> Tuple[int, ...]: + ... + + +def stride_minus_kernel_padding( + kernel_size: ScalarOrTuple[int], stride: ScalarOrTuple[int] +) -> ScalarOrTuple[int]: + device = str("cpu").replace("cuda", "gpu") + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=stride, dtype="int32", place=device) + ) + if k.ndim != 1: + raise ValueError( + "stride_minus_kernel_padding() 'kernel_size' must be scalar or sequence" + ) + ndim = tuple(k.shape)[0] + if ndim == 1 and tuple(s.shape)[0] > 1: + ndim = tuple(s.shape)[0] + for arg, name in zip([k, s], ["kernel_size", "stride"]): + if arg.ndim != 1 or arg.shape[0] not in (1, ndim): + raise ValueError( + f"stride_minus_kernel_padding() {name!r} must be scalar or sequence of length {ndim}" + ) + assert ( + k.ndim == 1 + ), "stride_minus_kernel_padding() 'kernel_size' must be scalar or sequence" + assert ( + s.ndim == 1 + ), "stride_minus_kernel_padding() 'stride' must be scalar or sequence" + p = s.sub(k).astype("int32") + if isinstance(kernel_size, int) and isinstance(stride, int): + return p[0].item() + return tuple(p.tolist()) + + +def upsample_padding( + kernel_size: ScalarOrTuple[int], scale_factor: ScalarOrTuple[int] +) -> Tuple[int, ...]: + """Padding on both sides for transposed convolution.""" + device = str("cpu").replace("cuda", "gpu") + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=scale_factor, dtype="int32", place=device) + ) + assert k.ndim == 1, "upsample_padding() 'kernel_size' must be scalar or sequence" + assert s.ndim == 1, "upsample_padding() 'scale_factor' must be scalar or sequence" + p = k.sub(s).add(1).div(2).astype("int32") + if p.less_than(y=paddle.to_tensor(0)).astype("bool").any(): + raise ValueError( + "upsample_padding() 'kernel_size' must be greater than or equal to 'scale_factor'" + ) + return tuple(p.tolist()) + + +def upsample_output_padding( + kernel_size: ScalarOrTuple[int], + scale_factor: ScalarOrTuple[int], + padding: ScalarOrTuple[int], +) -> Tuple[int, ...]: + """Output padding on one side for transposed convolution.""" + device = str("cpu").replace("cuda", "gpu") + k: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=kernel_size, dtype="int32", place=device) + ) + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=scale_factor, dtype="int32", place=device) + ) + p: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=padding, dtype="int32", place=device) + ) + assert ( + k.ndim == 1 + ), "upsample_output_padding() 'kernel_size' must be scalar or sequence" + assert ( + s.ndim == 1 + ), "upsample_output_padding() 'scale_factor' must be scalar or sequence" + assert p.ndim == 1, "upsample_output_padding() 'padding' must be scalar or sequence" + op = p.mul(2).sub(k).add(s).astype("int32") + if op.less_than(y=paddle.to_tensor(0)).astype("bool").any(): + raise ValueError( + "upsample_output_padding() 'output_padding' must be greater than or equal to zero" + ) + return tuple(op.tolist()) + + +def upsample_output_size( + in_size: ScalarOrTuple[int], + size: Optional[ScalarOrTuple[int]] = None, + scale_factor: Optional[ScalarOrTuple[float]] = None, +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after unpooling.""" + if size is not None and scale_factor is not None: + raise ValueError( + "upsample_output_size() 'size' and 'scale_factor' are mutually exclusive" + ) + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError("upsample_output_size() 'in_size' must be scalar or sequence") + ndim = tuple(m.shape)[0] + if size is not None: + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=size, dtype="int32", place=device) + ) + if s.ndim != 1 or tuple(s.shape)[0] not in (1, ndim): + raise ValueError( + f"upsample_output_size() 'size' must be scalar or sequence of length {ndim}" + ) + n = s.expand(shape=ndim) + elif scale_factor is not None: + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=scale_factor, dtype="int32", place=device) + ) + if s.ndim != 1 or tuple(s.shape)[0] not in (1, ndim): + raise ValueError( + f"upsample_output_size() 'scale_factor' must be scalar or sequence of length {ndim}" + ) + n = m.astype(dtype="float32").mul(s).floor().astype(dtype="int64") + else: + n = m + if isinstance(in_size, int): + return n.item() + return tuple(n.tolist()) diff --git a/jointContribution/HighResolution/deepali/core/path.py b/jointContribution/HighResolution/deepali/core/path.py new file mode 100644 index 0000000000..262b589823 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/path.py @@ -0,0 +1,135 @@ +import os +import re +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import mkdtemp +from tempfile import mkstemp +from typing import Generator +from typing import Optional +from typing import Union +from typing import overload + +from .types import PathStr + + +@overload +def abspath(path: str, parent: Optional[PathStr] = None) -> str: + ... + + +@overload +def abspath(path: Path, parent: Optional[PathStr] = None) -> Path: + ... + + +def abspath(path: PathStr, parent: Optional[PathStr] = None) -> PathStr: + """Make path absolute.""" + path_type = type(path) + if path_type not in (Path, str): + raise TypeError("abspath() 'path' must be pathlib.Path or str") + path = (Path(parent or ".") / Path(path)).absolute() + if path_type is str: + path = path.as_posix() + return path + + +@overload +def abspath_template(path: str, parent: Optional[Path] = None) -> str: + ... + + +@overload +def abspath_template(path: Path, parent: Optional[Path] = None) -> str: + ... + + +def abspath_template(path: PathStr, parent: Optional[Path] = None) -> PathStr: + """Make path format string absolute.""" + if not isinstance(path, (Path, str)): + raise TypeError("abspath_template() 'path' must be pathlib.Path or str") + if str(path).startswith("{"): + return path + return abspath(path, parent=parent) + + +def delete(path: PathStr) -> bool: + """Remove file or (non-empty) directory.""" + try: + shutil.rmtree(path) + except NotADirectoryError: + os.remove(path) + + +def filename_suffix(path: PathStr) -> str: + """Get filename suffix, including .gz, .bz2 if present.""" + m = re.search("(\\.[a-zA-Z0-9]+(\\.gz|\\.GZ|\\.bz2|\\.BZ2)?)$", str(path)) + return m.group(1) if m else "" + + +def make_parent_dir(path: PathStr, parents: bool = True, exist_ok: bool = True) -> Path: + """Make parent directory of file path.""" + parent = Path(path).absolute().parent + parent.mkdir(parents=parents, exist_ok=exist_ok) + return parent + + +def make_temp_file( + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: Optional[PathStr] = None, + text: bool = False, +) -> Path: + """Make temporary file with mkstemp, but close open file handle immediately.""" + fp, path = mkstemp(suffix=suffix, prefix=prefix, dir=dir, text=text) + os.close(fp) + return Path(path) + + +@contextmanager +def temp_dir( + suffix: str = None, prefix: str = None, dir: Union[Path, str] = None +) -> Generator[Path, None, None]: + """Create temporary directory within context.""" + path = mkdtemp(suffix=suffix, prefix=prefix, dir=dir) + try: + yield Path(path) + finally: + shutil.rmtree(path) + + +@contextmanager +def temp_file( + suffix: Optional[str] = None, + prefix: Optional[str] = None, + dir: PathStr = None, + text: bool = False, +) -> Generator[Path, None, None]: + """Create temporary file within context.""" + path = make_temp_file(suffix=suffix, prefix=prefix, dir=dir, text=text) + try: + yield path + finally: + os.remove(path) + + +def unlink_or_mkdir(path: PathStr) -> Path: + """Unlink existing file or make parent directory if non-existent. + + This function is useful when a script output file is managed by `DVC ` using + protected symbolic or hard links. Call this function before writing the new output file. It will + remove any existing output file, and ensure that the output directory exists. + + Args: + path: File path. + + Returns: + Absolute file path. + + """ + path = Path(path).absolute() + try: + path.unlink() + except FileNotFoundError: + path.parent.mkdir(parents=True, exist_ok=True) + return path diff --git a/jointContribution/HighResolution/deepali/core/pointset.py b/jointContribution/HighResolution/deepali/core/pointset.py new file mode 100644 index 0000000000..e7bf7a6bec --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/pointset.py @@ -0,0 +1,360 @@ +from typing import Optional +from typing import Tuple +from typing import Union + +import paddle + +from . import affine as A +from .flow import warp_grid +from .flow import warp_points +from .grid import ALIGN_CORNERS +from .tensor import move_dim + + +def bounding_box(points: paddle.Tensor) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Compute corners of minimum axes-aligned bounding box of given points.""" + return points.amin(axis=0), points.amax(axis=0) + + +def distance_matrix(x: paddle.Tensor, y: paddle.Tensor) -> paddle.Tensor: + """Compute squared Euclidean distances between all pairs of points. + + Args: + x: Point set of shape ``(N, X, D)``. + y: Point set of shape ``(N, Y, D)``. + + Returns: + paddle.Tensor of distance matrices of shape ``(N, X, Y)``. + + """ + if not isinstance(x, paddle.Tensor) or not isinstance(y, paddle.Tensor): + raise TypeError("distance_matrix() 'x' and 'y' must be paddle.Tensor") + if x.ndim != 3 or y.ndim != 3: + raise ValueError("distance_matrix() 'x' and 'y' must have shape (N, X, D)") + N, _, D = tuple(x.shape) + if tuple(y.shape)[0] != N: + raise ValueError("distance_matrix() 'x' and 'y' must have same batch size N") + if tuple(y.shape)[2] != D: + raise ValueError( + "distance_matrix() 'x' and 'y' must have same point dimension D" + ) + out_dtype = x.dtype + if not out_dtype.is_floating_point: + out_dtype = "float32" + x = x.astype("float64") + y = y.astype("float64") + x_norm = x.pow(y=2).sum(axis=2).view(N, -1, 1) + y_norm = y.pow(y=2).sum(axis=2).view(N, 1, -1) + x = y + perm_8 = list(range(x.ndim)) + perm_8[1] = 2 + perm_8[2] = 1 + dist = x_norm + y_norm - 2.0 * paddle.bmm(x=x, y=paddle.transpose(x=x, perm=perm_8)) + return dist.astype(out_dtype) + + +def closest_point_distances( + x: paddle.Tensor, y: paddle.Tensor, split_size: int = 10000 +) -> paddle.Tensor: + """Compute minimum Euclidean distance from each point in ``x`` to point set ``y``. + + Args: + x: Point set of shape ``(N, X, D)``. + y: Point set of shape ``(N, Y, D)``. + split_size: Maximum number of points in ``x`` to consider each time when computing + the full distance matrix between these points in ``x`` and every point in ``y``. + This is required to limit the size of the distance matrix. + + Returns: + paddle.Tensor of shape ``(N, X)`` with minimum distances from points in ``x`` to points in ``y``. + + """ + if not isinstance(x, paddle.Tensor) or not isinstance(y, paddle.Tensor): + raise TypeError("closest_point_distances() 'x' and 'y' must be paddle.Tensor") + if x.ndim != 3 or y.ndim != 3: + raise ValueError( + "closest_point_distances() 'x' and 'y' must have shape (N, X, D)" + ) + N, _, D = tuple(x.shape) + if tuple(y.shape)[0] != N: + raise ValueError( + "closest_point_distances() 'x' and 'y' must have same batch size N" + ) + if tuple(y.shape)[2] != D: + raise ValueError( + "closest_point_distances() 'x' and 'y' must have same point dimension D" + ) + x = x.astype(dtype="float32") + y = y.astype(x.dtype) + min_dists = paddle.empty(shape=tuple(x.shape)[0:2], dtype=x.dtype) + for i, points in enumerate(x.split(split_size, dim=1)): + dists = distance_matrix(points, y) + j = slice(i * split_size, i * split_size + tuple(points.shape)[1]) + min_dists[:, (j)] = ( + paddle.min(x=dists, axis=2), + paddle.argmin(x=dists, axis=2), + ).values + return min_dists + + +def closest_point_indices( + x: paddle.Tensor, y: paddle.Tensor, split_size: int = 10000 +) -> paddle.Tensor: + """Determine indices of points in ``y`` with minimum Euclidean distance from each point in ``x``. + + Args: + x: Point set of shape ``(N, X, D)``. + y: Point set of shape ``(N, Y, D)``. + split_size: Maximum number of points in ``x`` to consider each time when computing + the full distance matrix between these points in ``x`` and every point in ``y``. + This is required to limit the size of the distance matrix. + + Returns: + paddle.Tensor of shape ``(N, X)`` with indices of closest points in ``y``. + + """ + if not isinstance(x, paddle.Tensor) or not isinstance(y, paddle.Tensor): + raise TypeError("closest_point_indices() 'x' and 'y' must be paddle.Tensor") + if x.ndim != 3 or y.ndim != 3: + raise ValueError( + "closest_point_indices() 'x' and 'y' must have shape (N, X, D)" + ) + N, _, D = tuple(x.shape) + if tuple(y.shape)[0] != N: + raise ValueError( + "closest_point_indices() 'x' and 'y' must have same batch size N" + ) + if tuple(y.shape)[2] != D: + raise ValueError( + "closest_point_indices() 'x' and 'y' must have same point dimension D" + ) + x = x.astype(dtype="float32") + y = y.astype(x.dtype) + indices = paddle.empty(shape=tuple(x.shape)[0:2], dtype="int64") + for i, points in enumerate(x.split(split_size, dim=1)): + dists = distance_matrix(points, y) + j = slice(i * split_size, i * split_size + tuple(points.shape)[1]) + indices[:, (j)] = ( + paddle.min(x=dists, axis=2), + paddle.argmin(x=dists, axis=2), + ).indices + return indices + + +def normalize_grid( + grid: paddle.Tensor, + size: Optional[Union[paddle.Tensor, list]] = None, + side_length: float = 2, + align_corners: bool = ALIGN_CORNERS, + channels_last: bool = True, +) -> paddle.Tensor: + """Map unnormalized grid coordinates to normalized grid coordinates.""" + if not isinstance(grid, paddle.Tensor): + raise TypeError("normalize_grid() 'grid' must be tensors") + if not grid.is_floating_point(): + grid = grid.astype(dtype="float32") + if size is None: + if channels_last: + if grid.ndim < 4 or tuple(grid.shape)[-1] != grid.ndim - 2: + raise ValueError( + "normalize_grid() 'grid' must have shape (N, ..., X, D) when 'size' not given" + ) + size = tuple(reversed(tuple(grid.shape)[1:-1])) + else: + if grid.ndim < 4 or tuple(grid.shape)[1] != grid.ndim - 2: + raise ValueError( + "normalize_grid() 'grid' must have shape (N, D, ..., X) when 'size' not given" + ) + size = tuple(reversed(tuple(grid.shape)[2:])) + zero = paddle.to_tensor(data=0, dtype=grid.dtype, place=grid.place) + size = paddle.to_tensor(data=size, dtype=grid.dtype, place=grid.place) + size_ = size.sub(1) if align_corners else size + if not channels_last: + grid = move_dim(grid, 1, -1) + if side_length != 1: + grid = grid.mul(side_length) + grid = paddle.where(condition=size > 1, x=grid.div(size_).sub(1), y=zero) + if not channels_last: + grid = move_dim(grid, -1, 1) + return grid + + +def denormalize_grid( + grid: paddle.Tensor, + size: Optional[Union[paddle.Tensor, list]] = None, + side_length: float = 2, + align_corners: bool = ALIGN_CORNERS, + channels_last: bool = True, +) -> paddle.Tensor: + """Map normalized grid coordinates to unnormalized grid coordinates.""" + if not isinstance(grid, paddle.Tensor): + raise TypeError("denormalize_grid() 'grid' must be tensors") + if size is None: + if grid.ndim < 4 or tuple(grid.shape)[-1] != grid.ndim - 2: + raise ValueError( + "normalize_grid() 'grid' must have shape (N, ..., X, D) when 'size' not given" + ) + size = tuple(reversed(tuple(grid.shape)[1:-1])) + zero = paddle.to_tensor(data=0, dtype=grid.dtype, place=grid.place) + size = paddle.to_tensor(data=size, dtype=grid.dtype, place=grid.place) + size_ = size.sub(1) if align_corners else size + if not channels_last: + grid = move_dim(grid, 1, -1) + grid = paddle.where(condition=size > 1, x=grid.add(1).mul(size_), y=zero) + if side_length != 1: + grid = grid.div(side_length) + if not channels_last: + grid = move_dim(grid, -1, 1) + return grid + + +def polyline_directions( + points: paddle.Tensor, normalize: bool = False, repeat_last: bool = True +) -> paddle.Tensor: + """Compute proximal to distal facing tangent vectors.""" + if not isinstance(points, paddle.Tensor): + raise TypeError("polyline_directions() 'points' must be paddle.Tensor") + if points.ndim < 2: + raise ValueError("polyline_directions() 'points' must have shape (..., N, 3)") + dim = points.ndim - 2 + n = tuple(points.shape)[dim] + start_10 = points.shape[dim] + 1 if 1 < 0 else 1 + start_11 = points.shape[dim] + 0 if 0 < 0 else 0 + d = paddle.slice(points, [dim], [start_10], [start_10 + n - 1]).sub( + paddle.slice(points, [dim], [start_11], [start_11 + (n - 1)]) + ) + if normalize: + d = paddle.nn.functional.normalize(x=d, p=2, axis=dim) + if repeat_last: + start_12 = d.shape[dim] + (n - 2) if n - 2 < 0 else n - 2 + d = paddle.concat( + x=[d, paddle.slice(d, [dim], [start_12], [start_12 + 1])], axis=dim + ) + return d + + +def polyline_tangents( + points: paddle.Tensor, normalize: bool = False, repeat_first: bool = True +) -> paddle.Tensor: + """Compute distal to proximal facing tangent vectors.""" + if not isinstance(points, paddle.Tensor): + raise TypeError("polyline_tangents() 'points' must be paddle.Tensor") + if points.ndim < 2: + raise ValueError("polyline_tangents() 'points' must have shape (..., N, 3)") + dim = points.ndim - 2 + n = tuple(points.shape)[dim] + start_13 = points.shape[dim] + 0 if 0 < 0 else 0 + start_14 = points.shape[dim] + 1 if 1 < 0 else 1 + d = paddle.slice(points, [dim], [start_13], [start_13 + n - 1]).sub( + paddle.slice(points, [dim], [start_14], [start_14 + (n - 1)]) + ) + if normalize: + d = paddle.nn.functional.normalize(x=d, p=2, axis=dim) + if repeat_first: + start_15 = d.shape[dim] + 0 if 0 < 0 else 0 + d = paddle.concat( + x=[paddle.slice(d, [dim], [start_15], [start_15 + 1]), d], axis=dim + ) + return d + + +def transform_grid( + transform: paddle.Tensor, grid: paddle.Tensor, align_corners: bool = ALIGN_CORNERS +) -> paddle.Tensor: + """Transform undeformed grid by a spatial transformation. + + This function applies a spatial transformation to map a tensor of undeformed grid points to a + tensor of deformed grid points with the same shape as the input tensor. The input points must be + the positions of undeformed spatial grid points, because in case of a non-rigid transformation, + this function uses interpolation to resize the vector fields to the size of the input ``grid``. + This assumes that input points ``x`` are the coordinates of points located on a regularly spaced + undeformed grid which is aligned with the borders of the grid domain on which the vector fields + of the non-rigid transformations are sampled, i.e., ``y = x + u``. + + In case of a linear transformation ``y = Ax + t``. + + If in doubt whether the input points will be sampled regularly at grid points of the domain of + the spatial transformation, use ``transform_points()`` instead. + + Args: + transform: paddle.Tensor representation of spatial transformation, where the shape of the tensor + determines the type of transformation. A translation-only transformation must be given + as tensor of shape ``(N, D, 1)``. An affine-only transformation without translation can + be given as tensor of shape ``(N, D, D)``, and an affine transformation with translation + as tensor of shape ``(N, D, D + 1)``. Flow fields of non-rigid transformations, on the + other hand, are tensors of shape ``(N, D, ..., X)``, i.e., linear transformations are + represented by 3-dimensional tensors, and non-rigid transformations by tensors of at least + 4 dimensions. If batch size is one, but the batch size of ``points`` is greater than one, + all point sets are transformed by the same non-rigid transformation. + grid: Coordinates of undeformed grid points as tensor of shape ``(N, ..., D)`` or ``(1, ..., D)``. + If batch size is one, but multiple flow fields are given, this single point set is + transformed by each non-rigid transformation to produce ``N`` output point sets. + align_corners: Whether flow vectors in case of a non-rigid transformation are with respect to + ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True). The input ``grid`` points must be + with respect to the same spatial grid domain as the input flow fields. This option is in + particular passed on to the ``grid_reshape()`` function used to resize the flow fields to + the shape of the input grid. + + Returns: + paddle.Tensor of shape ``(N, ..., D)`` with coordinates of spatially transformed points. + + """ + if not isinstance(transform, paddle.Tensor): + raise TypeError("transform_grid() 'transform' must be paddle.Tensor") + if transform.ndim < 3: + raise ValueError( + "transform_grid() 'transform' must be at least 3-dimensional tensor" + ) + if transform.ndim == 3: + return A.transform_points(transform, grid) + return warp_grid(transform, grid, align_corners=align_corners) + + +def transform_points( + transform: paddle.Tensor, points: paddle.Tensor, align_corners: bool = ALIGN_CORNERS +) -> paddle.Tensor: + """Transform set of points by a tensor of non-rigid flow fields. + + This function applies a spatial transformation to map a tensor of points to a tensor of transformed + points of the same shape as the input tensor. Unlike ``transform_grid()``, it can be used to spatially + transform any set of points which are defined with respect to the grid domain of the spatial transformation, + including a tensor of shape ``(N, M, D)``, i.e., a batch of N point sets with cardianality M. It can also + be applied to a tensor of grid points of shape ``(N, ..., X, D)`` regardless if the grid points are located + at the undeformed grid positions or an already deformed grid. Therefore, in case of a non-rigid transformation, + the given flow fields are sampled at the input points ``x`` using linear interpolation. The flow vectors ``u(x)`` + are then added to the input points, i.e., ``y = x + u(x)``. + + In case of a linear transformation ``y = Ax + t``. + + Args: + transform: paddle.Tensor representation of spatial transformation, where the shape of the tensor + determines the type of transformation. A translation-only transformation must be given + as tensor of shape ``(N, D, 1)``. An affine-only transformation without translation can + be given as tensor of shape ``(N, D, D)``, and an affine transformation with translation + as tensor of shape ``(N, D, D + 1)``. Flow fields of non-rigid transformations, on the + other hand, are tensors of shape ``(N, D, ..., X)``, i.e., linear transformations are + represented by 3-dimensional tensors, and non-rigid transformations by tensors of at least + 4 dimensions. If batch size is one, but the batch size of ``points`` is greater than one, + all point sets are transformed by the same non-rigid transformation. + points: Coordinates of points given as tensor of shape ``(N, ..., D)`` or ``(1, ..., D)``. + If batch size is one, but multiple flow fields are given, this single point set is + transformed by each non-rigid transformation to produce ``N`` output point sets. + align_corners: Whether flow vectors in case of a non-rigid transformation are with respect to + ``Axes.CUBE`` (False) or ``Axes.CUBE_CORNERS`` (True). The input ``points`` must be + with respect to the same spatial grid domain as the input flow fields. This option is in + particular passed on to the ``grid_sample()`` function used to sample the flow vectors at + the input points. + + Returns: + paddle.Tensor of shape ``(N, ..., D)`` with coordinates of spatially transformed points. + + """ + if not isinstance(transform, paddle.Tensor): + raise TypeError("transform_points() 'transform' must be paddle.Tensor") + if transform.ndim < 3: + raise ValueError( + "transform_points() 'transform' must be at least 3-dimensional tensor" + ) + if transform.ndim == 3: + return A.transform_points(transform, points) + return warp_points(transform, points, align_corners=align_corners) diff --git a/jointContribution/HighResolution/deepali/core/random.py b/jointContribution/HighResolution/deepali/core/random.py new file mode 100644 index 0000000000..2cf08a9d57 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/random.py @@ -0,0 +1,114 @@ +from typing import Optional +from typing import Union + +import paddle +from pkg_resources import parse_version + +paddle.Generator = Union[ + paddle.framework.core.default_cuda_generator, + paddle.framework.core.default_cpu_generator, +] + + +def multinomial( + input: paddle.Tensor, + num_samples: int, + replacement: bool = False, + generator: Optional[paddle.Generator] = None, + out: Optional = None, +) -> paddle.int64: + """Sample from a multinomial probability distribution. + + Args: + input: Input vector of shape ``(N,)`` or matrix ``(M, N)``. + num_samples: Number of random samples to draw. + replacement: Whether to sample with or without replacement. + generator: Random number generator to use. + out: Pre-allocated output tensor. + + Returns: + Indices of random samples. When ``input`` is a vector, a vector of ``num_samples`` indices + is returned. Otherwise, a matrix of shape ``(M, num_samples)`` is returned. When ``out`` + is given, the returned tensor is a reference to ``out``. + + """ + if input.ndim == 0 or input.ndim > 2: + raise ValueError("multinomial() 'input' must be vector or matrix") + num_candidates = input.shape[-1] + if not replacement and num_candidates < num_samples: + raise ValueError( + "multinomial() 'num_samples' cannot be greater than number of categories" + ) + if num_candidates > 2**24: + impl = _multinomial + else: + impl = paddle.multinomial + input = input.astype(dtype="float32") + return impl( + input, num_samples, replacement=replacement, generator=generator, out=out + ) + + +def _multinomial( + input: paddle.Tensor, + num_samples: int, + replacement: bool = False, + generator: Optional[paddle.Generator] = None, + out: Optional = None, +) -> paddle.int64: + """Sample from a multinomial probability distribution. + + This function can be used for inputs of any size and is unlike ``paddle.multinomial`` not limited + to 2**24 categories at the expense of a less efficient implementation. + + Args: + input: Input vector of shape ``(N,)`` or matrix ``(M, N)``. + num_samples: Number of random samples to draw. + replacement: Whether to sample with or without replacement. + generator: Random number generator to use. + out: Pre-allocated output tensor. + + Returns: + Indices of random samples. When ``input`` is a vector, a vector of ``num_samples`` indices + is returned. Otherwise, a matrix of shape ``(M, num_samples)`` is returned. When ``out`` + is given, the returned tensor is a reference to ``out``. + + """ + if input.ndim == 0 or input.ndim > 2: + raise ValueError("_multinomial() 'input' must be vector or matrix") + num_candidates = input.shape[-1] + out_shape = tuple(input.shape)[:-1] + (num_samples,) + if out is not None: + if not isinstance(out, paddle.Tensor): + raise TypeError("_multinomial() 'out' must be paddle.Tensor") + if out.dtype != "int64": + raise TypeError("_multinomial() 'out' must be int64 tensor") + if tuple(out.shape) != out_shape: + raise ValueError(f"_multinomial() 'out' must have shape {out_shape}") + if replacement: + cdf = input.astype("float64").cumsum(axis=-1) + cdf = cdf.divide_(y=paddle.to_tensor(cdf[(...), -1:].clone())) + val = paddle.rand(shape=out_shape, dtype=cdf.dtype) + out = paddle.assign( + paddle.searchsorted(sorted_sequence=cdf, values=val), output=out + ).clip_(min=0, max=num_candidates - 1) + else: + if num_samples > num_candidates: + raise ValueError( + "_multinomial() 'num_samples' cannot be greater than number of categories" + ) + logit = input.log() + value = paddle.rand( + shape=tuple(input.shape)[:-1] + (num_candidates,), dtype=logit.dtype + ) + value = value.log_().neg_().log_().neg_().add_(y=paddle.to_tensor(logit)) + if parse_version(paddle.__version__) < parse_version("1.12"): + _, index = paddle.topk(k=num_samples, sorted=False, x=value, axis=-1) + out = index if out is None else paddle.assign(index, output=out) + else: + if out is None: + out = paddle.empty(shape=out_shape, dtype="int64") + _ = paddle.empty(shape=out_shape, dtype=value.dtype) + out1, out2 = paddle.topk(k=num_samples, sorted=False, x=value, axis=-1) + _, out = paddle.assign(out1, (_, out)[0]), paddle.assign(out2, (_, out)[1]) + return out diff --git a/jointContribution/HighResolution/deepali/core/tensor.py b/jointContribution/HighResolution/deepali/core/tensor.py new file mode 100644 index 0000000000..e3c9b87e70 --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/tensor.py @@ -0,0 +1,278 @@ +from logging import Logger +from typing import Callable +from typing import Optional +from typing import Tuple +from typing import Union + +import paddle + +from .types import Array +from .types import Device +from .types import Scalar + + +def as_tensor( + arg: Union[Scalar, Array], + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Create tensor from array if argument is not of type paddle.Tensor. + + Unlike ``paddle.as_tensor()``, this function preserves the tensor device if ``device=None``. + + """ + if isinstance(arg, paddle.Tensor): + return arg + if device is None and isinstance(arg, paddle.Tensor): + device = arg.place + return paddle.to_tensor(data=arg, dtype=dtype, place=device) + + +def as_float_tensor(arr: Array) -> paddle.Tensor: + """Create tensor with floating point type from argument if it is not yet.""" + arr_ = as_tensor(arr) + if not paddle.is_floating_point(x=arr_): + return arr_.astype("float32") + return arr_ + + +def as_one_hot_tensor( + tensor: paddle.Tensor, + num_classes: int, + ignore_index: Optional[int] = None, + dtype: Optional[paddle.dtype] = None, +) -> paddle.Tensor: + """Converts label image to one-hot encoding of multi-class segmentation. + + Adapted from: https://github.com/wolny/paddle-3dunet + + Args: + tensor: Input tensor of shape ``(N, 1, ..., X)`` or ``(N, C, ..., X)``. + When a tensor with ``C == num_classes`` is given, it is converted to the specified + ``dtype`` but not modified otherwise. Otherwise, the input tensor must contain + class labels in a single channel. + num_classes: Number of channels/labels. + ignore_index: Ignore index to be kept during the expansion. The locations of the index + value in the GT image is stored in the corresponding locations across all channels so + that this location can be ignored across all channels later e.g. in Dice computation. + This argument must be ``None`` if ``tensor`` has ``C == num_channels``. + dtype: Data type of output tensor. Default is ``paddle.float``. + + Returns: + Output tensor of shape ``(N, C, ..., X)``. + + """ + if dtype is None: + dtype = "float32" + if not isinstance(tensor, paddle.Tensor): + raise TypeError("as_one_hot_tensor() 'tensor' must be paddle.Tensor") + if tensor.dim() < 3: + raise ValueError("as_one_hot_tensor() 'tensor' must have shape (N, C, ..., X)") + if tuple(tensor.shape)[1] == num_classes: + return tensor.to(dtype=dtype) + elif tuple(tensor.shape)[1] != 1: + raise ValueError( + f"as_one_hot_tensor() 'tensor' must have shape (N, 1|{num_classes}, ..., X)" + ) + shape = list(tuple(tensor.shape)) + shape[1] = num_classes + if ignore_index is None: + return ( + paddle.zeros(shape=shape, dtype=dtype) + .to(tensor.place) + .put_along_axis_(axis=1, indices=tensor, values=1.0) + ) + mask = tensor.expand(shape=shape) == ignore_index + inputs = tensor.clone() + inputs[inputs == ignore_index] = 0 + result = ( + paddle.zeros(shape=shape, dtype=dtype) + .to(inputs.place) + .put_along_axis_(axis=1, indices=inputs, values=1.0) + ) + result[mask] = ignore_index + return result + + +def atleast_1d( + arr: Array, dtype: Optional[paddle.dtype] = None, device: Optional[Device] = None +) -> paddle.Tensor: + """Convert array-like argument to 1- or more-dimensional tensor.""" + arr_ = as_tensor(arr, dtype=dtype, device=device) + return arr_.unsqueeze(axis=0) if arr_.ndim == 0 else arr_ + + +def cat_scalars( + arg: Union[Scalar, Array], + *args: Scalar, + num: int = 0, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, +) -> paddle.Tensor: + """Join arguments into single 1-dimensional tensor. + + This auxiliary function is used by ``Grid``, ``Image``, and ``ImageBatch`` to support + method arguments for different spatial dimensions as either scalar constant, list + of scalar ``*args``, or single ``Array`` argument. If a single argument of type ``Array`` + is given, it must be a sequence of scalar values. + + Args: + arg: Either a single scalar or sequence of scalars. If the argument is a ``paddle.Tensor``, + it is cloned and detached in order to avoid unintended side effects. + args: Additional scalars. If ``arg`` is a sequence, ``args`` must be empty. + num: Number of expected scalar values. If a single scalar ``arg`` is given, + it is repeated ``num`` times to create a 1-dimensional array. If ``num=0``, + the length of the returned array corresponds to the number of given scalars. + dtype: Data type of output tensor. + device: Device on which to store tensor. + + Returns: + Scalar arguments joined into a 1-dimensional tensor. + + """ + if args: + if isinstance(arg, (tuple, list)) or isinstance(arg, paddle.Tensor): + raise ValueError("arg and args must either be all scalars, or args empty") + arg = paddle.to_tensor(data=(arg,) + args, dtype=dtype, place=device) + else: + arg = as_tensor(arg, dtype=dtype, device=device) + if arg.ndim == 0: + arg = arg.unsqueeze(0) + if arg.ndim != 1: + if num > 0: + raise ValueError( + f"Expected one scalar, a sequence of length {num}, or {num} args" + ) + raise ValueError( + "Expected one scalar, a sequence of scalars, or multiple scalars" + ) + if num > 0: + if len(arg) == 1: + arg = arg.repeat(num) + elif len(arg) != num: + raise ValueError( + f"Expected one scalar, a sequence of length {num}, or {num} args" + ) + return arg + + +def batched_index_select( + input: paddle.Tensor, dim: int, index: paddle.Tensor +) -> paddle.Tensor: + """Batched version of paddle.index_select(). + + See https://discuss.paddle.org/t/batched-index-select/9115/9. + + """ + for i in range(1, len(tuple(input.shape))): + if i != dim: + index = index.unsqueeze(axis=i) + shape = list(tuple(input.shape)) + shape[0] = -1 + shape[dim] = -1 + index = index.expand(shape=shape) + return paddle.take_along_axis(arr=input, axis=dim, indices=index) + + +def move_dim(tensor: paddle.Tensor, dim: int, pos: int) -> paddle.Tensor: + """Move the specified tensor dimension to another position.""" + if dim < 0: + dim = tensor.ndim + dim + if pos < 0: + pos = tensor.ndim + pos + if pos == dim: + return tensor + if dim < pos: + pos += 1 + tensor = tensor.unsqueeze(axis=pos) + if pos <= dim: + dim += 1 + x = tensor + perm_6 = list(range(x.ndim)) + perm_6[dim] = pos + perm_6[pos] = dim + tensor = x.transpose(perm=perm_6).squeeze(axis=dim) + return tensor + + +def unravel_coords(indices: paddle.Tensor, size: Tuple[int, ...]) -> paddle.Tensor: + """Converts flat indices into unraveled grid coordinates. + + Args: + indices: A tensor of flat indices with shape ``(..., N)``. + size: Sampling grid size with order ``(X, ...)``. + + Returns: + Grid coordinates of corresponding grid points. + + """ + size = tuple(size) + numel = size.size + if indices.greater_equal(y=paddle.to_tensor(numel)).astype("bool").any(): + raise ValueError(f"unravel_coords() indices must be smaller than {numel}") + coords = paddle.zeros( + shape=tuple(indices.shape) + (len(size),), dtype=indices.dtype + ) + for i, n in enumerate(size): + coords[..., i] = indices % n + indices = indices // n + return coords + + +def unravel_index(indices: paddle.Tensor, shape: Tuple[int, ...]) -> paddle.Tensor: + """Converts flat indices into unraveled coordinates in a target shape. + + This is a `paddle` implementation of `numpy.unravel_index`, but returning a + tensor of shape (..., N, D) rather than a D-dimensional tuple. See also + + Args: + indices: A tensor of indices with shape (..., N). + shape: The targeted tensor shape of length D. + + Returns: + Unraveled coordinates as tensor of shape (..., N, D) with coordinates + in the same order as the input ``shape`` dimensions. + + """ + shape = tuple(shape) + numel = shape.size + if indices.greater_equal(y=paddle.to_tensor(numel)).astype("bool").any(): + raise ValueError(f"unravel_coords() indices must be smaller than {numel}") + coords = paddle.zeros( + shape=tuple(indices.shape) + (len(shape),), dtype=indices.dtype + ) + for i, n in enumerate(reversed(shape)): + coords[..., i] = indices % n + indices = indices // n + return coords.flip(axis=-1) + + +def log_grad_hook( + name: str, logger: Optional[Logger] = None +) -> Callable[[paddle.Tensor], None]: + """Backward hook to print tensor gradient information for debugging.""" + + def printer(grad: paddle.Tensor) -> None: + if grad.size == 1: + msg = f"{name}.grad: value={grad}" + else: + msg = f"{name}.grad: shape={tuple(tuple(grad.shape))}, max={grad.max()}, min={grad.min()}, mean={grad.mean()}" + if logger is None: + print(msg) + else: + logger.debug(msg) + + return printer + + +def register_backward_hook( + tensor: paddle.Tensor, + hook: Callable[[paddle.Tensor], None], + retain_grad: bool = False, +) -> paddle.Tensor: + """Register backward hook and optionally enable retaining gradient.""" + if not tensor.stop_gradient: + if retain_grad: + tensor.retain_grads() + tensor.register_hook(hook=hook) + return tensor diff --git a/jointContribution/HighResolution/deepali/core/types.py b/jointContribution/HighResolution/deepali/core/types.py new file mode 100644 index 0000000000..b07d4653ba --- /dev/null +++ b/jointContribution/HighResolution/deepali/core/types.py @@ -0,0 +1,169 @@ +import re +from dataclasses import Field +from pathlib import Path +from typing import Any +from typing import Dict +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypeVar +from typing import Union + +import paddle +from typing_extensions import Protocol + +RE_OUTPUT_KEY_INDEX = re.compile("\\[([0-9]+)\\]") +EllipsisType = type(...) +T = TypeVar("T") +ScalarOrTuple = Union[T, Tuple[T, ...]] +ScalarOrTuple1d = Union[T, Tuple[T]] +ScalarOrTuple2d = Union[T, Tuple[T, T]] +ScalarOrTuple3d = Union[T, Tuple[T, T, T]] +ScalarOrTuple4d = Union[T, Tuple[T, T, T, T]] +ScalarOrTuple5d = Union[T, Tuple[T, T, T, T, T]] +ScalarOrTuple6d = Union[T, Tuple[T, T, T, T, T, T]] +ListOrTuple = Union[List[T], Tuple[T, ...]] + +paddle.Tensor = paddle.Tensor +Device = str +DType = paddle.dtype +Name = Optional[str] +Size = ScalarOrTuple[int] +Shape = ScalarOrTuple[int] +Scalar = Union[int, float, paddle.Tensor] +Array = Union[Sequence[Scalar], paddle.Tensor] +PathStr = Union[Path, str] + + +class Dataclass(Protocol): + """Type annotation for any dataclass.""" + + __dataclass_fields__: Dict[str, Any] + + +Batch = Union[Dataclass, Dict[str, Any], NamedTuple] +Sample = Union[Dataclass, Dict[str, Any], NamedTuple] +TensorMapOrSequence = Union[Mapping[str, paddle.Tensor], Sequence[paddle.Tensor]] +TensorCollection = Union[ + TensorMapOrSequence, + Mapping[str, TensorMapOrSequence], + Sequence[TensorMapOrSequence], +] + + +def tensor_collection_entry( + output: TensorCollection, key: str +) -> Union[TensorCollection, paddle.Tensor]: + """Get specified output entry.""" + key = RE_OUTPUT_KEY_INDEX.sub(".\\1", key) + for index in key.split("."): + if isinstance(output, (list, tuple)): + try: + index = int(index) + except TypeError: + raise KeyError(f"invalid output key {key}") + elif not index or not isinstance(output, dict): + raise KeyError(f"invalid output key {key}") + output = output[index] + return output + + +def get_tensor(output: TensorCollection, key: str) -> paddle.Tensor: + """Get tensor at specified output entry.""" + item = tensor_collection_entry(output, key) + if not isinstance(item, paddle.Tensor): + raise TypeError(f"get_output_tensor() entry {key} must be paddle.Tensor") + return item + + +def is_bool_dtype(dtype: DType) -> bool: + """Checks if ``dtype`` of given NumPy array or tensor is boolean type.""" + return dtype in ("bool",) + + +def is_float_dtype(dtype: DType) -> bool: + """Checks if ``dtype`` of given tensor is a floating point type.""" + return dtype in ("float16", "float32", "float64") or dtype in ( + paddle.float16, + paddle.float32, + paddle.float64, + ) + + +def is_int_dtype(dtype: DType) -> bool: + """Checks if ``dtype`` of given tensor is a signed integer type.""" + return dtype in ("int8", "int16", "int32", "int64") or dtype in ( + paddle.int8, + paddle.int16, + paddle.int32, + paddle.int64, + ) + + +def is_uint_dtype(dtype: DType) -> bool: + """Checks if ``dtype`` of given tensor is an unsigned integer type.""" + return dtype in ("uint8",) + + +def is_namedtuple(arg: Any) -> bool: + """Check if given object is a named tuple.""" + return isinstance(arg, tuple) and hasattr(arg, "_fields") + + +def is_optional_field(field: Field) -> bool: + """Whether given dataclass field type is ``Optional[T] = Union[T, NoneType]``.""" + return is_optional_type_hint(field.type) + + +def is_optional_type_hint(type_hint: Any) -> bool: + """Whether given type hint is ``Optional[T] = Union[T, NoneType]``.""" + type_origin = getattr(type_hint, "__origin__", None) + if type_origin is Union: + return type(None) in type_hint.__args__ + + +def is_path_str(arg: Any) -> bool: + """Whether given object is of type ``pathlib.Path`` or ``str``.""" + return isinstance(arg, (Path, str)) + + +def is_path_str_type_hint(type_hint: Any, required: bool = False) -> bool: + """Check if given type annotation is ``pathlib.Path``, ``PathStr = Union[pathlib.Path, str]``. + + Args: + type_hint: Type annotation, e.g., ``dataclasses.Field.type``. + required: Whether path argument is required. If ``False``, ``type(None)`` in the + type hint is ignore, i.e., also ``Optional[T]`` is considered valid. + + Returns: + Whether type hint is ``pathlib.Path``, ``Union[pathlib.Path, str]``, or + ``Union[str, pathlib.Path]``. When ``required=False``, type annotations + ``Optional[T] = Union[T, None]`` where ``T`` is one of the aforementioned + path string types also results in a return value of ``True``. + + """ + if type_hint in (Path, "Path", "Optional[Path]", "PathStr", "Optional[PathStr]"): + return True + type_origin = getattr(type_hint, "__origin__", None) + if type_origin is Union: + type_args = set(type_hint.__args__) + if not required: + type_args.discard(type(None)) + type_args.discard("None") + type_args.discard(str) + type_args.discard("str") + if not type_args: + return False + for type_arg in type_args: + if type_arg not in (Path, "Path", "PathStr"): + return False + return True + return False + + +def is_path_str_field(field: Field, required: bool = False) -> bool: + """Check if given dataclass field type is ``pathlib.Path``, ``PathStr = Union[pathlib.Path, str]``.""" + return is_path_str_type_hint(field.type, required=required) diff --git a/jointContribution/HighResolution/deepali/data/__init__.py b/jointContribution/HighResolution/deepali/data/__init__.py new file mode 100644 index 0000000000..192fff7ee2 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/__init__.py @@ -0,0 +1,32 @@ +"""Specialized subtypes of ``paddle.Tensor``, datasets thereof, and data transforms.""" +from .collate import collate_samples +from .dataset import Dataset +from .dataset import GroupDataset +from .dataset import ImageDataset +from .dataset import ImageDatasetConfig +from .dataset import JoinDataset +from .dataset import MetaDataset +from .flow import FlowField +from .flow import FlowFields +from .image import Image +from .image import ImageBatch +from .partition import Partition +from .partition import dataset_split_lengths +from .prepare import prepare_batch + +__all__ = ( + "FlowField", + "FlowFields", + "Image", + "ImageBatch", + "Dataset", + "GroupDataset", + "ImageDataset", + "ImageDatasetConfig", + "JoinDataset", + "MetaDataset", + "dataset_split_lengths", + "Partition", + "collate_samples", + "prepare_batch", +) diff --git a/jointContribution/HighResolution/deepali/data/collate.py b/jointContribution/HighResolution/deepali/data/collate.py new file mode 100644 index 0000000000..be8b21f602 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/collate.py @@ -0,0 +1,119 @@ +from collections import abc +from dataclasses import is_dataclass +from pathlib import Path +from typing import Any +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Sequence +from typing import overload + +import paddle + +from ..core.types import Batch +from ..core.types import Dataclass +from ..core.types import Sample +from ..core.types import is_namedtuple +from .flow import FlowField +from .flow import FlowFields +from .image import Image +from .image import ImageBatch +from .sample import replace_all_sample_field_values +from .sample import sample_field_names +from .sample import sample_field_value + +__all__ = ("collate_samples",) + + +@overload +def collate_samples(batch: Sequence[Mapping[str, Any]]) -> Mapping[str, Any]: + ... + + +@overload +def collate_samples(batch: Sequence[Dataclass]) -> Dataclass: + ... + + +@overload +def collate_samples(batch: Sequence[NamedTuple]) -> NamedTuple: + ... + + +def collate_samples(batch: Sequence[Sample]) -> Batch: + """Collate individual samples into a batch.""" + if not batch: + raise ValueError("collate_samples() 'batch' must have at least one element") + item0 = batch[0] + names = sample_field_names(item0) + values = [] + for name in names: + elem0 = sample_field_value(item0, name) + samples = [elem0] + is_none = elem0 is None + for item in batch[1:]: + value = sample_field_value(item, name) + if is_none and value is not None or not is_none and value is None: + raise ValueError( + f"collate_samples() 'batch' has some samples with '{name}' set and others without" + ) + if not isinstance(value, type(elem0)): + raise TypeError( + f"collate_samples() all 'batch' samples for field '{name}' must be of the same type" + ) + samples.append(value) + if is_none: + values.append(None) + elif isinstance(elem0, FlowField): + FlowFieldsType = type(elem0.batch()) + flow_fields: List[FlowField] = samples + if any(flow_field.axes() != elem0.axes() for flow_field in flow_fields): + raise ValueError( + f"collate_samples() 'batch' contains '{name}' flow fields with mixed axes" + ) + data = paddle.io.dataloader.collate.default_collate_fn( + batch=[flow_field.tensor() for flow_field in flow_fields] + ) + grid = tuple(flow_field.grid() for flow_field in flow_fields) + values.append(FlowFieldsType(data, grid, elem0.axes())) + elif isinstance(elem0, FlowFields): + FlowFieldsType = type(elem0) + flow_fields: List[FlowFields] = samples + if any(flow_field.axes() != elem0.axes() for flow_field in flow_fields): + raise ValueError( + f"collate_samples() 'batch' contains '{name}' flow fields with mixed axes" + ) + data = paddle.concat( + x=[flow_field.tensor() for flow_field in flow_fields], axis=0 + ) + grid = tuple( + grid for flow_field in flow_fields for grid in flow_field.grids() + ) + values.append(FlowFieldsType(data, grid, elem0.axes())) + elif isinstance(elem0, Image): + ImageBatchType = type(elem0.batch()) + images: List[Image] = samples + data = paddle.io.dataloader.collate.default_collate_fn( + batch=[image.tensor() for image in images] + ) + grid = tuple(image.grid() for image in images) + values.append(ImageBatchType(data, grid)) + elif isinstance(elem0, ImageBatch): + ImageBatchType = type(elem0) + images: List[ImageBatch] = samples + data = paddle.concat(x=[image.tensor() for image in images], axis=0) + grid = tuple(grid for image in images for grid in image.grids()) + values.append(ImageBatchType(data, grid)) + elif isinstance(elem0, (Path, str)): + values.append(samples) + elif ( + isinstance(elem0, abc.Mapping) + or is_dataclass(elem0) + or is_namedtuple(elem0) + ): + values.append(collate_samples(samples)) + else: + values.append( + paddle.io.dataloader.collate.default_collate_fn(batch=samples) + ) + return replace_all_sample_field_values(item0, values) diff --git a/jointContribution/HighResolution/deepali/data/dataset.py b/jointContribution/HighResolution/deepali/data/dataset.py new file mode 100644 index 0000000000..ff9cd659d0 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/dataset.py @@ -0,0 +1,503 @@ +from __future__ import annotations + +from abc import ABCMeta +from abc import abstractmethod +from copy import copy as shallowcopy +from dataclasses import dataclass +from dataclasses import field +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import TypeVar +from typing import Union +from typing import overload + +import paddle +import pandas as pd +from paddle.io import Subset +from paddle.nn import Sequential + +from ..core.config import DataclassConfig +from ..core.types import PathStr +from ..core.types import Sample +from ..core.types import is_namedtuple +from ..core.types import is_path_str +from .transforms import Transform +from .transforms.image import ImageTransformConfig +from .transforms.image import image_transforms +from .transforms.image import prepend_read_image_transform + +__all__ = ( + "Dataset", + "MetaDataset", + "ImageDataset", + "ImageDatasetConfig", + "GroupDataset", + "JoinDataset", + "read_table", +) +TDataset = TypeVar("TDataset", bound="Dataset") + + +class Dataset(paddle.io.Dataset, metaclass=ABCMeta): + """Base class of datasets with optionally on-the-fly pre-processed samples. + + This map-style dataset base class is convenient for attaching data transformations to + a given dataset. Otherwise, datasets may also derive directly from the respective + ``paddle.utils.data`` dataset classes or simply implement the expected interfaces. + + + """ + + def __init__( + self, transforms: Optional[Union[Transform, Sequence[Transform]]] = None + ): + """Initialize dataset. + + If a dataset produces samples (i.e., a dictionary, named tuple, or custom dataclass) + which contain fields with ``None`` values, ``collate_fn=collate_samples`` must be + passed to ``paddle.utils.data.DataLoader``. This custom collate function will ignore + ``None`` values and pass these on to the respective batch entry. Auxiliary function + ``prepare_batch()`` can be used to transfer the batch data retrieved by the data + loader to the execution device. + + Args: + transforms: Data preprocessing and augmentation transforms. + If more than one transformation is given, these will be composed + in the given order, where the first transformation in the sequence + is applied first. When the data samples are passed directly to + ``paddle.utils.data.DataLoader``, the transformed sample data must + be of type ``np.ndarray``, ``paddle.Tensor``, or ``None``. + + """ + super().__init__() + if transforms is None: + transform = Sequential() + elif isinstance(transforms, Sequential): + transform = transforms + else: + if not isinstance(transforms, (list, tuple)): + transforms = [transforms] + transform = Sequential(*transforms) + self._transform: Sequential = transform + + @abstractmethod + def __len__(self) -> int: + """Number of samples in dataset.""" + raise NotImplementedError + + def __getitem__(self, index: int) -> Sample: + """Processed data of i-th dataset sample. + + Args: + index: Index of dataset sample. + + Returns: + Sample data. + + """ + sample = self.sample(index) + sample = self._transform(sample) + return sample + + @abstractmethod + def sample(self, index: int) -> Sample: + """Data of i-th dataset sample.""" + raise NotImplementedError + + def samples(self) -> Iterable[Sample]: + """Get iterable over untransformed dataset samples.""" + + class SampleIterator(object): + def __init__(self, dataset: Dataset): + self.dataset = dataset + self.index = -1 + + def __iter__(self) -> Iterator[Sample]: + self.index = 0 + return self + + def __next__(self) -> Sample: + if self.index >= len(self.dataset): + raise StopIteration + sample = self.dataset.sample(self.index) + self.index += 1 + return sample + + return SampleIterator(self) + + @overload + def transform(self) -> Sequential: + ... + + @overload + def transform( + self: TDataset, + arg0: Union[Transform, Sequence[Transform], None], + *args: Union[Transform, Sequence[Transform], None], + ) -> TDataset: + ... + + def transform( + self: TDataset, *args: Union[Transform, Sequence[Transform], None] + ) -> Union[Sequential, TDataset]: + """Get composite data preprocessing and augmentation transform, or new dataset with specified transform.""" + if not args: + return self._transform + return shallowcopy(self).transform_(*args) + + def transform_( + self: TDataset, + arg0: Union[Transform, Sequence[Transform], None], + *args: Union[Transform, Sequence[Transform], None], + ) -> TDataset: + """Set data preprocessing and augmentation transform of this dataset.""" + transforms = [] + for arg in [arg0, *args]: + if arg is None: + continue + if isinstance(arg, (list, tuple)): + transforms.extend(arg) + else: + transforms.append(arg) + if not transforms: + self._transform = None + elif len(transforms) == 1 and isinstance(transforms[0], Sequential): + self._transform = transforms[0] + else: + self._transform = Sequential(*transforms) + return self + + @overload + def transforms(self) -> List[Transform]: + ... + + @overload + def transforms( + self: TDataset, + arg0: Union[Transform, Sequence[Transform], None], + *args: Union[Transform, Sequence[Transform], None], + ) -> TDataset: + ... + + def transforms( + self: TDataset, *args: Union[Transform, Sequence[Transform], None] + ) -> Union[List[Transform], TDataset]: + """Get or set dataset transforms.""" + if not args: + return [transform for transform in self._transform] + return shallowcopy(self).transform_(*args) + + def transforms_( + self, + arg0: Union[Transform, Sequence[Transform], None], + *args: Union[Transform, Sequence[Transform], None], + ) -> Dataset: + """Set data transforms of this dataset.""" + return self.transform_(arg0, *args) + + +class MetaDataset(Dataset): + """Dataset of file path template strings and sample meta-data given by Pandas DataFrame. + + This dataset can be used in conjunction with data reader transforms to load the data from + configured input file paths. For example, use the :class:`deepali.data.transforms.ReadImage` + transform followed by image data preprocessing and augmentation functions for image data. + The specified file path strings are Python format strings, where keywords are replaced by the + respective column entries for the sample in the dataset index table (`pandas.DataFrame`). + + """ + + def __init__( + self, + table: Union[Path, str, pd.DataFrame], + paths: Optional[Mapping[str, Union[PathStr, Callable[..., PathStr]]]] = None, + prefix: Optional[PathStr] = None, + transforms: Optional[Union[Transform, Sequence[Transform]]] = None, + **kwargs, + ): + """Initialize dataset. + + Args: + table: Table with sample IDs, optionally sample specific input file path template + strings (cf. ``paths``), and additional sample meta data. + paths: File path template strings of input data files. The format string may contain keys ``prefix``, + when a ``prefix`` path has been specified, and ``table`` column names. The dictionary keys of this + argument are used as sample data dictionary keys for the respective file paths. When the path value + is a string which matches exactly the name of a ``table`` column, the value of this column is used + without configuring a file path template string. This is useful when the input ``table`` already + specifies the file paths for each sample. Instead of a string, the dictionary value can be a + callable function instead, which takes the ``table`` row values as keyword arguments, and must return + the respectively formatted input file path string. When no ``paths`` are given, the dataset samples + only contain the meta-data from the input ``table`` columns. + prefix: Root directory of input file paths starting with ``"{prefix}/"``. + If ``None`` and ``table`` is a file path, it is set to the directory containing the index table. + Otherwise, template file path strings may not contain a ``{prefix}`` key if ``None``. + transforms: Data preprocessing and augmentation transforms. + kwargs: Additional format arguments used in addition to ``prefix`` and ``table`` column values. + + """ + if isinstance(table, (str, Path)): + if prefix is None: + path = Path(table).absolute() + prefix = path.parent + elif prefix: + prefix = Path(prefix).absolute() + path = prefix / Path(table) + else: + path = Path(table).absolute() + table = read_table(path) + if not isinstance(table, pd.DataFrame): + raise TypeError( + f"{type(self).__name__}() 'table' must be pandas.DataFrame or file path" + ) + df: pd.DataFrame = table + paths = {} if paths is None else dict(paths) + if "meta" in df.columns: + raise ValueError( + f"{type(self).__name__} 'table' contains column with reserved name 'meta'" + ) + if "meta" in paths: + raise ValueError( + f"{type(self).__name__} 'paths' contains reserved 'meta' key" + ) + prefix = Path(prefix).absolute() if prefix else None + self.table = df + self.paths = paths + self.prefix = prefix + self.kwargs = kwargs + super().__init__(transforms=transforms) + + def __len__(self) -> int: + """Number of samples in dataset.""" + return len(self.table) + + def row(self, index: int) -> Dict[str, Any]: + """Get i-th table row values.""" + return self.table.iloc[index].to_dict() + + def sample(self, index: int) -> Dict[str, Any]: + """Input file paths and/or meta-data of i-th sample in dataset.""" + meta = self.row(index) + if not self.paths: + return meta + data = {} + args = {"prefix": str(self.prefix)} if self.prefix else {} + args["index"] = index + args["index+1"] = index + 1 + args["index + 1"] = index + 1 + args.update(self.kwargs) + args.update(meta) + for name, path in self.paths.items(): + if callable(path): + path = path(**args) + elif path in meta: + path = meta[path] + else: + path = path.format(**args) + if not path: + continue + path = str(path) + data[name] = path + meta[name] = path + data["meta"] = meta + return data + + def samples(self) -> Iterable[Dict[str, Any]]: + """Get iterable over untransformed dataset samples.""" + + class DatasetSampleIterator(object): + def __init__(self, dataset: MetaDataset): + self.dataset = dataset + self.index = -1 + + def __iter__(self) -> Iterator[Dict[str, Any]]: + self.index = 0 + return self + + def __next__(self) -> Dict[str, Any]: + if self.index >= len(self.dataset): + raise StopIteration + sample = self.dataset.sample(self.index) + self.index += 1 + return sample + + return DatasetSampleIterator(self) + + +@dataclass +class ImageDatasetConfig(DataclassConfig): + """Configuration of image dataset.""" + + table: PathStr + prefix: Optional[PathStr] = None + images: Mapping[str, PathStr] = field(default_factory=dict) + transforms: Mapping[str, ImageTransformConfig] = field(default_factory=dict) + + @classmethod + def _from_dict( + cls, arg: Mapping[str, Any], parent: Optional[Path] = None + ) -> ImageDatasetConfig: + """Create configuration from dictionary. + + This function optionally re-organizes the dictionary entries to conform to the dataclass layout. + It allows the image data transforms to be specified as separate "transforms" entry for each image. + In this case, the image file path template string must given by the "path" dictionary entry. + Additionally, a "read" image transform is added when a "dtype" or "device" is specified on which + the image data is loaded and preprocessed can also be specified alongside the file "path". + Any image "transforms" specified at the top-level are applied after any "transforms" specified + underneath the "images" key. + + """ + arg = dict(arg) + images = arg.pop("images", None) or {} + transforms = arg.pop("transforms", None) or {} + image_paths = {} + for name, value in images.items(): + dtype = None + device = None + if isinstance(value, Mapping): + if "path" not in value: + raise ValueError( + f"{cls.__name__}.from_dict() 'images' key '{name}' dict must contain 'path' entry" + ) + path = value["path"] + dtype = value.get("dtype", dtype) + device = value.get("device", device) + image_transforms = value.get("transforms") or [] + if not isinstance(image_transforms, Sequence): + raise TypeError( + f"{cls.__name__}.from_dict() image 'transforms' value must be Sequence" + ) + elif is_path_str(value): + path = Path(value).as_posix() + image_transforms = [] + else: + raise ValueError( + f"{cls.__name__}.from_dict() 'images' key '{name}' must be PathStr or dict with 'path' entry" + ) + if name in transforms: + item_transforms = transforms[name] + if not isinstance(item_transforms, Sequence): + raise TypeError( + f"{cls.__name__}.from_dict() 'transforms' dict value must be Sequence" + ) + item_transforms = list(item_transforms) + else: + item_transforms = [] + image_transforms = image_transforms + item_transforms + if dtype or device: + image_transforms = prepend_read_image_transform( + image_transforms, dtype=dtype, device=device + ) + transforms[name] = image_transforms + image_paths[name] = path + arg["images"] = {k: v for k, v in image_paths.items() if v} + arg["transforms"] = {k: v for k, v in transforms.items() if v} + return super()._from_dict(arg, parent) + + +class ImageDataset(MetaDataset): + """Configurable image dataset.""" + + @classmethod + def from_config(cls, config: ImageDatasetConfig) -> ImageDataset: + transforms = [] + for image_name in config.images: + image_transforms_config = config.transforms.get(image_name, []) + image_transforms_config = prepend_read_image_transform( + image_transforms_config + ) + item_transforms = image_transforms(image_transforms_config, key=image_name) + transforms.extend(item_transforms) + return cls( + config.table, + paths=config.images, + prefix=config.prefix, + transforms=transforms, + ) + + +class GroupDataset(paddle.io.Dataset): + """Group samples in dataset.""" + + def __init__( + self, + dataset: MetaDataset, + groupby: Union[Sequence[str], str], + sortby: Optional[Union[Sequence[str], str]] = None, + ascending: bool = True, + ) -> None: + super().__init__() + indices = [] + df = dataset.table + if sortby: + df = df.sort_values(sortby, ascending=ascending) + groups = df.groupby(groupby) + for _, group in groups: + assert isinstance(group, pd.DataFrame) + ilocs = [row[0] for row in group.itertuples(index=True)] + indices.append(ilocs) + self.dataset = dataset + self.indices = indices + + def __len__(self) -> int: + return len(self.indices) + + def __getitem__(self, index: int) -> Subset[Dict[str, Any]]: + indices = self.indices[index] + return Subset(dataset=self.dataset, indices=indices) + + +class JoinDataset(Dataset): + """Join dict entries from one or more datasets in a single dict.""" + + def __init__(self, datasets: Iterable[Dataset]) -> None: + super().__init__() + datasets = list(datasets) + if not all(len(dataset) == len(datasets[0]) for dataset in datasets): + raise ValueError("JoinDataset() 'datasets' must have the same size") + self.datasets = datasets + + def __len__(self) -> int: + datasets = self.datasets + return len(datasets[0]) if datasets else 0 + + def sample(self, index: int) -> Sample: + sample = {} + for i, dataset in enumerate(self.datasets): + data = dataset[index] + if not isinstance(data, dict): + if is_namedtuple(data): + data = data._asdict() + else: + data = {str(i): data} + for key, value in data.items(): + current = sample.get(key, None) + if current is not None and current != value: + raise ValueError( + "JoinDataset() encountered ambiguous duplicate key '{key}'" + ) + sample[key] = value + return sample + + +def read_table(path: PathStr) -> pd.DataFrame: + """Read dataset index table.""" + path = Path(path).absolute() + if path.suffix.lower() == ".h5": + return pd.read_hdf(path) + if path.suffix.lower() == ".tsv": + return pd.read_csv(path, comment="#", skip_blank_lines=True, delimiter="\t") + if path.suffix.lower() == ".csv": + return pd.read_csv(path, comment="#", skip_blank_lines=True) + raise NotImplementedError( + f"read_table() does not support {path.suffix} file format" + ) diff --git a/jointContribution/HighResolution/deepali/data/flow.py b/jointContribution/HighResolution/deepali/data/flow.py new file mode 100644 index 0000000000..9d243d4f34 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/flow.py @@ -0,0 +1,627 @@ +from __future__ import annotations + +from typing import Any +from typing import Optional +from typing import Sequence +from typing import Type +from typing import TypeVar +from typing import Union +from typing import overload + +import paddle + +from ..core import flow as U +from ..core.enum import PaddingMode +from ..core.enum import Sampling +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.grid import grid_transform_vectors +from ..core.tensor import move_dim +from ..core.types import Array +from ..core.types import Device +from ..core.types import DType +from ..core.types import EllipsisType +from ..core.types import PathStr +from ..core.types import Scalar +from .image import Image +from .image import ImageBatch + +TFlowField = TypeVar("TFlowField", bound="FlowField") +TFlowFields = TypeVar("TFlowFields", bound="FlowFields") +__all__ = "FlowField", "FlowFields" + + +class FlowFields(ImageBatch): + """Batch of flow fields.""" + + def __init__( + self: TFlowFields, + data: Union[Array, ImageBatch], + grid: Optional[Union[Grid, Sequence[Grid]]] = None, + axes: Optional[Axes] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + ) -> None: + """Initialize flow fields. + + Args: + data: Batch data tensor of shape (N, D, ...X), where N is the batch size, and D + must be equal the number of spatial dimensions. The order of the image channels + must be such that vector components are in the order ``(x, ...)``. + grid: Flow field sampling grids. If not otherwise specified, this attribute + defines the fixed target image domain on which to resample a moving source image. + axes: Axes with respect to which vectors are defined. By default, it is assumed that + vectors are with respect to the unit ``grid`` cube in ``[-1, 1]^D``, where D are the + number of spatial dimensions. If ``grid.align_corners() == False``, the extrema + ``(-1, 1)`` refer to the boundary of the vector field ``grid``. Otherwise, the + extrema coincide with the corner points of the sampling grid. + dtype: Data type of the image data. A copy of the data is only made when the desired ``dtype`` + is not ``None`` and not the same as ``data.dtype``. + device: Device on which to store image data. A copy of the data is only made when the data + has to be copied to a different device. + requires_grad: If autograd should record operations on the returned image tensor. + pin_memory: If set, returned image tensor would be allocated in the pinned memory. + Works only for CPU tensors. + + """ + if grid is None and isinstance(data, ImageBatch): + grid = data.grid() + data = data.tensor() + super().__init__( + data, + grid, + dtype=dtype, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + if self.shape[1] != self.sdim: + raise ValueError( + f"{type(self).__name__}() 'data' nchannels={self.shape[1]} must be equal spatial ndim={self.sdim}" + ) + if axes is None: + axes = Axes.from_grid(self._grid[0]) + else: + axes = Axes.from_arg(axes) + self._axes = axes + + def _make_instance( + self: TFlowFields, + data: paddle.Tensor, + grid: Optional[Sequence[Grid]] = None, + axes: Optional[Axes] = None, + **kwargs, + ) -> Union[TFlowFields, ImageBatch]: + """Create a new instance while preserving subclass meta-data.""" + if tuple(data.shape)[1] != data.ndim - 2: + return ImageBatch(data, grid) + kwargs["axes"] = axes or self._axes + return super()._make_instance(data, grid, **kwargs) + + def _make_subitem(self, data: paddle.Tensor, grid: Grid) -> Union[FlowField, Image]: + """Create FlowField in __getitem__. Can be overridden by subclasses to return a subtype.""" + if tuple(data.shape)[0] == data.ndim - 1: + return FlowField(data, grid, self._axes) + return super()._make_subitem(data, grid) + + @staticmethod + def _paddle_function_axes(args) -> Optional[Axes]: + """Get flow field Axes from args passed to __paddle_function__.""" + if not args: + return None + if isinstance(args[0], (tuple, list)): + args = args[0] + axes: Sequence[Axes] + axes = [ + ax for ax in (getattr(arg, "_axes", None) for arg in args) if ax is not None + ] + if not axes: + return None + if any(ax != axes[0] for ax in axes[1:]): + raise ValueError( + "Cannot apply __paddle_function__ to flow fields with mismatching axes" + ) + return axes[0] + + @classmethod + def _paddle_function_result( + cls, func, data, grid: Optional[Sequence[Grid]], axes: Optional[Axes] + ) -> Any: + if not isinstance(data, paddle.Tensor): + return data + if ( + grid + and axes is not None + and data.ndim == grid[0].ndim + 2 + and tuple(data.shape)[1] == grid[0].ndim + and tuple(data.shape)[2:] == tuple(grid[0].shape) + or grid is not None + and not grid + and data.ndim >= 4 + and tuple(data.shape)[0] == 0 + ): + if func in (paddle.clone, paddle.Tensor.clone): + grid = [g.clone() for g in grid] + if isinstance(data, cls): + data._grid = grid + data._axes = axes + else: + data = cls(data, grid, axes) + else: + data = ImageBatch._paddle_function_result(func, data, grid) + return data + + @classmethod + def __paddle_function__(cls, func, types, args=(), kwargs=None): + if func == paddle.nn.functional.grid_sample: + raise ValueError( + "Argument of F.grid_sample() must be a batch, not a single image" + ) + if kwargs is None: + kwargs = {} + data = paddle.Tensor.__paddle_function__(func, (paddle.Tensor,), args, kwargs) + grid = cls._paddle_function_grid(func, args, kwargs) + axes = cls._paddle_function_axes(args) + if func in ( + paddle.split, + paddle.Tensor.split, + paddle.split_with_sizes, + paddle.Tensor.split_with_sizes, + paddle.tensor_split, + paddle.Tensor.tensor_split, + ): + return tuple( + cls._paddle_function_result(func, res, grid, axes) for res in data + ) + return cls._paddle_function_result(func, data, grid, axes) + + @overload + def __getitem__(self: TFlowFields, index: int) -> FlowField: + ... + + @overload + def __getitem__(self: TFlowFields, index: EllipsisType) -> TFlowFields: + ... + + @overload + def __getitem__( + self: TFlowFields, index: Union[list, slice, paddle.Tensor] + ) -> TFlowFields: + ... + + def __getitem__( + self: TFlowFields, + index: Union[ + EllipsisType, int, slice, Sequence[Union[EllipsisType, int, slice]] + ], + ) -> Union[FlowField, Image, TFlowFields, ImageBatch, paddle.Tensor]: + """Get flow field at specified batch index, get a sub-batch, or extract region of interest tensor.""" + return super().__getitem__(index) + + @overload + def axes(self: TFlowFields) -> Axes: + """Get axes with respect to which flow vectors are defined.""" + ... + + @overload + def axes(self: TFlowFields, axes: Axes) -> TFlowFields: + """Get new batch of flow fields with flow vectors defined with respect to specified axes.""" + ... + + def axes( + self: TFlowFields, axes: Optional[Axes] = None + ) -> Union[Axes, TFlowFields]: + """Rescale and reorient vectors.""" + if axes is None: + return self._axes + data = self.tensor() + data = move_dim(data, 1, -1) + data = tuple( + grid.transform_vectors(data[i : i + 1], axes=self._axes, to_axes=axes) + for i, grid in enumerate(self._grid) + ) + data = paddle.concat(x=data, axis=0) + data = move_dim(data, -1, 1) + return self._make_instance(data, self._grid, axes) + + def curl(self: TFlowFields, mode: str = "central") -> ImageBatch: + if self.ndim not in (2, 3): + raise RuntimeError( + "Cannot compute curl of {self.ndim}-dimensional flow field" + ) + spacing = self.spacing() + data = self.tensor() + data = U.curl(data, spacing=spacing, mode=mode) + return ImageBatch(data, self._grid) + + def exp( + self: TFlowFields, + scale: Optional[float] = None, + steps: Optional[int] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str] = PaddingMode.BORDER, + ) -> TFlowFields: + """Group exponential maps of flow fields computed using scaling and squaring.""" + axes = self._axes + align_corners = axes is Axes.CUBE_CORNERS + flow = self.axes(Axes.CUBE_CORNERS if align_corners else Axes.CUBE) + data = self.tensor() + data = U.expv( + data, + scale=scale, + steps=steps, + sampling=sampling, + padding=padding, + align_corners=align_corners, + ) + flow = self._make_instance(data, flow._grid, flow._axes) + flow = flow.axes(axes) + return flow + + def sample( + self: TFlowFields, + arg: Union[Grid, Sequence[Grid], paddle.Tensor], + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> Union[TFlowFields, paddle.Tensor]: + """Sample flow fields at optionally deformed unit grid points. + + Args: + arg: Either a single grid which defines the sampling points for all images in the batch, + a different grid for each image in the batch, or a tensor of normalized coordinates + with shape ``(N, ..., D)`` or ``(1, ..., D)``. In the latter case, note that the + shape ``...`` need not correspond to a (deformed) grid as required by ``grid_sample()``. + It can be an arbitrary shape, e.g., ``M`` to sample at ``M`` given points. + mode: Image interpolation mode. + padding: Image extrapolation mode or scalar padding value. + + Returns: + If ``arg`` is of type ``Grid`` or ``Sequence[Grid]``, a ``FlowFields`` batch is returned. + When these grids match the grids of this batch of flow fields, ``self`` is returned. + Otherwise, a ``paddle.Tensor`` of shape (N, C, ...) of sampled flow values is returned. + Note that when ``arg`` is of type ``Grid`` or ``Sequence[Grid]``, flow vectors that are + not expressed with respect to the world coordinate system will be implicitly converted to + flow vectors with respect to the new sampling grids. If this is not desired, use a ``paddle.Tensor`` + type with sampling coordinates instead of ``Grid`` instances. + + """ + flow = super().sample(arg, mode=mode, padding=padding) + if isinstance(flow, FlowFields): + axes = flow.axes() + if axes != Axes.WORLD: + data = flow.tensor() + data = U.move_dim(data, 1, -1) + data = paddle.concat( + x=[ + grid_transform_vectors(v, grid, axes, to_grid, axes).unsqueeze_( + 0 + ) + for v, grid, to_grid in zip(data, self._grid, flow.grids()) + ], + axis=0, + ) + data = U.move_dim(data, -1, 1) + flow = self._make_instance(data, flow.grids()) + return flow + + def warp_image( + self: TFlowFields, + image: Union[Image, ImageBatch], + sampling: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str]] = None, + ) -> ImageBatch: + """Deform given image (batch) using this batch of vector fields. + + Args: + image: Single image or image batch. If a single ``Image`` is given, it is deformed by + all the displacement fields in this batch. If an ``ImageBatch`` is given, the number + of images in the batch must match the number of displacement fields in this batch. + sampling: Interpolation mode for sampling values from ``image`` at deformed grid points. + padding: Extrapolation mode for sampling values outside ``image`` domain. + + Returns: + Batch of input images deformed by the vector fields of this batch. + + """ + if isinstance(image, Image): + image = image.batch() + align_corners = self._axes is Axes.CUBE_CORNERS + grid = ( + g.coords(align_corners=align_corners, device=self.device) + for g in self._grid + ) + grid = paddle.concat(x=tuple(g.unsqueeze(axis=0) for g in grid), axis=0) + flow = self.axes(Axes.from_align_corners(align_corners)) + flow = flow.tensor() + flow = move_dim(flow, 1, -1) + data = image.tensor() + data = U.warp_image( + data, + grid, + flow=flow, + mode=sampling, + padding=padding, + align_corners=align_corners, + ) + return image._make_instance(data, self._grid) + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(data={self.tensor()!r}, grids={self.grids()!r}, axes={self.axes()!r})" + ) + + def __str__(self) -> str: + return ( + type(self).__name__ + + f"(data={self.tensor()!s}, grids={self.grids()!s}, axes={self.axes()!r})" + ) + + +class FlowField(Image): + """Flow field image. + + A (dense) flow field is a vector image where the number of channels equals the number of spatial dimensions. + The starting points of the vectors are defined on a regular oriented sampling grid positioned in world space. + Orientation and scale of the vectors are defined with respect to a specified regular grid domain, which either + coincides with the sampling grid, the world coordinate system, or the unit cube with side length 2 centered at + the center of the sampling grid with axes parallel to the sampling grid. This unit cube domain is used by the + ``paddle.nn.functional.grid_sample()`` and ``paddle.nn.functional.interpolate()`` functions. + + When a flow field is convert to a ``SimpleITK.Image``, the vectors are by default reoriented and rescaled such + that these are with respect to the world coordinate system, a format common to ITK functions and other toolkits. + + """ + + def __init__( + self: TFlowField, + data: Array, + grid: Optional[Grid] = None, + axes: Optional[Axes] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + ) -> None: + """Initialize flow field. + + Args: + data: Flow field data tensor of shape (C, ...X), where C must be equal the number of spatial dimensions. + The order of the image channels must be such that vector components are in the order X, Y,... + grid: Flow field sampling grid. If not otherwise specified, this attribute often also defines the fixed + target image domain on which to resample a moving source image. + axes: Axes with respect to which vectors are defined. By default, it is assumed that vectors are with + respect to the unit ``grid`` cube in ``[-1, 1]^D``, where D are the number of spatial dimensions. + If ``None`` and ``grid.align_corners() == False``, the extrema ``(-1, 1)`` refer to the boundary of + the vector field ``grid``, and coincide with the grid corner points otherwise. + dtype: Data type of the image data. A copy of the data is only made when the desired ``dtype`` + is not ``None`` and not the same as ``data.dtype``. + device: Device on which to store image data. A copy of the data is only made when the data + has to be copied to a different device. + requires_grad: If autograd should record operations on the returned image tensor. + pin_memory: If set, returned image tensor would be allocated in the pinned memory. + Works only for CPU tensors. + + """ + super().__init__( + data, + grid, + dtype=dtype, + device=device, + requires_grad=requires_grad, + pin_memory=pin_memory, + ) + if self.nchannels != self._grid.ndim: + raise ValueError( + f"{type(self).__name__} nchannels={self.nchannels} must be equal grid.ndim={self._grid.ndim}" + ) + if axes is None: + axes = Axes.from_grid(self._grid) + else: + axes = Axes.from_arg(axes) + self._axes = axes + + def _make_instance( + self: TFlowField, + data: paddle.Tensor, + grid: Optional[Grid] = None, + axes: Optional[Axes] = None, + **kwargs, + ) -> TFlowField: + """Create a new instance while preserving subclass meta-data.""" + kwargs["axes"] = axes or self._axes + return super()._make_instance(data, grid, **kwargs) + + @staticmethod + def _paddle_function_axes(args) -> Optional[Axes]: + """Get flow field Axes from args passed to __paddle_function__.""" + if not args: + return None + if isinstance(args[0], (tuple, list)): + args = args[0] + axes: Sequence[Axes] + axes = [ + ax for ax in (getattr(arg, "_axes", None) for arg in args) if ax is not None + ] + if not axes: + return None + if any(ax != axes[0] for ax in axes[1:]): + raise ValueError( + "Cannot apply __paddle_function__ to flow fields with mismatching axes" + ) + return axes[0] + + @classmethod + def _paddle_function_result( + cls, func, data, grid: Optional[Grid], axes: Optional[Axes] + ) -> Any: + if not isinstance(data, paddle.Tensor): + return data + if ( + grid is not None + and axes is not None + and data.ndim == grid.ndim + 1 + and tuple(data.shape)[0] == grid.ndim + and tuple(data.shape)[1:] == tuple(grid.shape) + ): + if func in (paddle.clone, paddle.Tensor.clone): + grid = grid.clone() + if isinstance(data, cls): + data._grid = grid + data._axes = axes + else: + data = cls(data, grid, axes) + else: + data = Image._paddle_function_result(func, data, grid) + return data + + @classmethod + def __paddle_function__(cls, func, types, args=(), kwargs=None): + if func == paddle.nn.functional.grid_sample: + raise ValueError( + "Argument of F.grid_sample() must be a batch, not a single image" + ) + if kwargs is None: + kwargs = {} + data = paddle.Tensor.__paddle_function__(func, (paddle.Tensor,), args, kwargs) + grid = cls._paddle_function_grid(args) + axes = cls._paddle_function_axes(args) + if func in ( + paddle.split, + paddle.Tensor.split, + paddle.split_with_sizes, + paddle.Tensor.split_with_sizes, + paddle.tensor_split, + paddle.Tensor.tensor_split, + ): + return tuple( + cls._paddle_function_result(func, res, grid, axes) for res in data + ) + return cls._paddle_function_result(func, data, grid, axes) + + @classmethod + def from_image( + cls: Type[TFlowField], image: Image, axes: Optional[Axes] = None + ) -> TFlowField: + """Create flow field from image instance.""" + return cls(image, image._grid, axes) + + def batch(self: TFlowField) -> FlowFields: + """Batch of flow fields containing only this flow field.""" + data = self.unsqueeze(0) + return FlowFields(data, self._grid, self._axes) + + @overload + def axes(self: TFlowField) -> Axes: + """Get axes with respect to which flow vectors are defined.""" + ... + + @overload + def axes(self: TFlowField, axes: Axes) -> TFlowField: + """Get new flow field with flow vectors defined with respect to specified axes.""" + ... + + def axes(self: TFlowField, axes: Optional[Axes] = None) -> Union[Axes, TFlowField]: + """Rescale and reorient vectors with respect to specified axes.""" + if axes is None: + return self._axes + batch = self.batch() + batch = batch.axes(axes) + return batch[0] + + @classmethod + def from_sitk( + cls: Type[TFlowField], image: "sitk.Image", axes: Optional[Axes] = None + ) -> TFlowField: + """Create vector field from ``SimpleITK.Image``.""" + image = super().from_sitk(image) + return cls.from_image(image, axes=axes or Axes.WORLD) + + def sitk(self: TFlowField, axes: Optional[Axes] = None) -> "sitk.Image": + """Create ``SimpleITK.Image`` from this vector field.""" + disp: TFlowField = self.detach() + disp = disp.axes(axes or Axes.WORLD) + return Image.sitk(disp) + + @classmethod + def read( + cls: Type[TFlowField], path: PathStr, axes: Optional[Axes] = None + ) -> TFlowField: + """Read image data from file.""" + image = cls._read_sitk(path) + return cls.from_sitk(image, axes) + + def curl(self: TFlowField, mode: str = "central") -> Image: + """Compute curl of vector field.""" + batch = self.batch() + rotvec = batch.curl(mode=mode) + return rotvec[0] + + def exp( + self: TFlowField, + scale: Optional[float] = None, + steps: Optional[int] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str] = PaddingMode.BORDER, + ) -> TFlowField: + """Group exponential map of vector field computed using scaling and squaring.""" + batch = self.batch() + batch = batch.exp(scale=scale, steps=steps, sampling=sampling, padding=padding) + return batch[0] + + @overload + def warp_image( + self: TFlowField, + image: Image, + sampling: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str]] = None, + ) -> Image: + """Deform given image using this displacement field.""" + ... + + @overload + def warp_image( + self: TFlowField, + image: ImageBatch, + sampling: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str]] = None, + ) -> ImageBatch: + """Deform images in batch using this displacement field.""" + ... + + def warp_image( + self: TFlowField, + image: Union[Image, ImageBatch], + sampling: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str]] = None, + ) -> Union[Image, ImageBatch]: + """Deform given image (batch) using this displacement field. + + Args: + image: Single image or batch of images. + kwargs: Keyword arguments to pass on to ``ImageBatch.warp()``. + + Returns: + If ``image`` is an ``ImageBatch``, each image in the batch is deformed by this flow field + and a batch of deformed images is returned. Otherwise, a single deformed image is returned. + + """ + batch = self.batch() + result = batch.warp_image(image, sampling=sampling, padding=padding) + if isinstance(image, Image) and len(result) == 1: + return result[0] + return result + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(data={self.tensor()!r}, grid={self.grid()!r}, axes={self.axes()!r})" + ) + + def __str__(self) -> str: + return ( + type(self).__name__ + + f"(data={self.tensor()!s}, grid={self.grid()!s}, axes={self.axes()!r})" + ) diff --git a/jointContribution/HighResolution/deepali/data/image.py b/jointContribution/HighResolution/deepali/data/image.py new file mode 100644 index 0000000000..52cce11575 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/image.py @@ -0,0 +1,1544 @@ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union +from typing import overload + +import numpy as np +import paddle + +try: + import SimpleITK as sitk + + from ..utils.sitk.imageio import read_image + from ..utils.sitk.paddle import image_from_tensor + from ..utils.sitk.paddle import tensor_from_image +except ImportError: + sitk = None +from ..core import image as U +from ..core.cube import Cube +from ..core.enum import PaddingMode +from ..core.enum import Sampling +from ..core.enum import SpatialDimArg +from ..core.grid import ALIGN_CORNERS +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.grid import grid_transform_points +from ..core.itertools import zip_longest_repeat_last +from ..core.path import unlink_or_mkdir +from ..core.tensor import cat_scalars +from ..core.types import Array +from ..core.types import Device +from ..core.types import DType +from ..core.types import EllipsisType +from ..core.types import PathStr +from ..core.types import Scalar +from ..core.types import ScalarOrTuple +from ..core.types import Size +from .tensor import DataTensor + +Domain = Cube +TImage = TypeVar("TImage", bound="Image") +TImageBatch = TypeVar("TImageBatch", bound="ImageBatch") +__all__ = "Image", "ImageBatch" + + +class ImageBatch(DataTensor): + """Batch of images sampled on regular oriented grids.""" + + def __init__( + self: TImageBatch, + data: Array, + grid: Optional[Union[Grid, Sequence[Grid]]] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + ) -> None: + """Initialize image decorator. + + Args: + data: Image batch data tensor of shape (N, C, ...X). + grid: Sampling grid of image data oriented in world space. Can be either a single shared + sampling grid, or a separate grid for each image in the batch. Note that operations + which would result in differently sized images (e.g., resampling to a certain voxel + size, when images have different resolutions) will raise an exception. All images in + a batch must have the same number of channels and spatial size. If ``None``, a default + grid whose world space coordinate axes are aligned with the image axes, unit spacing, + and origin at the image centers is created. By default, image grid attributes are always + stored in CPU memory, regardless of the ``device`` on which the image data is located. + dtype: Data type of the image data. A copy of the data is only made when the desired ``dtype`` + is not ``None`` and not the same as ``data.dtype``. + device: Device on which to store image data. A copy of the data is only made when the data + has to be copied to a different device. + requires_grad: If autograd should record operations on the returned image tensor. + pin_memory: If set, returned image tensor would be allocated in the pinned memory. + Works only for CPU tensors. + + """ + if self.ndim < 4: + raise ValueError("Image batch tensor must have at least four dimensions") + self.grid_(grid) + + def _make_instance( + self: TImageBatch, + data: Optional[paddle.Tensor] = None, + grid: Optional[Sequence[Grid]] = None, + **kwargs, + ) -> TImageBatch: + """Create a new instance while preserving subclass (meta-)data.""" + kwargs["grid"] = self._grid if grid is None else grid + return super()._make_instance(data, **kwargs) + + def _make_subitem(self, data: paddle.Tensor, grid: Grid) -> Image: + """Create Image in __getitem__. Can be overridden by subclasses to return a subtype.""" + return Image(data, grid) + + def __deepcopy__(self: TImageBatch, memo) -> TImageBatch: + if id(self) in memo: + return memo[id(self)] + result = self._make_instance( + self.data.clone(), + grid=tuple(grid.clone() for grid in self._grid), + requires_grad=self.requires_grad, + pin_memory=self.is_pinned(), + ) + memo[id(self)] = result + return result + + @staticmethod + def _paddle_function_grid( + func, args, kwargs: Dict[str, Any] + ) -> Optional[Union[Sequence[Grid], Sequence[Sequence[Grid]]]]: + """Get spatial sampling grids from args passed to __paddle_function__.""" + if not args: + return None + if isinstance(args[0], (tuple, list)): + args = args[0] + grids: Sequence[Sequence[Grid]] + grids = [ + g for g in (getattr(arg, "_grid", None) for arg in args) if g is not None + ] + if not grids: + return None + if kwargs.get("dim", 0) == 0: + if func == paddle.concat: + return [g for grid in grids for g in grid] + if func in (paddle.split, paddle.Tensor.split): + grids = grids[0] + split_grids = [] + split_size_or_sections = args[1] + if isinstance(split_size_or_sections, int): + for start in range(0, len(grids), split_size_or_sections): + split_grids.append( + grids[start : start + split_size_or_sections] + ) + elif isinstance(split_size_or_sections, Sequence): + start = 0 + for num in split_size_or_sections: + split_grids.append(grids[start : start + num]) + return split_grids + if func in (paddle.split_with_sizes, paddle.Tensor.split_with_sizes): + grids = grids[0] + split_grids = [] + split_sizes = args[1] + start = 0 + for num in split_sizes: + split_grids.append(grids[start : start + num]) + return split_grids + if func in (paddle.tensor_split, paddle.Tensor.tensor_split): + grids = grids[0] + split_grids = [] + tensor_indices_or_sections = args[1] + if isinstance(tensor_indices_or_sections, int): + for start in range(0, len(grids), tensor_indices_or_sections): + split_grids.append( + grids[start : start + tensor_indices_or_sections] + ) + elif isinstance(tensor_indices_or_sections, Sequence): + indices = list(tensor_indices_or_sections) + for start, end in zip([0] + indices, indices + [len(grids)]): + split_grids.append(grids[start:end]) + return split_grids + return grids[0] + + @classmethod + def _paddle_function_result(cls, func, data, grid: Optional[Sequence[Grid]]) -> Any: + if not isinstance(data, paddle.Tensor): + return data + if ( + grid + and data.ndim == grid[0].ndim + 2 + and tuple(data.shape)[0] == len(grid) + and tuple(data.shape)[2:] == tuple(grid[0].shape) + or grid is not None + and len(grid) == 0 + and data.ndim >= 4 + and tuple(data.shape)[0] == 0 + ): + if func in (paddle.clone, paddle.Tensor.clone): + grid = [g.clone() for g in grid] + if isinstance(data, cls): + data._grid = grid + else: + data = cls(data, grid) + elif type(data) != paddle.Tensor: + data = data.as_subclass(paddle.Tensor) + return data + + @classmethod + def __paddle_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + args = tuple(arg.batch() if isinstance(arg, Image) else arg for arg in args) + data = paddle.Tensor.__paddle_function__(func, (paddle.Tensor,), args, kwargs) + grid = cls._paddle_function_grid(func, args, kwargs) + if func in (paddle.nn.functional.grid_sample,): + grid = None + elif func in ( + paddle.split, + paddle.Tensor.split, + paddle.split_with_sizes, + paddle.Tensor.split_with_sizes, + paddle.tensor_split, + paddle.Tensor.tensor_split, + ): + if type(data) not in (tuple, list): + raise AssertionError( + f"expected split 'data' to be tuple or list, got {type(data)}" + ) + if type(grid) not in (tuple, list): + raise AssertionError( + f"expected split 'grid' to be tuple or list, got {type(grid)}" + ) + if len(grid) != len(data): + raise AssertionError( + f"expected 'grid' tuple length to be equal batch size, but {len(grid)} != {len(data)}" + ) + assert all(isinstance(d, paddle.Tensor) for d in data) + assert all(isinstance(g, (tuple, list)) for g in grid) + assert all(len(d) == len(g) for d, g in zip(data, grid)) + return tuple( + cls._paddle_function_result(func, d, g) for d, g in zip(data, grid) + ) + return cls._paddle_function_result(func, data, grid) + + def cube(self: TImageBatch, n: int = 0) -> Cube: + """Get cube of n-th image in batch defining its normalized coordinates space with respect to the world.""" + return self._grid[n].cube() + + def cubes(self: TImageBatch) -> Tuple[Cube, ...]: + """Get cubes of all images which define their normalized coordinates spaces with respect to the world.""" + return tuple(grid.cube() for grid in self._grid) + + def domain(self: TImageBatch, n: int = 0) -> Domain: + """Get oriented bounding box of n-th image in world space which defines the domain within which it is defined.""" + return self._grid[n].domain() + + def domains(self: TImageBatch) -> Tuple[Domain, ...]: + """Get oriented bounding boxes of all images in world space which define the domains within which these are defined.""" + return tuple(grid.domain() for grid in self._grid) + + def same_domains_as(self, other: ImageBatch) -> bool: + """Check if images in this batch and another batch have the same cube domains.""" + if len(self) != len(other): + return False + return all(a.same_domain_as(b) for a, b in zip(self.grids(), other.grids())) + + @overload + def grid(self: TImageBatch, n: int = 0) -> Grid: + """Get sampling grid of n-th image in batch.""" + ... + + @overload + def grid(self: TImageBatch, arg: Union[Grid, Sequence[Grid]]) -> TImageBatch: + """Get new image batch with specified sampling grid(s).""" + ... + + def grid( + self: TImageBatch, arg: Optional[Union[int, Grid, Sequence[Grid]]] = None + ) -> Union[Grid, TImageBatch]: + """Get sampling grid of images in batch or new batch with specified grid, respectively.""" + if arg is None: + arg = 0 + if isinstance(arg, int): + return self._grid[arg] + return self._make_instance(grid=arg) + + def grid_(self: TImageBatch, arg: Union[Grid, Sequence[Grid], None]) -> TImageBatch: + """Change image sampling grid of this image batch.""" + shape = self.shape + if arg is None: + arg = (Grid(shape=shape[2:]),) * shape[0] + elif isinstance(arg, Grid): + grid = arg + if tuple(grid.shape) != tuple(shape[2:]): + raise ValueError( + "Image grid size does not match spatial dimensions of image batch tensor" + ) + arg = (grid,) * shape[0] + else: + arg = tuple(arg) + if any(tuple(grid.shape) != tuple(shape[2:]) for grid in arg): + raise ValueError( + "Image grid sizes must match spatial dimensions of image batch tensor" + ) + self._grid = arg + return self + + def grids(self: TImageBatch) -> Tuple[Grid, ...]: + """Get sampling grids of images in batch.""" + return self._grid + + def align_corners(self: TImageBatch) -> bool: + """Whether image resizing operations by default preserve corner points or grid extent.""" + return self._grid[0].align_corners() + + def center(self: TImageBatch) -> paddle.Tensor: + """Image centers in world space coordinates as tensor of shape (N, D).""" + return paddle.concat( + x=[grid.center().unsqueeze(axis=0) for grid in self.grids()], axis=0 + ) + + def origin(self: TImageBatch) -> paddle.Tensor: + """Image origins in world space coordinates as tensor of shape (N, D).""" + return paddle.concat( + x=[grid.origin().unsqueeze(axis=0) for grid in self.grids()], axis=0 + ) + + def spacing(self: TImageBatch) -> paddle.Tensor: + """Image spacings in world units as tensor of shape (N, D).""" + return paddle.concat( + x=[grid.spacing().unsqueeze(axis=0) for grid in self.grids()], axis=0 + ) + + def direction(self: TImageBatch) -> paddle.Tensor: + """Image direction cosines matrices as tensor of shape (N, D, D).""" + return paddle.concat( + x=[grid.direction().unsqueeze(axis=0) for grid in self.grids()], axis=0 + ) + + def __len__(self: TImageBatch) -> int: + """Number of images in batch.""" + return self.shape[0] + + @overload + def __getitem__(self: TImageBatch, index: int) -> Image: + ... + + @overload + def __getitem__(self: TImageBatch, index: EllipsisType) -> TImageBatch: + ... + + @overload + def __getitem__( + self: TImageBatch, index: Union[list, slice, paddle.Tensor] + ) -> TImageBatch: + ... + + def __getitem__( + self: TImageBatch, + index: Union[ + EllipsisType, int, slice, Sequence[Union[EllipsisType, int, slice]] + ], + ) -> Union[Image, TImageBatch, paddle.Tensor]: + """Get image at specified batch index, get a sub-batch, or a region of interest tensor.""" + if index is ...: + return self._make_instance(self.tensor(), self.grid()) + if type(index) is tuple: + index = [ + j for i, j in enumerate(index) if j is not ... or ... not in index[:i] + ] + if index[-1] is ...: + index = index[:-1] + try: + i = index.index(...) + j = len(index) - i - 1 + index = index[:i] + [slice(None)] * (self.ndim - i - j) + index[-j:] + except ValueError: + pass + index = tuple(index) + is_multi_index = True + elif isinstance(index, (np.ndarray, slice, Sequence, paddle.Tensor)): + index = (index,) + is_multi_index = True + else: + index = int(index) + is_multi_index = False + data = self.tensor()[index] + if is_multi_index and len(index) > 1 and isinstance(index[1], int): + return data + grid_index = index[0] if is_multi_index else index + if isinstance(grid_index, (np.ndarray, Sequence, paddle.Tensor)): + grid = tuple(self._grid[i] for i in grid_index) + else: + grid = self._grid[grid_index] + if is_multi_index and len(index) > 2: + same_grid = True + for i, n in zip(index[2:], self.shape[2:]): + if isinstance(i, int) or not isinstance(i, slice): + same_grid = False + break + if ( + i.start not in (None, 0) + or i.stop not in (None, n) + or i.step not in (None, 1) + ): + same_grid = False + break + if not same_grid: + return data + if isinstance(grid, Grid): + if data.ndim < 3: + return data + return self._make_subitem(data, grid) + elif data.ndim < 4: + return data + return self._make_instance(data, grid) + + @property + def sdim(self: TImageBatch) -> int: + """Number of spatial dimensions.""" + return self.ndim - 2 + + @property + def nchannels(self: TImageBatch) -> int: + """Number of image channels.""" + return self.shape[1] + + def normalize( + self: TImageBatch, + mode: str = "unit", + min: Optional[float] = None, + max: Optional[float] = None, + ) -> TImageBatch: + """Normalize image intensities in [min, max].""" + return U.normalize_image(self, mode=mode, min=min, max=max) + + def normalize_( + self: TImageBatch, + mode: str = "unit", + min: Optional[float] = None, + max: Optional[float] = None, + ) -> TImageBatch: + """Normalize image intensities in [min, max].""" + return U.normalize_image(self, mode=mode, min=min, max=max, inplace=True) + + def rescale( + self: TImageBatch, + min: Optional[Scalar] = None, + max: Optional[Scalar] = None, + data_min: Optional[Scalar] = None, + data_max: Optional[Scalar] = None, + dtype: Optional[DType] = None, + ) -> TImageBatch: + """Clamp and linearly rescale image values.""" + return U.rescale( + self, min, max, data_min=data_min, data_max=data_max, dtype=dtype + ) + + def narrow(self: TImageBatch, dim: int, start: int, length: int) -> TImageBatch: + """Narrow image batch along specified tensor dimension.""" + start_27 = self.tensor().shape[dim] + start if start < 0 else start + data = paddle.slice(self.tensor(), [dim], [start_27], [start_27 + length]) + grid = self.grid() + if dim > 1: + start_28 = grid.shape[self.ndim - dim - 1] + start if start < 0 else start + grid = paddle.slice( + grid, [self.ndim - dim - 1], [start_28], [start_28 + length] + ) + return self._make_instance(data, grid) + + def resize( + self: TImageBatch, + size: Union[int, Array, Size], + *args: int, + mode: Union[Sampling, str] = Sampling.LINEAR, + align_corners: Optional[bool] = None, + ) -> TImageBatch: + """Interpolate images on grid with specified size. + + Args: + size: Size of spatial image dimensions, where the size of the last tensor dimension, + which corresponds to the first grid dimension, must be given first, e.g., ``(nx, ny, nz)``. + mode: Image data interpolation mode. + align_corners: Whether to preserve grid extent (False) or corner points (True). + If ``None``, the default of the image sampling grid is used. + + Returns: + Image batch with specified size of spatial dimensions. + + """ + if align_corners is None: + align_corners = self.align_corners() + size = cat_scalars(size, *args, num=self.sdim, device=self.device) + data = U.grid_resize(self, size, mode=mode, align_corners=align_corners) + grid = tuple( + grid.resize(size, align_corners=align_corners) for grid in self._grid + ) + return self._make_instance(data, grid) + + def resample( + self: TImageBatch, + spacing: Union[float, Array, str], + *args: float, + mode: Union[Sampling, str] = Sampling.LINEAR, + ) -> TImageBatch: + """Interpolate images on grid with specified spacing. + + Args: + spacing: Spacing of grid on which to resample image data, where the spacing corresponding + to first grid dimension, which corresponds to the last tensor dimension, must be given + first, e.g., ``(sx, sy, sz)``. Alternatively, can be string 'min' or 'max' to resample + to the minimum or maximum voxel size, respectively. + mode: Image data interpolation mode. + + Returns: + Image batch with given grid spacing. + + """ + in_spacing = self._grid[0].spacing() + if not all( + paddle.allclose(x=grid.spacing(), y=in_spacing).item() + for grid in self._grid + ): + raise AssertionError( + f"{type(self).__name__}.resample() requires all images in batch to have the same grid spacing" + ) + if spacing == "min": + assert not args + out_spacing = in_spacing.min() + elif spacing == "max": + assert not args + out_spacing = in_spacing.max() + else: + out_spacing = spacing + out_spacing = cat_scalars(out_spacing, *args, num=self.sdim, device=self.device) + data = U.grid_resample( + self, in_spacing=in_spacing, out_spacing=out_spacing, mode=mode + ) + grid = tuple(grid.resample(out_spacing) for grid in self._grid) + return self._make_instance(data, grid) + + def avg_pool( + self: TImageBatch, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: ScalarOrTuple[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> TImageBatch: + """Average pooling of image data.""" + data = U.avg_pool( + self, + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + grid = tuple( + grid.avg_pool( + kernel_size, stride=stride, padding=padding, ceil_mode=ceil_mode + ) + for grid in self._grid + ) + return self._make_instance(data, grid) + + def downsample( + self: TImageBatch, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + min_size: int = 0, + align_corners: Optional[bool] = None, + ) -> TImageBatch: + """Downsample images in batch by halving their size the specified number of times. + + Args: + levels: Number of times the image size is halved (>0) or doubled (<0). + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian filter applied at each downsampling level. + mode: Image interpolation mode. + align_corners: Whether to preserve grid extent (False) or corner points (True). + If ``None``, the default of the image sampling grid is used. + + Returns: + Batch of downsampled images. + + """ + if not isinstance(levels, int): + raise TypeError( + f"{type(self).__name__}.downsample() 'levels' must be of type int" + ) + if align_corners is None: + align_corners = self.align_corners() + data = U.downsample( + self, + levels, + dims=dims, + sigma=sigma, + mode=mode, + min_size=min_size, + align_corners=align_corners, + ) + grid = tuple( + grid.downsample( + levels, dims=dims, min_size=min_size, align_corners=align_corners + ) + for grid in self._grid + ) + return self._make_instance(data, grid) + + def upsample( + self: TImageBatch, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + align_corners: Optional[bool] = None, + ) -> TImageBatch: + """Upsample image in batch by doubling their size the specified number of times. + + Args: + levels: Number of times the image size is doubled (>0) or halved (<0). + dims: Spatial dimensions along which to upsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian filter applied at each downsampling level. + mode: Image interpolation mode. + align_corners: Whether to preserve grid extent (False) or corner points (True). + If ``None``, the default of the image sampling grid is used. + + Returns: + Batch of upsampled images. + + """ + if not isinstance(levels, int): + raise TypeError( + f"{type(self).__name__}.upsample() 'levels' must be of type int" + ) + if align_corners is None: + align_corners = self.align_corners() + data = U.upsample( + self, levels, dims=dims, sigma=sigma, mode=mode, align_corners=align_corners + ) + grid = tuple( + grid.upsample(levels, dims=dims, align_corners=align_corners) + for grid in self._grid + ) + return self._make_instance(data, grid) + + def pyramid( + self: TImageBatch, + levels: int, + start: int = 0, + end: int = -1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + spacing: Optional[float] = None, + min_size: int = 0, + align_corners: Optional[bool] = None, + ) -> Dict[int, TImageBatch]: + """Create Gaussian resolution pyramid. + + Args: + levels: Number of resolution levels. + start: Highest resolution level to return, where 0 corresponds to the finest resolution. + end: Lowest resolution level to return (inclusive). + dims: Spatial dimensions along which to downsample. If not specified, consider all spatial dimensions. + sigma: Standard deviation of Gaussian filter applied at each downsampling level. + mode: Interpolation mode for resampling image data on downsampled grid. + spacing: Grid spacing at finest resolution level. Note that this option may increase the + cube extent of the multi-resolution pyramid sampling grids. + min_size: Minimum grid size. + align_corners: Whether to preserve grid extent (False) or corner points (True). + If ``None``, the default of the image sampling grid is used. + + Returns: + Dictionary of downsampled image batches with keys corresponding to level indices. + + """ + if not isinstance(levels, int): + raise TypeError(f"{type(self).__name__}.pyramid() 'levels' must be int") + if not isinstance(start, int): + raise TypeError(f"{type(self).__name__}.pyramid() 'start' must be int") + if not isinstance(end, int): + raise TypeError(f"{type(self).__name__}.pyramid() 'end' must be int") + if start < 0: + start = levels + start + if start < 0 or start > levels - 1: + raise ValueError( + f"{type(self).__name__}.pyramid() 'start' must be in [{-levels}, {levels - 1}]" + ) + if end < 0: + end = levels + end + if end < 0 or end > levels - 1: + raise ValueError( + f"{type(self).__name__}.pyramid() 'end' must be in [{-levels}, {levels - 1}]" + ) + if start > end: + return {} + if align_corners is None: + align_corners = self.align_corners() + grids = tuple(grid.align_corners(align_corners) for grid in self._grid) + if spacing is not None: + spacing0 = grids[0].spacing() + if not all( + paddle.allclose(x=grid.spacing(), y=spacing0).item() for grid in grids + ): + raise AssertionError( + f"{type(self).__name__}.pyramid() requires all images to have the same grid spacing when output 'spacing' at finest level is specified" + ) + grids = tuple(grid.resample(spacing) for grid in grids) + grids = tuple( + grid.pyramid(levels, dims=dims, min_size=min_size)[0] for grid in grids + ) + assert all(tuple(grid.shape) == tuple(grids[0].shape) for grid in grids) + if paddle.allclose( + x=grids[0].cube_extent(), y=self._grid[0].cube_extent() + ).item(): + size = tuple(grids[0].shape) + data = U.grid_resize(self, size, mode=mode, align_corners=align_corners) + else: + points = grids[0].coords(device=self.device) + data = U.grid_sample(self, points, mode=mode, align_corners=align_corners) + pyramid = {} + batch = self._make_instance(data, grids) + if start == 0: + pyramid[0] = batch + for level in range(1, end + 1): + batch = batch.downsample( + dims=dims, sigma=sigma, mode=mode, min_size=min_size + ) + if level >= start: + pyramid[level] = batch + return pyramid + + def crop( + self: TImageBatch, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImageBatch: + """Crop images at boundary. + + Args: + margin: Number of spatial grid points to remove (positive) or add (negative) at each border. + Use instead of ``num`` in order to symmetrically crop the input ``data`` tensor, e.g., + ``(nx, ny, nz)`` is equivalent to ``num=(nx, nx, ny, ny, nz, nz)``. + num: Number of spatial gird points to remove (positive) or add (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, nx, ny, ny)``. If a scalar is given, the input is cropped equally at all borders. + Otherwise, the given sequence must have an even length. + mode: Image extrapolation mode in case of negative crop value. + value: Constant value used for extrapolation if ``mode=PaddingMode.CONSTANT``. + + Returns: + Batch of images with modified size, but unchanged spacing. + + """ + data = U.crop(self, margin=margin, num=num, mode=mode, value=value) + grid = tuple(grid.crop(margin=margin, num=num) for grid in self._grid) + return self._make_instance(data, grid) + + def pad( + self: TImageBatch, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImageBatch: + """Pad images at boundary. + + Args: + margin: Number of spatial grid points to add (positive) or remove (negative) at each border, + Use instead of ``num`` in order to symmetrically pad the input ``data`` tensor. + num: Number of spatial gird points to add (positive) or remove (negative) at each border, + where margin of the last dimension of the ``data`` tensor must be given first, e.g., + ``(nx, ny, nz)``. If a scalar is given, the input is padded equally at all borders. + Otherwise, the given sequence must have an even length. + mode: Image extrapolation mode in case of positive pad value. + value: Constant value used for extrapolation if ``mode=PaddingMode.CONSTANT``. + + Returns: + Batch of images with modified size, but unchanged spacing. + + """ + data = U.pad(self, margin=margin, num=num, mode=mode, value=value) + grid = tuple(grid.pad(margin=margin, num=num) for grid in self._grid) + return self._make_instance(data, grid) + + def center_crop( + self: TImageBatch, size: Union[int, Array], *args: int + ) -> TImageBatch: + """Crop image to specified maximum size. + + Args: + size: Maximum output size, where the size of the last tensor + dimension must be given first, i.e., ``(X, ...)``. + If an ``int`` is given, all spatial output dimensions + are cropped to this maximum size. If the length of size + is less than the spatial dimensions of the ``data`` tensor, + then only the last ``len(size)`` dimensions are modified. + + Returns: + Batch of cropped images. + + """ + size = cat_scalars(size, *args, num=self.sdim, device=self.device) + data = U.center_crop(self.tensor(), size) + grid = tuple(grid.center_crop(size) for grid in self._grid) + return self._make_instance(data, grid) + + def center_pad( + self: TImageBatch, + size: Union[int, Array], + *args: int, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImageBatch: + """Pad image to specified minimum size. + + Args: + size: Minimum output size, where the size of the last tensor + dimension must be given first, i.e., ``(X, ...)``. + If an ``int`` is given, all spatial output dimensions + are cropped to this maximum size. If the length of size + is less than the spatial dimensions of the ``data`` tensor, + then only the last ``len(size)`` dimensions are modified. + mode: Padding mode (cf. ``paddle.nn.functional.pad()``). + value: Value for padding mode "constant". + + Returns: + Batch of images with modified size, but unchanged spacing. + + """ + size = cat_scalars(size, *args, num=self.sdim, device=self.device) + data = U.center_pad(self, size, mode=mode, value=value) + grid = tuple(grid.center_pad(size) for grid in self._grid) + return self._make_instance(data, grid) + + def region_of_interest( + self: TImageBatch, + start: Union[int, Array], + size: Union[int, Array], + padding: Union[PaddingMode, str, float] = PaddingMode.CONSTANT, + value: float = 0, + ) -> TImageBatch: + """Extract region of interest from images in batch. + + Args: + start: Indices of first spatial point to include in region of interest. + size: Size of region of interest. + padding: Image extrapolation mode or fill value. + value: Fill value to use when ``padding=Padding.CONSTANT``. + + Returns: + Image batch of extracted image region of interest. + + """ + data = U.region_of_interest(self, start, size, padding=padding, value=value) + grid = tuple(grid.region_of_interest(start, size) for grid in self._grid) + return self._make_instance(data, grid) + + def conv( + self: TImageBatch, + kernel: Union[paddle.Tensor, Sequence[Optional[paddle.Tensor]]], + padding: Union[PaddingMode, str, int] = None, + ) -> TImageBatch: + """Filter images in batch with a given convolutional kernel. + + Args: + kernel: Weights of kernel used to filter the images in this batch by. + The dtype of the kernel defines the intermediate data type used for convolutions. + If a 1-dimensional kernel is given, it is used as seperable convolution kernel in + all spatial image dimensions. Otherwise, the kernel is applied to the last spatial + image dimensions. For example, a 2D kernel applied to a batch of 3D image volumes + is applied slice-by-slice by convolving along the X and Y image axes. + padding: Image padding mode or margin size. If ``None``, use default mode ``PaddingMode.ZEROS``. + + Returns: + Result of filtering operation with data type set to the image data type before convolution. + If this data type is not a floating point data type, the filtered data is rounded and clamped + before it is being cast to the original dtype. + + """ + data = U.conv(self, kernel, padding=padding) + crop = tuple( + (m - n) // 2 for m, n in zip(self.shape[2:], tuple(data.shape)[2:]) + ) + crop = tuple(reversed(crop)) + grid = tuple(grid.crop(crop) for grid in self._grid) + return self._make_instance(data, grid) + + @overload + def sample( + self: TImageBatch, + grid: Union[Grid, Sequence[Grid]], + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> TImageBatch: + """Sample images at optionally deformed unit grid points. + + Args: + grid: Spatial grid points at which to sample image values. + mode: Image interpolation mode. + padding: Image extrapolation mode or scalar padding value. + + Returns: + A new image batch of the resampled data with the given sampling grids. + + """ + ... + + @overload + def sample( + self: TImageBatch, + coords: paddle.Tensor, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> paddle.Tensor: + """Sample images at optionally deformed unit grid points. + + Args: + coords: Normalized coordinates at which to sample image values as tensor of + shape ``(N, ..., D)`` or ``(1, ..., D)``. Note that the shape ``...`` need + not correspond to a (deformed) grid as required by ``grid_sample()``. + It can be an arbitrary shape, e.g., ``M`` to sample at ``M`` given points. + mode: Image interpolation mode. + padding: Image extrapolation mode or scalar padding value. + + Returns: + A tensor of shape (N, C, ...) of sampled image values. + + """ + ... + + def sample( + self: TImageBatch, + arg: Union[Grid, Sequence[Grid], paddle.Tensor], + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> Union[TImageBatch, paddle.Tensor]: + """Sample images at optionally deformed unit grid points. + + Args: + arg: Either a single grid which defines the sampling points for all images in the batch, + a different grid for each image in the batch, or a tensor of normalized coordinates + with shape ``(N, ..., D)`` or ``(1, ..., D)``. In the latter case, note that the + shape ``...`` need not correspond to a (deformed) grid as required by ``grid_sample()``. + It can be an arbitrary shape, e.g., ``M`` to sample at ``M`` given points. + mode: Image interpolation mode. + padding: Image extrapolation mode or scalar padding value. + + Returns: + If ``arg`` is of type ``Grid`` or ``Sequence[Grid]``, an ``ImageBatch`` is returned. + When these grids match the grids of this image batch, ``self`` is returned. + Otherwise, a ``paddle.Tensor`` of shape (N, C, ...) of sampled image values is returned. + + """ + data = self.tensor() + align_corners = self.align_corners() + if isinstance(arg, paddle.Tensor): + return U.sample_image( + data, arg, mode=mode, padding=padding, align_corners=align_corners + ) + if isinstance(arg, Grid): + arg = (arg,) + elif not isinstance(arg, Sequence) or any( + not isinstance(item, Grid) for item in arg + ): + raise TypeError( + f"{type(self).__name__}.sample() 'arg' must be Grid, Sequence[Grid], or paddle.Tensor" + ) + if len(arg) not in (1, len(self)): + raise ValueError( + f"{type(self).__name__}.sample() 'arg' must be one or {len(self)} grids" + ) + if all(grid == g for grid, g in zip_longest_repeat_last(arg, self._grid)): + return self + axes = Axes.from_align_corners(align_corners) + coords = [ + grid.coords(align_corners=align_corners, device=self.device) for grid in arg + ] + coords = paddle.concat( + x=[ + grid_transform_points(p, grid, axes, to_grid, axes).unsqueeze(0) + for p, grid, to_grid in zip_longest_repeat_last(coords, arg, self._grid) + ], + axis=0, + ) + data = U.grid_sample( + data, coords, mode=mode, padding=padding, align_corners=align_corners + ) + return self._make_instance(data, arg) + + def __repr__(self) -> str: + return type(self).__name__ + f"(data={self.tensor()!r}, grid={self.grids()!r})" + + def __str__(self) -> str: + return type(self).__name__ + f"(data={self.tensor()!s}, grid={self.grids()!s})" + + +class Image(DataTensor): + """Image sampled on oriented grid.""" + + def __init__( + self: TImage, + data: Array, + grid: Optional[Grid] = None, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + ) -> None: + """Initialize image tensor. + + Args: + data: Image data tensor of shape (C, ...X). To create an ``Image`` instance from an + image of a mini-batch without creating a copy of the data, simply provide the + respective slice of the mini-batch corresponding to this image, e.g., ``batch[i]``. + grid: Sampling grid of image ``data`` oriented in world space. + If ``None``, a default grid whose world space coordinate axes are aligned with the + image axes, unit spacing, and origin at the image center is created on CPU. + dtype: Data type of the image data. A copy of the data is only made when the desired + ``dtype`` is not ``None`` and not the same as ``data.dtype``. + device: Device on which to store image data. A copy of the data is only made when + the data has to be copied to a different device. + requires_grad: If autograd should record operations on the returned image tensor. + pin_memory: If set, returned image tensor would be allocated in the pinned memory. + Works only for CPU tensors. + + """ + if self.ndim < 3: + raise ValueError( + "Image tensor must have at least three dimensions (C, H, W)" + ) + self.grid_(grid) + + def _make_instance( + self: TImage, + data: Optional[paddle.Tensor] = None, + grid: Optional[Grid] = None, + **kwargs, + ) -> TImage: + """Create a new instance while preserving subclass meta-data.""" + kwargs["grid"] = self._grid if grid is None else grid + return super()._make_instance(data, **kwargs) + + def __deepcopy__(self: TImage, memo) -> TImage: + if id(self) in memo: + return memo[id(self)] + result = self._make_instance( + self.data.clone(), + grid=self._grid.clone(), + requires_grad=self.requires_grad, + pin_memory=self.is_pinned(), + ) + memo[id(self)] = result + return result + + @staticmethod + def _paddle_function_grid(args) -> Optional[Grid]: + """Get spatial sampling grid from args passed to __paddle_function__.""" + if not args: + return None + if isinstance(args[0], (tuple, list)): + args = args[0] + grids: Sequence[Grid] + grids = [ + g for g in (getattr(arg, "_grid", None) for arg in args) if g is not None + ] + if not grids: + return None + return grids[0] + + @classmethod + def _paddle_function_result(cls, func, data, grid: Optional[Grid]) -> Any: + if not isinstance(data, paddle.Tensor): + return data + if ( + grid is not None + and data.ndim == grid.ndim + 1 + and tuple(data.shape)[1:] == tuple(grid.shape) + ): + if func in (paddle.clone, paddle.Tensor.clone): + grid = grid.clone() + if isinstance(data, cls): + data._grid = grid + else: + data = cls(data, grid) + elif type(data) != paddle.Tensor: + data = data.as_subclass(paddle.Tensor) + return data + + @classmethod + def __paddle_function__(cls, func, types, args=(), kwargs=None): + if func == paddle.nn.functional.grid_sample: + raise ValueError( + "Argument of F.grid_sample() must be a batch, not a single image" + ) + data = paddle.Tensor.__paddle_function__(func, (paddle.Tensor,), args, kwargs) + grid = cls._paddle_function_grid(args) + if func in ( + paddle.split, + paddle.Tensor.split, + paddle.split_with_sizes, + paddle.Tensor.split_with_sizes, + paddle.tensor_split, + paddle.Tensor.tensor_split, + ): + return tuple(cls._paddle_function_result(func, sub, grid) for sub in data) + return cls._paddle_function_result(func, data, grid) + + def batch(self: TImage) -> ImageBatch: + """Image batch consisting of this image only. + + Because batched operations are generally more efficient, especially when executed on the GPU, + most ``Image`` operations are implemented by ``ImageBatch``. The single-image batch instance + property of this ``Image`` instance is used to execute single-image operations of ``self``. + The ``ImageBatch`` uses a view on the tensor data of this ``Image``, as well as the ``Grid`` + object reference. No copies are made. + + """ + data = self.unsqueeze(0) + grid = self._grid + return ImageBatch(data, grid) + + def cube(self: TImage) -> Cube: + """Get cube which defines the normalized coordinates space of the image with respect to the world.""" + return self._grid.cube() + + def domain(self: TImage) -> Domain: + """Get oriented bounding box in world space which defines the domain within which the image is defined.""" + return self._grid.domain() + + def same_domain_as(self, other: Image) -> bool: + """Check if this and another image have the same cube domain.""" + return self.same_domain_as(other.grid()) + + @overload + def grid(self: TImage) -> Grid: + """Get sampling grid.""" + ... + + @overload + def grid(self: TImage, grid: Grid) -> Image: + """Get new image with given sampling grid.""" + ... + + def grid(self: TImage, grid: Optional[Grid] = None) -> Union[Grid, TImage]: + """Get sampling grid or image with given grid, respectively.""" + if grid is None: + return self._grid + return self._make_instance(grid=grid) + + def grid_(self: TImage, grid: Optional[Grid]) -> TImage: + """Change image sampling grid of this image.""" + if grid is None: + grid = Grid(shape=self.shape[1:]) + elif tuple(grid.shape) != tuple(self.shape[1:]): + raise ValueError( + "Image grid size does not match spatial dimensions of image tensor" + ) + self._grid = grid + return self + + def align_corners(self: TImage) -> bool: + """Whether image resizing operations by default preserve corner points or grid extent.""" + return self._grid.align_corners() + + def center(self: TImage) -> paddle.Tensor: + """Image center in world space coordinates as tensor of shape (D,).""" + return self._grid.center() + + def origin(self: TImage) -> paddle.Tensor: + """Image origin in world space coordinates as tensor of shape (D,).""" + return self._grid.origin() + + def spacing(self: TImage) -> paddle.Tensor: + """Image spacing in world units as tensor of shape (D,).""" + return self._grid.spacing() + + def direction(self: TImage) -> paddle.Tensor: + """Image direction cosines matrix as tensor of shape (D, D).""" + return self._grid.direction() + + @property + def sdim(self: TImage) -> int: + """Number of spatial dimensions.""" + return self._grid.ndim + + @property + def nchannels(self: TImage) -> int: + """Number of image channels.""" + return self.shape[0] + + @classmethod + def from_sitk( + cls: Type[TImage], + image: "sitk.Image", + align_corners: bool = ALIGN_CORNERS, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, + ) -> TImage: + """Create image from ``SimpleITK.Image`` instance.""" + if sitk is None: + raise RuntimeError(f"{cls.__name__}.from_sitk() requires SimpleITK") + data = tensor_from_image(image, dtype=dtype, device=device) + grid = Grid.from_sitk(image, align_corners=align_corners) + return cls(data, grid) + + def sitk(self: TImage) -> "sitk.Image": + """Create ``SimpleITK.Image`` from this image.""" + if sitk is None: + raise RuntimeError(f"{type(self).__name__}.sitk() requires SimpleITK") + grid = self._grid + origin = grid.origin().tolist() + spacing = grid.spacing().tolist() + direction = grid.direction().flatten().tolist() + return image_from_tensor( + self, origin=origin, spacing=spacing, direction=direction + ) + + @classmethod + def read( + cls: Type[TImage], + path: PathStr, + align_corners: bool = ALIGN_CORNERS, + dtype: Optional[paddle.dtype] = None, + device: Optional[Device] = None, + ) -> TImage: + """Read image data from file.""" + image = cls._read_sitk(path) + return cls.from_sitk( + image, align_corners=align_corners, dtype=dtype, device=device + ) + + @classmethod + def _read_sitk(cls, path: PathStr) -> "sitk.Image": + """Read SimpleITK.Image from file path.""" + if sitk is None: + raise RuntimeError(f"{cls.__name__}.read() requires SimpleITK") + return read_image(path) + + def write(self: TImage, path: PathStr, compress: bool = True) -> None: + """Write image data to file.""" + if sitk is None: + raise RuntimeError(f"{type(self).__name__}.write() requires SimpleITK") + image = self.sitk() + path = unlink_or_mkdir(path) + sitk.WriteImage(image, str(path), compress) + + def normalize( + self: TImage, + mode: str = "unit", + min: Optional[float] = None, + max: Optional[float] = None, + ) -> TImage: + """Normalize image intensities in [min, max].""" + batch = self.batch() + batch = batch.normalize(mode=mode, min=min, max=max) + return batch[0] + + def normalize_( + self: TImage, + mode: str = "unit", + min: Optional[float] = None, + max: Optional[float] = None, + ) -> TImage: + """Normalize image intensities in [min, max].""" + batch = self.batch() + batch = batch.normalize_(mode=mode, min=min, max=max) + return batch[0] + + def rescale( + self: TImage, + min: Optional[Scalar] = None, + max: Optional[Scalar] = None, + data_min: Optional[Scalar] = None, + data_max: Optional[Scalar] = None, + dtype: Optional[DType] = None, + ) -> TImage: + """Clamp and linearly rescale image values.""" + batch = self.batch() + batch = batch.rescale(min, max, data_min, data_max, dtype=dtype) + return batch[0] + + def narrow(self: TImage, dim: int, start: int, length: int) -> TImage: + """Narrow image along specified dimension.""" + batch = self.batch() + start_29 = batch.shape[dim + 1] + start if start < 0 else start + batch = paddle.slice(batch, [dim + 1], [start_29], [start_29 + length]) + return batch[0] + + def resize( + self: TImage, + size: Union[int, Array, Size], + *args: int, + mode: Union[Sampling, str] = Sampling.LINEAR, + align_corners: Optional[bool] = None, + ) -> TImage: + """Interpolate image with specified spatial image grid size.""" + batch = self.batch() + batch = batch.resize(size, *args, mode=mode, align_corners=align_corners) + return batch[0] + + def resample( + self: TImage, + spacing: Union[float, Array, str], + *args: float, + mode: Union[Sampling, str] = Sampling.LINEAR, + ) -> TImage: + """Interpolate image with specified spacing.""" + batch = self.batch() + batch = batch.resample(spacing, *args, mode=mode) + return batch[0] + + def avg_pool( + self: TImage, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: ScalarOrTuple[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> TImage: + """Average pooling of image data.""" + batch = self.batch() + batch = batch.avg_pool( + kernel_size, + stride=stride, + padding=padding, + ceil_mode=ceil_mode, + count_include_pad=count_include_pad, + ) + return batch[0] + + def downsample( + self: TImage, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + min_size: int = 0, + align_corners: Optional[bool] = None, + ) -> TImage: + """Downsample image a given number of times.""" + batch = self.batch() + batch = batch.downsample( + levels, + dims=dims, + sigma=sigma, + mode=mode, + min_size=min_size, + align_corners=align_corners, + ) + return batch[0] + + def upsample( + self: TImage, + levels: int = 1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + align_corners: Optional[bool] = None, + ) -> TImage: + """Upsample image a given number of times.""" + batch = self.batch() + batch = batch.upsample( + levels, dims=dims, sigma=sigma, mode=mode, align_corners=align_corners + ) + return batch[0] + + def pyramid( + self: TImage, + levels: int, + start: int = 0, + end: int = -1, + dims: Optional[Sequence[SpatialDimArg]] = None, + sigma: Optional[Union[Scalar, Array]] = None, + mode: Optional[Union[Sampling, str]] = None, + spacing: Optional[float] = None, + min_size: int = 0, + align_corners: Optional[bool] = None, + ) -> Dict[int, TImage]: + """Create Gaussian resolution pyramid.""" + batch = self.batch() + batches = batch.pyramid( + levels, + start, + end, + dims=dims, + sigma=sigma, + mode=mode, + spacing=spacing, + min_size=min_size, + align_corners=align_corners, + ) + return {level: batch[0] for level, batch in batches.items()} + + def crop( + self: TImage, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImage: + """Crop image at boundary.""" + batch = self.batch() + batch = batch.crop(margin=margin, num=num, mode=mode, value=value) + return batch[0] + + def pad( + self: TImage, + margin: Optional[Union[int, Array]] = None, + num: Optional[Union[int, Array]] = None, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImage: + """Pad image at boundary.""" + batch = self.batch() + batch = batch.pad(margin=margin, num=num, mode=mode, value=value) + return batch[0] + + def center_crop(self: TImage, size: Union[int, Array], *args: int) -> TImage: + """Crop image to specified maximum size.""" + batch = self.batch() + batch = batch.center_crop(size, *args) + return batch[0] + + def center_pad( + self: TImage, + size: Union[int, Array], + *args: int, + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: Scalar = 0, + ) -> TImage: + """Pad image to specified minimum size.""" + batch = self.batch() + batch = batch.center_pad(size, *args, mode=mode, value=value) + return batch[0] + + def region_of_interest(self: TImage, *args, **kwargs) -> TImage: + """Extract image region of interest.""" + batch = self.batch() + batch = batch.region_of_interest(*args, **kwargs) + return batch[0] + + def conv( + self: TImage, + kernel: Union[paddle.Tensor, Sequence[Optional[paddle.paddle.Tensor]]], + padding: Union[PaddingMode, str, int] = None, + ) -> TImage: + """Filter image with a given (separable) kernel.""" + batch = self.batch() + batch = batch.conv(kernel, padding=padding) + return batch[0] + + @overload + def sample( + self: TImage, + coords: paddle.Tensor, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> paddle.Tensor: + """Sample image at optionally deformed unit grid points. + + Note, to sample a set of 2D patches from a 3D volume, it may be beneficial to use ``ImageBatch.sample()`` + instead such that the output tensor shape is ``(N, C, Y, X)`` instead of ``(C, N, Y, X)`` given an + input ``coords`` shape of ``(N, Y, X, D)``. For example, use ``image.batch().sample()``. + + Args: + coords: Normalized coordinates of points at which to sample image as tensor of shape ``(..., D)``. + Typical tensor shapes are: ``(Y, X, D)`` or ``(Z, Y, X, D)`` to sample an image at (deformed) + 2D or 3D grid points (cf. ``grid_sample()``), ``(N, D)`` to sample an image at a set of ``N`` + points, and ``(N, Y, X, D)`` to sample ``N`` 2D patches of size ``(X, Y)`` from a 3D volume. + mode: Interpolation mode. + padding: Extrapolation mode or scalar padding value. + + Returns: + paddle.Tensor of sampled image values with shape ``(C, ...)``, where ``C`` is the number of channels + of this image and ``...`` are the leading dimensions of ``coords`` (i.e., ``coords.shape[:-1]``). + + """ + ... + + @overload + def sample( + self: TImage, + grid: Grid, + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> TImage: + """Sample image at optionally deformed unit grid points. + + Args: + grid: Sample this image at the points of the given sampling grid. + mode: Interpolation mode. + padding: Extrapolation mode or scalar padding value. + + Returns: + Image sampled at grid points. + + """ + ... + + def sample( + self: TImage, + arg: Union[Grid, paddle.Tensor], + mode: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + ) -> Union[paddle.Tensor, TImage]: + """Sample image at points given as normalized coordinates. + + Note, to sample a set of 2D patches from a 3D volume, it may be beneficial to use ``ImageBatch.sample()`` + instead such that the output tensor shape is ``(N, C, Y, X)`` instead of ``(C, N, Y, X)`` given an + input ``arg`` shape of ``(N, Y, X, D)``. For example, use ``image.batch().sample()``. + + Args: + arg: Sampling grid defining points at which to sample image data, or normalized coordinates of + points at which to sample image as tensor of shape ``(..., D)``. Typical tensor shapes are: + ``(Y, X, D)`` or ``(Z, Y, X, D)`` to sample an image at (deformed) 2D or 3D grid points + (cf. ``grid_sample()``), ``(N, D)`` to sample an image at a set of ``N`` points, and + ``(N, Y, X, D)`` to sample ``N`` 2D patches of size ``(X, Y)`` from a 3D volume. + mode: Interpolation mode. + padding: Extrapolation mode or scalar padding value. + + Returns: + If ``arg`` is of type ``Grid``, an ``Image`` with the sampled values and given sampling grid is returend. + When ``arg == self.grid()``, a reference to ``self`` is returned. Otherwise, a ``paddle.Tensor`` of sampled image + values with shape ``(C, ...)`` is returned, where ``C`` is the number of channels of this image and ``...`` + are the leading dimensions of ``grid`` (i.e., ``grid.shape[:-1]``). + + """ + batch = self.batch() + if isinstance(arg, Grid): + if arg == self.grid(): + return self + batch = batch.sample(arg, mode=mode, padding=padding) + assert isinstance(batch, ImageBatch) + return batch[0] + if isinstance(arg, paddle.Tensor): + grid = arg.unsqueeze(0) + data = batch.sample(grid, mode=mode, padding=padding) + assert type(data) is paddle.Tensor + assert tuple(data.shape)[0] == 1 + data = data.squeeze(axis=0) + return data + raise TypeError( + f"{type(self).__name__}.sample() 'arg' must be Grid or paddle.Tensor" + ) + + def __repr__(self) -> str: + return type(self).__name__ + f"(data={self.tensor()!r}, grid={self.grid()!r})" + + def __str__(self) -> str: + return type(self).__name__ + f"(data={self.tensor()!s}, grid={self.grid()!s})" diff --git a/jointContribution/HighResolution/deepali/data/partition.py b/jointContribution/HighResolution/deepali/data/partition.py new file mode 100644 index 0000000000..13816da7c4 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/partition.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +from enum import Enum +from itertools import accumulate +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle + +__all__ = "Partition", "dataset_split_lengths", "random_split_indices" + + +class Partition(Enum): + """Enumeration of dataset partitions / splits.""" + + NONE = "none" + TEST = "test" + TRAIN = "train" + VALID = "valid" + + @classmethod + def from_arg(cls, arg: Union[Partition, str, None]) -> Partition: + """Create enumeration value from function argument.""" + if arg is None: + return cls.NONE + return cls(arg) + + +def dataset_split_lengths( + total: int, ratios: Union[float, Sequence[float]] +) -> Tuple[int, int, int]: + """Split dataset in training, validation, and test subset. + + The output ``lengths`` of this function can be passed to ``paddle.utils.data.random_split`` to obtain + the ``paddle.utils.data.dataset.Subset`` for each split. + + Args: + total: Total number of samples in dataset. + ratios: Fraction of samples in each split. When a float or 1-tuple is given, + the specified fraction of samples is used for training and all remaining + samples for validation during training. When a 2-tuple is given, the test + set is assigned no samples. Otherwise, a 3-tuple consisting of ratios + for training, validation, and test set, respectively should be given. + The ratios must sum to one. + + Returns: + lengths: Number of dataset samples in each subset. + + """ + if not isinstance(ratios, float) and len(ratios) == 1: + ratios = ratios[0] + if isinstance(ratios, float): + ratios = ratios, 1.0 - ratios + if len(ratios) == 2: + ratios += (0.0,) + elif len(ratios) != 3: + raise ValueError( + "dataset_split_lengths() 'ratios' must be float or tuple of length 1, 2, or 3" + ) + if ratios[0] <= 0 or ratios[0] > 1: + raise ValueError( + "dataset_split_lengths() training split ratio must be in (0, 1]" + ) + if any([(ratio < 0 or ratio > 1) for ratio in ratios]): + raise ValueError("dataset_split_lengths() ratios must be in [0, 1]") + if sum(ratios) != 1: + raise ValueError("dataset_split_lengths() 'ratios' must sum to one") + lengths = [int(round(ratio * total)) for ratio in ratios] + lengths[2] = max(0, lengths[2] + (total - sum(lengths))) + lengths[1] = max(0, lengths[1] + (total - sum(lengths))) + assert sum(lengths) == total + return tuple(lengths) + + +def random_split_indices( + lengths: Sequence[int], generator: Optional[paddle.Generator] = None +) -> List[List[int]]: + """Randomly split dataset indices into non-overlapping sets of given lengths. + + Args: + lengths: Lengths of splits to be produced. + generator: Generator used for the random permutation. + + Returns: + Lists of specified ``lengths`` with randomly selected indices in ``[0, sum(lengths))``. + + """ + subsets = [] + indices = paddle.randperm(n=sum(lengths)).tolist() + for offset, length in zip(accumulate(lengths), lengths): + subsets.append(indices[offset - length : offset]) + return subsets diff --git a/jointContribution/HighResolution/deepali/data/prepare.py b/jointContribution/HighResolution/deepali/data/prepare.py new file mode 100644 index 0000000000..c73befd46d --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/prepare.py @@ -0,0 +1,71 @@ +from collections import abc +from dataclasses import is_dataclass +from typing import Any +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Union +from typing import overload + +import paddle + +from ..core.types import Batch +from ..core.types import Dataclass +from ..core.types import Device +from ..core.types import is_namedtuple +from .sample import replace_all_sample_field_values +from .sample import sample_field_names +from .sample import sample_field_value + +__all__ = ("prepare_batch",) + + +@overload +def prepare_batch(batch: Sequence[Mapping[str, Any]]) -> Mapping[str, Any]: + ... + + +@overload +def prepare_batch(batch: Sequence[Dataclass]) -> Dataclass: + ... + + +@overload +def prepare_batch(batch: Sequence[NamedTuple]) -> NamedTuple: + ... + + +def prepare_batch( + batch: Batch, + device: Optional[Union[Device, str]] = None, + non_blocking: bool = False, + memory_format=None, +) -> Batch: + """Move batch data to execution device.""" + names = sample_field_names(batch) + values = [] + for name in names: + value = sample_field_value(batch, name) + value = prepare_item( + value, device=device, non_blocking=non_blocking, memory_format=memory_format + ) + values.append(value) + return replace_all_sample_field_values(batch, values) + + +def prepare_item( + value: Any, + device: Optional[Union[Device, str]] = None, + non_blocking: bool = False, + memory_format=None, +) -> Any: + """Move batch item data to execution device.""" + kwargs = dict(device=device, non_blocking=non_blocking, memory_format=memory_format) + if isinstance(value, paddle.Tensor): + value = value.astype(**kwargs) + elif isinstance(value, abc.Mapping) or is_dataclass(value) or is_namedtuple(value): + value = prepare_batch(value, **kwargs) + elif isinstance(value, Sequence) and not isinstance(value, str): + value = [prepare_item(item, **kwargs) for item in value] + return value diff --git a/jointContribution/HighResolution/deepali/data/sample.py b/jointContribution/HighResolution/deepali/data/sample.py new file mode 100644 index 0000000000..e950b5ecb9 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/sample.py @@ -0,0 +1,56 @@ +"""Functions for dealing with dataset sample collections.""" +from collections import OrderedDict +from copy import copy as shallowcopy +from dataclasses import fields +from dataclasses import is_dataclass +from typing import Any +from typing import Mapping +from typing import Sequence +from typing import Tuple + +from ..core.types import Sample +from ..core.types import is_namedtuple + +__all__ = ( + "Sample", + "sample_field_names", + "sample_field_value", + "replace_all_sample_field_values", +) + + +def sample_field_names(sample: Sample) -> Tuple[str]: + """Get names of fields in data sample.""" + if is_dataclass(sample): + return tuple(field.name for field in fields(sample)) + if is_namedtuple(sample): + return sample._fields + if not isinstance(sample, Mapping): + raise TypeError("Dataset 'sample' must be dataclass, Mapping, or NamedTuple") + return tuple(sample.keys()) + + +def sample_field_value(sample: Sample, name: str) -> Any: + """Get sample value of named data field.""" + if isinstance(sample, Mapping): + return sample[name] + return getattr(sample, name) + + +def replace_all_sample_field_values(sample: Sample, values: Sequence[Any]) -> Sample: + """Replace all sample field values.""" + names = sample_field_names(sample) + if len(names) != len(values): + raise ValueError( + "replace_all_values() 'values' must contain an entry for every field" + ) + if is_dataclass(sample): + result = shallowcopy(sample) + for name, value in zip(names, values): + setattr(result, name, value) + return result + if is_namedtuple(sample): + return sample._replace(**{name: value for name, value in zip(names, values)}) + if isinstance(sample, OrderedDict): + return OrderedDict([(name, value) for name, value in zip(names, values)]) + return {name: value for name, value in zip(names, values)} diff --git a/jointContribution/HighResolution/deepali/data/sampler.py b/jointContribution/HighResolution/deepali/data/sampler.py new file mode 100644 index 0000000000..823c1d7a9a --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/sampler.py @@ -0,0 +1,131 @@ +import math +import multiprocessing as mp +from typing import Iterator +from typing import Optional +from typing import Sequence + +import paddle +from paddle.io import _T +from paddle.io import Sampler + + +class DistributedWeightedRandomSampler(Sampler[_T]): + """A version of WeightedRandomSampler that can be used with DistributedDataParallel training.""" + + def __init__( + self, + weights: Sequence[float], + num_samples: int = -1, + replacement: bool = True, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + seed: int = 0, + drop_last: bool = False, + ) -> None: + """Initialize map-style dataset index sampler. + + Sampler that restricts data loading to a subset of the random samples. + It is especially useful in conjunction with ``paddle.nn.parallel.DistributedDataParallel``. + In such a case, each process can pass a ``paddle.utils.data.DistributedWeightedRandomSampler`` + instance as a ``paddle.utils.data.DataLoader`` sampler, and load a subset of the + original dataset that is exclusive to it. + + .. note:: + Dataset is assumed to be a map-style dataset of constant size. + + Args: + weights: Sampling weights for each of the ``len(weights)`` indices. + These are passed on to ``paddle.multinomial()`` to sample indices. + num_samples: Number of dataset indices to sample. If negative, use ``len(weights)``. + Note that using all samples would only result in shuffling the dataset if + ``replacement=False``, but no weighted sampling of a subset of it. + replacement: Whether to sample dataset indices with replacement. + num_replicas: Number of processes participating in distributed training. + By default, ``world_size`` is retrieved from the current distributed group. + rank: Rank of the current process within ``num_replicas``. By default, ``rank`` is + retrieved from the current distributed group. + shuffle: If ``True``, the current ``epoch`` is added to the ``seed`` value to draw + different samples at each epoch. Moreover, when sampling without replacement, + the sampler will further shuffle the randomly drawn indices. + seed: Random seed used to draw and shuffle random samples. This number should be identical + across all processes in the distributed group. + drop_last: If ``True``, then the sampler will drop the tail of the data to make it evenly + divisible across the number of replicas. If ``False``, the sampler will add extra indices + to make the data evenly divisible across the replicas. + + .. warning:: + In distributed mode, calling the :meth:``set_epoch`` method at the beginning of each epoch + **before** creating the ``DataLoader`` iterator is necessary to make shuffling work properly + across multiple epochs. Otherwise, the same ordering will always be used. It should further + be noted, that the ``epoch`` is stored in shared memory such that persistent worker processes + all receive an update of the epoch number when ``set_epoch`` is called. + + """ + if num_replicas is None: + if not paddle.distributed.is_available(): + raise RuntimeError( + f"{type(self).__name__}() requires distributed package to be available" + ) + num_replicas = paddle.distributed.get_world_size() + if rank is None: + if not paddle.distributed.is_available(): + raise RuntimeError( + f"{type(self).__name__}() requires distributed package to be available" + ) + rank = paddle.distributed.get_rank() + if rank >= num_replicas or rank < 0: + raise ValueError( + f"{type(self).__name__}() invalid rank {rank}, rank should be in the interval [0, {num_replicas - 1}]" + ) + if num_samples < 1: + num_samples = len(weights) + self.weights = paddle.to_tensor(data=weights, dtype="float32") + self.replacement = replacement + self.num_replicas = num_replicas + self.rank = rank + self.drop_last = drop_last + if self.drop_last and num_samples % self.num_replicas != 0: + self.num_samples = math.ceil( + (num_samples - self.num_replicas) / self.num_replicas + ) + else: + self.num_samples = math.ceil(num_samples / self.num_replicas) + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + self._epoch = mp.Value("Q", 0) + if not replacement and self.total_size > len(self.weights): + raise ValueError( + f"{type(self).__name__}() total number of samples is greater than number of available samples to draw without replacement; reduce 'num_samples' and/or enable 'drop_last'" + ) + + @property + def epoch(self) -> int: + return self._epoch.value + + def set_epoch(self, epoch: int) -> None: + """Set training epoch used to adjust seed of number generator when shuffling is enabled.""" + if epoch < 0: + raise ValueError( + f"{type(self).__name__}.set_epoch() 'epoch' must be non-negative" + ) + self._epoch.value = epoch + + def __len__(self) -> int: + return self.num_samples + + def __iter__(self) -> Iterator[paddle.utils.data.sampler.T_co]: + g = paddle.framework.core.default_cpu_generator() + g = g.manual_seed(self.seed + (self.epoch if self.shuffle else 0)) + indices = paddle.multinomial( + x=self.weights, num_samples=self.total_size, replacement=self.replacement + ) + indices = indices.tolist() + if self.shuffle and not self.replacement: + perm = paddle.randperm(n=len(indices)).tolist() + indices = [indices[j] for j in perm] + assert len(indices) == self.total_size + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + return iter(indices) diff --git a/jointContribution/HighResolution/deepali/data/tensor.py b/jointContribution/HighResolution/deepali/data/tensor.py new file mode 100644 index 0000000000..f75da4c1f7 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/tensor.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from collections import OrderedDict +from typing import Optional +from typing import Type +from typing import TypeVar + +import paddle + +from ..core.tensor import as_tensor +from ..core.types import Array +from ..core.types import Device +from ..core.types import DType + +T = TypeVar("T", bound="DataTensor") +__all__ = ("DataTensor",) + + +class DataTensor(paddle.Tensor): + """Data tensor base class.""" + + def __new__( + cls: Type[T], + data: Array, + *args, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + **kwargs, + ) -> T: + data = as_tensor(data, dtype=dtype, device=device) + if requires_grad is None: + requires_grad = not data.stop_gradient + if pin_memory: + data = data.pin_memory() + + instance = super().__new__(cls) + instance = paddle.assign(data, instance) + + if requires_grad: + instance.stop_gradient = False + else: + instance.stop_gradient = True + + return instance + + def __init__( + self: T, + data: Array, + dtype: Optional[DType] = None, + device: Optional[Device] = None, + requires_grad: Optional[bool] = None, + pin_memory: bool = False, + ) -> None: + """Initialize data tensor. + + Args: + data: paddle.Tensor data. + dtype: Data type. A copy of the data is only made when the desired + ``dtype`` is not ``None`` and not the same as ``data.dtype``. + device: Device on which to store the data. A copy of the data is only made when + the data has to be copied to a different device. + requires_grad: If autograd should record operations on the returned data tensor. + pin_memory: If set, returned data tensor would be allocated in the pinned memory. + Works only for CPU tensors. + + """ + ... + + def _make_instance(self: T, data: Optional[paddle.Tensor] = None, **kwargs) -> T: + """Create a new instance while preserving subclass (meta-)data.""" + if data is None: + data = self + if type(data) is not paddle.Tensor: + data = data.as_subclass(paddle.Tensor) + return type(self)(data, **kwargs) + + def __copy__(self: T) -> T: + return self._make_instance() + + def __deepcopy__(self: T, memo) -> T: + if id(self) in memo: + return memo[id(self)] + result = self._make_instance( + self.data.clone(), + requires_grad=self.requires_grad, + pin_memory=self.is_pinned(), + ) + memo[id(self)] = result + return result + + def __reduce_ex__(self, proto): + paddle.utils.hooks.warn_if_has_hooks(self) + args = self.storage(), self.storage_offset(), tuple(self.size()), self.stride() + if self.is_quantized: + args = args + (self.q_scale(), self.q_zero_point()) + args = args + (self.requires_grad, OrderedDict()) + f = ( + paddle._utils._rebuild_qtensor + if self.is_quantized + else paddle._utils._rebuild_tensor_v2 + ) + return _rebuild_from_type, (f, type(self), args, self.__dict__) + + def tensor(self: T) -> paddle.Tensor: + """Convert to plain paddle.Tensor.""" + return self.detach() + + +def _rebuild_from_type(func, type, args, dict): + """Function used by DataTensor.__reduce_ex__ to support unpickling of subclass type.""" + ret: paddle.Tensor = func(*args) + if type is not paddle.Tensor: + ret = ret.as_subclass(type) + ret.__dict__ = dict + return ret diff --git a/jointContribution/HighResolution/deepali/data/transforms/__init__.py b/jointContribution/HighResolution/deepali/data/transforms/__init__.py new file mode 100644 index 0000000000..e550f85493 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/transforms/__init__.py @@ -0,0 +1,51 @@ +"""The transforms in this Python package generally built on the :mod:`.core` library. +The classes defined by these modules can be used, for example, in a data input pipeline +which is attached to a data loader. The spatial transforms defined in the :mod:`.spatial` +library, on the other hand, can be used to implement either a traditional or machine +learning based co-registration approach. + +Note that data transforms are included in the :mod:`.data` library to avoid cyclical +imports between modules defining specialized tensor types such as :mod:`.data.image` +and datasets defined in :mod:`.data.dataset`, which also use these transforms to read +and preprocess the loaded data. + +Following torchvision's lead, data transform classes which operate on tensors and do not require +lambda functions are derived from ``paddle.nn.Layer``. Use ``paddle.nn.Sequential`` to compose +transforms instead of ``torchvision.transforms.Compose``. This is to support ``paddle.jit.script``. + +""" +from typing import Callable + +from .image import AvgPoolImage +from .image import CastImage +from .image import CenterCropImage +from .image import CenterPadImage +from .image import ClampImage +from .image import ImageToTensor +from .image import NarrowImage +from .image import NormalizeImage +from .image import ReadImage +from .image import ResampleImage +from .image import RescaleImage +from .image import ResizeImage +from .item import ItemTransform +from .item import ItemwiseTransform + +Transform = Callable +__all__ = ( + "Transform", + "ItemTransform", + "ItemwiseTransform", + "AvgPoolImage", + "CastImage", + "CenterCropImage", + "CenterPadImage", + "ClampImage", + "ImageToTensor", + "NarrowImage", + "NormalizeImage", + "ReadImage", + "ResampleImage", + "RescaleImage", + "ResizeImage", +) diff --git a/jointContribution/HighResolution/deepali/data/transforms/image.py b/jointContribution/HighResolution/deepali/data/transforms/image.py new file mode 100644 index 0000000000..781fc9af73 --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/transforms/image.py @@ -0,0 +1,506 @@ +from pathlib import Path +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from ...core.enum import PaddingMode +from ...core.enum import Sampling +from ...core.types import PathStr +from ...core.types import ScalarOrTuple +from ..image import Image +from .item import ItemTransform +from .item import ItemwiseTransform + +__all__ = ( + "AvgPoolImage", + "CastImage", + "CenterCropImage", + "CenterPadImage", + "ClampImage", + "ImageToTensor", + "NarrowImage", + "NormalizeImage", + "ReadImage", + "ResampleImage", + "RescaleImage", + "ResizeImage", + "ImageTransformConfig", + "config_has_read_image_transform", + "prepend_read_image_transform", + "image_transform", + "image_transforms", +) + + +class AvgPoolImage(ItemwiseTransform, paddle.nn.Layer): + """Downsample image using average pooling.""" + + def __init__( + self, + kernel_size: ScalarOrTuple[int], + stride: Optional[ScalarOrTuple[int]] = None, + padding: ScalarOrTuple[int] = 0, + ceil_mode: bool = False, + count_include_pad: bool = True, + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.stride = stride + self.padding = padding + self.ceil_mode = ceil_mode + self.count_include_pad = count_include_pad + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + return image.avg_pool( + self.kernel_size, + stride=self.stride, + padding=self.padding, + ceil_mode=self.ceil_mode, + count_include_pad=self.count_include_pad, + ) + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(kernel_size={self.kernel_size!r}," + + f" stride={self.stride!r}," + + f" padding={self.padding!r}," + + f" ceil_mode={self.ceil_mode}," + + f" count_include_pad={self.count_include_pad})" + ) + + +class CastImage(ItemwiseTransform, paddle.nn.Layer): + """Cast image data to specified type.""" + + def __init__(self, dtype: Union[paddle.dtype, str]) -> None: + super().__init__() + if isinstance(dtype, str): + attr = dtype + dtype = getattr(paddle, attr, None) + if dtype is None: + raise ValueError( + f"{type(self).__name__}() module paddle has no 'dtype' named paddle.{attr}" + ) + if not isinstance(dtype, paddle.dtype): + raise ValueError( + f"{type(self).__name__}() paddle.{attr} is not a paddle.dtype name" + ) + elif not isinstance(dtype, paddle.dtype): + raise TypeError( + f"{type(self).__name__}() 'dtype' must by dtype name or paddle.dtype" + ) + self.dtype = dtype + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + return image.astype(self.dtype) + + def __repr__(self) -> str: + return type(self).__name__ + f"(dtype={self.dtype!r})" + + +class CenterCropImage(ItemwiseTransform, paddle.nn.Layer): + """Crop image to specified maximum output size.""" + + def __init__(self, size: Union[int, Sequence[int]]) -> None: + super().__init__() + self.size = size + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + return image.center_crop(self.size) + + def __repr__(self) -> str: + return type(self).__name__ + f"(size={self.size!r})" + + +class CenterPadImage(ItemwiseTransform, paddle.nn.Layer): + """Pad image to specified minimum output size.""" + + def __init__( + self, + size: Union[int, Sequence[int]], + mode: Union[PaddingMode, str] = PaddingMode.CONSTANT, + value: float = 0, + ) -> None: + super().__init__() + self.size = size + self.mode = PaddingMode(mode) + self.value = float(value) + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + return image.center_pad(self.size, mode=self.mode, value=self.value) + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(size={self.size!r}, mode={self.mode.value!r}, value={self.value!r})" + ) + + +class ClampImage(ItemwiseTransform, paddle.nn.Layer): + """Clamp image intensities to specified minimum and/or maximum value.""" + + def __init__( + self, + min: Optional[float] = None, + max: Optional[float] = None, + inplace: bool = False, + ) -> None: + super().__init__() + self.min = min + self.max = max + self.inplace = bool(inplace) + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + clamp_fn = image.clamp_ if self.inplace else image.clamp + image = clamp_fn(self.min, self.max) + return image + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(min={self.min!r}, max={self.max!r}, inplace={self.inplace!r})" + ) + + +class ImageToTensor(ItemwiseTransform, paddle.nn.Layer): + """Convert image to data tensor.""" + + def forward(self, image: Image) -> paddle.Tensor: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + return image.tensor() + + def __repr__(self) -> str: + return type(self).__name__ + "()" + + +class NarrowImage(ItemwiseTransform, paddle.nn.Layer): + """Return image with data tensor narrowed along specified dimension.""" + + def __init__(self, dim: int, start: int, length: int = 1) -> None: + super().__init__() + if dim != 0: + raise NotImplementedError( + "NarrowImage() 'dim' must be zero at the moment.Extend implementation to adjust image.grid() for other image dimensions." + ) + self.dim = dim + self.start = start + self.length = length + + def forward(self, image: Image) -> paddle.Tensor: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + start_26 = image.shape[self.dim] + self.start if self.start < 0 else self.start + image = paddle.slice(image, [self.dim], [start_26], [start_26 + self.length]) + return image + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(dim={self.dim}, start={self.start}, length={self.length})" + ) + + +class NormalizeImage(ItemwiseTransform, paddle.nn.Layer): + """Normalize and clamp image intensities in [min, max].""" + + def __init__( + self, + min: Optional[float] = None, + max: Optional[float] = None, + mode: str = "unit", + inplace: bool = False, + ) -> None: + super().__init__() + if mode not in ("center", "unit", "zscore", "z-score"): + raise ValueError( + "NormalizeImage() 'mode' must be 'center', 'unit', or 'zscore'" + ) + self.min = min + self.max = max + self.mode = mode + self.inplace = inplace + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + normalize_fn = image.normalize_ if self.inplace else image.normalize + return normalize_fn(mode=self.mode, min=self.min, max=self.max) + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(mode={self.mode!r}, min={self.min!r}, max={self.max!r}, inplace={self.inplace!r})" + ) + + +class ReadImage(ItemwiseTransform, paddle.nn.Layer): + """Read image data from file path.""" + + def __init__( + self, + dtype: Optional[Union[paddle.dtype, str]] = None, + device: Optional[Union[str, str]] = None, + ) -> None: + super().__init__() + if isinstance(dtype, str): + attr = dtype + dtype = getattr(paddle, attr, None) + if dtype is None: + raise ValueError( + f"ReadImage() module paddle has no 'dtype' named paddle.{attr}" + ) + if dtype is not None and not isinstance(dtype, paddle.dtype): + raise TypeError("ReadImage() 'dtype' must by None or paddle.dtype") + self.dtype = dtype + self.device = str(device or "cpu").replace("cuda", "gpu") + + def forward(self, path: PathStr) -> Image: + if not isinstance(path, (str, Path)): + raise TypeError(f"{type(self).__name__}() 'path' must be Path or str") + image = Image.read(path, dtype=self.dtype, device=self.device) + return image + + def __repr__(self) -> str: + return type(self).__name__ + f"(dtype={self.dtype}, device='{self.device!s}')" + + +class ResampleImage(ItemwiseTransform, paddle.nn.Layer): + """Resample image to specified voxel size.""" + + def __init__( + self, + spacing: Union[float, Sequence[float], str], + mode: Union[Sampling, str] = Sampling.LINEAR, + ) -> None: + super().__init__() + self.spacing = spacing + self.mode = Sampling(mode) + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + image = image.resample(self.spacing, mode=self.mode) + return image + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"(spacing={self.spacing!r}, mode={self.mode.value!r})" + ) + + +class RescaleImage(ItemwiseTransform, paddle.nn.Layer): + """Linearly rescale image data.""" + + def __init__( + self, + min: Optional[float] = None, + max: Optional[float] = None, + mul: Optional[float] = None, + add: Optional[float] = None, + ) -> None: + super().__init__() + if mul is not None or add is not None: + if min is not None or max is not None: + raise ValueError( + "RescaleImage() 'min'/'max' and 'add'/'mul' are mutually exclusive" + ) + self.min = None + self.max = None + self.mul = 1 if mul is None else float(mul) + self.add = 0 if add is None else float(add) + else: + self.min = min + self.max = max + self.mul = None + self.add = None + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + if self.mul is not None or self.add is not None: + assert self.min is None and self.max is None + if self.mul != 1: + image = image.mul(self.mul) + if self.add != 0: + image = image.add(self.add) + else: + image = image.rescale(min=self.min, max=self.max) + return image + + def __repr__(self) -> str: + s = type(self).__name__ + "(" + if self.mul is not None or self.add is not None: + s += f"mul={self.mul!r}, add={self.add!r}" + else: + s += f"min={self.min!r}, max={self.max!r}" + return s + + +class ResizeImage(ItemwiseTransform, paddle.nn.Layer): + """Resample image to specified image size.""" + + def __init__( + self, + size: Union[int, Sequence[int]], + mode: Union[Sampling, str] = Sampling.LINEAR, + ) -> None: + super().__init__() + self.size = size + self.mode = Sampling(mode) + + def forward(self, image: Image) -> Image: + if not isinstance(image, Image): + raise TypeError(f"{type(self).__name__}.forward() argument must be Image") + image = image.resize(self.size, mode=self.mode) + return image + + def __repr__(self) -> str: + return type(self).__name__ + f"(size={self.size!r}, mode={self.mode.value!r})" + + +ImageTransformMapping = Mapping[str, Union[Sequence, Mapping]] +ImageTransformConfig = Union[ + str, ImageTransformMapping, Sequence[Union[str, ImageTransformMapping]] +] +IMAGE_TRANSFORM_TYPES = { + "avgpool": AvgPoolImage, + "cast": CastImage, + "centercrop": CenterCropImage, + "centerpad": CenterPadImage, + "clamp": ClampImage, + "narrow": NarrowImage, + "normalize": NormalizeImage, + "read": ReadImage, + "rescale": RescaleImage, + "resample": ResampleImage, + "resize": ResizeImage, +} +INPLACE_IMAGE_TRANSFORMS = {"clamp"} + + +def config_has_read_image_transform(config: ImageTransformConfig) -> bool: + """Whether image data transformation configuration contains a "read" image transform.""" + if isinstance(config, str): + return config.lower() == "read" + if isinstance(config, Mapping): + for name in config: + if name.lower() == "read": + return True + return False + if isinstance(config, Sequence): + for item in config: + if isinstance(item, Sequence) and not isinstance(item, str): + raise ValueError( + "config_has_read_image_transform() 'config' Sequence cannot be nested" + ) + if config_has_read_image_transform(item): + return True + return False + raise TypeError( + "config_has_read_image_transform() 'config' must be str, Mapping, or Sequence" + ) + + +def prepend_read_image_transform( + config: ImageTransformConfig, + dtype: Optional[str] = None, + device: Optional[str] = None, +) -> ImageTransformConfig: + """Insert a "read" image transform before any other image data transform.""" + if config_has_read_image_transform(config): + return config + read_transform_config = {"read": dict(dtype=dtype, device=device)} + if isinstance(config, str): + return [read_transform_config, config] + if isinstance(config, Mapping): + return {**read_transform_config, **config} + if isinstance(config, Sequence): + return [read_transform_config] + list(config) + raise TypeError( + "prepend_read_image_transform() 'config' must be str, Mapping, or Sequence" + ) + + +def image_transform( + name: str, *args, key: Optional[str] = None, inplace: bool = True, **kwargs +) -> ItemwiseTransform: + """Create image data transform given its name.""" + cls = IMAGE_TRANSFORM_TYPES.get(name.lower()) + if cls is None: + raise ValueError(f"image_transform() unknown image data transform '{name}'") + if name in INPLACE_IMAGE_TRANSFORMS: + kwargs["inplace"] = inplace + transform = cls(*args, **kwargs) + if key: + transform = ItemTransform(transform, key=key) + return transform + + +def image_transforms( + config: ImageTransformConfig, key: Optional[str] = None +) -> List[paddle.nn.Layer]: + """Create image data transforms from configuration. + + A sequence of image transforms can be configured using names: + + - ``avgpool``: :class:`.AvgPoolImage` + - ``cast``: :class:`.CastImage` + - ``centercrop``: :class:`.CenterCropImage` + - ``centerpad``: :class:`.CenterPadImage` + - ``clamp``: :class:`.ClampImage` + - ``narrow``: :class:`.NarrowImage` + - ``normalize``: :class:`.NormalizeImage` + - ``read``: :class:`.ReadImage` + - ``rescale``: :class:`.RescaleImage` + - ``resample``: :class:`.ResampleImage` + - ``resize``: :class:`.ResizeImage` + + Parameters can be passed to the data transform as keyword arguments, e.g., + + .. code-block:: yaml + + transforms: + - read: {dtype: float32} + - avgpool: {kernel_size: 2} + - normalize: {min: 0, max: 255, mode: unit} + + """ + transforms = [] + if isinstance(config, str): + transforms.append(image_transform(config, key=key)) + elif isinstance(config, Mapping): + for name, value in config.items(): + if value is None: + transform = image_transform(name, key=key) + elif isinstance(value, (list, tuple)): + transform = image_transform(name, *value, key=key) + elif isinstance(value, Mapping): + transform = image_transform(name, key=key, **value) + else: + transform = image_transform(name, value, key=key) + transforms.append(transform) + elif isinstance(config, Sequence): + for item in config: + if isinstance(item, Sequence) and not isinstance(item, str): + raise ValueError("image_transform() 'config' Sequence cannot be nested") + transforms.extend(image_transforms(item, key=key)) + else: + raise TypeError("image_transforms() 'config' must be str, Mapping, or Sequence") + return transforms diff --git a/jointContribution/HighResolution/deepali/data/transforms/item.py b/jointContribution/HighResolution/deepali/data/transforms/item.py new file mode 100644 index 0000000000..10629b09af --- /dev/null +++ b/jointContribution/HighResolution/deepali/data/transforms/item.py @@ -0,0 +1,275 @@ +from collections.abc import KeysView +from copy import copy as shallowcopy +from copy import deepcopy +from dataclasses import fields +from dataclasses import is_dataclass +from typing import Any +from typing import Callable +from typing import Iterable +from typing import Mapping +from typing import Optional +from typing import Union + +import paddle + +from ...core.types import RE_OUTPUT_KEY_INDEX +from ...core.types import is_namedtuple + +__all__ = "ItemTransform", "ItemwiseTransform" + + +class ItemTransform(paddle.nn.Layer): + """Transform only specified item/field of dict, named tuple, tuple, list, or dataclass.""" + + def __init__( + self, + transform: Callable, + key: Optional[Union[int, str, KeysView, Iterable[Union[int, str]]]] = None, + copy: bool = False, + ignore_meta: bool = True, + ignore_missing: bool = False, + ): + """Initialize item transformation. + + Args: + transform: Item value transformation. + key: Index, key, or field name of item to transform. If ``None``, empty string, or 'all', + apply ``transform`` to all items in the input ``data`` object. Can be a nested + key with dot ('.') character as subkey delimiters, and contain indices to access + list or tuple items. For example, "a.b.c", "a.b.2", "a.b.c[1]", "a[1][0]", + "a[0].c", and "a.0.c", are all valid keys as long as the data to transform has + the appropriate entries and (nested) container types. + copy: Whether to copy all input items. By default, a shallow copy of the input ``data`` + is made and only the item specified by ``key`` is replaced by its transformed value. + If ``True``, a deep copy of all input ``data`` items is made. + ignore_missing: Whether to skip processing of items with value ``None``. + ignore_meta: Preserve 'meta' dictionary key value as is (cf. ``MetaDataset``). + + """ + super().__init__() + self.transform = transform + if not isinstance(key, str) and isinstance(key, (KeysView, Iterable)): + key = set(key) + self.key = key + self.copy = copy + self.ignore_meta = ignore_meta + self.ignore_missing = ignore_missing + + def forward(self, data: Any) -> Any: + """Apply transformation. + + Args: + data: Input value, dict, tuple, list, or dataclass. + + Returns: + Copy of ``data`` with value of ``self.key`` replaced by its transformed value. + If ``self.key is None``, all values in the input ``data`` are transformed. + By default, a shallow copy of ``data`` is made unless ``self.copy == True``. + + """ + if self.key in (None, "", "all", "ALL"): + return self._apply_all(data) + if isinstance(self.key, set): + keys = self.key + elif not isinstance(self.key, str) and isinstance( + self.key, (KeysView, Iterable) + ): + keys = set(self.key) + else: + keys = [self.key] + for key in keys: + if isinstance(key, int): + data = self._apply_index(data, key) + elif isinstance(key, str): + key = RE_OUTPUT_KEY_INDEX.sub(".\\1", key) + data = self._apply_key(data, key) + else: + raise TypeError( + f"{type(self).__name__}() 'key' must be None, int, or str" + ) + return data + + def _apply_all(self, data: Any) -> Any: + """Transform all leaf items.""" + if is_dataclass(data): + output = shallowcopy(data) + for field in fields(data): + value = getattr(data, field.name) + if not self.ignore_meta or field.name != "meta": + value = self._apply_all(value) + setattr(output, field.name, value) + elif isinstance(data, Mapping): + output = {} + for k, v in data.items(): + if not self.ignore_meta or k != "meta": + v = self._apply_all(v) + output[k] = v + elif isinstance(data, tuple): + output = tuple(self._apply_all(d) for d in data) + if is_namedtuple(data): + output = type(data)(*output) + elif isinstance(data, list): + output = list(self._apply_all(d) for d in data) + elif data is None: + if self.ignore_missing: + output = None + else: + raise ValueError( + f"{type(self).__name__}() value is None (ignore_missing=False)" + ) + else: + output = self.transform(data) + return output + + def _apply_index(self, data: Any, index: int) -> Any: + """Transform item at specified index.""" + if not isinstance(data, (list, tuple)): + raise TypeError( + f"{type(self).__name__}() 'data' must be list or tuple when key is int" + ) + try: + item = data[index] + except IndexError: + raise IndexError( + f"{type(self).__name__}() 'data' sequence must have item at index {index}" + ) + item = self.transform(item) + args = (item if i == index else self._maybe_copy(v) for i, v in enumerate(data)) + if is_namedtuple(data): + return type(data)(*args) + return type(data)(args) + + def _apply_key(self, data: Any, key: str, prefix: str = "") -> Any: + """Transform specified item in data map.""" + parts = key.split(".", 1) + if len(parts) == 1: + parts = [parts[0], ""] + key, subkey = parts + if not key: + raise KeyError( + f"{type(self).__name__}() 'key' must not be empty (prefix={prefix!r}, key={key!r}, subkey={subkey!r})" + ) + index = None + item = None + if is_dataclass(data) or is_namedtuple(data): + try: + item = getattr(data, key) + except AttributeError: + if prefix: + msg = f"{type(self).__name__}() 'data' entry {prefix!r} has no attribute named {key}" + else: + msg = f"{type(self).__name__}() 'data' has no attribute named {key}" + raise AttributeError(msg) + elif isinstance(data, (list, tuple)): + try: + index = int(key) + except (TypeError, ValueError): + if prefix: + msg = f"{type(self).__name__}() 'data' entry {prefix!r} is list or tuple, but key is no index" + else: + msg = f"{type(self).__name__}() 'data' is list or tuple, but key is no index" + raise AttributeError(msg) + try: + item = data[index] + except IndexError: + if prefix: + msg = f"{type(self).__name__}() 'data' entry {prefix!r} index {index} is out of bounds" + else: + msg = ( + f"{type(self).__name__}() 'data' index {index} is out of bounds" + ) + raise IndexError(msg) + elif isinstance(data, Mapping): + try: + item = data[key] + except KeyError: + try: + index = int(key) + item = data[index] + key = index + except (IndexError, KeyError, TypeError, ValueError): + if prefix: + msg = f"{type(self).__name__}() 'data' dict {prefix!r} must have key {key!r}" + else: + msg = ( + f"{type(self).__name__}() 'data' dict must have key {key!r}" + ) + raise KeyError(msg) + else: + if prefix: + msg = f"{type(self).__name__}() 'data' entry {prefix!r} must be list, tuple, dict, dataclass, or namedtuple" + else: + msg = f"{type(self).__name__}() 'data' must be list, tuple, dict, dataclass, or namedtuple" + raise TypeError(msg) + if subkey: + item = self._apply_key( + item, subkey, prefix=prefix + "." + key if prefix else key + ) + else: + item = self._apply_all(item) + if is_dataclass(data): + if self.copy: + args = ( + item if field.name == key else deepcopy(getattr(data, field.name)) + for field in fields(data) + ) + data = type(data)(*args) + else: + setattr(data, key, item) + elif is_namedtuple(data): + if self.copy: + args = ( + item if k == key else deepcopy(v) + for k, v in zip(data._fields, data) + ) + data = type(data)(*args) + else: + data = data._replace(**{key: item}) + elif isinstance(data, tuple): + assert index is not None + data = tuple( + item if i == index else self._maybe_copy(v) for i, v in enumerate(data) + ) + elif isinstance(data, (list, tuple)): + assert index is not None + data = list( + item if i == index else self._maybe_copy(v) for i, v in enumerate(data) + ) + else: + assert isinstance(data, Mapping) + data = { + k: (item if k == key else self._maybe_copy(v)) for k, v in data.items() + } + return data + + def _maybe_copy(self, data: Any) -> Any: + if self.copy: + return deepcopy(data) + return data + + def __repr__(self) -> str: + return ( + type(self).__name__ + + f"({self.transform!r}, key={self.key!r}, copy={self.copy!r})" + ) + + +class ItemwiseTransform: + """Mix-in for data preprocessing and augmentation transforms.""" + + @classmethod + def item( + cls, + key: Union[int, str, KeysView, Iterable[Union[int, str]]], + *args, + ignore_meta: bool = True, + ignore_missing: bool = False, + **kwargs, + ) -> ItemTransform: + """Apply transform to specified item only.""" + return ItemTransform( + cls(*args, **kwargs), + key=key, + ignore_meta=ignore_meta, + ignore_missing=ignore_missing, + ) diff --git a/jointContribution/HighResolution/deepali/losses/__init__.py b/jointContribution/HighResolution/deepali/losses/__init__.py new file mode 100644 index 0000000000..459352071b --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/__init__.py @@ -0,0 +1,162 @@ +import sys +from typing import Any + +import paddle + +from .base import BSplineLoss +from .base import DisplacementLoss +from .base import NormalizedPairwiseImageLoss +from .base import PairwiseImageLoss +from .base import ParamsLoss +from .base import PointSetDistance +from .base import RegistrationLoss +from .base import RegistrationLosses +from .base import RegistrationResult +from .bspline import BSplineBending +from .bspline import BSplineBendingEnergy +from .flow import BE +from .flow import TV +from .flow import Bending +from .flow import BendingEnergy +from .flow import Curvature +from .flow import Diffusion +from .flow import Divergence +from .flow import Elasticity +from .flow import TotalVariation +from .image import DSC +from .image import LCC +from .image import LNCC +from .image import MAE +from .image import MI +from .image import MSE +from .image import NMI +from .image import SLCC +from .image import SSD +from .image import WLCC +from .image import Dice +from .image import HuberImageLoss +from .image import L1ImageLoss +from .image import L2ImageLoss +from .image import PatchLoss +from .image import PatchwiseImageLoss +from .image import SmoothL1ImageLoss +from .params import L1_Norm +from .params import L1Norm +from .params import L2_Norm +from .params import L2Norm +from .params import Sparsity +from .pointset import CPD +from .pointset import LPD +from .pointset import ClosestPointDistance +from .pointset import LandmarkPointDistance + +__all__ = ( + "BSplineLoss", + "DisplacementLoss", + "NormalizedPairwiseImageLoss", + "PairwiseImageLoss", + "PatchwiseImageLoss", + "PatchLoss", + "ParamsLoss", + "PointSetDistance", + "RegistrationLoss", + "RegistrationLosses", + "RegistrationResult", + "is_pairwise_image_loss", + "is_displacement_loss", + "is_pointset_distance", + "BE", + "Bending", + "BendingEnergy", + "BSplineBending", + "BSplineBendingEnergy", + "ClosestPointDistance", + "CPD", + "Curvature", + "Dice", + "Diffusion", + "Divergence", + "DSC", + "Elasticity", + "HuberImageLoss", + "L1ImageLoss", + "L2ImageLoss", + "L1Norm", + "L1_Norm", + "L2Norm", + "L2_Norm", + "LandmarkPointDistance", + "LPD", + "LCC", + "LNCC", + "MAE", + "MI", + "MSE", + "NMI", + "SLCC", + "SmoothL1ImageLoss", + "Sparsity", + "SSD", + "TotalVariation", + "TV", + "WLCC", + "create_loss", + "new_loss", +) + + +def is_pairwise_image_loss(arg: Any) -> bool: + """Check if given argument is name or instance of pairwise image loss.""" + return is_loss_of_type(PairwiseImageLoss, arg) + + +def is_displacement_loss(arg: Any) -> bool: + """Check if given argument is name or instance of displacement field loss.""" + return is_loss_of_type(DisplacementLoss, arg) + + +def is_pointset_distance(arg: Any) -> bool: + """Check if given argument is name or instance of point set distance.""" + return is_loss_of_type(PointSetDistance, arg) + + +def is_loss_of_type(base, arg: Any) -> bool: + """Check if given argument is name or instance of pairwise image loss.""" + cls = None + if isinstance(arg, str): + cls = getattr(sys.modules[__name__], arg, None) + elif type(arg) is type: + cls = arg + elif arg is not None: + cls = type(arg) + if cls is not None: + bases = list(cls.__bases__) + while bases: + b = bases.pop() + if b is base: + return True + bases.extend(b.__bases__) + return False + + +def new_loss(name: str, *args, **kwargs) -> paddle.nn.Layer: + """Initialize new loss module. + + Args: + name: Name of loss type. + args: Loss arguments. + kwargs: Loss keyword arguments. + + Returns: + New loss module. + + """ + cls = getattr(sys.modules[__name__], name, None) + if cls is None: + raise ValueError(f"new_loss() unknown loss {name}") + if cls is paddle.nn.Layer or not issubclass(cls, paddle.nn.Layer): + raise TypeError(f"new_loss() '{name}' is not a subclass of paddle.nn.Layer") + return cls(*args, **kwargs) + + +create_loss = new_loss diff --git a/jointContribution/HighResolution/deepali/losses/base.py b/jointContribution/HighResolution/deepali/losses/base.py new file mode 100644 index 0000000000..33d44ae4f0 --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/base.py @@ -0,0 +1,218 @@ +from abc import ABCMeta +from abc import abstractmethod +from collections import OrderedDict +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from ..core.math import max_difference +from ..core.types import ScalarOrTuple + +RegistrationResult = Dict[str, Any] +RegistrationLosses = Union[ + paddle.nn.Layer, + paddle.nn.LayerDict, + paddle.nn.LayerList, + Mapping[str, paddle.nn.Layer], + Sequence[paddle.nn.Layer], +] + + +class RegistrationLoss(paddle.nn.Layer, metaclass=ABCMeta): + """Base class of registration loss functions. + + A registration loss function, also referred to as energy function, is an objective function + to be minimized by an optimization routine. In particular, these energy functions are used inside + the main loop which performs individual gradient steps using an instance of ``paddle.optim.Optimizer``. + + Registration loss consists of one or more loss terms, which may be either one of: + - A pairwise data term measuring the alignment of a single input pair (e.g., images, point sets, surfaces). + - A groupwise data term measuring the alignment of two or more inputs. + - A regularization term penalizing certain dense spatial deformations. + - Other regularization terms based on spatial transformation parameters. + + """ + + @staticmethod + def as_module_dict( + arg: Optional[RegistrationLosses], start: int = 0 + ) -> paddle.nn.LayerDict: + """Convert argument to ``ModuleDict``.""" + if arg is None: + return paddle.nn.LayerDict() + if isinstance(arg, paddle.nn.LayerDict): + return arg + if isinstance(arg, paddle.nn.Layer): + arg = [arg] + if isinstance(arg, dict): + modules = arg + else: + modules = OrderedDict( + [ + (f"loss_{i + start}", m) + for i, m in enumerate(arg) + if isinstance(m, paddle.nn.Layer) + ] + ) + return paddle.nn.LayerDict(sublayers=modules) + + @abstractmethod + def eval(self) -> RegistrationResult: + """Evaluate registration loss. + + Returns: + Dictionary of current registration result. The entries in the dictionary depend on the + respective registration loss function used, but must include at a minimum the total + scalar "loss" value. + + """ + raise NotImplementedError(f"{type(self).__name__}.eval()") + + def forward(self) -> paddle.Tensor: + """Evaluate registration loss.""" + result = self.eval() + if not isinstance(result, dict): + raise TypeError(f"{type(self).__name__}.eval() must return a dictionary") + if "loss" not in result: + raise ValueError( + f"{type(self).__name__}.eval() result must contain key 'loss'" + ) + loss = result["loss"] + if not isinstance(loss, paddle.Tensor): + raise TypeError( + f"{type(self).__name__}.eval() result 'loss' must be tensor" + ) + return loss + + +class PairwiseImageLoss(paddle.nn.Layer, metaclass=ABCMeta): + """Base class of pairwise image dissimilarity criteria.""" + + @abstractmethod + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + raise NotImplementedError(f"{type(self).__name__}.forward()") + + +class NormalizedPairwiseImageLoss(PairwiseImageLoss): + """Base class of pairwise image dissimilarity criteria with implicit input normalization.""" + + def __init__( + self, + source: Optional[paddle.Tensor] = None, + target: Optional[paddle.Tensor] = None, + norm: Optional[Union[bool, paddle.Tensor]] = None, + ): + """Initialize similarity metric. + + Args: + source: Source image from which to compute ``norm``. If ``None``, only use ``target`` if specified. + target: Target image from which to compute ``norm``. If ``None``, only use ``source`` if specified. + norm: Positive factor by which to divide loss. If ``None`` or ``True``, use ``source`` and/or ``target``. + If ``False`` or both ``source`` and ``target`` are ``None``, a normalization factor of one is used. + + """ + super().__init__() + if norm is True: + norm = None + if norm is None: + if target is None: + target = source + elif source is None: + source = target + if source is not None and target is not None: + norm = max_difference(source, target).square() + elif norm is False: + norm = None + assert norm is None or isinstance(norm, (float, int, paddle.Tensor)) + self.norm = norm + + def extra_repr(self) -> str: + s = "" + norm = self.norm + if isinstance(norm, paddle.Tensor) and norm.size != 1: + s += f"norm={self.norm!r}" + elif norm is not None: + s += f"norm={float(norm):.5f}" + return s + + +class DisplacementLoss(paddle.nn.Layer, metaclass=ABCMeta): + """Base class of regularization terms based on dense displacements.""" + + @abstractmethod + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + raise NotImplementedError(f"{type(self).__name__}.forward()") + + +class BSplineLoss(paddle.nn.Layer, metaclass=ABCMeta): + """Base class of loss terms based on cubic B-spline deformation coefficients.""" + + def __init__(self, stride: ScalarOrTuple[int] = 1, reduction: str = "mean"): + """Initialize regularization term. + + Args: + stride: Number of points between control points at which to evaluate bending energy, plus one. + If a sequence of values is given, these must be the strides for the different spatial + dimensions in the order ``(sx, ...)``. A stride of 1 is equivalent to evaluating bending + energy only at the usually coarser resolution of the control point grid. It should be noted + that the stride need not match the stride used to densely sample the spline deformation field + at a given fixed target image resolution. + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + """ + super().__init__() + self.stride = stride + self.reduction = reduction + + @abstractmethod + def forward( + self, params: paddle.Tensor, stride: ScalarOrTuple[int] = 1 + ) -> paddle.Tensor: + """Evaluate loss term for given free-form deformation parameters.""" + raise NotImplementedError(f"{type(self).__name__}.forward()") + + def extra_repr(self) -> str: + return f"stride={self.stride!r}, reduction={self.reduction!r}" + + +class PointSetDistance(paddle.nn.Layer, metaclass=ABCMeta): + """Base class of point set distance terms.""" + + @abstractmethod + def forward(self, x: paddle.Tensor, *ys: paddle.Tensor) -> paddle.Tensor: + """Evaluate point set distance. + + Note that some point set distance measures require a 1-to-1 correspondence + between the two input point sets, and thus ``M == N``. Other distance losses + may compute correspondences themselves, e.g., based on closest points. + + Args: + x: paddle.Tensor of shape ``(M, X, D)`` with points of (transformed) target point set. + ys: Tensors of shape ``(N, Y, D)`` with points of other point sets. + + Returns: + Point set distance. + + """ + raise NotImplementedError(f"{type(self).__name__}.forward()") + + +class ParamsLoss(paddle.nn.Layer, metaclass=ABCMeta): + """Regularization loss based on model parameters.""" + + @abstractmethod + def forward(self, params: paddle.Tensor) -> paddle.Tensor: + """Evaluate loss term for given model parameters.""" + raise NotImplementedError(f"{type(self).__name__}.forward()") diff --git a/jointContribution/HighResolution/deepali/losses/bspline.py b/jointContribution/HighResolution/deepali/losses/bspline.py new file mode 100644 index 0000000000..9d91a4c95c --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/bspline.py @@ -0,0 +1,17 @@ +import paddle + +from . import functional as L +from .base import BSplineLoss + + +class BSplineBending(BSplineLoss): + """Bending energy of cubic B-spline free form deformation.""" + + def forward(self, params: paddle.Tensor) -> paddle.Tensor: + """Evaluate loss term for given free form deformation parameters.""" + return L.bspline_bending_loss( + params, stride=self.stride, reduction=self.reduction + ) + + +BSplineBendingEnergy = BSplineBending diff --git a/jointContribution/HighResolution/deepali/losses/flow.py b/jointContribution/HighResolution/deepali/losses/flow.py new file mode 100644 index 0000000000..bbcf35f723 --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/flow.py @@ -0,0 +1,214 @@ +from __future__ import annotations + +from typing import Optional +from typing import Union + +import paddle + +from ..core.types import Shape +from . import functional as L +from .base import DisplacementLoss + + +class _SpatialDerivativesLoss(DisplacementLoss): + """Base class of regularization terms based on spatial derivatives of dense displacements.""" + + def __init__( + self, + mode: str = "central", + sigma: Optional[float] = None, + reduction: str = "mean", + ): + """Initialize regularization term. + + Args: + mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``. + sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``. + reduction: Operation to use for reducing spatially distributed loss values. + + """ + super().__init__() + self.mode = mode + self.sigma = float(0 if sigma is None else sigma) + self.reduction = reduction + + def _spacing(self, u_shape: Shape) -> Optional[paddle.Tensor]: + ndim = len(u_shape) + if ndim < 3: + raise ValueError( + f"{type(self).__name__}.forward() 'u' must be at least 3-dimensional" + ) + if ndim == 3: + return None + size = paddle.to_tensor( + data=u_shape[-1:1:-1], + dtype="float32", + place=str("cpu").replace("cuda", "gpu"), + ) + return 2 / (size - 1) + + def extra_repr(self) -> str: + return f"mode={self.mode!r}, sigma={self.sigma!r}, reduction={self.reduction!r}" + + +class GradLoss(_SpatialDerivativesLoss): + """Displacement field gradient loss.""" + + def __init__( + self, + p: Union[int, float] = 2, + q: Optional[Union[int, float]] = 1, + mode: str = "central", + sigma: Optional[float] = None, + reduction: str = "mean", + ): + """Initialize regularization term. + + Args: + mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``. + sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``. + reduction: Operation to use for reducing spatially distributed loss values. + + """ + super().__init__(mode=mode, sigma=sigma, reduction=reduction) + self.p = p + self.q = q + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.grad_loss( + u, + p=self.p, + q=self.q, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + def extra_repr(self) -> str: + return f"p={self.p}, q={self.q}, " + super().extra_repr() + + +class Bending(_SpatialDerivativesLoss): + """Bending energy of displacement field.""" + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.bending_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + +BendingEnergy = Bending +BE = Bending + + +class Curvature(_SpatialDerivativesLoss): + """Curvature of displacement field.""" + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.curvature_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + +class Diffusion(_SpatialDerivativesLoss): + """Diffusion of displacement field.""" + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.diffusion_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + +class Divergence(_SpatialDerivativesLoss): + """Divergence of displacement field.""" + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.divergence_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + +class Elasticity(_SpatialDerivativesLoss): + """Linear elasticity of displacement field.""" + + def __init__( + self, + material_name: Optional[str] = None, + first_parameter: Optional[float] = None, + second_parameter: Optional[float] = None, + poissons_ratio: Optional[float] = None, + youngs_modulus: Optional[float] = None, + shear_modulus: Optional[float] = None, + mode: str = "central", + sigma: Optional[float] = None, + reduction: str = "mean", + ): + super().__init__(mode=mode, sigma=sigma, reduction=reduction) + self.material_name = material_name + self.first_parameter = first_parameter + self.second_parameter = second_parameter + self.poissons_ratio = poissons_ratio + self.youngs_modulus = youngs_modulus + self.shear_modulus = shear_modulus + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.elasticity_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + material_name=self.material_name, + first_parameter=self.first_parameter, + second_parameter=self.second_parameter, + poissons_ratio=self.poissons_ratio, + youngs_modulus=self.youngs_modulus, + shear_modulus=self.shear_modulus, + ) + + +class TotalVariation(_SpatialDerivativesLoss): + """Total variation of displacement field.""" + + def forward(self, u: paddle.Tensor) -> paddle.Tensor: + """Evaluate regularization loss for given transformation.""" + spacing = self._spacing(tuple(u.shape)) + return L.total_variation_loss( + u, + spacing=spacing, + mode=self.mode, + sigma=self.sigma, + reduction=self.reduction, + ) + + +TV = TotalVariation diff --git a/jointContribution/HighResolution/deepali/losses/functional.py b/jointContribution/HighResolution/deepali/losses/functional.py new file mode 100644 index 0000000000..74c7674280 --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/functional.py @@ -0,0 +1,1791 @@ +import itertools +import math +from typing import Optional +from typing import Protocol +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle +from paddle.nn.functional import binary_cross_entropy_with_logits + +from ..core.bspline import evaluate_cubic_bspline +from ..core.enum import SpatialDerivativeKeys +from ..core.enum import SpatialDim +from ..core.flow import denormalize_flow +from ..core.grid import Grid +from ..core.image import avg_pool +from ..core.image import dot_channels +from ..core.image import rand_sample +from ..core.image import spatial_derivatives +from ..core.pointset import transform_grid +from ..core.pointset import transform_points +from ..core.tensor import as_one_hot_tensor +from ..core.tensor import move_dim +from ..core.types import Array +from ..core.types import ScalarOrTuple +from ..utils import paddle_aux + +__all__ = ( + "balanced_binary_cross_entropy_with_logits", + "binary_cross_entropy_with_logits", + "label_smoothing", + "dice_score", + "dice_loss", + "kld_loss", + "lcc_loss", + "mae_loss", + "mse_loss", + "ssd_loss", + "mi_loss", + "grad_loss", + "bending_loss", + "bending_energy", + "be_loss", + "bspline_bending_loss", + "bspline_bending_energy", + "bspline_be_loss", + "curvature_loss", + "diffusion_loss", + "divergence_loss", + "elasticity_loss", + "focal_loss_with_logits", + "total_variation_loss", + "tv_loss", + "tversky_index", + "tversky_index_with_logits", + "tversky_loss", + "tversky_loss_with_logits", + "inverse_consistency_loss", + "masked_loss", + "reduce_loss", + "wlcc_loss", +) + + +class ElementwiseLoss(Protocol): + """Type annotation of a eleemntwise loss function.""" + + def __call__( + self, input: paddle.Tensor, target: paddle.Tensor, reduction: str = "mean" + ) -> paddle.Tensor: + ... + + +def label_smoothing( + labels: paddle.Tensor, + num_classes: Optional[int] = None, + ignore_index: Optional[int] = None, + alpha: float = 0.1, +) -> paddle.Tensor: + """Apply label smoothing to target labels. + + Implements label smoothing as proposed by Muller et al., (2019) in https://arxiv.org/abs/1906.02629v2. + + Args: + labels: Scalar target labels or one-hot encoded class probabilities. + num_classes: Number of target labels. If ``None``, use maximum value in ``target`` plus 1 + when a scalar label map is given. + ignore_index: Ignore index to be kept during the expansion. The locations of the index + value in the labels image is stored in the corresponding locations across all channels + so that this location can be ignored across all channels later, e.g. in Dice computation. + This argument must be ``None`` if ``labels`` has ``C == num_channels``. + alpha: Label smoothing factor in [0, 1]. If zero, no label smoothing is done. + + Returns: + Multi-channel tensor of smoothed target class probabilities. + + """ + if not isinstance(labels, paddle.Tensor): + raise TypeError("label_smoothing() 'labels' must be paddle.Tensor") + if labels.ndim < 4: + raise ValueError( + "label_smoothing() 'labels' must be tensor of shape (N, C, ..., X)" + ) + if tuple(labels.shape)[1] == 1: + target = as_one_hot_tensor( + labels, num_classes, ignore_index=ignore_index, dtype="float32" + ) + else: + target = labels.astype(dtype="float32") + if alpha > 0: + target = (1 - alpha) * target + alpha * (1 - target) / (target.shape[1] - 1) + return target + + +def balanced_binary_cross_entropy_with_logits( + logits: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Balanced binary cross entropy. + + Shruti Jadon (2020) A survey of loss functions for semantic segmentation. + https://arxiv.org/abs/2006.14822 + + Args: + logits: Logits of binary predictions as tensor of shape ``(N, 1, ..., X)``. + target: Target label probabilities as tensor of shape ``(N, 1, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, 1, ..., X)``. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + Balanced binary cross entropy (bBCE). If ``reduction="none"``, the returned tensor has shape + ``(N, 1, ..., X)`` with bBCE for each element. Otherwise, it is reduced into a scalar. + + """ + if logits.ndim < 3 or tuple(logits.shape)[1] != 1: + raise ValueError( + "balanced_binary_cross_entropy_with_logits() 'logits' must have shape (N, 1, ..., X)" + ) + if target.ndim < 3 or tuple(target.shape)[1] != 1: + raise ValueError( + "balanced_binary_cross_entropy_with_logits() 'target' must have shape (N, 1, ..., X)" + ) + if tuple(logits.shape)[0] != tuple(target.shape)[0]: + raise ValueError( + "balanced_binary_cross_entropy_with_logits() 'logits' and 'target' must have matching batch size" + ) + neg_weight = ( + target.flatten(start_axis=1) + .mean(axis=-1) + .reshape((-1,) + (1,) * (target.ndim - 1)) + ) + pos_weight = 1 - neg_weight + log_y_pred: paddle.Tensor = paddle.nn.functional.log_sigmoid(x=logits) + loss_pos = -log_y_pred.mul(target) + loss_neg = logits.sub(log_y_pred).mul(1 - target) + loss = loss_pos.mul(pos_weight).add(loss_neg.mul(neg_weight)) + loss = masked_loss( + loss, weight, "balanced_binary_cross_entropy_with_logits", inplace=True + ) + loss = reduce_loss(loss, reduction=reduction) + return loss + + +def dice_score( + input: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + epsilon: float = 1e-15, + reduction: str = "mean", +) -> paddle.Tensor: + """Soft Dice similarity of multi-channel classification result. + + Args: + input: Normalized logits of binary predictions as tensor of shape ``(N, C, ..., X)``. + target: Target label probabilities as tensor of shape ``(N, C, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, C, ..., X)``. + epsilon: Small constant used to avoid division by zero. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + Dice similarity coefficient (DSC). If ``reduction="none"``, the returned tensor has shape + ``(N, C)`` with DSC for each batch. Otherwise, the DSC scores are reduced into a scalar. + + """ + if not isinstance(input, paddle.Tensor): + raise TypeError("dice_score() 'input' must be paddle.Tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError("dice_score() 'target' must be paddle.Tensor") + if input.dim() < 3: + raise ValueError("dice_score() 'input' must be tensor of shape (N, C, ..., X)") + if tuple(input.shape) != tuple(target.shape): + raise ValueError("dice_score() 'input' and 'target' must have identical shape") + y_pred = input.astype(dtype="float32") + y = target.astype(dtype="float32") + intersection = dot_channels(y_pred, y, weight=weight) + denominator = dot_channels(y_pred, y_pred, weight=weight) + dot_channels( + y, y, weight=weight + ) + loss = ( + intersection.multiply_(y=paddle.to_tensor(2)) + .add_(y=paddle.to_tensor(epsilon)) + .div(denominator.add_(y=paddle.to_tensor(epsilon))) + ) + loss = reduce_loss(loss, reduction) + return loss + + +def dice_loss( + input: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + epsilon: float = 1e-15, + reduction: str = "mean", +) -> paddle.Tensor: + """One minus soft Dice similarity of multi-channel classification result. + + Args: + input: Normalized logits of binary predictions as tensor of shape ``(N, C, ..., X)``. + target: Target label probabilities as tensor of shape ``(N, C, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, C, ..., X)``. + epsilon: Small constant used to avoid division by zero. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + One minus Dice similarity coefficient (DSC). If ``reduction="none"``, the returned tensor has shape + ``(N, C)`` with DSC for each batch. Otherwise, the DSC scores are reduced into a scalar. + + """ + dsc = dice_score(input, target, weight=weight, epsilon=epsilon, reduction="none") + loss = reduce_loss(1 - dsc, reduction) + return loss + + +def tversky_index( + input: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + epsilon: float = 1e-15, + normalize: bool = False, + binarize: bool = False, + reduction: str = "mean", +) -> paddle.Tensor: + """Tversky index as described in https://arxiv.org/abs/1706.05721. + + Args: + input: Binary predictions as tensor of shape ``(N, 1, ..., X)`` + or multi-class prediction tensor of shape ``(N, C, ..., X)``. + target: Target labels as tensor of shape ``(N, ..., X)``, binary classification target + of shape ``(N, 1, ..., X)``, or one-hot encoded tensor of shape ``(N, C, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, ..., X)`` or ``(N, 1|C, ..., X)``.. + alpha: False positives multiplier. Set to ``1 - beta`` if ``None``. + beta: False negatives multiplier. + epsilon: Constant used to avoid division by zero. + normalize: Whether to normalize ``input`` predictions using ``sigmoid`` or ``softmax``. + binarize: Whether to round normalized predictions to 0 or 1, respectively. If ``False``, + soft normalized predictions (and target scores) are used. In order for the Tversky + index to be identical to Dice, this option must be set to ``True`` and ``alpha=beta=0.5``. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + Tversky index (TI). If ``reduction="none"``, the returned tensor has shape ``(N, C)`` + with TI for each batch. Otherwise, the TI values are reduced into a scalar. + + """ + if alpha is None and beta is None: + alpha = beta = 0.5 + elif alpha is None: + alpha = 1 - beta + elif beta is None: + beta = 1 - alpha + if not isinstance(input, paddle.Tensor): + raise TypeError("tversky_index() 'input' must be paddle.Tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError("tversky_index() 'target' must be paddle.Tensor") + if input.ndim < 3 or tuple(input.shape)[1] < 1: + raise ValueError( + "tversky_index() 'input' must be have shape (N, 1, ..., X) or (N, C>1, ..., X)" + ) + if target.ndim < 2 or tuple(target.shape)[1] < 1: + raise ValueError( + "tversky_index() 'target' must be have shape (N, ..., X), (N, 1, ..., X), or (N, C>1, ..., X)" + ) + if tuple(target.shape)[0] != tuple(input.shape)[0]: + raise ValueError( + f"tversky_index() 'input' and 'target' batch size must be identical, got {tuple(input.shape)[0]} != {tuple(target.shape)[0]}" + ) + input = input.astype(dtype="float32") + if tuple(input.shape)[1] == 1: + y_pred = input.sigmoid() if normalize else input + else: + y_pred = paddle.nn.functional.softmax(input, axis=1) if normalize else input + if binarize: + y_pred = y_pred.round() + num_classes = max(2, tuple(y_pred.shape)[1]) + if target.ndim == input.ndim: + y = target.astype(y_pred.dtype) + if tuple(target.shape)[1] == 1: + if num_classes > 2: + raise ValueError( + f"tversky_index() 'target' has shape (N, 1, ..., X), but 'input' is multi-class prediction (C={num_classes})" + ) + if tuple(y_pred.shape)[1] == 2: + start_20 = y_pred.shape[1] + 1 if 1 < 0 else 1 + y_pred = paddle.slice(y_pred, [1], [start_20], [start_20 + 1]) + else: + if tuple(target.shape)[1] != num_classes: + raise ValueError( + f"tversky_index() 'target' has shape (N, C, ..., X), but C does not match 'input' with C={num_classes}" + ) + if tuple(y_pred.shape)[1] == 1: + start_21 = y.shape[1] + 1 if 1 < 0 else 1 + y = paddle.slice(y, [1], [start_21], [start_21 + 1]) + if binarize: + y = y.round() + elif target.ndim + 1 == y_pred.ndim: + if num_classes == 2 and tuple(y_pred.shape)[1] == 1: + y = ( + target.unsqueeze(axis=1) + .greater_equal(y=paddle.to_tensor(0.5)) + .astype(y_pred.dtype) + ) + if binarize: + y = y.round() + else: + y = as_one_hot_tensor(target, num_classes, dtype=y_pred.dtype) + else: + raise ValueError( + "tversky_index() 'target' must be tensor of shape (N, ..., X) or (N, C, ... X)" + ) + if tuple(y.shape) != tuple(y_pred.shape): + raise ValueError( + "tversky_index() 'input' and 'target' shapes must be compatible" + ) + if weight is not None: + if weight.ndim + 1 == y.ndim: + weight = weight.unsqueeze(axis=1) + if weight.ndim != y.ndim: + raise ValueError( + "tversky_index() 'weight' shape must be (N, ..., X) or (N, C, ..., X)" + ) + if tuple(weight.shape)[0] != tuple(target.shape)[0]: + raise ValueError( + "tversky_index() 'weight' batch size does not match 'input' and 'target'" + ) + if tuple(weight.shape)[1] == 1: + weight = weight.repeat((1,) + (num_classes,) + (1,) * (weight.ndim - 2)) + if tuple(weight.shape) != tuple(y.shape): + raise ValueError( + "tversky_index() 'weight' shape must be compatible with 'input' and 'target'" + ) + intersection = dot_channels(y_pred, y, weight=weight) + fps = dot_channels(y_pred, 1 - y, weight=weight).mul_(alpha) + fns = dot_channels(1 - y_pred, y, weight=weight).mul_(beta) + numerator = intersection.add_(y=paddle.to_tensor(epsilon)) + denominator = numerator.add(fps).add(fns) + loss = numerator.div(denominator) + loss = reduce_loss(loss, reduction) + return loss + + +def tversky_index_with_logits( + logits: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + epsilon: float = 1e-15, + binarize: bool = False, + reduction: str = "mean", +) -> paddle.Tensor: + """Tversky index as described in https://arxiv.org/abs/1706.05721. + + Args: + logits: Binary prediction logits as tensor of shape ``(N, 1, ..., X)``. + target: Target labels as tensor of shape ``(N, ..., X)`` or ``(N, 1, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, ..., X)`` or ``(N, 1, ..., X)``. + alpha: False positives multiplier. Set to ``1 - beta`` if ``None``. + beta: False negatives multiplier. + epsilon: Constant used to avoid division by zero. + normalize: Whether to normalize ``input`` predictions using ``sigmoid`` or ``softmax``. + binarize: Whether to round normalized predictions to 0 or 1, respectively. If ``False``, + soft normalized predictions (and target scores) are used. In order for the Tversky + index to be identical to Dice, this option must be set to ``True`` and ``alpha=beta=0.5``. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + Tversky index (TI). If ``reduction="none"``, the returned tensor has shape ``(N, 1)`` + with TI for each batch. Otherwise, the TI values are reduced into a scalar. + + """ + return tversky_index( + logits, + target, + weight=weight, + alpha=alpha, + beta=beta, + epsilon=epsilon, + normalize=True, + binarize=binarize, + reduction=reduction, + ) + + +def tversky_loss( + input: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + gamma: Optional[float] = None, + epsilon: float = 1e-15, + normalize: bool = False, + binarize: bool = False, + reduction: str = "mean", +) -> paddle.Tensor: + """Tversky loss as described in https://arxiv.org/abs/1706.05721. + + Args: + input: Normalized logits of binary predictions as tensor of shape ``(N, C, ..., X)``. + target: Target label probabilities as tensor of shape ``(N, C, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, ..., X)`` or ``(N, 1, ..., X)``. + alpha: False positives multiplier. Set to ``1 - beta`` if ``None``. + beta: False negatives multiplier. + gamma: Exponent used for focal Tverksy loss. + epsilon: Constant used to avoid division by zero. + normalize: Whether to normalize ``input`` predictions using ``sigmoid`` or ``softmax``. + binarize: Whether to round normalized predictions to 0 or 1, respectively. If ``False``, + soft normalized predictions (and target scores) are used. In order for the Tversky + index to be identical to Dice, this option must be set to ``True`` and ``alpha=beta=0.5``. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + One minus Tversky index (TI) to the power of gamma. If ``reduction="none"``, the returned + tensor has shape ``(N, C)`` with the loss for each batch. Otherwise, a scalar is returned. + + """ + ti = tversky_index( + input, + target, + weight=weight, + alpha=alpha, + beta=beta, + gamma=gamma, + epsilon=epsilon, + normalize=normalize, + binarize=binarize, + reduction="none", + ) + one = paddle.to_tensor(data=1, dtype=ti.dtype, place=ti.place) + loss = one.sub(ti) + if gamma: + if gamma > 1: + loss = loss.pow_(y=gamma) + elif gamma < 1: + raise ValueError( + "tversky_loss() 'gamma' must be greater than or equal to 1" + ) + loss = reduce_loss(loss, reduction) + return loss + + +def tversky_loss_with_logits( + logits: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + alpha: Optional[float] = None, + beta: Optional[float] = None, + gamma: Optional[float] = None, + epsilon: float = 1e-15, + binarize: bool = False, + reduction: str = "mean", +) -> paddle.Tensor: + """Tversky loss as described in https://arxiv.org/abs/1706.05721. + + Args: + logits: Binary prediction logits as tensor of shape ``(N, 1, ..., X)``. + target: Target labels as tensor of shape ``(N, ..., X)`` or ``(N, 1, ..., X)``. + weight: Voxelwise label weight tensor of shape ``(N, ..., X)`` or ``(N, 1, ..., X)``. + alpha: False positives multiplier. Set to ``1 - beta`` if ``None``. + beta: False negatives multiplier. + gamma: Exponent used for focal Tverksy loss. + epsilon: Constant used to avoid division by zero. + normalize: Whether to normalize ``input`` predictions using ``sigmoid`` or ``softmax``. + binarize: Whether to round normalized predictions to 0 or 1, respectively. If ``False``, + soft normalized predictions (and target scores) are used. In order for the Tversky + index to be identical to Dice, this option must be set to ``True`` and ``alpha=beta=0.5``. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + One minus Tversky index (TI) to the power of gamma. If ``reduction="none"``, the returned + tensor has shape ``(N, C)`` with the loss for each batch. Otherwise, a scalar is returned. + + """ + return tversky_loss( + logits, + target, + weight=weight, + alpha=alpha, + beta=beta, + gamma=gamma, + epsilon=epsilon, + normalize=True, + binarize=binarize, + reduction=reduction, + ) + + +def focal_loss_with_logits( + input: paddle.Tensor, + target: paddle.Tensor, + weight: Optional[paddle.Tensor] = None, + alpha: float = 0.25, + gamma: float = 2, + reduction: str = "mean", +) -> paddle.Tensor: + """Loss used in RetinaNet for dense detection: https://arxiv.org/abs/1708.02002. + + Args: + input: Logits of the predictions for each example. + target: A tensor with the same shape as ``input``. Stores the binary classification + label for each element in inputs (0 for the negative class and 1 for the positive class). + weight: Multiplicative mask tensor with same shape as ``input``. + alpha: Weighting factor in [0, 1] to balance positive vs negative examples or -1 for ignore. + gamma: Exponent of the modulating factor (1 - p_t) to balance easy vs hard examples. + reduction: Either ``none``, ``mean``, or ``sum``. + + Returns: + Loss tensor with the reduction option applied. + + """ + bce = binary_cross_entropy_with_logits(logit=input, label=target, reduction="none") + one = paddle.to_tensor(data=1, dtype=bce.dtype, place=bce.place) + loss = one.sub(paddle.exp(x=-bce)).pow(y=gamma).mul(bce) + if alpha >= 0: + if alpha > 1: + raise ValueError("focal_loss() 'alpha' must be in [0, 1]") + loss = target.mul(alpha).add(one.sub(target).mul(1 - alpha)).mul(loss) + loss = masked_loss(loss, weight, "focal_loss_with_logits", inplace=True) + loss = reduce_loss(loss, reduction) + return loss + + +def kld_loss( + mean: paddle.Tensor, logvar: paddle.Tensor, reduction: str = "mean" +) -> paddle.Tensor: + """Kullback-Leibler divergence in case of zero-mean and isotropic unit variance normal prior distribution. + + Kingma and Welling, Auto-Encoding Variational Bayes, ICLR 2014, https://arxiv.org/abs/1312.6114 (Appendix B). + + """ + loss = ( + mean.square() + .add_(y=paddle.to_tensor(logvar.exp())) + .subtract_(y=paddle.to_tensor(1)) + .subtract_(y=paddle.to_tensor(logvar)) + ) + loss = reduce_loss(loss, reduction) + loss = loss.multiply_(y=paddle.to_tensor(0.5)) + return loss + + +def lcc_loss( + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + kernel_size: ScalarOrTuple[int] = 7, + epsilon: float = 1e-15, + reduction: str = "mean", +) -> paddle.Tensor: + """Local normalized cross correlation. + + References: + Avants et al., 2008, Symmetric Diffeomorphic Image Registration with Cross Correlation: + Evaluating Automated Labeling of Elderly and Neurodegenerative Brain, + doi:10.1016/j.media.2007.06.004. + + Args: + source: Source image sampled on ``target`` grid. + target: Target image with same shape as ``source``. + mask: Multiplicative mask tensor with same shape as ``source``. + kernel_size: Local rectangular window size in number of grid points. + epsilon: Small constant added to denominator to avoid division by zero. + reduction: Whether to compute "mean" or "sum" over all grid points. If "none", output + tensor shape is equal to the shape of the input tensors given an odd kernel size. + + Returns: + Negative local normalized cross correlation plus one. + + """ + if not isinstance(source, paddle.Tensor): + raise TypeError("lcc_loss() 'source' must be tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError("lcc_loss() 'target' must be tensor") + if tuple(source.shape) != tuple(target.shape): + raise ValueError("lcc_loss() 'source' must have same shape as 'target'") + + def local_sum(data: paddle.Tensor) -> paddle.Tensor: + return avg_pool( + data, kernel_size=kernel_size, stride=1, padding=None, divisor_override=1 + ) + + def local_mean(data: paddle.Tensor) -> paddle.Tensor: + return avg_pool( + data, + kernel_size=kernel_size, + stride=1, + padding=None, + count_include_pad=False, + ) + + source = source.astype(dtype="float32") + target = target.astype(dtype="float32") + source_mean = local_mean(source) + target_mean = local_mean(target) + x = source.sub(source_mean) + y = target.sub(target_mean) + a = local_sum(x.mul(y)) + b = local_sum(x.square()) + c = local_sum(y.square()) + loss = ( + paddle.square_(a) + .divide_( + y=paddle.to_tensor( + b.multiply_(y=paddle.to_tensor(c)).add_(y=paddle.to_tensor(epsilon)) + ) + ) + .neg_() + .add_(y=paddle.to_tensor(1)) + ) + loss = masked_loss(loss, mask, "lcc_loss") + loss = reduce_loss(loss, reduction, mask) + return loss + + +def wlcc_loss( + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + source_mask: Optional[paddle.Tensor] = None, + target_mask: Optional[paddle.Tensor] = None, + kernel_size: ScalarOrTuple[int] = 7, + epsilon: float = 1e-15, + reduction: str = "mean", +) -> paddle.Tensor: + """Weighted local normalized cross correlation. + + References: + Lewis et al., 2020, Fast Learning-based Registration of Sparse 3D Clinical Images, arXiv:1812.06932. + + Args: + source: Source image sampled on ``target`` grid. + target: Target image with same shape as ``source``. + mask: Multiplicative mask tensor ``w_c`` with same shape as ``target`` and ``source``. + This tensor is used for computing the weighted local correlation. If ``None`` and + both ``source_mask`` and ``target_mask`` are given, it is set to the product of these. + Otherwise, no mask is used to aggregate the local cross correlation values. When both + ``source_mask`` and ``target_mask`` are ``None``, but ``mask`` is not, then the specified + ``mask`` is used both as ``source_mask`` and ``target_mask``. + source_mask: Multiplicative mask tensor ``w_m`` with same shape as ``source``. + This tensor is used for computing the weighted local ``source`` mean. If ``None``, + the local mean is computed over all ``source`` elements within each local neighborhood. + target_mask: Multiplicative mask tensor ``w_f`` with same shape as ``source``. + This tensor is used for computing the weighted local ``target`` mean. If ``None``, + the local mean is computed over all ``target`` elements within each local neighborhood. + kernel_size: Local rectangular window size in number of grid points. + epsilon: Small constant added to denominator to avoid division by zero. + reduction: Whether to compute "mean" or "sum" over all grid points. If "none", output + tensor shape is equal to the shape of the input tensors given an odd kernel size. + + Returns: + Negative local normalized cross correlation plus one. + + """ + if not isinstance(source, paddle.Tensor): + raise TypeError("wlcc_loss() 'source' must be tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError("wlcc_loss() 'target' must be tensor") + if tuple(source.shape) != tuple(target.shape): + raise ValueError("wlcc_loss() 'source' must have same shape as 'target'") + for t, t_name, w, w_name in zip( + [target, source, target], + ["target", "source", "target"], + [mask, source_mask, target_mask], + ["mask", "source_mask", "target_mask"], + ): + if w is None: + continue + if not isinstance(w, paddle.Tensor): + raise TypeError(f"wlcc_loss() '{w_name}' must be tensor") + if tuple(w.shape)[0] not in (1, tuple(t.shape)[0]): + raise ValueError( + f"wlcc_loss() '{w_name}' batch size ({tuple(w.shape)[0]}) must be 1 or match '{t_name}' ({tuple(t.shape)[0]})" + ) + if tuple(w.shape)[1] not in (1, tuple(t.shape)[1]): + raise ValueError( + f"wlcc_loss() '{w_name}' number of channels ({tuple(w.shape)[1]}) must be 1 or match '{t_name}' ({tuple(t.shape)[1]})" + ) + if tuple(w.shape)[2:] != tuple(t.shape)[2:]: + raise ValueError( + f"wlcc_loss() '{w_name}' grid shape ({tuple(w.shape)[2:]}) must match '{t_name}' ({tuple(t.shape)[2:]})" + ) + + def local_sum(data: paddle.Tensor) -> paddle.Tensor: + return avg_pool( + data, kernel_size=kernel_size, stride=1, padding=None, divisor_override=1 + ) + + def local_mean( + data: paddle.Tensor, weight: Optional[paddle.Tensor] = None + ) -> paddle.Tensor: + if weight is None: + return avg_pool( + data, + kernel_size=kernel_size, + stride=1, + padding=None, + count_include_pad=False, + ) + a = local_sum(data.mul(weight)) + b = local_sum(weight).add_(y=paddle.to_tensor(epsilon)) + return a.divide_(y=paddle.to_tensor(b)) + + if mask is not None and source_mask is None and target_mask is None: + source_mask = mask.astype(dtype="float32") + target_mask = source_mask + else: + if source_mask is not None: + source_mask = source_mask.astype(dtype="float32") + if target_mask is not None: + target_mask = target_mask.astype(dtype="float32") + source = source.astype(dtype="float32") + target = target.astype(dtype="float32") + source_mean = local_mean(source, source_mask) + target_mean = local_mean(target, target_mask) + x = source.sub(source_mean) + y = target.sub(target_mean) + if mask is None and source_mask is not None and target_mask is not None: + mask = source_mask.mul(target_mask) + if mask is not None: + x = x.multiply_(y=paddle.to_tensor(mask)) + y = y.multiply_(y=paddle.to_tensor(mask)) + a = local_sum(x.mul(y)) + b = local_sum(x.square()) + c = local_sum(y.square()) + loss = ( + paddle.square_(a) + .divide_( + y=paddle.to_tensor( + b.multiply_(y=paddle.to_tensor(c)).add_(y=paddle.to_tensor(epsilon)) + ) + ) + .neg_() + .add_(y=paddle.to_tensor(1)) + ) + loss = masked_loss(loss, mask, name="wlcc_loss") + loss = reduce_loss(loss, reduction, mask) + return loss + + +def huber_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", + delta: float = 1.0, +) -> paddle.Tensor: + """Normalized masked Huber loss. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + delta: Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + + Returns: + Masked, aggregated, and normalized Huber loss. + + """ + + def loss_fn( + input: paddle.Tensor, target: paddle.Tensor, reduction: str = "mean" + ) -> paddle.Tensor: + return paddle.nn.functional.smooth_l1_loss( + input=input, label=target, reduction=reduction, delta=delta + ) + + return elementwise_loss( + "huber_loss", loss_fn, input, target, mask=mask, norm=norm, reduction=reduction + ) + + +def smooth_l1_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", + beta: float = 1.0, +) -> paddle.Tensor: + """Normalized masked smooth L1 loss. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + delta: Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + + Returns: + Masked, aggregated, and normalized smooth L1 loss. + + """ + + def loss_fn( + input: paddle.Tensor, target: paddle.Tensor, reduction: str = "mean" + ) -> paddle.Tensor: + return ( + paddle.nn.functional.smooth_l1_loss( + input=input, reduction=reduction, label=target, delta=beta + ) + / beta + ) + + return elementwise_loss( + "smooth_l1_loss", + loss_fn, + input, + target, + mask=mask, + norm=norm, + reduction=reduction, + ) + + +def l1_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Normalized mean absolute error. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + + Returns: + Normalized mean absolute error. + + """ + + def loss_fn( + input: paddle.Tensor, target: paddle.Tensor, reduction: str = "mean" + ) -> paddle.Tensor: + return paddle.nn.functional.l1_loss( + input=input, label=target, reduction=reduction + ) + + return elementwise_loss( + "l1_loss", loss_fn, input, target, mask=mask, norm=norm, reduction=reduction + ) + + +def mae_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Normalized mean absolute error. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + + Returns: + Normalized mean absolute error. + + """ + + def loss_fn( + input: paddle.Tensor, target: paddle.Tensor, reduction: str = "mean" + ) -> paddle.Tensor: + return paddle.nn.functional.l1_loss( + input=input, label=target, reduction=reduction + ) + + return elementwise_loss( + "mae_loss", loss_fn, input, target, mask=mask, norm=norm, reduction=reduction + ) + + +def mse_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Average normalized squared differences. + + This loss is equivalent to `ssd_loss`, except that the default `reduction` is "mean". + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + + Returns: + Average normalized squared differences. + + """ + return ssd_loss(input, target, mask=mask, norm=norm, reduction=reduction) + + +def ssd_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "sum", +) -> paddle.Tensor: + """Sum of normalized squared differences. + + The SSD loss is equivalent to MSE, except that an optional overlap mask is supported and + that the loss value is optionally multiplied by a normalization constant. Moreover, by default + the sum instead of the mean of per-element loss values is returned (cf. ``reduction``). + The value returned by ``max_difference(source, target).square()`` can be used as normalization + factor, which is equvalent to first normalizing the images to [0, 1]. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + + Returns: + Sum of normalized squared differences. + + """ + if not isinstance(input, paddle.Tensor): + raise TypeError("ssd_loss() 'input' must be tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError("ssd_loss() 'target' must be tensor") + if tuple(input.shape) != tuple(target.shape): + raise ValueError("ssd_loss() 'input' must have same shape as 'target'") + loss = input.sub(target).square() + loss = masked_loss(loss, mask, "ssd_loss") + loss = reduce_loss(loss, reduction, mask) + if norm is not None: + norm = paddle.to_tensor(data=norm, dtype=loss.dtype, place=loss.place).squeeze() + if not norm.ndim == 0: + raise ValueError("ssd_loss() 'norm' must be scalar") + if norm > 0: + loss = loss.divide_(y=paddle.to_tensor(norm)) + return loss + + +def mi_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + num_bins: Optional[int] = None, + num_samples: Optional[int] = None, + sample_ratio: Optional[float] = None, + normalized: bool = False, +) -> paddle.Tensor: + """Calculate mutual information loss using Parzen window density and entropy estimations. + + References: + Qiu, H., Qin, C., Schuh, A., Hammernik, K.: Learning Diffeomorphic and Modality-invariant + Registration using B-splines. Medical Imaging with Deep Learning. (2021). + Thévenaz, P., Unser, M.: Optimization of mutual information for multiresolution image registration. + IEEE Trans. Image Process. 9, 2083–2099 (2000). + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Region of interest mask with same shape as ``input``. + vmin: Minimal intensity value the joint and marginal density is estimated. + vmax: Maximal intensity value the joint and marginal density is estimated. + num_bins: Number of bin edges to discretize the density estimation. + num_samples: Number of voxels in the image domain randomly sampled to compute the loss, + ignored if `sample_ratio` is also set. + sample_ratio: Ratio of voxels in the image domain randomly sampled to compute the loss. + normalized: Calculate Normalized Mutual Information instead of Mutual Information if True. + + Returns: + Negative mutual information. If ``normalized=True``, 2 is added such that the loss value is in [0, 1]. + + """ + if target.ndim < 3: + raise ValueError("mi_loss() 'target' must be tensor of shape (N, C, ..., X)") + if tuple(input.shape) != tuple(target.shape): + raise ValueError("ssd_loss() 'input' must have same shape as 'target'") + if vmin is None: + vmin = paddle_aux.min(input.min(), target.min()).item() + if vmax is None: + vmax = paddle_aux.max(input.max(), target.max()).item() + if num_bins is None: + num_bins = 64 + elif num_bins == "auto": + raise NotImplementedError( + "mi_loss() automatically setting num_bins based on dynamic range of input" + ) + shape = tuple(target.shape) + input = input.flatten(start_axis=2) + target = target.flatten(start_axis=2) + if mask is not None: + if ( + mask.ndim < 3 + or tuple(mask.shape)[2:] != shape[2:] + or tuple(mask.shape)[1] != 1 + ): + raise ValueError( + "mi_loss() 'mask' must be tensor of shape (1|N, 1, ..., X) with spatial dimensions matching 'target'" + ) + mask = mask.flatten(start_axis=2) + if sample_ratio is not None: + if num_samples is not None: + raise ValueError( + "mi_loss() 'num_samples' and 'sample_ratio' are mutually exclusive" + ) + if sample_ratio <= 0 or sample_ratio > 1: + raise ValueError( + "mi_loss() 'sample_ratio' must be in open-closed interval (0, 1]" + ) + num_samples = max(1, int(sample_ratio * tuple(target.shape)[2:].size)) + if num_samples is not None: + input, target = rand_sample( + [input, target], num_samples, mask=mask, replacement=True + ) + elif mask is not None: + input = input.mul(mask) + target = target.mul(mask) + bin_width = (vmax - vmin) / num_bins + out_3 = paddle.linspace(start=vmin, stop=vmax, num=num_bins) + out_3.stop_gradient = not False + bin_center = out_3 + bin_center = bin_center.unsqueeze(axis=1).astype(dtype=input.dtype) + pw_sdev = bin_width * (1 / (2 * math.sqrt(2 * math.log(2)))) + pw_norm = 1 / math.sqrt(2 * math.pi) * pw_sdev + + def parzen_window_fn(x: paddle.Tensor) -> paddle.Tensor: + return x.sub(bin_center).square().div(2 * pw_sdev**2).neg().exp().mul(pw_norm) + + pw_input = parzen_window_fn(input) + pw_target = parzen_window_fn(target) + x = pw_target + perm_11 = list(range(x.ndim)) + perm_11[1] = 2 + perm_11[2] = 1 + hist_joint = pw_input.bmm(y=x.transpose(perm=perm_11)) + hist_norm = hist_joint.flatten(start_axis=1, stop_axis=-1).sum(axis=1) + 1e-05 + p_joint = hist_joint / hist_norm.view(-1, 1, 1) + p_input = paddle.sum(x=p_joint, axis=2) + p_target = paddle.sum(x=p_joint, axis=1) + ent_input = -paddle.sum(x=p_input * paddle.log(x=p_input + 1e-05), axis=1) + ent_target = -paddle.sum(x=p_target * paddle.log(x=p_target + 1e-05), axis=1) + ent_joint = -paddle.sum(x=p_joint * paddle.log(x=p_joint + 1e-05), axis=(1, 2)) + if normalized: + loss = 2 - paddle.mean(x=(ent_input + ent_target) / ent_joint) + else: + loss = paddle.mean(x=ent_input + ent_target - ent_joint).neg() + return loss + + +def nmi_loss( + input: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + num_bins: Optional[int] = None, + num_samples: Optional[int] = None, + sample_ratio: Optional[float] = None, +) -> paddle.Tensor: + return mi_loss( + input, + target, + mask=mask, + vmin=vmin, + vmax=vmax, + num_bins=num_bins, + num_samples=num_samples, + sample_ratio=sample_ratio, + normalized=True, + ) + + +def grad_loss( + u: paddle.Tensor, + p: Union[int, float] = 2, + q: Optional[Union[int, float]] = 1, + spacing: Optional[Array] = None, + sigma: Optional[float] = None, + mode: str = "central", + which: Optional[Union[str, Sequence[str]]] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Loss term based on p-norm of spatial gradient of vector fields. + + The ``p`` and ``q`` parameters can be used to specify which norm to compute, i.e., ``sum(abs(du)**p)**q``, + where ``du`` are the 1st order spatial derivative of the input vector fields ``u`` computed using a finite + difference scheme and optionally normalized using a specified grid ``spacing``. + + This regularization loss is the basis, for example, for total variation and diffusion penalties. + + Args: + u: Batch of vector fields as tensor of shape ``(N, D, ..., X)``. When a tensor with less than + four dimensions is given, it is assumed to be a linear transformation and zero is returned. + p: The order of the gradient norm. When ``p = 0``, the partial derivatives are summed up. + q: Power parameter of gradient norm. If ``None``, then ``q = 1 / p``. If ``q == 0``, the + absolute value of the sum of partial derivatives is computed at each grid point. + spacing: Sampling grid spacing. + sigma: Standard deviation of Gaussian in grid units. + mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``. + which: String codes of spatial deriviatives to compute. See ``SpatialDerivativeKeys``. + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Spatial gradient loss of vector fields. + + """ + if u.ndim < 4: + if reduction == "none": + raise NotImplementedError( + "grad_loss() not implemented for linear transformation and 'reduction'='none'" + ) + return paddle.to_tensor(data=0, dtype=u.dtype, place=u.place) + D = tuple(u.shape)[1] + if u.ndim - 2 != D: + raise ValueError( + f"grad_loss() 'u' must be tensor of shape (N, {u.ndim - 2}, ..., X)" + ) + if q is None: + q = 1.0 / p + derivs = spatial_derivatives( + u, mode=mode, which=which, order=1, sigma=sigma, spacing=spacing + ) + loss = paddle.concat( + x=[deriv.unsqueeze(axis=-1) for deriv in derivs.values()], axis=-1 + ) + if p == 1: + loss = loss.abs() + elif p != 0: + if p % 2 == 0: + loss = loss.pow(y=p) + else: + loss = loss.abs().pow_(y=p) + loss = loss.sum(axis=-1) + if q == 0: + loss.abs_() + elif q != 1: + loss.pow_(y=q) + loss = reduce_loss(loss, reduction) + return loss + + +def bending_loss( + u: paddle.Tensor, + spacing: Optional[Array] = None, + sigma: Optional[float] = None, + mode: str = "sobel", + reduction: str = "mean", +) -> paddle.Tensor: + """Bending energy of vector fields. + + Args: + u: Batch of vector fields as tensor of shape ``(N, D, ..., X)``. When a tensor with less than + four dimensions is given, it is assumed to be a linear transformation and zero is returned. + spacing: Sampling grid spacing. + sigma: Standard deviation of Gaussian in grid units (cf. ``spatial_derivatives()``). + mode: Method used to approximate spatial derivatives (cf. ``spatial_derivatives()``). + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Bending energy. + + """ + if u.ndim < 4: + if reduction == "none": + raise NotImplementedError( + "bending_energy() not implemented for linear transformation and 'reduction'='none'" + ) + return paddle.to_tensor(data=0, dtype=u.dtype, place=u.place) + D = tuple(u.shape)[1] + if u.ndim - 2 != D: + raise ValueError( + f"bending_energy() 'u' must be tensor of shape (N, {u.ndim - 2}, ..., X)" + ) + which = SpatialDerivativeKeys.unique(SpatialDerivativeKeys.all(ndim=D, order=2)) + derivs = spatial_derivatives( + u, mode=mode, which=which, sigma=sigma, spacing=spacing + ) + derivs = paddle.concat( + x=[deriv.unsqueeze(axis=-1) for deriv in derivs.values()], axis=-1 + ) + derivs *= paddle.to_tensor( + data=[(2 if SpatialDerivativeKeys.is_mixed(key) else 1) for key in which], + place=u.place, + ) + loss = derivs.pow(y=2).sum(axis=-1) + loss = reduce_loss(loss, reduction) + return loss + + +be_loss = bending_loss +bending_energy = bending_loss + + +def bspline_bending_loss( + data: paddle.Tensor, stride: ScalarOrTuple[int] = 1, reduction: str = "mean" +) -> paddle.Tensor: + """Evaluate bending energy of cubic B-spline function, e.g., spatial free-form deformation. + + Args: + data: Cubic B-spline interpolation coefficients as tensor of shape ``(N, C, ..., X)``. + stride: Number of points between control points at which to evaluate bending energy, plus one. + If a sequence of values is given, these must be the strides for the different spatial + dimensions in the order ``(sx, ...)``. A stride of 1 is equivalent to evaluating bending + energy only at the usually coarser resolution of the control point grid. It should be noted + that the stride need not match the stride used to densely sample the spline deformation field + at a given fixed target image resolution. + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Bending energy of cubic B-spline. + + """ + if not isinstance(data, paddle.Tensor): + raise TypeError("bspline_bending_loss() 'data' must be paddle.Tensor") + if not paddle.is_floating_point(x=data): + raise TypeError("bspline_bending_loss() 'data' must have floating point dtype") + if data.ndim < 3: + raise ValueError("bspline_bending_loss() 'data' must have shape (N, C, ..., X)") + D = data.ndim - 2 + C = tuple(data.shape)[1] + if C != D: + raise ValueError( + f"bspline_bending_loss() 'data' number of channels ({C}) does not match number of spatial dimensions ({D})" + ) + energy: Optional[paddle.Tensor] = None + derivs = SpatialDerivativeKeys.all(D, order=2) + derivs = SpatialDerivativeKeys.unique(derivs) + npoints = 0 + for deriv in derivs: + derivative = [0] * D + for sdim in SpatialDerivativeKeys.split(deriv): + derivative[sdim] += 1 + assert sum(derivative) == 2 + values = evaluate_cubic_bspline( + data, stride=stride, derivative=derivative + ).square() + if reduction != "none": + npoints = len(tuple(values.shape)[2:]) + values = values.sum() + if not SpatialDerivativeKeys.is_mixed(deriv): + values = values.multiply_(y=paddle.to_tensor(2)) + energy = values if energy is None else energy.add_(y=paddle.to_tensor(values)) + assert energy is not None + assert npoints > 0 + if reduction == "mean" and npoints > 1: + energy = energy.divide_(y=paddle.to_tensor(npoints)) + return energy + + +bspline_be_loss = bspline_bending_loss +bspline_bending_energy = bspline_bending_loss + + +def curvature_loss( + u: paddle.Tensor, + spacing: Optional[Array] = None, + sigma: Optional[float] = None, + mode: str = "sobel", + reduction: str = "mean", +) -> paddle.Tensor: + """Loss term based on unmixed 2nd order spatial derivatives of vector fields. + + References: + Fischer & Modersitzki (2003). Curvature based image registration. + Journal Mathematical Imaging and Vision, 18(1), 81–85. + + Args: + u: Batch of vector fields as tensor of shape ``(N, D, ..., X)``. When a tensor with less than + four dimensions is given, it is assumed to be a linear transformation and zero is returned. + spacing: Sampling grid spacing. + sigma: Standard deviation of Gaussian in grid units (cf. ``spatial_derivatives()``). + mode: Method used to approximate spatial derivatives (cf. ``spatial_derivatives()``). + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Curvature loss of vector fields. + + """ + if u.ndim < 4: + if reduction == "none": + raise NotImplementedError( + "curvature_loss() not implemented for linear transformation and reduction='none'" + ) + return paddle.to_tensor(data=0, dtype=u.dtype, place=u.place) + D = tuple(u.shape)[1] + if u.ndim - 2 != D: + raise ValueError( + f"curvature_loss() 'u' must be tensor of shape (N, {u.ndim - 2}, ..., X)" + ) + which = SpatialDerivativeKeys.unmixed(ndim=D, order=2) + derivs = spatial_derivatives( + u, mode=mode, which=which, sigma=sigma, spacing=spacing + ) + derivs = paddle.concat( + x=[deriv.unsqueeze(axis=-1) for deriv in derivs.values()], axis=-1 + ) + loss = 0.5 * derivs.sum(axis=-1).pow(y=2) + loss = reduce_loss(loss, reduction) + return loss + + +def diffusion_loss( + u: paddle.Tensor, + spacing: Optional[paddle.Tensor] = None, + sigma: Optional[float] = None, + mode: str = "central", + reduction: str = "mean", +) -> paddle.Tensor: + """Diffusion regularization loss.""" + loss = grad_loss( + u, p=2, q=1, spacing=spacing, sigma=sigma, mode=mode, reduction=reduction + ) + return loss.multiply_(y=paddle.to_tensor(0.5)) + + +def divergence_loss( + u: paddle.Tensor, + q: Optional[Union[int, float]] = 1, + spacing: Optional[Array] = None, + sigma: Optional[float] = None, + mode: str = "central", + reduction: str = "mean", +) -> paddle.Tensor: + """Loss term encouraging divergence-free vector fields.""" + if u.ndim < 4: + if reduction == "none": + raise NotImplementedError( + "divergence_loss() not implemented for linear transformation and reduction='none'" + ) + return paddle.to_tensor(data=0, dtype=u.dtype, place=u.place) + D = tuple(u.shape)[1] + if u.ndim - 2 != D: + raise ValueError( + f"divergence_loss() 'u' must be tensor of shape (N, {u.ndim - 2}, ..., X)" + ) + derivs = spatial_derivatives(u, mode=mode, order=1, sigma=sigma, spacing=spacing) + derivs = paddle.concat( + x=[deriv.unsqueeze(axis=-1) for deriv in derivs.values()], axis=-1 + ) + loss = derivs.sum(axis=-1) + loss = loss.abs_() if q < 2 else loss.pow_(y=q) + loss = reduce_loss(loss, reduction) + return loss + + +def lame_parameters( + material_name: Optional[str] = None, + first_parameter: Optional[float] = None, + second_parameter: Optional[float] = None, + shear_modulus: Optional[float] = None, + poissons_ratio: Optional[float] = None, + youngs_modulus: Optional[float] = None, +) -> Tuple[float, float]: + """Get Lame parameters of linear elasticity given different quantities. + + Args: + material_name: Name of material preset. Cannot be used in conjunction with other arguments. + first_parameter: Lame's first parameter. + second_parameter: Lame's second parameter, i.e., shear modulus. + shear_modulus: Shear modulus, i.e., Lame's second parameter. + poissons_ratio: Poisson's ratio. + youngs_modulus: Young's modulus. + + Returns: + lambda: Lame's first parameter. + mu: Lame's second parameter. + + """ + RUBBER_POISSONS_RATIO = 0.4999 + RUBBER_SHEAR_MODULUS = 0.0006 + kwargs = { + name: value + for name, value in zip( + [ + "first_parameter", + "second_parameter", + "shear_modulus", + "poissons_ratio", + "youngs_modulus", + ], + [ + first_parameter, + second_parameter, + poissons_ratio, + youngs_modulus, + shear_modulus, + ], + ) + if value is not None + } + if material_name: + if kwargs: + raise ValueError( + "lame_parameters() 'material_name' cannot be specified in combination with other quantities" + ) + if material_name == "rubber": + poissons_ratio = RUBBER_POISSONS_RATIO + shear_modulus = RUBBER_SHEAR_MODULUS + else: + raise ValueError( + f"lame_parameters() unknown 'material_name': {material_name}" + ) + elif len(kwargs) != 2: + raise ValueError( + "lame_parameters() specify 'material_name' or exactly two parameters, got: " + + ", ".join(f"{k}={v}" for k, v in kwargs.items()) + ) + if second_parameter is None: + second_parameter = shear_modulus + elif shear_modulus is None: + shear_modulus = second_parameter + else: + raise ValueError( + "lame_parameters() 'second_parameter' and 'shear_modulus' are mutually exclusive" + ) + if first_parameter is None: + if shear_modulus is None: + if poissons_ratio is not None and youngs_modulus is not None: + first_parameter = ( + poissons_ratio + * youngs_modulus + / (1 + poissons_ratio)(1 - 2 * poissons_ratio) + ) + second_parameter = youngs_modulus / (2 * (1 + poissons_ratio)) + elif youngs_modulus is None: + if poissons_ratio is None: + poissons_ratio = RUBBER_POISSONS_RATIO + first_parameter = ( + 2 * shear_modulus * poissons_ratio / (1 - 2 * poissons_ratio) + ) + else: + first_parameter = ( + shear_modulus + * (youngs_modulus - 2 * shear_modulus) + / (3 * shear_modulus - youngs_modulus) + ) + elif second_parameter is None: + if youngs_modulus is None: + if poissons_ratio is None: + poissons_ratio = RUBBER_POISSONS_RATIO + second_parameter = ( + first_parameter * (1 - 2 * poissons_ratio) / (2 * poissons_ratio) + ) + else: + r = math.sqrt( + youngs_modulus**2 + + 9 * first_parameter**2 + + 2 * youngs_modulus * first_parameter + ) + second_parameter = youngs_modulus - 3 * first_parameter + r / 4 + if first_parameter is None or second_parameter is None: + raise NotImplementedError( + "lame_parameters() deriving Lame parameters from: " + + ", ".join(f"'{name}'" for name in kwargs.keys()) + ) + if first_parameter < 0: + raise ValueError("lame_parameter() 'first_parameter' is negative") + if second_parameter < 0: + raise ValueError("lame_parameter() 'second_parameter' is negative") + return first_parameter, second_parameter + + +def elasticity_loss( + u: paddle.Tensor, + material_name: Optional[str] = None, + first_parameter: Optional[float] = None, + second_parameter: Optional[float] = None, + shear_modulus: Optional[float] = None, + poissons_ratio: Optional[float] = None, + youngs_modulus: Optional[float] = None, + spacing: Optional[Array] = None, + sigma: Optional[float] = None, + mode: str = "sobel", + reduction: str = "mean", +) -> paddle.Tensor: + """Loss term based on Navier-Cauchy PDE of linear elasticity. + + References: + Fischer & Modersitzki, 2004, A unified approach to fast image registration and a new + curvature based registration technique. + + Args: + u: Batch of vector fields as tensor of shape ``(N, D, ..., X)``. When a tensor with less than + four dimensions is given, it is assumed to be a linear transformation and zero is returned. + material_name: Name of material preset. Cannot be used in conjunction with other arguments. + first_parameter: Lame's first parameter. + second_parameter: Lame's second parameter, i.e., shear modulus. + shear_modulus: Shear modulus, i.e., Lame's second parameter. + poissons_ratio: Poisson's ratio. + youngs_modulus: Young's modulus. + spacing: Sampling grid spacing. + sigma: Standard deviation of Gaussian in grid units (cf. ``spatial_derivatives()``). + mode: Method used to approximate spatial derivatives (cf. ``spatial_derivatives()``). + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Linear elasticity loss of vector field. + + """ + lambd, mu = lame_parameters( + material_name=material_name, + first_parameter=first_parameter, + second_parameter=second_parameter, + shear_modulus=shear_modulus, + poissons_ratio=poissons_ratio, + youngs_modulus=youngs_modulus, + ) + if u.ndim < 4: + if reduction == "none": + raise NotImplementedError( + "elasticity_loss() not implemented for linear transformation and reduction='none'" + ) + return paddle.to_tensor(data=0, dtype=u.dtype, place=u.place) + D = tuple(u.shape)[1] + if u.ndim - 2 != D: + raise ValueError( + f"elasticity_loss() 'u' must be tensor of shape (N, {u.ndim - 2}, ..., X)" + ) + derivs = spatial_derivatives(u, mode=mode, order=1, sigma=sigma, spacing=spacing) + derivs = [derivs[str(SpatialDim(i))] for i in range(D)] + start_22 = derivs[0].shape[1] + 0 if 0 < 0 else 0 + loss = paddle.slice(derivs[0], [1], [start_22], [start_22 + 1]).clone() + for i in range(1, D): + start_23 = derivs[i].shape[1] + i if i < 0 else i + loss = loss.add_( + y=paddle.to_tensor(paddle.slice(derivs[i], [1], [start_23], [start_23 + 1])) + ) + loss = paddle.square_(loss).multiply_(y=paddle.to_tensor(lambd / 2)) + for j, k in itertools.product(range(D), repeat=2): + start_24 = derivs[j].shape[1] + k if k < 0 else k + start_25 = derivs[k].shape[1] + j if j < 0 else j + temp = paddle.slice(derivs[j], [1], [start_24], [start_24 + 1]).add( + paddle.slice(derivs[k], [1], [start_25], [start_25 + 1]) + ) + loss = loss.add_( + y=paddle.to_tensor( + paddle.square_(temp).multiply_(y=paddle.to_tensor(mu / 4)) + ) + ) + loss = reduce_loss(loss, reduction) + return loss + + +def total_variation_loss( + u: paddle.Tensor, + spacing: Optional[paddle.Tensor] = None, + sigma: Optional[float] = None, + mode: str = "central", + reduction: str = "mean", +) -> paddle.Tensor: + """Total variation regularization loss.""" + return grad_loss( + u, p=1, q=1, spacing=spacing, sigma=sigma, mode=mode, reduction=reduction + ) + + +tv_loss = total_variation_loss + + +def inverse_consistency_loss( + forward: paddle.Tensor, + inverse: paddle.Tensor, + grid: Optional[Grid] = None, + margin: Union[int, float] = 0, + mask: Optional[paddle.Tensor] = None, + units: str = "cube", + reduction: str = "mean", +) -> paddle.Tensor: + """Evaluate inverse consistency error. + + This function expects forward and inverse coordinate maps to be with respect to the unit cube + of side length 2 as defined by the domain and codomain ``grid`` (see also ``Grid.axes()``). + + Args: + forward: paddle.Tensor representation of spatial transformation. + inverse: paddle.Tensor representation of inverse transformation. + grid: Coordinate domain and codomain of forward transformation. + margin: Number of ``grid`` points to ignore when computing mean error. If type of the + argument is ``int``, this number of points are dropped at each boundary in each dimension. + If a ``float`` value is given, it must be in [0, 1) and denote the percentage of sampling + points to drop at each border. Inverse consistency of points near the domain boundary is + affected by extrapolation and excluding these may be preferrable. See also ``mask``. + mask: Foreground mask as tensor of shape ``(N, 1, ..., X)`` with size matching ``forward``. + Inverse consistency errors at target grid points with a zero mask value are ignored. + units: Compute mean inverse consistency error in specified units: ``cube`` with respect to + normalized grid cube coordinates, ``voxel`` in voxel units, or in ``world`` units (mm). + reduction: Specifies the reduction to apply to the output: 'none' | 'mean' | 'sum'. + + Returns: + Inverse consistency error. + + """ + if not isinstance(forward, paddle.Tensor): + raise TypeError("inverse_consistency_loss() 'forward' must be tensor") + if not isinstance(inverse, paddle.Tensor): + raise TypeError("inverse_consistency_loss() 'inverse' must be tensor") + if not isinstance(margin, (int, float)): + raise TypeError("inverse_consistency_loss() 'margin' must be int or float") + if grid is None: + if forward.ndim < 4: + if inverse.ndim < 4: + raise ValueError( + "inverse_consistency_loss() 'grid' required when both transforms are affine" + ) + grid = Grid(shape=tuple(inverse.shape)[2:]) + else: + grid = Grid(shape=tuple(forward.shape)[2:]) + x = grid.coords(dtype=forward.dtype, device=forward.place).unsqueeze(axis=0) + y = transform_grid(forward, x, align_corners=grid.align_corners()) + y = transform_points(inverse, y, align_corners=grid.align_corners()) + error = y - x + if mask is not None: + if not isinstance(mask, paddle.Tensor): + raise TypeError("inverse_consistency_loss() 'mask' must be tensor") + if mask.ndim != grid.ndim + 2: + raise ValueError( + f"inverse_consistency_loss() 'mask' must be {grid.ndim + 2}-dimensional" + ) + if tuple(mask.shape)[1] != 1: + raise ValueError( + "inverse_consistency_loss() 'mask' must have shape (N, 1, ..., X)" + ) + if tuple(mask.shape)[0] != 1 and tuple(mask.shape)[0] != tuple(error.shape)[0]: + raise ValueError( + f"inverse_consistency_loss() 'mask' batch size must be 1 or {tuple(error.shape)[0]}" + ) + error[move_dim(mask == 0, 1, -1).expand_as(error)] = 0 + if margin > 0: + if isinstance(margin, float): + if margin < 0 or margin >= 1: + raise ValueError( + f"inverse_consistency_loss() 'margin' must be in [0, 1), got {margin}" + ) + m = [int(margin * n) for n in tuple(grid.shape)] + else: + m = [max(0, int(margin))] * grid.ndim + subgrid = tuple( + reversed([slice(i, n - i) for i, n in zip(m, tuple(grid.shape))]) + ) + error = error[ + (slice(0, tuple(error.shape)[0]),) + subgrid + (slice(0, grid.ndim),) + ] + if units in ("voxel", "world"): + error = denormalize_flow(error, size=tuple(grid.shape), channels_last=True) + if units == "world": + error *= grid.spacing().to(error) + error: paddle.Tensor = error.norm(p=2, axis=-1) + if reduction != "none": + count = error.size + error = error.sum() + if reduction == "mean" and mask is not None: + count = (mask != 0).sum() + error /= count + return error + + +def elementwise_loss( + name: str, + loss_fn: ElementwiseLoss, + input: paddle.paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + norm: Optional[Union[float, paddle.Tensor]] = None, + reduction: str = "mean", +) -> paddle.Tensor: + """Evaluate, aggregate, and normalize elementwise loss, optionally within masked region. + + Args: + input: Source image sampled on ``target`` grid. + target: Target image with same shape as ``input``. + mask: Multiplicative mask with same shape as ``input``. + norm: Positive factor by which to divide loss value. + reduction: Whether to compute "mean" or "sum" over all grid points. + If "none", output tensor shape is equal to the shape of the input tensors. + + Returns: + Aggregated normalized loss value. + + """ + if not isinstance(input, paddle.Tensor): + raise TypeError(f"{name}() 'input' must be tensor") + if not isinstance(target, paddle.Tensor): + raise TypeError(f"{name}() 'target' must be tensor") + if tuple(input.shape) != tuple(target.shape): + raise ValueError(f"{name}() 'input' must have same shape as 'target'") + if mask is None: + loss = loss_fn(input, target, reduction=reduction) + else: + loss = loss_fn(input, target, reduction="none") + loss = masked_loss(loss, mask, name) + loss = reduce_loss(loss, reduction, mask) + if norm is not None: + norm = paddle.to_tensor(data=norm, dtype=loss.dtype, place=loss.place).squeeze() + if not norm.ndim == 0: + raise ValueError(f"{name}() 'norm' must be scalar") + if norm > 0: + loss = loss.divide_(y=paddle.to_tensor(norm)) + return loss + + +def masked_loss( + loss: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + name: Optional[str] = None, + inplace: bool = False, +) -> paddle.Tensor: + """Multiply loss with an optionally specified spatial mask.""" + if mask is None: + return loss + if not name: + name = "masked_loss" + if not isinstance(mask, paddle.Tensor): + raise TypeError(f"{name}() 'mask' must be tensor") + if tuple(mask.shape)[0] != 1 and tuple(mask.shape)[0] != tuple(loss.shape)[0]: + raise ValueError( + f"{name}() 'mask' must have same batch size as 'target' or batch size 1" + ) + if tuple(mask.shape)[1] != 1 and tuple(mask.shape)[1] != tuple(loss.shape)[1]: + raise ValueError( + f"{name}() 'mask' must have same number of channels as 'target' or only 1" + ) + if tuple(mask.shape)[2:] != tuple(loss.shape)[2:]: + raise ValueError(f"{name}() 'mask' must have same spatial shape as 'target'") + if inplace: + loss = loss.multiply_(y=paddle.to_tensor(mask)) + else: + loss = loss.mul(mask.astype(loss.dtype)) + return loss + + +def reduce_loss( + loss: paddle.Tensor, reduction: str = "mean", mask: Optional[paddle.Tensor] = None +) -> paddle.Tensor: + """Reduce loss computed at each grid point.""" + if reduction not in ("mean", "sum", "none"): + raise ValueError("reduce_loss() 'reduction' must be 'mean', 'sum' or 'none'") + if reduction == "none": + return loss + if mask is None: + return loss.mean() if reduction == "mean" else loss.sum() + value = loss.sum() + if reduction == "mean": + numel = mask.expand_as(y=loss).sum() + value = value.divide_(y=paddle.to_tensor(numel, dtype=value.dtype)) + return value diff --git a/jointContribution/HighResolution/deepali/losses/image.py b/jointContribution/HighResolution/deepali/losses/image.py new file mode 100644 index 0000000000..0f8c664bbe --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/image.py @@ -0,0 +1,422 @@ +from typing import Optional +from typing import Union + +import paddle + +from ..core import functional as U +from ..core.types import ScalarOrTuple +from . import functional as L +from .base import NormalizedPairwiseImageLoss +from .base import PairwiseImageLoss + + +class Dice(PairwiseImageLoss): + """Generalized Sorensen-Dice similarity coefficient.""" + + def __init__(self, epsilon: float = 1e-15) -> None: + super().__init__() + self.epsilon = epsilon + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.dice_loss(source, target, weight=mask, epsilon=self.epsilon) + + def extra_repr(self) -> str: + return f"epsilon={self.epsilon:.2e}" + + +DSC = Dice + + +class LCC(PairwiseImageLoss): + """Local normalized cross correlation.""" + + def __init__( + self, kernel_size: ScalarOrTuple[int] = 7, epsilon: float = 1e-15 + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.epsilon = epsilon + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.lcc_loss( + source, + target, + mask=mask, + kernel_size=self.kernel_size, + epsilon=self.epsilon, + ) + + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, epsilon={self.epsilon:.2e}" + + +LNCC = LCC + + +class WLCC(PairwiseImageLoss): + """Weighted local normalized cross correlation.""" + + def __init__( + self, kernel_size: ScalarOrTuple[int] = 7, epsilon: float = 1e-15 + ) -> None: + super().__init__() + self.kernel_size = kernel_size + self.epsilon = epsilon + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + source_mask: Optional[paddle.Tensor] = None, + target_mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.wlcc_loss( + source, + target, + mask=mask, + source_mask=source_mask, + target_mask=target_mask, + kernel_size=self.kernel_size, + epsilon=self.epsilon, + ) + + def extra_repr(self) -> str: + return f"kernel_size={self.kernel_size}, epsilon={self.epsilon:.2e}" + + +SLCC = WLCC + + +class L1ImageLoss(NormalizedPairwiseImageLoss): + """Average absolute intensity differences.""" + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.mae_loss(source, target, mask=mask, norm=self.norm) + + +MAE = L1ImageLoss + + +class HuberImageLoss(NormalizedPairwiseImageLoss): + """Average Huber loss.""" + + def __init__( + self, + source: Optional[paddle.Tensor] = None, + target: Optional[paddle.Tensor] = None, + norm: Optional[Union[bool, paddle.Tensor]] = None, + delta: Optional[float] = None, + beta: Optional[float] = None, + ): + """Initialize similarity metric. + + Args: + source: Source image from which to compute ``norm``. If ``None``, only use ``target`` if specified. + target: Target image from which to compute ``norm``. If ``None``, only use ``source`` if specified. + norm: Positive factor by which to divide loss. If ``None`` or ``True``, use ``source`` and/or ``target``. + If ``False`` or both ``source`` and ``target`` are ``None``, a normalization factor of one is used. + delta: Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + beta: Alternative name for ``delta`` to be compatible with ``SmoothL1ImageLoss``. + + """ + if beta is not None: + if delta is not None: + raise ValueError( + f"{type(self).__name__}() 'delta' and 'beta' are mutually exclusive" + ) + delta = beta + elif delta is None: + delta = 1.0 + super().__init__(source, target, norm) + self.delta = delta + + @property + def beta(self) -> float: + return self.delta + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.huber_loss(source, target, mask=mask, norm=self.norm, delta=self.delta) + + +class SmoothL1ImageLoss(NormalizedPairwiseImageLoss): + """Average smooth L1 loss.""" + + def __init__( + self, + source: Optional[paddle.Tensor] = None, + target: Optional[paddle.Tensor] = None, + norm: Optional[Union[bool, paddle.Tensor]] = None, + beta: Optional[float] = None, + delta: Optional[float] = None, + ): + """Initialize similarity metric. + + Args: + source: Source image from which to compute ``norm``. If ``None``, only use ``target`` if specified. + target: Target image from which to compute ``norm``. If ``None``, only use ``source`` if specified. + norm: Positive factor by which to divide loss. If ``None`` or ``True``, use ``source`` and/or ``target``. + If ``False`` or both ``source`` and ``target`` are ``None``, a normalization factor of one is used. + beta: Specifies the threshold at which to change between delta-scaled L1 and L2 loss. + delta: Alternative name for ``beta`` to be compatible with ``HuberImageLoss``. + + """ + if delta is not None: + if beta is not None: + raise ValueError( + f"{type(self).__name__}() 'beta' and 'detla' are mutually exclusive" + ) + beta = delta + elif beta is None: + beta = 1.0 + super().__init__(source, target, norm) + self.beta = beta + + @property + def delta(self) -> float: + return self.beta + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.smooth_l1_loss(source, target, mask=mask, norm=self.norm) + + +class L2ImageLoss(NormalizedPairwiseImageLoss): + """Average squared intensity differences.""" + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.mse_loss(source, target, mask=mask, norm=self.norm) + + +MSE = L2ImageLoss + + +class SSD(NormalizedPairwiseImageLoss): + """Sum of squared intensity differences.""" + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate image dissimilarity loss.""" + return L.ssd_loss(source, target, mask=mask, norm=self.norm) + + +class MI(PairwiseImageLoss): + """Mutual information loss using Parzen window estimate with Gaussian kernel.""" + + def __init__( + self, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + bins: Optional[int] = None, + sample: Optional[float] = None, + num_bins: Optional[int] = None, + num_samples: Optional[int] = None, + sample_ratio: Optional[float] = None, + normalized: bool = False, + ): + """Initialize mutual information loss term. + + See :func:`deepali.losses.functional.mi_loss`. + + """ + if bins is not None: + if num_bins is not None: + raise ValueError( + f"{type(self).__name__}() 'bins' and 'num_bins' are mutually exclusive" + ) + num_bins = bins + if sample is not None: + if sample_ratio is not None or num_samples is not None: + raise ValueError( + f"{type(self).__name__}() 'sample', 'sample_ratio', and 'num_samples' are mutually exclusive" + ) + if 0 < sample < 1: + sample_ratio = float(sample) + else: + try: + num_samples = int(sample) + except TypeError: + pass + if num_samples is None or float(num_samples) != sample: + raise ValueError( + f"{type(self).__name__}() 'sample' must be float in (0, 1) or positive int" + ) + if num_samples == -1: + num_samples = None + if num_samples is not None and ( + not isinstance(num_samples, int) or num_samples <= 0 + ): + raise ValueError( + f"{type(self).__name__}() 'num_samples' must be positive integral value" + ) + if sample_ratio is not None and (sample_ratio <= 0 or sample_ratio >= 1): + raise ValueError( + f"{type(self).__name__}() 'sample_ratio' must be in closed interval [0, 1]" + ) + super().__init__() + self.vmin = vmin + self.vmax = vmax + self.num_bins = num_bins + self.num_samples = num_samples + self.sample_ratio = sample_ratio + self._normalized = normalized + + @property + def bins(self) -> int: + return self.num_bins + + @property + def normalized(self) -> bool: + return self._normalized + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate patch dissimilarity loss.""" + return L.mi_loss( + source, + target, + mask=mask, + vmin=self.vmin, + vmax=self.vmax, + num_bins=self.num_bins, + num_samples=self.num_samples, + sample_ratio=self.sample_ratio, + normalized=self.normalized, + ) + + def extra_repr(self) -> str: + return f"vmin={self.vmin!r}, vmax={self.vmax!r}, num_bins={self.num_bins!r}, num_samples={self.num_samples!r}, sampling_ratio={self.sample_ratio!r}, normalized={self.normalized!r}" + + +class NMI(MI): + """Normalized mutual information loss using Parzen window estimate with Gaussian kernel.""" + + def __init__( + self, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + bins: Optional[int] = None, + sample: Optional[float] = None, + num_bins: Optional[int] = None, + num_samples: Optional[int] = None, + sample_ratio: Optional[float] = None, + ): + """Initialize normalized mutual information loss term. + + See :func:`deepali.losses.functional.nmi_loss`. + + """ + super().__init__( + vmin=vmin, + vmax=vmax, + bins=bins, + sample=sample, + num_bins=num_bins, + num_samples=num_samples, + sample_ratio=sample_ratio, + ) + + def extra_repr(self) -> str: + return f"vmin={self.vmin!r}, vmax={self.vmax!r}, num_bins={self.num_bins!r}, num_samples={self.num_samples!r}, sampling_ratio={self.sample_ratio!r}" + + +class PatchwiseImageLoss(PairwiseImageLoss): + """Pairwise similarity of 2D image patches defined within a 3D volume.""" + + def __init__(self, patches: paddle.Tensor, loss_fn: PairwiseImageLoss = SSD()): + """Initialize loss term. + + Args: + patches: Patch sampling points as tensor of shape ``(N, Z, Y, X, 3)``. + loss_fn: Pairwise image similarity loss term used to evaluate similarity of patches. + + """ + super().__init__() + if not isinstance(patches, paddle.Tensor): + raise TypeError("PatchwiseImageLoss() 'patches' must be paddle.Tensor") + if not patches.ndim == 5 or tuple(patches.shape)[-1] != 3: + raise ValueError( + "PatchwiseImageLoss() 'patches' must have shape (N, Z, Y, X, 3)" + ) + self.patches = patches + self.loss_fn = loss_fn + + def forward( + self, + source: paddle.Tensor, + target: paddle.Tensor, + mask: Optional[paddle.Tensor] = None, + ) -> paddle.Tensor: + """Evaluate patch dissimilarity loss.""" + if target.ndim != 5: + raise ValueError( + "PatchwiseImageLoss.forward() 'target' must have shape (N, C, Z, Y, X)" + ) + if tuple(source.shape) != tuple(target.shape): + raise ValueError( + "PatchwiseImageLoss.forward() 'source' must have same shape as 'target'" + ) + if mask is not None: + if tuple(mask.shape) != tuple(target.shape): + raise ValueError( + "PatchwiseImageLoss.forward() 'mask' must have same shape as 'target'" + ) + mask = self._reshape(U.grid_sample_mask(mask, self.patches)) + source = self._reshape(U.grid_sample(source, self.patches)) + target = self._reshape(U.grid_sample(target, self.patches)) + return self.loss_fn(source, target, mask=mask) + + @staticmethod + def _reshape(x: paddle.Tensor) -> paddle.Tensor: + """Reshape tensor to (N * Z, C, 1, Y, X) such that each patch is a separate image in the batch.""" + N, C, Z, Y, X = tuple(x.shape) + x = x.transpose(perm=[0, 2, 1, 3, 4]) + x = x.reshape(N * Z, C, 1, Y, X) + return x + + +PatchLoss = PatchwiseImageLoss diff --git a/jointContribution/HighResolution/deepali/losses/params.py b/jointContribution/HighResolution/deepali/losses/params.py new file mode 100644 index 0000000000..8f3587d663 --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/params.py @@ -0,0 +1,55 @@ +import paddle + +from .base import ParamsLoss + + +class L1Norm(ParamsLoss): + """Regularization loss term based on L1-norm of model parameters.""" + + def __init__(self, scale: float = 1000.0) -> None: + """Initialize loss term. + + Args: + scale: Constant factor by which to scale loss value such that magnitude is in + similar range to other registration loss terms, i.e., image matching terms. + + """ + super().__init__() + self.scale = scale + + def forward(self, params: paddle.Tensor) -> paddle.Tensor: + """Evaluate loss term for given model parameters.""" + return self.scale * params.abs().mean() + + +L1_Norm = L1Norm + + +class L2Norm(ParamsLoss): + """Regularization loss term based on L2-norm of model parameters.""" + + def __init__(self, scale: float = 1000000.0) -> None: + """Initialize loss term. + + Args: + scale: Constant factor by which to scale loss value such that magnitude is in + similar range to other registration loss terms, i.e., image matching terms. + + """ + super().__init__() + self.scale = scale + + def forward(self, params: paddle.Tensor) -> paddle.Tensor: + """Evaluate loss term for given model parameters.""" + return self.scale * params.square().mean() + + +L2_Norm = L2Norm + + +class Sparsity(ParamsLoss): + """Regularization loss term encouraging sparsity of non-zero model parameters.""" + + def forward(self, params: paddle.Tensor) -> paddle.Tensor: + """Evaluate loss term for given model parameters.""" + return params.abs().mean() diff --git a/jointContribution/HighResolution/deepali/losses/pointset.py b/jointContribution/HighResolution/deepali/losses/pointset.py new file mode 100644 index 0000000000..2d819a61a5 --- /dev/null +++ b/jointContribution/HighResolution/deepali/losses/pointset.py @@ -0,0 +1,83 @@ +import paddle + +from ..core import functional as U +from .base import PointSetDistance + + +class ClosestPointDistance(PointSetDistance): + """Average closest point distance.""" + + def __init__(self, scale: float = 10, split_size: int = 100000.0): + """Initialize closest point distance loss. + + Args: + scale: Constant factor by which to scale average closest point + distance value such that magnitude is in similar range to + other registration loss terms, i.e., image similarity losses. + split_size: Number of points by which to split point sets during + distance calculation to avoid running out of memory. For each + split, all pairwise point distances are calculated, followed + by a reduction of the results by selecting the minimum across + all splits. + + """ + super().__init__() + self.scale = float(scale) + self.split_size = int(split_size) + + def forward(self, x: paddle.Tensor, *ys: paddle.Tensor) -> paddle.Tensor: + """Evaluate point set distance.""" + if not ys: + raise ValueError( + f"{type(self).__name__}.forward() requires at least two point sets" + ) + x = x.astype(dtype="float32") + loss = paddle.to_tensor(data=0, dtype=x.dtype, place=x.place) + for y in ys: + with paddle.no_grad(): + indices = U.closest_point_indices(x, y, split_size=self.split_size) + y = U.batched_index_select(y, 1, indices) + dists: paddle.Tensor = paddle.linalg.norm(x=x - y, p=2, axis=2) + loss += dists.mean() + return self.scale * loss / len(ys) + + def extra_repr(self) -> str: + return f"scale={self.scale}, split_size={self.split_size}" + + +CPD = ClosestPointDistance + + +class LandmarkPointDistance(PointSetDistance): + """Average distance between corresponding landmarks.""" + + def __init__(self, scale: float = 10): + """Initialize point distance loss. + + Args: + scale: Constant factor by which to scale average point distance value + such that magnitude is in similar range to other registration loss + terms, i.e., image similarity losses. + + """ + super().__init__() + self.scale = float(scale) + + def forward(self, x: paddle.Tensor, *ys: paddle.Tensor) -> paddle.Tensor: + """Evaluate point set distance.""" + if not ys: + raise ValueError( + f"{type(self).__name__}.forward() requires at least two point sets" + ) + x = x.astype(dtype="float32") + loss = paddle.to_tensor(data=0, dtype=x.dtype, place=x.place) + for y in ys: + dists: paddle.Tensor = paddle.linalg.norm(x=x - y, p=2, axis=2) + loss += dists.mean() + return self.scale * loss / len(ys) + + def extra_repr(self) -> str: + return f"scale={self.scale}" + + +LPD = LandmarkPointDistance diff --git a/jointContribution/HighResolution/deepali/modules/__init__.py b/jointContribution/HighResolution/deepali/modules/__init__.py new file mode 100644 index 0000000000..1e8dc0eac1 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/__init__.py @@ -0,0 +1,48 @@ +"""Modules without learnable parameters. + +This library defines subclasses of ``paddle.nn.Layer`` which expose the tensor operations +available in the :mod:`.core` library via a stateful functor object that can be used +in models to perform predefined operations with in general no optimizable parameters. + +""" +from .basic import GetItem +from .basic import Narrow +from .basic import Pad +from .basic import Reshape +from .basic import View +from .flow import ExpFlow +from .image import BlurImage +from .image import FilterImage +from .image import GaussianConv +from .lambd import LambdaFunc +from .lambd import LambdaLayer +from .mixins import DeviceProperty +from .mixins import ReprWithCrossReferences +from .output import ToImmutableOutput +from .sample import AlignImage +from .sample import SampleImage +from .sample import TransformImage +from .utilities import remove_layers_in_state_dict +from .utilities import rename_layers_in_state_dict + +__all__ = ( + "AlignImage", + "BlurImage", + "DeviceProperty", + "ExpFlow", + "FilterImage", + "GaussianConv", + "GetItem", + "LambdaFunc", + "LambdaLayer", + "Narrow", + "Pad", + "ReprWithCrossReferences", + "Reshape", + "SampleImage", + "ToImmutableOutput", + "TransformImage", + "View", + "remove_layers_in_state_dict", + "rename_layers_in_state_dict", +) diff --git a/jointContribution/HighResolution/deepali/modules/basic.py b/jointContribution/HighResolution/deepali/modules/basic.py new file mode 100644 index 0000000000..a1006eff4d --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/basic.py @@ -0,0 +1,137 @@ +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from ..core import functional as U +from ..core.enum import PaddingMode +from ..core.types import ScalarOrTuple + + +class GetItem(paddle.nn.Layer): + """Get item at specified input tensor sequence index or with given dictionary key.""" + + def __init__(self, key: Any) -> None: + """Set item index or key. + + Args: + key: Index of item in sequence of input tensors or key into input map. + + """ + super().__init__() + self.key = key + + def forward( + self, input: Union[Sequence[paddle.Tensor], Mapping[Any, paddle.Tensor]] + ) -> paddle.Tensor: + return input[self.key] + + def extra_repr(self) -> str: + return repr(self.key) + + +class Narrow(paddle.nn.Layer): + """Narrowed version of input tensor.""" + + def __init__(self, dim: int, start: int, length: int) -> None: + super().__init__() + self.dim = dim + self.start = start + self.length = length + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + start_6 = x.shape[self.dim] + self.start if self.start < 0 else self.start + return paddle.slice(x, [self.dim], [start_6], [start_6 + self.length]) + + def extra_repr(self) -> str: + return f"dim={self.dim}, start={self.start}, length={self.length}" + + +class Pad(paddle.nn.Layer): + """Pad tensor.""" + + def __init__( + self, + margin: Optional[ScalarOrTuple[int]] = None, + padding: Optional[ScalarOrTuple[int]] = None, + mode: Union[PaddingMode, str] = PaddingMode.ZEROS, + value: float = 0, + ) -> None: + if margin is None and padding is None: + raise AssertionError("Pad() either 'margin' or 'padding' is required") + if margin is not None and padding is not None: + raise AssertionError("Pad() 'margin' and 'padding' are mutually exclusive") + super().__init__() + self.margin = margin + self.padding = padding + self.mode = PaddingMode(mode) + self.value = value + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return U.pad( + x, margin=self.margin, num=self.padding, mode=self.mode, value=self.value + ) + + def extra_repr(self) -> str: + if self.margin is None: + s = f"padding={self.padding}" + else: + s = f"margin={self.margin}" + s += f", mode={self.mode.value!r}" + if self.mode is PaddingMode.CONSTANT: + s += f", value={self.value}" + return s + + +class Reshape(paddle.nn.Layer): + """Reshape input tensor. + + This module provides a view of the input tensor without making a copy if possible. + Otherwise, a copy is made of the input data. See ``paddle.reshape()`` for details. + + """ + + def __init__(self, shape: Sequence[int]) -> None: + """Set output tensor shape. + + Args: + shape: Output tensor shape, optionally excluding first batch dimension. + + """ + super().__init__() + self.shape = shape + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + shape = self.shape + if len(shape) == x.ndim - 1: + shape = (-1,) + shape + return x.reshape(shape) + + def extra_repr(self) -> str: + return repr(self.shape) + + +class View(paddle.nn.Layer): + """View input tensor with specified shape.""" + + def __init__(self, shape: Sequence[int]) -> None: + """Set output tensor shape. + + Args: + shape: Output tensor shape, optionally excluding first batch dimension. + + """ + super().__init__() + self.shape = shape + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + shape = self.shape + if len(shape) == x.ndim - 1: + shape = (-1,) + shape + return x.view(*shape) + + def extra_repr(self) -> str: + return repr(self.shape) diff --git a/jointContribution/HighResolution/deepali/modules/flow.py b/jointContribution/HighResolution/deepali/modules/flow.py new file mode 100644 index 0000000000..d98d9bf1f0 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/flow.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from typing import Optional + +import paddle + +from ..core import ALIGN_CORNERS +from ..core import functional as U + + +class ExpFlow(paddle.nn.Layer): + """Layer that computes exponential map of flow field.""" + + def __init__( + self, + scale: Optional[float] = None, + steps: Optional[int] = None, + align_corners: bool = ALIGN_CORNERS, + ): + """Initialize parameters. + + Args: + scale: Constant scaling factor of input velocities (e.g., -1 for inverse). Default is 1. + steps: Number of squaring steps. + align_corners: Whether input vectors are with respect to ``Axes.CUBE`` (False) + or ``Axes.CUBE_CORNERS`` (True). This flag is passed on to ``grid_sample()``. + + """ + super().__init__() + self.scale = float(1 if scale is None else scale) + self.steps = int(5 if steps is None else steps) + self.align_corners = bool(align_corners) + + def forward(self, x: paddle.Tensor, inverse: bool = False) -> paddle.Tensor: + """Compute exponential map of vector field.""" + scale = self.scale + if inverse: + scale *= -1 + return U.expv( + x, scale=scale, steps=self.steps, align_corners=self.align_corners + ) + + @property + def inv(self) -> ExpFlow: + """Get inverse exponential map. + + .. code-block:: python + + u = exp(v) + w = exp.inv(v) + + """ + return self.inverse() + + def inverse(self) -> ExpFlow: + """Get inverse exponential map.""" + copy = shallow_copy(self) + copy.scale *= -1 + return copy + + def extra_repr(self) -> str: + return f"scale={repr(self.scale)}, steps={repr(self.steps)}" diff --git a/jointContribution/HighResolution/deepali/modules/image.py b/jointContribution/HighResolution/deepali/modules/image.py new file mode 100644 index 0000000000..6cf14593ea --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/image.py @@ -0,0 +1,117 @@ +import math +from numbers import Number +from typing import Optional +from typing import Union + +import paddle +from pkg_resources import parse_version + +from ..core import functional as U +from ..core.enum import PaddingMode +from ..core.kernels import gaussian1d +from ..core.types import ScalarOrTuple +from ..utils import paddle_aux + + +class FilterImage(paddle.nn.Layer): + """Convoles an image with a predefined filter kernel.""" + + def __init__( + self, + kernel: Optional[paddle.Tensor], + padding: Optional[Union[PaddingMode, str]] = None, + ): + """Initialize parameters. + + Args: + kernel: Predefined convolution kernel. + padding: Image extrapolation mode. + + """ + super().__init__() + self.padding = ( + PaddingMode.CONSTANT if padding is None else PaddingMode.from_arg(padding) + ) + self.register_buffer(name="kernel", tensor=kernel) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """Convolve input image with predefined filter kernel.""" + kernel: Optional[paddle.Tensor] = self.kernel + if self.kernel is None or kernel.size < 2: + return x + return U.conv(x, kernel, padding=self.padding) + + def extra_repr(self) -> str: + return f"padding={repr(self.padding.value)}" + + +class BlurImage(FilterImage): + """Blurs an image by a predefined Gaussian low-pass filter.""" + + def __init__(self, sigma: float, padding: Optional[Union[PaddingMode, str]] = None): + """Initialize parameters. + + Args: + sigma: Standard deviation of isotropic Gaussian kernel in grid units (pixel, voxel). + padding: Image extrapolation mode. + + """ + sigma = float(sigma) + kernel = gaussian1d(sigma) if sigma > 0 else None + super().__init__(kernel=kernel, padding=padding) + self.sigma = sigma + + def extra_repr(self) -> str: + return f"sigma={repr(self.sigma)}, " + super().extra_repr() + + +class GaussianConv(paddle.nn.Layer): + """Blurs an image by a predefined Gaussian low-pass filter.""" + + def __init__( + self, channels: int, kernel_size: ScalarOrTuple[int], sigma: float, dim: int = 3 + ) -> None: + """Initialize Gaussian convolution kernel. + + Args: + channels (int, sequence): Number of channels of the input and output tensors. + kernel_size (int, sequence): Size of the gaussian kernel. + sigma (float, sequence): Standard deviation of the gaussian kernel. + dim (int, optional): The number of dimensions of the data. + + """ + if dim < 2 or dim > 3: + raise ValueError(f"Only 2 and 3 dimensions are supported, got: {dim}") + super().__init__() + if isinstance(kernel_size, Number): + kernel_size = (kernel_size,) * dim + if isinstance(sigma, Number): + sigma = (sigma,) * dim + kernel = paddle.to_tensor(data=1, dtype="float32", place="cpu") + mgrids = [paddle.arange(dtype="float32", end=n) for n in kernel_size] + if parse_version(paddle.__version__) < parse_version("1.10"): + mgrids = paddle.meshgrid(mgrids) + else: + mgrids = paddle.meshgrid(mgrids) + norm = math.sqrt(2 * math.pi) + for size, std, mgrid in zip(kernel_size, sigma, mgrids): + mean = (size - 1) / 2 + kernel *= ( + 1 / (std * norm) * paddle.exp(x=-(((mgrid - mean) / std) ** 2) / 2) + ) + kernel = kernel.divide_(y=paddle.to_tensor(kernel.sum())) + kernel = kernel.view(1, 1, *tuple(kernel.shape)) + kernel = kernel.repeat(channels, *((1,) * (kernel.dim() - 1))) + self.register_buffer(name="kernel", tensor=kernel, persistable=True) + self.groups = channels + self.pad = (kernel_size[0] // 2,) * (2 * dim) + self.conv = ( + paddle.nn.functional.conv2d if dim == 2 else paddle.nn.functional.conv3d + ) + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """Convolve input with Gaussian kernel.""" + kernel: paddle.Tensor = self.kernel + data = paddle_aux._FUNCTIONAL_PAD(pad=self.pad, mode="replicate", x=x) + data = self.conv(data, kernel, groups=self.groups) + return data diff --git a/jointContribution/HighResolution/deepali/modules/lambd.py b/jointContribution/HighResolution/deepali/modules/lambd.py new file mode 100644 index 0000000000..e2dad2c929 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/lambd.py @@ -0,0 +1,25 @@ +from typing import Callable + +import paddle + +LambdaFunc = Callable[[paddle.Tensor], paddle.Tensor] + + +class LambdaLayer(paddle.nn.Layer): + """Wrap any tensor operation in a network module.""" + + def __init__(self, func: LambdaFunc) -> None: + """Set callable tensor operation. + + Args: + func: Callable tensor operation. Must be instance of ``paddle.nn.Layer`` + if it contains learnable parameters. In this case, however, the + ``LambdaLayer`` wrapper becomes redundant. Main use is to wrap + non-learnable Python functions. + + """ + super().__init__() + self.func = func + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + return self.func(x) diff --git a/jointContribution/HighResolution/deepali/modules/mixins.py b/jointContribution/HighResolution/deepali/modules/mixins.py new file mode 100644 index 0000000000..835edb7d4e --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/mixins.py @@ -0,0 +1,74 @@ +from typing import Dict +from typing import Optional + +import paddle + + +class DeviceProperty(object): + """Mixin for paddle.nn.Layer to provide 'device' property.""" + + @property + def device(self) -> str: + """Device of first found module parameter or buffer.""" + for param in self.parameters(): + if param is not None: + return param.place + for buffer in self.buffers(): + if buffer is not None: + return buffer.place + return str("cpu").replace("cuda", "gpu") + + +class ReprWithCrossReferences(object): + """Mixin of __repr__ for paddle.nn.Layer subclasses to include cross-references to reused modules.""" + + def __repr__(self) -> str: + return self._repr_impl() + + def _repr_impl( + self, + prefix: str = "", + module: Optional[paddle.nn.Layer] = None, + memo: Optional[Dict[paddle.nn.Layer, str]] = None, + ) -> str: + if module is None: + module = self + if memo is None: + memo = {} + extra_lines = [] + extra_repr = module.extra_repr() + if extra_repr: + extra_lines = extra_repr.split("\n") + child_lines = [] + for key, child in module._modules.items(): + mod_str = self._repr_impl( + prefix=prefix + key + ".", module=child, memo=memo + ) + mod_str = _addindent(mod_str, 2) + prev_key = memo.get(child) + if prev_key: + mod_str = f"{prev_key}(\n {mod_str}\n)" + mod_str = _addindent(mod_str, 2) + child_lines.append(f"({key}): {mod_str}") + memo[child] = prefix + key + lines = extra_lines + child_lines + main_str = module._get_name() + "(" + if lines: + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += "\n " + "\n ".join(lines) + "\n" + main_str += ")" + return main_str + + +def _addindent(s_: str, numSpaces: int) -> str: + """Add indentation to multi-line string.""" + s = s_.split("\n") + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(numSpaces * " " + line) for line in s] + s = "\n".join(s) + s = first + "\n" + s + return s diff --git a/jointContribution/HighResolution/deepali/modules/output.py b/jointContribution/HighResolution/deepali/modules/output.py new file mode 100644 index 0000000000..a1add0ac69 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/output.py @@ -0,0 +1,31 @@ +from typing import Mapping +from typing import Sequence +from typing import Union + +import paddle + +from ..core.nnutils import as_immutable_container + + +class ToImmutableOutput(paddle.nn.Layer): + """Convert input to immutable output container. + + For use with ``paddle.utils.tensorboard.SummaryWriter.add_graph`` when model output is list or dict. + See error message: "Encountering a dict at the output of the tracer might cause the trace to be incorrect, + this is only valid if the container structure does not change based on the module's inputs. Consider using + a constant container instead (e.g. for `list`, use a `tuple` instead. for `dict`, use a `NamedTuple` instead). + If you absolutely need this and know the side effects, pass strict=False to trace() to allow this behavior." + + """ + + def __init__(self, recursive: bool = True) -> None: + super().__init__() + self.recursive = recursive + + def forward( + self, input: Union[paddle.Tensor, Sequence, Mapping] + ) -> Union[paddle.Tensor, Sequence, Mapping]: + return as_immutable_container(input, recursive=self.recursive) + + def extra_repr(self) -> str: + return f"recursive={self.recursive!r}" diff --git a/jointContribution/HighResolution/deepali/modules/sample.py b/jointContribution/HighResolution/deepali/modules/sample.py new file mode 100644 index 0000000000..f25d178a85 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/sample.py @@ -0,0 +1,378 @@ +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Tuple +from typing import Union +from typing import cast +from typing import overload + +import paddle + +from ..core import functional as U +from ..core.enum import PaddingMode +from ..core.enum import Sampling +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.grid import grid_points_transform +from ..core.linalg import homogeneous_matmul +from ..core.linalg import homogeneous_transform +from ..core.types import Scalar + + +class SampleImage(paddle.nn.Layer): + """Sample images at grid points.""" + + def __init__( + self, + target: Grid, + source: Optional[Grid] = None, + axes: Optional[Union[Axes, str]] = None, + sampling: Optional[Union[Sampling, str]] = None, + padding: Optional[Union[PaddingMode, str, Scalar]] = None, + align_centers: bool = False, + ): + """Initialize base class. + + Args: + target: Grid on which to sample transformed images. + source: Grid on which input source images are sampled. + axes: Axes with respect to which grid coordinates are defined. + If ``None``, use cube axes corresponding to ``target.align_corners()`` setting. + sampling: Image interpolation mode. + padding: Image extrapolation mode. + align_centers: Whether to implicitly align the ``target`` and ``source`` centers. + If ``True``, only the affine component of the target to source transformation + is applied. If ``False``, also the translation of grid center points is considered. + + """ + super().__init__() + self._target = target + self._source = source or target + if axes is None: + axes = Axes.from_grid(target) + self._axes = Axes(axes) + self._sampling = Sampling.from_arg(sampling) + if padding is None or isinstance(padding, (PaddingMode, str)): + self._padding = PaddingMode.from_arg(padding) + else: + self._padding = float(padding) + self._align_centers = bool(align_centers) + self.register_buffer(name="matrix", tensor=self._matrix()) + + def axes(self) -> Axes: + """Axes with respect to which target grid points and transformations thereof are defined.""" + return self._axes + + def target_grid(self) -> Grid: + """Target sampling grid.""" + return self._target + + def source_grid(self) -> Grid: + """Source sampling grid.""" + return self._source + + def sampling(self) -> Sampling: + """Image sampling mode.""" + return self._sampling + + def padding(self) -> Union[PaddingMode, Scalar]: + """Image padding mode or value, respectively.""" + return self._padding + + def padding_mode(self) -> PaddingMode: + """Image padding mode.""" + return ( + self._padding + if isinstance(self._padding, PaddingMode) + else PaddingMode.CONSTANT + ) + + def padding_value(self) -> float: + """Image padding value if mode is "constant".""" + return 0.0 if isinstance(self._padding, PaddingMode) else float(self._padding) + + def align_centers(self) -> bool: + """Whether grid center points are implicitly aligned.""" + return self._align_centers + + def align_corners(self) -> bool: + """Whether to sample images using ``align_corners=False`` or ``align_corners=True``.""" + return self._target.align_corners() + + def _matrix(self) -> paddle.Tensor: + """Homogeneous coordinate transformation from target grid points to source grid cube.""" + align_corners = self.align_corners() + to_axes = Axes.from_align_corners(align_corners) + matrix = grid_points_transform(self._target, self._axes, self._source, to_axes) + if self._align_centers: + offset = self._target.world_to_cube( + self._source.center(), align_corners=align_corners + ) + matrix = homogeneous_matmul(matrix, offset) + return matrix.unsqueeze(axis=0) + + def _transform_target_to_source(self, grid: paddle.Tensor) -> paddle.Tensor: + """Transform target grid points to source cube.""" + matrix = cast(paddle.Tensor, self.matrix) + return homogeneous_transform(matrix, grid) + + def _sample_source_image( + self, + grid: paddle.Tensor, + input: Optional[Union[paddle.Tensor, Mapping[str, paddle.Tensor]]] = None, + data: Optional[paddle.paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + ) -> Union[ + paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor], Dict[str, paddle.Tensor] + ]: + """Sample images at specified source grid points.""" + source = {} + output = {} + shape = None + align_corners = self.align_corners() + if not isinstance(grid, paddle.Tensor): + raise TypeError(f"{type(self).__name__}() 'grid' must be paddle.Tensor") + if isinstance(input, dict): + if data is not None or mask is not None: + raise ValueError( + f"{type(self).__name__}() 'input' dict and 'data'/'mask' are mutually exclusive" + ) + for data in input.values(): + if not isinstance(data, paddle.Tensor): + raise TypeError( + f"{type(self).__name__}() 'input' dict values must be paddle.Tensor" + ) + source: Dict[str, paddle.Tensor] = { + name: data for name, data in input.items() if name != "mask" + } + shape = tuple(next(iter(source.values())).shape) if source else None + mask = input.get("mask") + elif isinstance(input, paddle.Tensor): + if data is not None: + raise ValueError( + f"{type(self).__name__}() 'input' and 'data' are mutually exclusive" + ) + source = {"data": input} + shape = tuple(input.shape) + elif input is None: + if data is None and mask is None: + raise ValueError( + f"{type(self).__name__} 'input', 'data', and/or 'mask' is required" + ) + else: + raise TypeError( + f"{type(self).__name__}() 'input' must be paddle.Tensor or Mapping[str, paddle.Tensor]" + ) + for name, data in source.items(): + is_unbatched = data.ndim == tuple(grid.shape)[-1] + 1 + if is_unbatched: + data = data.unsqueeze(axis=0) + data = U.grid_sample( + data, + grid, + mode=self._sampling, + padding=self._padding, + align_corners=align_corners, + ) + if is_unbatched: + data = data.squeeze(axis=0) + output[name] = data + if mask is not None: + if not isinstance(mask, paddle.Tensor): + raise TypeError(f"{type(self).__name__}() 'mask' must be paddle.Tensor") + if shape is not None: + if mask.ndim != len(shape): + raise ValueError( + f"{type(self).__name__}() 'mask' must have same ndim as 'input' data" + ) + if tuple(mask.shape)[0] != shape[0]: + raise ValueError( + f"{type(self).__name__}() 'mask' must have same batch size as 'input' data" + ) + if tuple(mask.shape)[2:] != shape[2:]: + raise ValueError( + f"{type(self).__name__}() 'mask' must have same spatial shape as 'input' data" + ) + temp = U.grid_sample_mask(mask, grid, align_corners=align_corners) + output["mask"] = temp > 0.9 + if isinstance(input, dict): + return output + if data is None: + return output["mask"] + if mask is None: + return output["data"] + return output["data"], output["mask"] + + @overload + def forward( + self, grid: paddle.Tensor, input: paddle.Tensor, data=None, mask=None + ) -> paddle.Tensor: + """Sample batch of images at spatially transformed target grid points.""" + ... + + @overload + def forward( + self, grid: paddle.Tensor, input: Dict[str, paddle.Tensor] + ) -> Dict[str, paddle.Tensor]: + """Sample batch of optionally masked images at spatially transformed target grid points.""" + ... + + def forward( + self, + grid: paddle.Tensor, + input: Optional[Union[paddle.Tensor, Dict[str, paddle.Tensor]]] = None, + data: Optional[paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + ) -> Union[ + paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor], Dict[str, paddle.Tensor] + ]: + """Sample images at target grid points after mapping these to the source grid cube.""" + if grid.ndim == tuple(grid.shape)[-1] + 1: + grid = grid.unsqueeze(axis=0) + grid = self._transform_target_to_source(grid) + return self._sample_source_image(grid, input=input, data=data, mask=mask) + + def extra_repr(self) -> str: + return ( + f"target={repr(self._target)}" + + f", source={repr(self._source)}" + + f", axes={repr(self._axes.value)}" + + f", sampling={repr(self._sampling.value)}" + + f", padding={repr(self._padding.value if isinstance(self._padding, PaddingMode) else self._padding)}" + + f", align_centers={repr(self._align_centers)}" + ) + + +class TransformImage(SampleImage): + """Sample images at transformed target grid points. + + This module can be used for both linear and non-rigid transformations, where the shape of the + input ``transform`` tensor determines the type of transformation. See also ``transform_grid()``. + + """ + + def __init__( + self, + target: Grid, + source: Optional[Grid] = None, + axes: Optional[Union[Axes, str]] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str, Scalar] = PaddingMode.BORDER, + align_centers: bool = False, + ): + """Initialize module. + + Args: + target: Grid on which to sample transformed images. + source: Grid on which input source images are sampled. + axes: Axes with respect to which transformations are defined. + Use ``Axes.from_grid(target)`` if ``None``. + sampling: Image interpolation mode. + padding: Image extrapolation mode. + align_centers: Whether to implicitly align the ``target`` and ``source`` centers. + If ``True``, only the affine component of the target to source transformation + is applied. If ``False``, also the translation of grid center points is considered. + + """ + super().__init__( + target, + source, + axes=axes, + sampling=sampling, + padding=padding, + align_centers=align_centers, + ) + self.register_buffer(name="grid", tensor=self._grid(), persistable=False) + + def _grid(self) -> paddle.Tensor: + """Target grid points before spatial transformation.""" + return self._target.points(self._axes).unsqueeze(axis=0) + + def forward( + self, + transform: Optional[paddle.Tensor], + input: Optional[Union[paddle.Tensor, Dict[str, paddle.Tensor]]] = None, + data: Optional[paddle.paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + ) -> Union[ + paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor], Dict[str, paddle.Tensor] + ]: + """Sample images at transformed target grid points after mapping these to the source grid cube.""" + grid = cast(paddle.Tensor, self.grid) + if isinstance(transform, paddle.Tensor): + if transform.ndim == tuple(grid.shape)[-1] + 1: + transform = transform.unsqueeze(axis=0) + grid = U.transform_grid(transform, grid, align_corners=self.align_corners()) + elif transform is not None: + raise TypeError("TransformImage() 'transform' must be paddle.Tensor") + grid = self._transform_target_to_source(grid) + return self._sample_source_image(grid, input=input, data=data, mask=mask) + + +class AlignImage(SampleImage): + """Sample images at linearly transformed target grid points. + + Instead of applying two separate linear transformations to the target grid points, this module first composes + the two linear transformations and then applies the composite transformation to the target grid points. + + """ + + def __init__( + self, + target: Grid, + source: Optional[Grid] = None, + axes: Optional[Union[Axes, str]] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str, Scalar] = PaddingMode.BORDER, + align_centers: bool = False, + ): + """Initialize module. + + Args: + target: Grid on which to sample transformed images. + source: Grid on which input source images are sampled. + axes: Axes with respect to which transformations are defined. + Use ``Axes.from_grid(target)`` if ``None``. + sampling: Image interpolation mode. + padding: Image extrapolation mode. + align_centers: Whether to implicitly align the ``target`` and ``source`` centers. + If ``True``, only the affine component of the target to source transformation + is applied. If ``False``, also the translation of grid center points is considered. + + """ + super().__init__( + target, + source, + axes=axes, + sampling=sampling, + padding=padding, + align_centers=align_centers, + ) + self.register_buffer(name="grid", tensor=self._grid(), persistable=False) + + def _grid(self) -> paddle.Tensor: + """Target grid points before spatial transformation.""" + return self._target.points(self._axes).unsqueeze(axis=0) + + def forward( + self, + transform: Optional[paddle.Tensor], + input: Optional[Union[paddle.Tensor, Dict[str, paddle.Tensor]]] = None, + data: Optional[paddle.paddle.Tensor] = None, + mask: Optional[paddle.Tensor] = None, + ) -> Union[ + paddle.Tensor, Tuple[paddle.Tensor, paddle.Tensor], Dict[str, paddle.Tensor] + ]: + """Sample batch of optionally masked images at linearly transformed target grid points.""" + composite_transform = cast(paddle.Tensor, self.matrix) + if transform is not None: + if not isinstance(transform, paddle.Tensor): + raise TypeError("AlignImage() 'transform' must be paddle.Tensor") + if transform.ndim != 3: + raise ValueError( + "AlignImage() 'transform' must be 3-dimensional tensor" + ) + composite_transform = homogeneous_matmul(composite_transform, transform) + grid = cast(paddle.Tensor, self.grid) + grid = homogeneous_transform(composite_transform, grid) + return self._sample_source_image(grid, input=input, data=data, mask=mask) diff --git a/jointContribution/HighResolution/deepali/modules/utilities.py b/jointContribution/HighResolution/deepali/modules/utilities.py new file mode 100644 index 0000000000..18d3792192 --- /dev/null +++ b/jointContribution/HighResolution/deepali/modules/utilities.py @@ -0,0 +1,115 @@ +from collections import OrderedDict +from collections import defaultdict +from typing import Any +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional + +import paddle + + +def has_children(module: paddle.nn.Layer) -> bool: + """Check if module has other modules as children.""" + try: + next(iter(module.children())) + except StopIteration: + return False + return True + + +def module_ids_by_class_name( + model: paddle.nn.Layer, + duplicates: bool = False, + containers: bool = False, + memo: Optional[dict] = None, +) -> Dict[str, List[int]]: + """Obtain ids of module objects in model, indexed by module type name. + + Args: + model: Neural network model. + duplicates: Whether to return lists of object ids instead of sets. + containers: Whether to include object ids of modules with children. + memo: Used internally to recursively collect object ids. + + Returns: + Dictionary with module class names as keys and object ids of each type stored list. + + """ + if memo is None: + memo = defaultdict(list) + modules: dict = model._modules + for module in modules.values(): + assert isinstance(module, paddle.nn.Layer) + if containers or not has_children(module): + ids = memo[module.__class__.__name__] + ids.append(id(module)) + module_ids_by_class_name( + module, duplicates=duplicates, containers=containers, memo=memo + ) + if not duplicates: + memo = {k: list(set(v)) for k, v in memo.items()} + return memo + + +def module_counts_by_class_name( + model: paddle.nn.Layer, duplicates: bool = False, containers: bool = False +) -> Dict[str, int]: + """Count how many module objects are in a model of each module type. + + Args: + model: Neural network model. + duplicates: Whether to count duplicate objects. + containers: Whether to include modules with children. + + Returns: + Dictionary with module class names as keys and counts of object ids as values. + + """ + ids = module_ids_by_class_name(model, duplicates=duplicates, containers=containers) + return {k: len(v) for k, v in ids.items()} + + +def remove_layers_in_state_dict( + state: Mapping[str, Any], prefix: str +) -> Dict[str, Any]: + """Remove layers in loaded state dict.""" + metadata = getattr(state, "_metadata", None) + result = OrderedDict(state) + if metadata is not None: + metadata = remove_layers_in_state_dict(metadata, prefix) + setattr(result, "_metadata", metadata) + for key in state.keys(): + if key == prefix: + result.pop(key) + elif key.startswith(prefix + "."): + result.pop(key) + return result + + +def rename_layers_in_state_dict( + state: Mapping[str, Any], + rename: Optional[Mapping[str, str]] = None, + parent: Optional[str] = None, +) -> Dict[str, Any]: + """Rename layers in loaded state dict.""" + metadata = getattr(state, "_metadata", None) + state = OrderedDict(state) + if metadata is not None: + metadata = rename_layers_in_state_dict(metadata, rename=rename, parent=parent) + setattr(state, "_metadata", metadata) + if rename: + for prefix, new_name in rename.items(): + for key in list(state.keys()): + if key == prefix: + value = state.pop(key) + if new_name: + state[new_name] = value + elif key.startswith(prefix + "."): + value = state.pop(key) + new_key = new_name + key[len(prefix) + (0 if new_name else 1) :] + if new_key: + state[new_key] = value + if parent: + state = {f"{parent}.{key}": value for key, value in state.items()} + return state diff --git a/jointContribution/HighResolution/deepali/networks/__init__.py b/jointContribution/HighResolution/deepali/networks/__init__.py new file mode 100644 index 0000000000..31f2a6967d --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/__init__.py @@ -0,0 +1,13 @@ +"""Basic building blocks and (sub-)networks of learned image registration models. + +For spatial transformation models used in both classic non-learning based registration and +learned image registration, where transformation parameters are inferred by a neural network +from the input data instead of being optimized directly, see the :mod:`.spatial` library. + +Commonly, the neural network model infers the transformation parameters from the input data. +These parameters are then used to evaluate and apply the spatial transformation. For this, the +``params`` attribute of parameteric transformations can be set to a neural network instance of +type ``paddle.nn.Layer``. The input data of the so parametrized spatial transformation is then +set as :meth:`.SpatialTransform.condition`, which constitutes the input of the neural network. + +""" diff --git a/jointContribution/HighResolution/deepali/networks/blocks/__init__.py b/jointContribution/HighResolution/deepali/networks/blocks/__init__.py new file mode 100644 index 0000000000..8483977a2b --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/blocks/__init__.py @@ -0,0 +1,7 @@ +"""Building blocks to construct subnetworks, and predefined subnetworks (blocks).""" +from .residual import ResidualUnit +from .skip import DenseBlock +from .skip import Shortcut +from .skip import SkipConnection + +__all__ = "DenseBlock", "ResidualUnit", "Shortcut", "SkipConnection" diff --git a/jointContribution/HighResolution/deepali/networks/blocks/residual.py b/jointContribution/HighResolution/deepali/networks/blocks/residual.py new file mode 100644 index 0000000000..b7927c8074 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/blocks/residual.py @@ -0,0 +1,316 @@ +from __future__ import annotations + +from typing import Any +from typing import Mapping +from typing import Optional +from typing import Union + +import paddle + +from ...core.enum import PaddingMode +from ...core.types import ScalarOrTuple +from ..layers.acti import ActivationArg +from ..layers.acti import activation +from ..layers.conv import ConvLayer +from ..layers.conv import convolution +from ..layers.conv import same_padding +from ..layers.join import JoinLayer +from ..layers.norm import NormArg +from .skip import SkipConnection +from .skip import SkipFunc + +__all__ = ("ResidualUnit",) + + +class ResidualUnit(SkipConnection): + """Sequence of convolutional layers with a residual skip connection. + + Implements a number of variants of residual blocks as described in: + - He et al., 2015, Deep Residual Learning for Image Recognition, https://arxiv.org/abs/1512.03385 + - He et al., 2016, Identity Mappings in Deep Residual Networks, https://arxiv.org/abs/1603.05027 + - Xie et al., 2017, Aggregated Residual Transformations for Deep Neural Networks, https://arxiv.org/abs/1611.05431 + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: Optional[int] = None, + num_channels: Optional[int] = None, + kernel_size: ScalarOrTuple[int] = 3, + stride: ScalarOrTuple[int] = 1, + padding: Optional[ScalarOrTuple[int]] = None, + padding_mode: Union[PaddingMode, str] = "zeros", + dilation: ScalarOrTuple[int] = 1, + groups: int = 1, + init: str = "default", + bias: Optional[Union[bool, str]] = False, + norm: NormArg = "batch", + acti: ActivationArg = "relu", + skip: Optional[ + Union[SkipFunc, str, Mapping[str, Any]] + ] = "identity | conv1 | conv", + num_layers: Optional[int] = None, + order: str = "cna", + pre_conv: str = "conv1", + post_conv: str = "conv1", + other: Optional[ResidualUnit] = None, + ) -> None: + """Initialize residual unit. + + Args: + spatial_dims: Number of spatial input and output dimensions. + in_channels: Number of input feature maps. + out_channels: Number of output feature maps. + num_channels: Number of input feature maps to spatial convolutions. + If ``None`` or equal to ``out_channels``, the residual branch consists of ``num_layers`` + of convolutions with specified ``kernel_size`` and pre- or post-activation (and normalization). + Otherwise, the first and last convolution has kernel size 1 in order to match the + specified number of input and output channels of the shortcut connection. + When the interim number of channels is smaller than the input and output channels, + this residual unit is a so-called bottleneck block. + kernel_size: Size of residual convolutions. + pre_conv: Type of pre-convolution when residual unit is a bottleneck block. + The kernel size is set to 1 if "conv1", and ``kernel_size`` otherwise. + post_conv: Type of post-convolution when residual unit is a bottleneck block. See ``pre_conv``. + stride: Stride of first residual convolution. All subsequent convolutions have stride 1. + In case of a bottleneck block, the first convolution with kernel size 1 also has stride 1, + and the specified stride is applied on the second convolution instead. + padding: Padding used by residual convolutions with specified ``kernel_size``. + If specified, must result in a same padding such that spatial tensor shape remains unchanged. + padding_mode: Padding mode used for same padding of input to each convolutional layers. + dilation: Dilation of convolution kernel. + groups: Number of groups into which to split each convolution. Note that combined with a bottleneck, + i.e., where ``num_channels`` is smaller than ``out_channels``, this is equivalent to increasing + the number of residual paths also referred to as cardinality (cf. ResNeXt). For example, setting + ``in_channels=256``, out_channels=256``, ``num_channels=128``, and ``groups=32`` is equivalent to + a residual block consisting of 32 paths with 4 channels in each path which are concatentated prior + to the final convolution with kernel size 1 before the resulting residuals are added to the result + of the (identity) skip connection (cf. Xie et al., 2017, Fig. 3). + init: Mode used to initialize weights of convolution kernels. + bias: Whether to use bias terms of convolutional filters. + norm: Normalization layer to use in each convolutional layer. + acti: Activation function to use in each convolutional layer. + skip: Function along the shortcut connection. If specified, must ensure that the + shape of the output along this shortcut connection is identical to the shape of + the output from the sequence of convolutional layers in this block. If ``None``, + the shortcut connection is the identity map. By default, the identity map is used when the + output shape matches the input shape. Otherwise, a convolution with kernel size 1 is used if + ``in_channels != out_channels``, or a convolution with specified ``kernel_size``. The argument + can be a string specifying which of these, i.e., "identity", "conv1", or "conv", to consider. + Multiple options can be allowed by using the "|" operator. The default is "identity | conv1 | conv". + In order to always use a convolution with kernel size 1, even in case of a stride > 1, use + "identity | conv1". To force the use of the identity, set argument to ``None`` or "identity". + num_layers: Number of convolutional layers in residual block. + order: Order of convolution, normalization, and activation in each layer. + In case of post-convolution activation (post-activiation), the final activation + is applied after the addition of the layer output with the shortcut connection. + Otherwise, the addition is of input and output is the last operation of this block. + other: If specified, convolutions in convolutional layers are reused. The given residual unit + must have created with the same parameters as this residual unit. Note that the resulting + and the other residual unit will share references to the same convolution modules. + + """ + if spatial_dims < 0 or spatial_dims > 3: + raise ValueError("ResidualUnit() 'spatial_dims' must be 1, 2, or 3") + order = order.upper() + if len(set(order)) != len(order) or "C" not in order or "A" not in order: + raise ValueError( + f"ResidualUnit() 'order' must be permutation of unique characters 'a|A' (activation), 'n|N' (norm, optional), and 'c|C' (convolution), got {order!r}" + ) + if order.index("C") < order.index("A"): + post_acti = activation(acti) + else: + post_acti = None + if out_channels is None: + out_channels = in_channels + if num_channels is None: + num_channels = out_channels + is_bottleneck = num_channels != out_channels + if num_layers is None: + num_layers = 3 if is_bottleneck else 2 + if not isinstance(num_layers, int): + raise TypeError("ResidualUnit() 'num_layers' must be int or None") + if is_bottleneck and num_layers < 3: + raise ValueError( + "ResidualUnit() 'num_layers' must be at least 3 in case of a bottleneck block" + ) + elif num_layers < 1: + raise ValueError("ResidualUnit() 'num_layers' must be positive") + residual = paddle.nn.Sequential() + if is_bottleneck: + name = f"layer_{len(residual)}" + if other is None: + other_conv = None + else: + other_layer = getattr(other.residual, name) + assert isinstance(other_layer, ConvLayer) + other_conv = other_layer.conv + assert isinstance(other_conv, paddle.nn.Layer) + if pre_conv == "conv": + pre_kernel_size = kernel_size + elif pre_conv == "conv1": + pre_kernel_size = 1 + else: + raise ValueError("ResidualUnit() 'pre_conv' must be 'conv' or 'conv1'") + conv = ConvLayer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=num_channels, + kernel_size=pre_kernel_size, + padding_mode=padding_mode, + init=init, + bias=bias, + norm=norm, + acti=acti, + order=order, + conv=other_conv, + ) + residual.add_sublayer(name=name, sublayer=conv) + for i in range( + 1 if is_bottleneck else 0, num_layers - 1 if is_bottleneck else num_layers + ): + name = f"layer_{len(residual)}" + is_first_layer = i == 0 + is_last_layer = i == num_layers - 1 + if other is None: + other_conv = None + else: + other_layer = getattr(other.residual, name) + assert isinstance(other_layer, ConvLayer) + other_conv = other_layer.conv + assert isinstance(other_conv, paddle.nn.Layer) + conv = ConvLayer( + spatial_dims=spatial_dims, + in_channels=in_channels if is_first_layer else num_channels, + out_channels=num_channels, + kernel_size=kernel_size, + stride=stride if is_first_layer else 1, + padding=padding, + padding_mode=padding_mode, + dilation=dilation, + groups=groups, + init=init, + bias=bias, + norm=norm, + acti=None if is_last_layer and post_acti is not None else acti, + order=order, + conv=other_conv, + ) + residual.add_sublayer(name=name, sublayer=conv) + if is_bottleneck: + name = f"layer_{len(residual)}" + if other is None: + other_conv = None + else: + other_layer = getattr(other.residual, name) + assert isinstance(other_layer, ConvLayer) + other_conv = other_layer.conv + assert isinstance(other_conv, paddle.nn.Layer) + if post_conv == "conv": + post_kernel_size = kernel_size + elif post_conv == "conv1": + post_kernel_size = 1 + else: + raise ValueError("ResidualUnit() 'post_conv' must be 'conv' or 'conv1'") + conv = ConvLayer( + spatial_dims=spatial_dims, + in_channels=num_channels, + out_channels=out_channels, + kernel_size=post_kernel_size, + padding_mode=padding_mode, + init=init, + bias=bias, + norm=norm, + acti=acti if post_acti is None else None, + order=order, + conv=other_conv, + ) + residual.add_sublayer(name=name, sublayer=conv) + has_strided_conv = ( + not paddle.to_tensor(data=stride, dtype="int32", place="cpu") + .prod() + .equal(y=1) + ) + if skip in (None, ""): + skip = "identity" + if isinstance(skip, str): + skip_names = [s.strip() for s in skip.lower().split("|")] + if has_strided_conv or in_channels != out_channels: + if "conv" in skip_names and ( + has_strided_conv or "conv1" not in skip_names + ): + skip = {} + elif "conv1" in skip_names: + skip = {"kernel_size": 1, "padding": 0} + elif skip == "identity": + raise ValueError( + f"ResidualUnit() cannot use 'identity' skip connection (has_strided_conv={has_strided_conv}, in_channels={in_channels}, out_channels={out_channels})" + ) + else: + raise ValueError(f"ResidualUnit() invalid 'skip' value {skip!r}") + elif "identity" in skip_names: + skip = paddle.nn.Identity() + elif "conv1" in skip_names: + skip = {"kernel_size": 1, "padding": 0} + elif "conv" in skip_names: + skip = {} + else: + raise ValueError(f"ResidualUnit() invalid 'skip' value {skip!r}") + if isinstance(skip, dict): + if "padding" in skip and "kernel_size" not in skip: + raise ValueError( + "ResidualUnit() 'skip' specifies 'padding' but not 'kernel_size'" + ) + conv_args = dict( + kernel_size=kernel_size if has_strided_conv else 1, + stride=stride, + dilation=1, + init=init, + bias=bias is not False, + ) + conv_args.update(skip) + if "padding" not in conv_args: + conv_args["padding"] = same_padding( + conv_args["kernel_size"], conv_args["dilation"] + ) + skip = convolution(spatial_dims, in_channels, out_channels, **conv_args) + elif not callable(skip): + raise TypeError( + "ResidualUnit() 'skip' must be str, dict, callable, or None" + ) + add = JoinLayer("add") + if post_acti is None: + join = add + else: + join = paddle.nn.Sequential() + join.add_sublayer(name="add", sublayer=add) + join.add_sublayer(name="acti", sublayer=post_acti) + super().__init__(residual, name="residual", skip=skip, join=join) + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.num_channels = num_channels + + def is_bottleneck(self) -> bool: + """Whether this residual unit is a "bottleneck" type ResNet block.""" + return self.out_channels != self.num_channels + + @property + def last_conv_layer(self) -> ConvLayer: + """Get last residual convolutional layer.""" + return self.func[-1] + + def zero_init_residual(self) -> None: + """Zero-initialize normalization layer in residual branch. + + This is so that the residual branch starts with zeros, and each residual block behaves like an identity. + This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677. + + """ + conv_layer = self.last_conv_layer + if conv_layer.has_norm_after_conv(): + weight = getattr(conv_layer.norm, "weight", None) + init_Constant = paddle.nn.initializer.Constant(value=0) + init_Constant(weight) diff --git a/jointContribution/HighResolution/deepali/networks/blocks/skip.py b/jointContribution/HighResolution/deepali/networks/blocks/skip.py new file mode 100644 index 0000000000..4a3adb60f0 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/blocks/skip.py @@ -0,0 +1,130 @@ +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Union +from typing import overload + +import paddle +import paddle.nn.Layer as Layer + +from ...modules.mixins import ReprWithCrossReferences +from ..layers.join import JoinFunc +from ..layers.join import join_func + +__all__ = "DenseBlock", "Shortcut", "SkipConnection", "SkipFunc" +SkipFunc = Callable[[paddle.Tensor], paddle.Tensor] + + +class DenseBlock(ReprWithCrossReferences, paddle.nn.Layer): + """Subnetwork with dense skip connections.""" + + @overload + def __init__( + self, *args: paddle.nn.Layer, join: Union[str, JoinFunc], dim: int + ) -> None: + ... + + @overload + def __init__( + self, arg: Mapping[str, Layer], join: Union[str, JoinFunc], dim: int + ) -> None: + ... + + def __init__( + self, *args: Any, join: Union[str, JoinFunc] = "cat", dim: int = 1 + ) -> None: + super().__init__() + if len(args) == 1 and isinstance(args[0], Mapping): + layers = args[0] + else: + layers = {str(i): m for i, m in enumerate(args)} + self.layers = paddle.nn.LayerDict(sublayers=layers) + self.join = join_func(join, dim=dim) + self.is_associative = join in ("add", "cat", "concat") + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + y, ys = x, [x] + join = self.join + is_associative = self.is_associative + for module in self.layers.values(): + x = join(ys) + y = module(x) + ys = [x, y] if is_associative else [*ys, y] + return y + + +class SkipConnection(ReprWithCrossReferences, paddle.nn.Layer): + """Combine input with subnetwork output along a single skip connection.""" + + @overload + def __init__( + self, + *args: paddle.nn.Layer, + name: str = "func", + skip: Union[str, SkipFunc] = "identity", + join: Union[str, JoinFunc] = "cat", + dim: int = 1 + ) -> None: + ... + + @overload + def __init__( + self, + arg: Mapping[str, Layer], + name: str = "func", + skip: Union[str, SkipFunc] = "identity", + join: Union[str, JoinFunc] = "cat", + dim: int = 1, + ) -> None: + ... + + def __init__( + self, + *args: Any, + name: str = "func", + skip: Union[str, SkipFunc] = "identity", + join: Union[str, JoinFunc] = "cat", + dim: int = 1 + ) -> None: + super().__init__() + if len(args) == 1 and isinstance(args[0], paddle.nn.Layer): + func = args[0] + else: + func = paddle.nn.Sequential(*args) + self.name = name + if skip in (None, "identity"): + skip = paddle.nn.Identity() + elif not callable(skip): + raise ValueError( + "SkipConnection() 'skip' must be 'identity', callable, or None" + ) + self.skip = skip + self.join = join_func(join, dim=dim) + self._modules[self.name] = func + + @property + def func(self) -> paddle.nn.Layer: + return self._modules[self.name] + + @property + def shortcut(self) -> SkipFunc: + return self.skip + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + a = self.skip(x) + b = self.func(x) + if not isinstance(a, paddle.Tensor): + raise TypeError( + "SkipConnection() 'skip' function must return a paddle.Tensor" + ) + if not isinstance(b, paddle.Tensor): + raise TypeError("SkipConnection() module must return a paddle.Tensor") + c = self.join([b, a]) + if not isinstance(c, paddle.Tensor): + raise TypeError( + "SkipConnection() 'join' function must return a paddle.Tensor" + ) + return c + + +Shortcut = SkipConnection diff --git a/jointContribution/HighResolution/deepali/networks/layers/__init__.py b/jointContribution/HighResolution/deepali/networks/layers/__init__.py new file mode 100644 index 0000000000..8d18ff913e --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/__init__.py @@ -0,0 +1,88 @@ +"""Basic network modules, usually with learnable parameters.""" +from ...modules import Pad +from ...modules import Reshape +from ...modules import View +from .acti import Activation +from .acti import ActivationArg +from .acti import ActivationFunc +from .acti import activation +from .acti import is_activation +from .conv import Conv1d +from .conv import Conv2d +from .conv import Conv3d +from .conv import ConvLayer +from .conv import ConvTranspose1d +from .conv import ConvTranspose2d +from .conv import ConvTranspose3d +from .conv import conv_module +from .conv import convolution +from .conv import is_conv_module +from .conv import is_convolution +from .join import JoinFunc +from .join import JoinLayer +from .join import join_func +from .lambd import LambdaFunc +from .lambd import LambdaLayer +from .linear import Linear +from .norm import NormArg +from .norm import NormFunc +from .norm import NormLayer +from .norm import is_batch_norm +from .norm import is_group_norm +from .norm import is_instance_norm +from .norm import is_norm_layer +from .norm import norm_layer +from .norm import normalization +from .pool import PoolArg +from .pool import PoolFunc +from .pool import PoolLayer +from .pool import pool_layer +from .pool import pooling +from .upsample import SubpixelUpsample +from .upsample import Upsample +from .upsample import UpsampleMode + +__all__ = ( + "Pad", + "Reshape", + "View", + "Activation", + "ActivationArg", + "ActivationFunc", + "activation", + "is_activation", + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "ConvLayer", + "conv_module", + "convolution", + "is_convolution", + "is_conv_module", + "JoinLayer", + "JoinFunc", + "join_func", + "Linear", + "NormArg", + "NormFunc", + "NormLayer", + "norm_layer", + "normalization", + "is_batch_norm", + "is_group_norm", + "is_instance_norm", + "is_norm_layer", + "PoolArg", + "PoolFunc", + "PoolLayer", + "pool_layer", + "pooling", + "Upsample", + "UpsampleMode", + "SubpixelUpsample", + "LambdaLayer", + "LambdaFunc", +) diff --git a/jointContribution/HighResolution/deepali/networks/layers/acti.py b/jointContribution/HighResolution/deepali/networks/layers/acti.py new file mode 100644 index 0000000000..cd46372053 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/acti.py @@ -0,0 +1,153 @@ +from functools import partial +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Type +from typing import Union + +import paddle + +from .lambd import LambdaLayer + +ActivationFunc = Callable[[paddle.Tensor], paddle.Tensor] +ActivationArg = Union[ActivationFunc, str, Mapping, Sequence, None] +ACTIVATION_TYPES = { + "celu": paddle.nn.CELU, + "elu": paddle.nn.ELU, + "hardtanh": paddle.nn.Hardtanh, + "identity": paddle.nn.Identity, + "none": paddle.nn.Identity, + "relu": paddle.nn.ReLU, + "relu6": paddle.nn.ReLU6, + "lrelu": paddle.nn.LeakyReLU, + "leakyrelu": paddle.nn.LeakyReLU, + "leaky_relu": paddle.nn.LeakyReLU, + "rrelu": paddle.nn.RReLU, + "selu": paddle.nn.SELU, + "gelu": paddle.nn.GELU, + "hardshrink": paddle.nn.Hardshrink, + "hardsigmoid": paddle.nn.Hardsigmoid, + "hardswish": paddle.nn.Hardswish, + "logsigmoid": paddle.nn.LogSigmoid, + "logsoftmax": paddle.nn.LogSoftmax, + "prelu": paddle.nn.PReLU, + "sigmoid": paddle.nn.Sigmoid, + "softmax": paddle.nn.Softmax, + "softmin": paddle.nn.Softmin, + "softplus": paddle.nn.Softplus, + "softshrink": paddle.nn.Softshrink, + "softsign": paddle.nn.Softsign, + "tanh": paddle.nn.Tanh, + "tanhshrink": paddle.nn.Tanhshrink, +} +INPLACE_ACTIVATIONS = { + "elu", + "hardtanh", + "lrelu", + "leakyrelu", + "relu", + "relu6", + "rrelu", + "selu", + "celu", +} +SOFTMINMAX_ACTIVATIONS = {"softmin", "softmax", "logsoftmax"} + + +def activation( + arg: ActivationArg, + *args: Any, + dim: Optional[int] = None, + inplace: Optional[bool] = None, + **kwargs, +) -> paddle.nn.Layer: + """Get activation function. + + Args: + arg: Custom activation function or module, or name of activation function with optional keyword arguments. + args: Arguments to pass to activation init function. + dim: Dimension along which to compute softmax activations (cf. ``ACT_SOFTMINMAX``). Unused by other activations. + inplace: Whether to compute activation output in place. Unused if unsupported by specified activation function. + kwargs: Additional keyword arguments for activation function. Overrides keyword arguments given as second + tuple item when ``arg`` is a ``(name, kwargs)`` tuple instead of a string. + + Returns: + Given activation function when ``arg`` is a ``paddle.nn.Layer``, or a new activation module otherwise. + + """ + if isinstance(arg, paddle.nn.Layer) and not args and not kwargs: + return arg + if callable(arg): + return Activation(arg, *args, **kwargs) + acti_name = "identity" + acti_args = {} + if isinstance(arg, str): + acti_name = arg + elif isinstance(arg, Mapping): + acti_name = arg.get("name") + if not acti_name: + raise ValueError("activation() 'arg' map must contain 'name'") + if not isinstance(acti_name, str): + raise TypeError("activation() 'name' must be str") + acti_args = {key: value for key, value in arg.items() if key != "name"} + elif isinstance(arg, Sequence): + if len(arg) != 2: + raise ValueError("activation() 'arg' sequence must have length two") + acti_name, acti_args = arg + if not isinstance(acti_name, str): + raise TypeError("activation() first 'arg' sequence argument must be str") + if not isinstance(acti_args, dict): + raise TypeError("activation() second 'arg' sequence argument must be dict") + acti_args = acti_args.copy() + elif arg is not None: + raise TypeError( + "activation() 'arg' must be str, mapping, 2-tuple, callable, or None" + ) + acti_name = acti_name.lower() + acti_type: Type[paddle.nn.Layer] = ACTIVATION_TYPES.get(acti_name) + if acti_type is None: + raise ValueError( + f"activation() 'arg' name {acti_name!r} is unknown. Pass a callable activation function or module instead." + ) + acti_args.update(kwargs) + if inplace is not None and acti_name in INPLACE_ACTIVATIONS: + acti_args["inplace"] = bool(inplace) + if acti_name in SOFTMINMAX_ACTIVATIONS: + if dim is None and len(args) == 1 and isinstance(args, int): + dim = args[0] + elif args or acti_args: + raise ValueError("activation() named {act_name!r} has no parameters") + if dim is None: + dim = 1 + acti = acti_type(dim) + else: + acti = acti_type(*args, **acti_args) + return acti + + +class Activation(LambdaLayer): + """Non-linear activation function.""" + + def __init__( + self, + arg: ActivationArg, + *args: Any, + dim: int = 1, + inplace: Optional[bool] = None, + **kwargs, + ) -> None: + if callable(arg): + acti = partial(arg, *args, **kwargs) if args or kwargs else arg + else: + acti = activation(arg, *args, dim=dim, inplace=inplace, **kwargs) + super().__init__(acti) + + +def is_activation(arg: Any) -> bool: + """Whether given object is an non-linear activation function module.""" + if isinstance(arg, Activation): + return True + types = tuple(ACTIVATION_TYPES.values()) + return isinstance(arg, types) diff --git a/jointContribution/HighResolution/deepali/networks/layers/conv.py b/jointContribution/HighResolution/deepali/networks/layers/conv.py new file mode 100644 index 0000000000..60a1113667 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/conv.py @@ -0,0 +1,463 @@ +import math +from collections import OrderedDict +from numbers import Integral +from typing import Any +from typing import Optional +from typing import Union + +import paddle +from paddle.nn import Conv1D as _Conv1d +from paddle.nn import Conv1DTranspose as _ConvTranspose1d +from paddle.nn import Conv2D as _Conv2d +from paddle.nn import Conv2DTranspose as _ConvTranspose2d +from paddle.nn import Conv3D as _Conv3d +from paddle.nn import Conv3DTranspose as _ConvTranspose3d + +from ...core.enum import PaddingMode +from ...core.nnutils import same_padding +from ...core.nnutils import stride_minus_kernel_padding +from ...core.types import ScalarOrTuple +from ...core.types import ScalarOrTuple1d +from ...core.types import ScalarOrTuple2d +from ...core.types import ScalarOrTuple3d +from ...modules import ReprWithCrossReferences +from .acti import ActivationArg +from .acti import activation +from .norm import NormArg +from .norm import normalization + +__all__ = ( + "Conv1d", + "Conv2d", + "Conv3d", + "ConvTranspose1d", + "ConvTranspose2d", + "ConvTranspose3d", + "ConvLayer", + "convolution", + "conv_module", + "is_convolution", + "is_conv_module", + "is_transposed_convolution", + "same_padding", + "stride_minus_kernel_padding", +) + + +class _ConvInit(object): + """Mix-in for initialization of convolutional layer parameters.""" + + def reset_parameters(self) -> None: + if self.weight_init in ("default", "uniform"): + init_KaimingUniform = paddle.nn.initializer.KaimingUniform( + negative_slope=math.sqrt(5), nonlinearity="leaky_relu" + ) + init_KaimingUniform(self.weight) + elif self.weight_init == "xavier": + init_XavierUniform = paddle.nn.initializer.XavierUniform() + init_XavierUniform(self.weight) + elif self.weight_init == "constant": + init_Constant = paddle.nn.initializer.Constant(value=0.1) + init_Constant(self.weight) + elif self.weight_init == "zeros": + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(self.weight) + else: + raise AssertionError( + f"{type(self).__name__}.reset_parameters() invalid 'init' value: {self.weight_init!r}" + ) + if self.bias is not None: + if self.bias_init in ("default", "uniform"): + fan_in, _ = paddle.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init_Uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) + init_Uniform(self.bias) + elif self.bias_init == "constant": + init_Constant = paddle.nn.initializer.Constant(value=0.1) + init_Constant(self.bias) + elif self.bias_init == "zeros": + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(self.bias) + else: + raise AssertionError( + f"{type(self).__name__}.reset_parameters() invalid 'bias' value: {self.bias_init!r}" + ) + + +class Conv1d(_ConvInit, paddle.nn.modules.Conv1d): + """Convolutional layer with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple1d[int], + stride: ScalarOrTuple1d[int] = 1, + padding: ScalarOrTuple1d[int] = 0, + dilation: ScalarOrTuple1d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "default", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class Conv2d(_ConvInit, paddle.nn.modules.Conv2d): + """Convolutional layer with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple2d[int], + stride: ScalarOrTuple2d[int] = 1, + padding: ScalarOrTuple2d[int] = 0, + dilation: ScalarOrTuple2d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "default", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class Conv3d(_ConvInit, paddle.nn.modules.Conv3d): + """Convolutional layer with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple3d[int], + stride: ScalarOrTuple3d[int] = 1, + padding: ScalarOrTuple3d[int] = 0, + dilation: ScalarOrTuple3d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "uniform", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class ConvTranspose1d(_ConvInit, paddle.nn.modules.ConvTranspose1d): + """Transposed convolution in 1D with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple1d[int], + stride: ScalarOrTuple1d[int] = 1, + padding: ScalarOrTuple1d[int] = 0, + output_padding: ScalarOrTuple1d[int] = 0, + dilation: ScalarOrTuple1d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "default", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class ConvTranspose2d(_ConvInit, paddle.nn.modules.ConvTranspose2d): + """Transposed convolution in 2D with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple2d[int], + stride: ScalarOrTuple2d[int] = 1, + padding: ScalarOrTuple2d[int] = 0, + output_padding: ScalarOrTuple2d[int] = 0, + dilation: ScalarOrTuple2d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "default", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class ConvTranspose3d(_ConvInit, paddle.nn.modules.ConvTranspose3d): + """Transposed convolution in 3D with custom initialization of learnable parameters.""" + + def __init__( + self, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple3d[int], + stride: ScalarOrTuple3d[int] = 1, + padding: ScalarOrTuple3d[int] = 0, + output_padding: ScalarOrTuple3d[int] = 0, + dilation: ScalarOrTuple3d[int] = 1, + groups: int = 1, + bias: Union[bool, str] = True, + padding_mode: str = "zeros", + init: str = "default", + ): + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = "uniform" if init == "default" else init + super().__init__( + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + dilation=dilation, + groups=groups, + bias=bool(bias), + padding_mode=padding_mode, + ) + + +class ConvLayer(ReprWithCrossReferences, paddle.nn.modules.Sequential): + """Convolutional layer with optional pre- or post-convolution normalization and/or activation.""" + + def __init__( + self, + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: Optional[ScalarOrTuple[int]] = None, + output_padding: Optional[ScalarOrTuple[int]] = None, + padding_mode: Union[PaddingMode, str] = "zeros", + dilation: ScalarOrTuple[int] = 1, + groups: int = 1, + init: str = "default", + bias: Optional[Union[bool, str]] = None, + norm: NormArg = None, + acti: ActivationArg = None, + order: str = "CNA", + transposed: bool = False, + conv: Optional[paddle.nn.Layer] = None, + ) -> None: + if spatial_dims < 0 or spatial_dims > 3: + raise ValueError("ConvLayer() 'spatial_dims' must be 1, 2, or 3") + if isinstance(kernel_size, Integral): + kernel_size = (int(kernel_size),) * spatial_dims + if padding is None: + padding = same_padding(kernel_size, dilation) + if not isinstance(order, str): + raise TypeError("ConvLayer() 'order' must be str") + order = order.upper() + if "C" not in order: + raise ValueError("ConvLayer() 'order' must contain 'C' for convolution") + if "D" in order: + raise NotImplementedError( + "ConvLayer() 'order' has 'D' for not implemented dropout" + ) + if len(set(order)) != len(order) or any( + c not in {"A", "C", "N"} for c in order + ): + raise ValueError( + f"ConvLayer() 'order' must be permutation of characters 'A' (activation), 'N' (norm), and 'C' (convolution), got {order!r}" + ) + if acti is None or "A" not in order: + acti_fn = None + else: + acti_fn = activation(acti) + norm_after_conv = False + if norm is None or "N" not in order: + norm_layer = None + else: + norm_after_conv = order.index("C") < order.index("N") + num_features = out_channels if norm_after_conv else in_channels + norm_layer = normalization( + norm, spatial_dims=spatial_dims, num_features=num_features + ) + if bias is None: + bias = True + if norm_after_conv and isinstance(norm_layer, paddle.nn.Layer): + bias = not getattr(norm_layer, "affine", False) + if output_padding is None: + output_padding = stride_minus_kernel_padding(1, stride) + if conv is None: + conv = convolution( + spatial_dims, + in_channels, + out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + output_padding=output_padding, + padding_mode=padding_mode, + dilation=dilation, + groups=groups, + bias=bias, + init=init, + transposed=transposed, + ) + modules = { + "A": ("acti", acti_fn), + "N": ("norm", norm_layer), + "C": ("conv", conv), + } + modules = OrderedDict(modules[k] for k in order if modules[k][1] is not None) + super().__init__(modules) + if acti_fn is None: + self.acti = None + if norm_layer is None: + self.norm = None + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = out_channels + self.order = order + + def has_norm_after_conv(self) -> bool: + """Whether this layer has a normalization layer after the convolution.""" + names = tuple( + name for name, _ in self.named_sublayers() if name in ("conv", "norm") + ) + return len(names) == 2 and names[-1] == "norm" + + +def convolution( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple[int], + stride: ScalarOrTuple[int] = 1, + padding: Optional[ScalarOrTuple[int]] = 0, + output_padding: Optional[ScalarOrTuple[int]] = 0, + padding_mode: Union[PaddingMode, str] = "zeros", + dilation: ScalarOrTuple[int] = 1, + groups: int = 1, + init: str = "default", + bias: Union[bool, str] = True, + transposed: bool = False, +) -> paddle.nn.Layer: + """Create convolution module for specified number of spatial input tensor dimensions.""" + if in_channels < 1: + raise ValueError( + f"convolution() 'in_channels' ({in_channels}) must be positive" + ) + if out_channels < 1: + raise ValueError( + f"convolution() 'out_channels' ({out_channels}) must be positive" + ) + padding_mode = PaddingMode.from_arg(padding_mode) + if padding_mode is PaddingMode.NONE: + padding = 0 + elif padding is None: + padding = same_padding(kernel_size, dilation) + kwargs = dict( + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + init=init, + ) + if any(n != 0 for n in ((padding,) if isinstance(padding, int) else padding)): + kwargs["padding_mode"] = padding_mode.conv_mode(spatial_dims) + if transposed: + if output_padding is None: + output_padding = stride_minus_kernel_padding(1, stride) + kwargs["output_padding"] = output_padding + if spatial_dims == 1: + conv_type = ConvTranspose1d if transposed else Conv1d + elif spatial_dims == 2: + conv_type = ConvTranspose2d if transposed else Conv2d + elif spatial_dims == 3: + conv_type = ConvTranspose3d if transposed else Conv3d + else: + raise ValueError("convolution() 'spatial_dims' must be 1, 2, or 3") + return conv_type(in_channels, out_channels, **kwargs) + + +def conv_module(*args, **kwargs) -> paddle.nn.Layer: + """Create convolution layer, see ``convolution()``.""" + return convolution(*args, **kwargs) + + +def is_convolution(arg: Any) -> bool: + """Whether given module is a learnable convolution.""" + types = ( + _Conv1d, + _Conv2d, + _Conv3d, + _ConvTranspose1d, + _ConvTranspose2d, + _ConvTranspose3d, + ) + return isinstance(arg, types) + + +def is_conv_module(arg: Any) -> bool: + """Whether given module is a learnable convolution.""" + return is_convolution(arg) + + +def is_transposed_convolution(arg: Any) -> bool: + """Whether given module is a learnable transposed convolution.""" + types = _ConvTranspose1d, _ConvTranspose2d, _ConvTranspose3d + return isinstance(arg, types) diff --git a/jointContribution/HighResolution/deepali/networks/layers/join.py b/jointContribution/HighResolution/deepali/networks/layers/join.py new file mode 100644 index 0000000000..4bd645ddea --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/join.py @@ -0,0 +1,66 @@ +from typing import Callable +from typing import Sequence +from typing import Union + +import paddle + +from .lambd import LambdaLayer + +JoinFunc = Callable[[Sequence[paddle.Tensor]], paddle.Tensor] + + +def join_func(arg: Union[str, JoinFunc], dim: int = 1) -> JoinFunc: + """paddle.Tensor operation which combines features of input tensors, e.g., along skip connection. + + Args: + arg: Name of operation: 'add': Elementwise addition, 'cat' or 'concat': Concatenate along feature dimension. + dim: Dimension of input tensors containing features. + + """ + if callable(arg): + return arg + if not isinstance(arg, str): + raise TypeError("join_func() 'arg' must be str or callable") + name = arg.lower() + if name == "add": + + def add(args: Sequence[paddle.Tensor]) -> paddle.Tensor: + assert args, "join_func('add') requires at least one input tensor" + out = args[0] + for i in range(1, len(args)): + out = out + args[i] + return out + + return add + elif name in ("cat", "concat"): + + def cat(args: Sequence[paddle.Tensor]) -> paddle.Tensor: + assert args, "join_func('cat') requires at least one input tensor" + return paddle.concat(x=args, axis=dim) + + return cat + elif name == "mul": + + def mul(args: Sequence[paddle.Tensor]) -> paddle.Tensor: + assert args, "join_func('mul') requires at least one input tensor" + out = args[0] + for i in range(1, len(args)): + out = out * args[i] + return out + + return mul + raise ValueError("join_func() unknown merge function {name!r}") + + +class JoinLayer(LambdaLayer): + """Merge network branches.""" + + def __init__(self, arg: Union[str, JoinFunc], dim: int = 1) -> None: + func = join_func(arg, dim=dim) + super().__init__(func) + + def forward(self, xs: Sequence[paddle.Tensor]) -> paddle.Tensor: + return self.func(xs) + + def extra_repr(self) -> str: + return repr(self.func.__name__) diff --git a/jointContribution/HighResolution/deepali/networks/layers/lambd.py b/jointContribution/HighResolution/deepali/networks/layers/lambd.py new file mode 100644 index 0000000000..fc5b5ce7f5 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/lambd.py @@ -0,0 +1,4 @@ +from ...modules.lambd import LambdaFunc +from ...modules.lambd import LambdaLayer + +__all__ = "LambdaFunc", "LambdaLayer" diff --git a/jointContribution/HighResolution/deepali/networks/layers/linear.py b/jointContribution/HighResolution/deepali/networks/layers/linear.py new file mode 100644 index 0000000000..264637ee99 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/linear.py @@ -0,0 +1,52 @@ +import math +from typing import Optional + +import paddle + + +class Linear(paddle.nn.linear): + """Fully connected layer.""" + + def __init__( + self, + in_features: int, + out_features: int, + bias: bool = True, + init: Optional[str] = "uniform", + ) -> None: + self.bias_init = "uniform" if isinstance(bias, bool) else bias + self.weight_init = paddle.nn.init + super().__init__(in_features, out_features, bias=bool(bias)) + + def reset_parameters(self) -> None: + if self.weight_init == "uniform": + init_KaimingUniform = paddle.nn.initializer.KaimingUniform( + negative_slope=math.sqrt(5), nonlinearity="leaky_relu" + ) + init_KaimingUniform(self.weight) + elif self.weight_init == "constant": + init_Constant = paddle.nn.initializer.Constant(value=0.1) + init_Constant(self.weight) + elif self.weight_init == "zeros": + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(self.weight) + else: + raise AssertionError( + "Linear.reset_parameters() invalid 'init' value: {self.weight_init!r}" + ) + if self.bias is not None: + if self.bias_init == "uniform": + fan_in, _ = paddle.nn.init._calculate_fan_in_and_fan_out(self.weight) + bound = 1 / math.sqrt(fan_in) + init_Uniform = paddle.nn.initializer.Uniform(low=-bound, high=bound) + init_Uniform(self.bias) + elif self.bias_init == "constant": + init_Constant = paddle.nn.initializer.Constant(value=0.1) + init_Constant(self.bias) + elif self.bias_init == "zeros": + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(self.bias) + else: + raise AssertionError( + "Linear.reset_parameters() invalid 'bias' value: {self.bias_init!r}" + ) diff --git a/jointContribution/HighResolution/deepali/networks/layers/norm.py b/jointContribution/HighResolution/deepali/networks/layers/norm.py new file mode 100644 index 0000000000..5452a00583 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/norm.py @@ -0,0 +1,172 @@ +from functools import partial +from numbers import Integral +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle + +from .lambd import LambdaLayer + +NormFunc = Callable[[paddle.Tensor], paddle.Tensor] +NormArg = Union[NormFunc, str, Mapping[str, Any], Sequence, None] + + +def normalization( + arg: NormArg, + *args, + spatial_dims: Optional[int] = None, + num_features: Optional[int] = None, + **kwargs +) -> paddle.nn.Layer: + """Create normalization layer. + + Args: + arg: Custom normalization function or module, or name of normalization layer with optional keyword arguments. + args: Positional arguments passed to normalization layer. + num_features: Number of input features. + spatial_dims: Number of spatial dimensions of input tensors. + kwargs: Additional keyword arguments for normalization layer. Overrides keyword arguments given as second + tuple item when ``arg`` is a ``(name, kwargs)`` tuple instead of a string. + + Returns: + Given normalization function when ``arg`` is a ``paddle.nn.Layer``, or a new normalization layer otherwise. + + """ + if isinstance(arg, paddle.nn.Layer) and not args and not kwargs: + return arg + if callable(arg): + return NormLayer(arg, *args, **kwargs) + norm_name = "identity" + norm_args = {} + if isinstance(arg, str): + norm_name = arg + elif isinstance(arg, Mapping): + norm_name = arg.get("name") + if not norm_name: + raise ValueError("normalization() 'arg' map must contain 'name'") + if not isinstance(norm_name, str): + raise TypeError("normalization() 'name' must be str") + norm_args = {key: value for key, value in arg.items() if key != "name"} + elif isinstance(arg, Sequence): + if len(arg) != 2: + raise ValueError("normalization() 'arg' sequence must have length two") + norm_name, norm_args = arg + if not isinstance(norm_name, str): + raise TypeError("normalization() first 'arg' sequence argument must be str") + if not isinstance(norm_args, dict): + if norm_name == "group" and isinstance(norm_args, Integral): + norm_args = dict(num_groups=norm_args) + else: + raise TypeError( + "normalization() second 'arg' sequence argument must be dict" + ) + norm_args = norm_args.copy() + elif arg is not None: + raise TypeError( + "normalization() 'arg' must be str, mapping, 2-tuple, callable, or None" + ) + norm_name = norm_name.lower() + norm_args.update(kwargs) + if norm_name in ("none", "identity"): + norm = paddle.nn.Identity() + elif norm_name in ("batch", "batchnorm"): + if spatial_dims is None: + raise ValueError("normalization() 'spatial_dims' required for 'batch' norm") + if spatial_dims < 0 or spatial_dims > 3: + raise ValueError("normalization() 'spatial_dims' must be 1, 2, or 3") + if num_features is None: + raise ValueError("normalization() 'num_features' required for 'batch' norm") + norm_type = ( + paddle.nn.BatchNorm1D, + paddle.nn.BatchNorm2D, + paddle.nn.BatchNorm3D, + )[spatial_dims - 1] + norm = norm_type(num_features, *args, **norm_args) + elif norm_name in ("group", "groupnorm"): + num_groups = norm_args.pop("num_groups", 1) + if num_features is None: + if "num_channels" not in norm_args: + raise ValueError( + "normalization() 'num_features' required for 'group' norm" + ) + num_features = norm_args.pop("num_channels") + norm = paddle.nn.GroupNorm(num_groups, num_features, *args, **norm_args) + elif norm_name in ("layer", "layernorm"): + if num_features is None: + raise ValueError("normalization() 'num_features' required for 'layer' norm") + norm = paddle.nn.GroupNorm(1, num_features, *args, **norm_args) + elif norm_name in ("instance", "instancenorm"): + if spatial_dims is None: + raise ValueError( + "normalization() 'spatial_dims' required for 'instance' norm" + ) + if spatial_dims < 0 or spatial_dims > 3: + raise ValueError("normalization() 'spatial_dims' must be 1, 2, or 3") + if num_features is None: + raise ValueError( + "normalization() 'num_features' required for 'instance' norm" + ) + norm_type = ( + paddle.nn.InstanceNorm1D, + paddle.nn.InstanceNorm2D, + paddle.nn.InstanceNorm3D, + )[spatial_dims - 1] + norm = norm_type(num_features, *args, **norm_args) + else: + raise ValueError("normalization() unknown layer type {norm_name!r}") + return norm + + +def norm_layer(*args, **kwargs) -> paddle.nn.Layer: + """Create normalization layer, see ``normalization``.""" + return normalization(*args, **kwargs) + + +def is_norm_layer(arg: Any) -> bool: + """Whether given module is a normalization layer.""" + if isinstance(arg, NormLayer): + return True + return is_batch_norm(arg) or is_group_norm(arg) or is_instance_norm(arg) + + +def is_batch_norm(arg: Any) -> bool: + """Whether given module is a batch normalization layer.""" + return isinstance( + arg, (paddle.nn.BatchNorm1D, paddle.nn.BatchNorm2D, paddle.nn.BatchNorm3D) + ) + + +def is_group_norm(arg: Any) -> bool: + """Whether given module is a group normalization layer.""" + return isinstance(arg, paddle.nn.GroupNorm) + + +def is_instance_norm(arg: Any) -> bool: + """Whether given module is an instance normalization layer.""" + return isinstance( + arg, + (paddle.nn.InstanceNorm1D, paddle.nn.InstanceNorm2D, paddle.nn.InstanceNorm3D), + ) + + +class NormLayer(LambdaLayer): + """Normalization layer.""" + + def __init__( + self, + arg: NormArg, + *args, + spatial_dims: Optional[int] = None, + num_features: Optional[int] = None, + **kwargs: Mapping[str, Any] + ) -> None: + if callable(arg): + norm = partial(arg, *args, **kwargs) if args or kwargs else arg + else: + kwargs.update(dict(spatial_dims=spatial_dims, num_features=num_features)) + norm = normalization(arg, *args, **kwargs) + super().__init__(norm) diff --git a/jointContribution/HighResolution/deepali/networks/layers/pool.py b/jointContribution/HighResolution/deepali/networks/layers/pool.py new file mode 100644 index 0000000000..b9007bedce --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/pool.py @@ -0,0 +1,131 @@ +from functools import partial +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Type +from typing import Union + +import paddle + +from .lambd import LambdaLayer + +PoolFunc = Callable[[paddle.Tensor], paddle.Tensor] +PoolArg = Union[PoolFunc, str, Mapping, Sequence, None] +POOLING_TYPES = { + "avg": (paddle.nn.AvgPool1D, paddle.nn.AvgPool2D, paddle.nn.AvgPool3D), + "avgpool": (paddle.nn.AvgPool1D, paddle.nn.AvgPool2D, paddle.nn.AvgPool3D), + "adaptiveavg": ( + paddle.nn.AdaptiveAvgPool1D, + paddle.nn.AdaptiveAvgPool2D, + paddle.nn.AdaptiveAvgPool3D, + ), + "adaptiveavgpool": ( + paddle.nn.AdaptiveAvgPool1D, + paddle.nn.AdaptiveAvgPool2D, + paddle.nn.AdaptiveAvgPool3D, + ), + "adaptivemax": ( + paddle.nn.AdaptiveMaxPool1D, + paddle.nn.AdaptiveMaxPool2D, + paddle.nn.AdaptiveMaxPool3D, + ), + "adaptivemaxpool": ( + paddle.nn.AdaptiveMaxPool1D, + paddle.nn.AdaptiveMaxPool2D, + paddle.nn.AdaptiveMaxPool3D, + ), + "max": (paddle.nn.MaxPool1D, paddle.nn.MaxPool2D, paddle.nn.MaxPool3D), + "maxpool": (paddle.nn.MaxPool1D, paddle.nn.MaxPool2D, paddle.nn.MaxPool3D), + "maxunpool": (paddle.nn.MaxUnPool1D, paddle.nn.MaxUnPool2D, paddle.nn.MaxUnPool3D), + "identity": paddle.nn.Identity, +} + + +def pooling( + arg: PoolArg, *args: Any, spatial_dims: Optional[int] = None, **kwargs +) -> paddle.nn.Layer: + """Get pooling layer. + + Args: + arg: Custom pooling function or module, or name of pooling layer with optional keyword arguments. + When ``arg`` is a callable but not of type ``paddle.nn.Layer``, it is wrapped in a ``PoolLayer``. + If ``None`` or 'identity', an instance of ``paddle.nn.Identity`` is returned. + spatial_dims: Number of spatial dimensions of input tensors. + args: Arguments to pass to init function of pooling layer. If ``arg`` is a callable, the given arguments + are passed to the function each time it is called as arguments. + kwargs: Additional keyword arguments for initialization of pooling layer. Overrides keyword arguments given as + second tuple item when ``arg`` is a ``(name, kwargs)`` tuple instead of a string. When ``arg`` is a callable, + the keyword arguments are passed each time the pooling function is called. + + Returns: + Pooling layer instance. + + """ + if isinstance(arg, paddle.nn.Layer) and not args and not kwargs: + return arg + if callable(arg): + return PoolLayer(arg, *args, **kwargs) + pool_name = "identity" + pool_args = {} + if isinstance(arg, str): + pool_name = arg + elif isinstance(arg, Mapping): + pool_name = arg.get("name") + if not pool_name: + raise ValueError("pooling() 'arg' map must contain 'name'") + if not isinstance(pool_name, str): + raise TypeError("pooling() 'name' must be str") + pool_args = {key: value for key, value in arg.items() if key != "name"} + elif isinstance(arg, Sequence): + if len(arg) != 2: + raise ValueError("pooling() 'arg' sequence must have length two") + pool_name, pool_args = arg + if not isinstance(pool_name, str): + raise TypeError("pooling() first 'arg' sequence argument must be str") + if not isinstance(pool_args, dict): + raise TypeError("pooling() second 'arg' sequence argument must be dict") + pool_args = pool_args.copy() + elif arg is not None: + raise TypeError( + "pooling() 'arg' must be str, mapping, 2-tuple, callable, or None" + ) + pool_type: Union[ + Type[paddle.nn.Layer], Sequence[Type[paddle.nn.Layer]] + ] = POOLING_TYPES.get(pool_name.lower()) + if pool_type is None: + raise ValueError(f"pooling() unknown pooling layer {pool_name!r}") + if isinstance(pool_type, Sequence): + if spatial_dims is None: + raise ValueError( + f"pooling() 'spatial_dims' required for pooling layer {pool_name!r}" + ) + try: + pool_type = pool_type[spatial_dims - 1] + except IndexError: + pool_type = None + if pool_type is None: + raise ValueError( + f"pooling({pool_name!r}) does not support spatial_dims={spatial_dims}" + ) + pool_args.update(kwargs) + module = pool_type(*args, **pool_args) + return module + + +def pool_layer(*args, **kwargs) -> paddle.nn.Layer: + return pooling(*args, **kwargs) + + +class PoolLayer(LambdaLayer): + """Pooling layer.""" + + def __init__( + self, arg: PoolArg, *args: Any, spatial_dims: Optional[int] = None, **kwargs + ) -> None: + if callable(arg): + pool = partial(arg, *args, **kwargs) if args or kwargs else arg + else: + pool = pooling(arg, *args, spatial_dims=spatial_dims, **kwargs) + super().__init__(pool) diff --git a/jointContribution/HighResolution/deepali/networks/layers/upsample.py b/jointContribution/HighResolution/deepali/networks/layers/upsample.py new file mode 100644 index 0000000000..a0af7c3a64 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/layers/upsample.py @@ -0,0 +1,408 @@ +from enum import Enum +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle +import paddle.nn.Layer as Layer + +from ...core.enum import PaddingMode +from ...core.enum import Sampling +from ...core.grid import ALIGN_CORNERS +from ...core.nnutils import upsample_output_padding +from ...core.nnutils import upsample_padding +from ...core.types import ScalarOrTuple +from ...modules import Pad +from .conv import convolution +from .pool import pooling + +__all__ = "Upsample", "UpsampleMode", "SubpixelUpsample" + + +class UpsampleMode(Enum): + DECONV = "deconv" + INTERPOLATE = "interpolate" + PIXELSHUFFLE = "pixelshuffle" + + +class Upsample(paddle.nn.Sequential): + """Upsamples data by `scale_factor`. + + This module is adapted from the MONAI project. Supported modes are: + + - "deconv": uses a transposed convolution. + - "interpolate": uses :py:class:`paddle.nn.Upsample`. + - "pixelshuffle": uses :py:class:`monai.networks.blocks.SubpixelUpsample`. + + This module can optionally apply a convolution prior to upsampling interpolation + (e.g., used to map the number of features from `in_channels` to `out_channels`). + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + scale_factor: Union[Sequence[Union[int, float]], Union[int, float]] = 2, + kernel_size: Optional[ScalarOrTuple[int]] = None, + mode: Union[UpsampleMode, str] = "default", + pre_conv: Optional[Union[Layer, str, int]] = "default", + sampling: Union[Sampling, str] = Sampling.LINEAR, + align_corners: bool = ALIGN_CORNERS, + apply_pad_pool: bool = True, + padding_mode: Union[PaddingMode, str] = PaddingMode.ZEROS, + bias: bool = True, + init: str = "default", + ) -> None: + """Initialize upsampling layer. + + Args: + spatial_dims: Number of spatial dimensions of input tensor. + in_channels: Number of channels of the input tensor. + out_channels: Number of channels of the output tensor. Defaults to `in_channels`. + scale_factor: Multiplier for spatial size. Has to match input size if it is a tuple. + kernel_size: Kernel size of transposed convolution in "deconv" mode or default ``pre_conv`` otherwise. + The default kernel size in "deconv" mode is equal to the specified ``scale_factor``. The default + ``pre_conv`` kernel size in "interpolate" mode is 1. In "pixelshuffle" mode, the default is 3. + mode: Upsampling mode: "deconv", "interpolate", or "pixelshuffle". + pre_conv: A conv block applied before upsampling. When ``conv_block`` is ``"default"``, one reserved + conv layer will be utilized. This argument is only used for "interpolate" or "pixelshuffle" mode. + sampling: Interpolation mode used for ``paddle.nn.Upsample`` in "interpolate" mode. + align_corners: See `paddle.nn.Upsample`. + apply_pad_pool: If True the upsampled tensor is padded then average pooling is applied with a kernel the + size of `scale_factor` with a stride of 1. Only used in the pixelshuffle mode. + padding_mode: Padding mode to use for default ``pre_conv`` and pixelshuffle ``apply_pad_pool``. + bias: Whether to have a bias term in the default preconv and deconv layers. + init: How to initialize default ``conv_block`` weights (cf. ``convolution()``). In case of "pixelshuffle" mode, + if value is "icnr" or "default", ICNR initialization is used. Use "uniform" for initialization without ICNR. + + """ + super().__init__() + if out_channels is None: + out_channels = in_channels + upsample_mode = UpsampleMode.DECONV if mode == "default" else UpsampleMode(mode) + padding_mode = PaddingMode(padding_mode).conv_mode(spatial_dims) + if upsample_mode == UpsampleMode.DECONV: + if not in_channels: + raise ValueError( + f"{type(self).__name__}() 'in_channels' required in {upsample_mode.value!r} mode" + ) + if isinstance(scale_factor, (int, float)): + scale_factor = (scale_factor,) * spatial_dims + elif len(scale_factor) != spatial_dims: + raise ValueError( + f"{type(self).__name__}() 'scale_factor' must be scalar or sequence of length {spatial_dims}" + ) + try: + scale_factor = tuple(int(s) for s in scale_factor) + except TypeError: + raise TypeError( + f"{type(self).__name__}() 'scale_factor' must be ints for {upsample_mode.value!r} mode" + ) + if kernel_size is None: + kernel_size = scale_factor + elif isinstance(kernel_size, (int, float)): + kernel_size = (kernel_size,) * spatial_dims + elif len(kernel_size) != spatial_dims: + raise ValueError( + f"{type(self).__name__}() 'kernel_size' must be scalar or sequence of length {spatial_dims}" + ) + if any(k < s for k, s in zip(kernel_size, scale_factor)): + raise ValueError( + f"{type(self).__name__}() 'kernel_size' must be greater than or equal to 'scale_factor'" + ) + padding = upsample_padding(kernel_size, scale_factor) + output_padding = upsample_output_padding(kernel_size, scale_factor, padding) + deconv = convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=scale_factor, + padding=padding, + output_padding=output_padding, + dilation=1, + init=init, + bias=bias, + transposed=True, + ) + self.add_sublayer(name="deconv", sublayer=deconv) + elif upsample_mode == UpsampleMode.INTERPOLATE: + if pre_conv == "default": + if in_channels is None or in_channels == out_channels: + pre_conv = None + else: + if kernel_size is None: + kernel_size = 1 + if isinstance(kernel_size, int): + kernel_size = (kernel_size,) * spatial_dims + elif not isinstance(kernel_size, Sequence): + raise TypeError( + f"{type(self).__name__}() 'kernel_size' must be int or Sequence[int]" + ) + elif len(kernel_size) != spatial_dims: + raise ValueError( + f"{type(self).__name__}() 'kernel_size' must be int or {spatial_dims}-tuple" + ) + if any(k < 1 for k in kernel_size): + raise ValueError( + f"{type(self).__name__}() 'kernel_size' must be positive" + ) + if any(ks % 2 == 0 for ks in kernel_size): + padding = tuple(((ks - 1) // 2, ks // 2) for ks in kernel_size) + padding = tuple(p for a, b in reversed(padding) for p in (a, b)) + pre_pad = Pad(padding=padding, mode=padding_mode) + self.add_sublayer(name="prepad", sublayer=pre_pad) + padding = 0 + else: + padding = tuple(ks // 2 for ks in kernel_size) + pre_conv = convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + padding=padding, + padding_mode=padding_mode, + init=init, + bias=bias, + ) + if pre_conv is not None: + if not isinstance(pre_conv, paddle.nn.Layer): + raise TypeError( + f"{type(self).__name__}() 'preconv' must be string 'default' or Module" + ) + self.add_sublayer(name="preconv", sublayer=pre_conv) + mode = Sampling(sampling).interpolate_mode(spatial_dims) + upsample = paddle.nn.Upsample( + scale_factor=scale_factor, mode=mode, align_corners=align_corners + ) + self.add_sublayer(name="interpolate", sublayer=upsample) + elif upsample_mode == UpsampleMode.PIXELSHUFFLE: + try: + scale_factor_ = int(scale_factor) + except TypeError: + raise TypeError( + f"{type(self).__name__}() 'scale_factor' must be int for {upsample_mode.value!r} mode" + ) + if kernel_size is None: + kernel_size = 3 + module = SubpixelUpsample( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=out_channels, + scale_factor=scale_factor_, + conv_block=pre_conv, + apply_pad_pool=apply_pad_pool, + kernel_size=kernel_size, + padding_mode=padding_mode, + init=init, + bias=bias, + ) + self.add_sublayer(name="pixelshuffle", sublayer=module) + else: + raise NotImplementedError( + f"{type(self).__name__}() mode={mode!r} not implemented" + ) + + +class SubpixelUpsample(paddle.nn.Layer): + """Upsample using a subpixel CNN. + + This module is adapted from the MONAI project and supports 1D, 2D and 3D input images. + The module consists of two parts. First of all, a convolutional layer is employed + to increase the number of channels into: ``in_channels * (scale_factor ** spatial_dims)``. + Secondly, a pixel shuffle manipulation is utilized to aggregate the feature maps from + low resolution space and build the super resolution space. The first part of the module + is not fixed, a sequential layer can be used to replace the default single layer. + + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using a nEfficient Sub-Pixel Convolutional Neural Network." + + See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". + + The idea comes from: + https://arxiv.org/abs/1609.05158 + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int], + out_channels: Optional[int] = None, + scale_factor: Union[int, float] = 2, + conv_block: Optional[Union[paddle.nn.Layer, str, int]] = "default", + apply_pad_pool: bool = True, + kernel_size: ScalarOrTuple[int] = 3, + padding_mode: Union[PaddingMode, str] = PaddingMode.ZEROS, + init: str = "default", + bias: Union[bool, str] = True, + ) -> None: + """Initialize upsampling layer. + + Args: + spatial_dims: Number of spatial dimensions of the input image. + in_channels: Number of channels of the input image. + out_channels: Optional number of channels of the output image. + scale_factor: Multiplier for spatial size. Must be castable to ``int``. Defaults to 2. + conv_block: A conv block to extract feature maps before upsampling. + + - When ``"default"``, one reserved conv layer will be utilized. + - When ``paddle..nn.Module``, the output number of channels must be divisible by ``(scale_factor ** spatial_dims)``. + + apply_pad_pool: If True the upsampled tensor is padded then average pooling is applied with a kernel the + size of `scale_factor` with a stride of 1. This implements the nearest neighbour resize convolution + component of subpixel convolutions described in Aitken et al. + kernel_size: Size of default ``conv_block`` kernel. Defaults to 3. + padding_mode: Padding mode to use for default ``conv_block`` and ``apply_pad_pool``. + init: How to initialize default ``conv_block`` weights (cf. ``convolution()``). If value is "icnr" or + "default", ICNR initialization is used. Use "uniform" for standard initialization without ICNR. + bias: Whether to have a bias term in the default conv_block. When a string is given, it specifies how + the bias term is initialized (cf. ``convolution()``). + + """ + super().__init__() + try: + scale_factor = int(scale_factor) + except TypeError: + raise TypeError("SubpixelUpsample() 'scale_factor' must be int") + if scale_factor < 1: + raise ValueError( + "SubpixelUpsample() 'scale_factor' must be a positive integer" + ) + if init in ("icnr", "ICNR"): + init = "default" + self.spatial_dims = spatial_dims + self.scale_factor = scale_factor + if conv_block == "default": + if not in_channels: + raise ValueError("SubpixelUpsample() 'in_channels' required") + out_channels = out_channels or in_channels + conv_out_channels = out_channels * scale_factor**spatial_dims + if kernel_size % 2 == 0: + padding = ((kernel_size - 1) // 2, kernel_size // 2) * spatial_dims + pre_pad = Pad(padding=padding, mode=padding_mode) + padding = 0 + else: + pre_pad = None + padding = kernel_size // 2 + conv_block = convolution( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=conv_out_channels, + kernel_size=kernel_size, + stride=1, + padding=padding, + padding_mode=padding_mode, + init=init, + bias=bias, + ) + if init == "default": + icnr_init(conv_block.weight, scale_factor) + if pre_pad is not None: + conv_block = paddle.nn.Sequential(pre_pad, conv_block) + elif conv_block is None: + conv_block = paddle.nn.Identity() + elif not isinstance(conv_block, paddle.nn.Layer): + raise ValueError( + "SubpixelUpsample() 'conv_block' must be string 'default', Module, or None" + ) + self.conv_block = conv_block + if apply_pad_pool: + pad_pool = paddle.nn.Sequential( + Pad( + padding=(scale_factor - 1, 0) * spatial_dims, + mode=padding_mode, + value=0, + ), + pooling( + "avg", spatial_dims=spatial_dims, kernel_size=scale_factor, stride=1 + ), + ) + else: + pad_pool = paddle.nn.Identity() + self.pad_pool = pad_pool + + def forward(self, x: paddle.Tensor) -> paddle.Tensor: + """ + + Args: + x: paddle.Tensor in shape (batch, channel, spatial_1[, spatial_2, ...). + + """ + x = self.conv_block(x) + x = pixelshuffle(x, self.spatial_dims, self.scale_factor) + x = self.pad_pool(x) + return x + + +def icnr_init( + weight: paddle.Tensor, + upsample_factor: int, + init=paddle.nn.initializer.KaimingNormal, +) -> None: + """ICNR initialization for 2D/3D kernels. + + Adapted from MONAI project and based on Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". + + """ + out_channels, in_channels, *dims = tuple(weight.shape) + scale_factor = upsample_factor ** len(dims) + oc2 = int(out_channels / scale_factor) + kernel = paddle.zeros(shape=[oc2, in_channels] + dims) + kernel: paddle.Tensor = init(kernel) + x = kernel + perm_4 = list(range(x.ndim)) + perm_4[0] = 1 + perm_4[1] = 0 + kernel = x.transpose(perm=perm_4) + kernel = kernel.reshape(oc2, in_channels, -1) + kernel = kernel.repeat(1, 1, scale_factor) + kernel = kernel.reshape([in_channels, out_channels] + dims) + x = kernel + perm_5 = list(range(x.ndim)) + perm_5[0] = 1 + perm_5[1] = 0 + kernel = x.transpose(perm=perm_5) + weight.data.copy_(kernel) + + +def pixelshuffle( + x: paddle.Tensor, spatial_dims: int, scale_factor: int +) -> paddle.Tensor: + """Apply pixel shuffle to the tensor `x` with spatial dimensions `spatial_dims` and scaling factor `scale_factor`. + + See: Shi et al., 2016, "Real-Time Single Image and Video Super-Resolution + Using an Efficient Sub-Pixel Convolutional Neural Network." + + See: Aitken et al., 2017, "Checkerboard artifact free sub-pixel convolution". + + Args: + x: Input tensor + spatial_dims: number of spatial dimensions, typically 2 or 3 for 2D or 3D + scale_factor: factor to rescale the spatial dimensions by, must be >=1 + + Returns: + Reshuffled version of `x`. + + Raises: + ValueError: When input channels of `x` are not divisible by (scale_factor ** spatial_dims) + + """ + dim, factor = spatial_dims, scale_factor + input_size = list(tuple(x.shape)) + batch_size, channels = input_size[:2] + scale_divisor = factor**dim + if channels % scale_divisor != 0: + raise ValueError( + f"pixelshuffle() number of input channels ({channels}) must be evenly divisible by scale_factor ** spatial_dims ({factor}**{dim}={scale_divisor})" + ) + org_channels = channels // scale_divisor + output_size = [batch_size, org_channels] + [(d * factor) for d in input_size[2:]] + indices = tuple(range(2, 2 + 2 * dim)) + indices_factor, indices_dim = indices[:dim], indices[dim:] + permute_indices = (0, 1) + sum(zip(indices_dim, indices_factor), ()) + x = x.reshape(batch_size, org_channels, *([factor] * dim + input_size[2:])) + x = x.transpose(perm=permute_indices).reshape(output_size) + return x diff --git a/jointContribution/HighResolution/deepali/networks/resnet.py b/jointContribution/HighResolution/deepali/networks/resnet.py new file mode 100644 index 0000000000..7b3370ad75 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/resnet.py @@ -0,0 +1,460 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any +from typing import Callable +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Type +from typing import Union + +import paddle +import paddle.nn.Layer as Layer + +from ..core.config import DataclassConfig +from ..core.enum import PaddingMode +from ..core.itertools import zip_longest_repeat_last +from ..core.types import ScalarOrTuple +from ..modules import ReprWithCrossReferences +from .blocks import ResidualUnit +from .layers import ActivationArg +from .layers import ConvLayer +from .layers import Linear +from .layers import NormArg +from .layers import Upsample +from .layers import is_batch_norm +from .layers import is_convolution +from .layers import is_group_norm +from .layers import pooling + +__all__ = "ResidualUnit", "ResNet", "ResNetConfig" +ModuleFactory = Union[Callable[..., Layer], Type[paddle.nn.Layer]] +MODEL_DEPTHS = {10, 18, 34, 50, 101, 152, 200} + + +@dataclass +class ResNetConfig(DataclassConfig): + """Configuration of residual network architecture.""" + + spatial_dims: int + in_channels: int = 1 + stride: ScalarOrTuple[int] = (1, 2, 2, 2) + num_blocks: ScalarOrTuple[int] = (3, 4, 6, 3) + num_channels: Sequence[int] = (64, 128, 256, 512) + num_layers: int = 2 + num_classes: Optional[int] = None + kernel_size: int = 3 + expansion: float = 1 + padding_mode: Union[PaddingMode, str] = "zeros" + pre_conv: bool = False + post_deconv: bool = False + recursive: ScalarOrTuple[bool] = False + bias: bool = False + norm: NormArg = "batch" + acti: ActivationArg = "relu" + order: str = "cna" + skip: str = "identity | conv1" + residual_pre_conv: str = "conv1" + residual_post_conv: str = "conv1" + + @classmethod + def from_depth( + cls, model_depth: int, spatial_dims: int, in_channels: int = 1, **kwargs + ) -> ResNetConfig: + """Get default ResNet configuration for given depth.""" + config = cls(spatial_dims=spatial_dims, in_channels=in_channels, **kwargs) + if model_depth == 10: + config.num_blocks = 1, 1, 1, 1 + config.num_layers = 2 + config.expansion = 1 + elif model_depth == 18: + config.num_blocks = 2, 2, 2, 2 + config.num_layers = 2 + config.expansion = 1 + elif model_depth == 34: + config.num_blocks = 3, 4, 6, 3 + config.num_layers = 2 + config.expansion = 1 + elif model_depth == 50: + config.num_blocks = 3, 4, 6, 3 + config.num_layers = 3 + config.expansion = 4 + elif model_depth == 101: + config.num_blocks = 3, 4, 23, 3 + config.num_layers = 3 + config.expansion = 4 + elif model_depth == 152: + config.num_blocks = 3, 8, 36, 3 + config.num_layers = 3 + config.expansion = 4 + elif model_depth == 200: + config.num_blocks = 3, 24, 36, 3 + config.num_layers = 3 + config.expansion = 4 + else: + raise ValueError( + "ResNetConfig.from_depth() 'model_depth' must be in {MODEL_DEPTHS!r}" + ) + return config + + +def classification_head( + spatial_dims: int, in_channels: int, num_classes: int, **kwargs +) -> paddle.nn.Layer: + """Image classification head for ResNet model.""" + pool = pooling("AdaptiveAvg", spatial_dims=spatial_dims, output_size=1) + fc = Linear(in_channels, num_classes) + return paddle.nn.Sequential(pool, paddle.nn.Flatten(), fc) + + +def conv_layer( + level: int, + is_first: bool, + spatial_dims: int, + in_channels: int, + out_channels: int, + **kwargs, +) -> paddle.nn.Layer: + """Convolutional layer before/between residual blocks when ``pre_conv=True``.""" + norm = kwargs.get("norm") + acti = kwargs.get("acti") + order = kwargs.get("order", "cna").upper() + if is_first: + if "N" in order and order.index("N") < order.index("C"): + norm = None + if "A" in order and order.index("A") < order.index("C"): + acti = None + kwargs.update(dict(norm=norm, acti=acti, order=order)) + return ConvLayer(spatial_dims, in_channels, out_channels, **kwargs) + + +def input_layer( + spatial_dims: int, in_channels: int, out_channels: int, **kwargs +) -> paddle.nn.Layer: + """First convolutional ResNet layer.""" + order = kwargs.get("order", "cna").upper() + norm = ( + None + if "N" in order and order.index("N") < order.index("C") + else kwargs.get("norm") + ) + kwargs.update( + dict(kernel_size=7, padding=3, stride=2, norm=norm, acti=None, order=order) + ) + conv = ConvLayer(spatial_dims, in_channels, out_channels, **kwargs) + pool = pooling("max", spatial_dims=spatial_dims, kernel_size=3, stride=2, padding=1) + return paddle.nn.Sequential(conv, pool) + + +class ResNet(ReprWithCrossReferences, paddle.nn.Sequential): + """Residual network. + + Note that unlike ``torchvision.models.ResNet``, the ``__init__`` function of this class + does not initialize the parameters of the model, other than the standard initialization + for each module type. In order to apply the initialization of the torchvision ResNet, call + functions ``init_conv_modules()``, ``init_norm_layers()``, and ``zero_init_residuals()`` + (in this order!) after constructing the ResNet model. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: int = 1, + num_channels: ScalarOrTuple[int] = (64, 128, 256, 512), + num_blocks: ScalarOrTuple[int] = (3, 4, 6, 3), + num_layers: int = 2, + num_classes: Optional[int] = None, + kernel_size: int = 3, + stride: ScalarOrTuple[int] = (1, 2, 2, 2), + expansion: float = 1, + padding_mode: Union[PaddingMode, str] = "zeros", + recursive: ScalarOrTuple[bool] = False, + pre_conv: bool = False, + post_deconv: bool = False, + bias: Union[bool, str] = False, + norm: NormArg = "batch", + acti: ActivationArg = "relu", + order: str = "cna", + skip: str = "identity | conv1", + residual_pre_conv: str = "conv1", + residual_post_conv: str = "conv1", + conv_layer: ModuleFactory = conv_layer, + deconv_layer: ModuleFactory = Upsample, + resnet_block: ModuleFactory = ResidualUnit, + input_layer: Optional[ModuleFactory] = input_layer, + output_layer: Optional[ModuleFactory] = None, + ) -> None: + """Initialize layers. + + Args: + spatial_dims: Number of spatial tensor dimensions. + in_channels: Number of input channels. + num_channels: Number of feature channels at each level. + num_blocks: Number of residual blocks at each level. + num_layers: Number of convolutional layers in each residual block. + num_classes: Number of output class probabilities of ``output_layer``. + kernel_size: Size of convolutional filters in residual blocks. + stride: Stride of initial convolution at each level. Subsequent convolutions have stride 1. + expansion: Expansion factor of ``num_channels``. If specified, the number of input and output + feature maps for each residual block are equal to ``expansion * num_channels``, and the + bottleneck convolutional layers after an initial convolution with kernel size 1 operate + on feature maps with ``num_channels`` each, which are subsequently expanded again by the + specified expansion level by another convolution with kernel size 1. + padding_mode: Padding mode for convolutional layers with kernel size greater than 1. + recursive: Whether residual blocks at each level are applied recursively. If ``True``, all residual + blocks at a given level share their convolutional modules. Other modules such as normalization + layers are not shared. When ``recursive=False``, a new residual block without any shared modules + is created each time. When ``recursive=True`` and the number of feature channels or spatial size + of a given residual block does not match the number of output channels or spatial size of the + preceeding block, respectively, a pre-convolutional layer which adjusts the number of channels + and/or spatial size is inserted between these residual blocks even when ``pre_conv=False``. + pre_conv: Always insert a separate convolutional layer between levels. When ``recursive=True``, + this is also the case when ``pre_conv=False`` if the output and input tensor shapes of final + and first residual block in subsequent levels do not match. + post_deconv: Whether to place upsampling layers after the sequence of residual blocks for levels + with a ``stride`` that is less than 1. By default, upsampling is performed as part of the + ``pre_conv``. If ``post_deconv=False``, a pre-upsampling layer is always inserted when a + level has an initial stride of less than 1 regardless of the ``pre_conv`` setting. + bias: Whether to use bias terms of convolutional layers. Can be either a boolean, or a string + indicating the function used to initialize these bias terms (cf. ``ConvLayer``). + norm: Type of normalization to use in convolutional layers. Use no normalization if ``None``. + acti: Type of non-linear activation function to use in each convolutional layer. + order: Order of convolution (C), normalization (N), and non-linear activation (A) in each + convolutional layer. If this string does not contain the character ``n|N``, no normalization + is performed regardless of the setting of ``norm`` (cf. ``ConvLayer``). + skip: Type(s) of shortcut connections (cf. ``ResidualUnit``). + residual_pre_conv: Type of pre-convolution when residual unit is a bottleneck block. + The kernel size is set to 1 if "conv1", and ``kernel_size`` otherwise. + residual_post_conv: Type of post-convolution when residual unit is a bottleneck block. + conv_layer: Type or callable used to create convolutional layers (cf. ``pre_conv``). + deconv_layer: Type or callable used to create upsampling layers. + resnet_block: Type or callable used to create residual blocks. + input_layer: Type or callable used to create initial layer which receives the input tensor. + output_layer: Type or callable used to create an output layer (head). If ``None`` and + ``num_classes`` is specified, a default ``classification_head`` is added. + + """ + super().__init__() + padding_mode = PaddingMode(padding_mode) + order = order.upper() + if "C" not in order: + raise ValueError("ResNet() 'order' must contain 'C' for convolution") + if isinstance(acti, str): + nonlinearity = acti + elif isinstance(acti, Sequence) and len(acti) == 2: + nonlinearity = acti[0] + if not isinstance(nonlinearity, str): + raise TypeError("ResNet() 'acti[0]' must be str") + else: + raise ValueError("ResNet() 'acti' must be str or 2-tuple") + self.nonlinearity = nonlinearity + if isinstance(num_channels, int): + num_channels = (num_channels,) + if isinstance(num_blocks, int): + num_blocks = (num_blocks,) + if isinstance(stride, (float, int)): + stride = (stride,) + if input_layer is None: + channels = in_channels + else: + channels = max(1, int(num_channels[0] * expansion + 0.5)) + layer = input_layer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels, + kernel_size=kernel_size, + padding_mode=padding_mode, + bias=bias, + norm=norm, + acti=acti, + order=order, + ) + self.add_sublayer(name="input_layer", sublayer=layer) + for i, (m, b, s) in enumerate( + zip_longest_repeat_last(num_channels, num_blocks, stride) + ): + if m < 1: + raise ValueError(f"ResNet() 'num_channels' must be positive, got {m}") + n = max(1, int(m * expansion + 0.5)) + deconv = None + if s < 1: + deconv_in_channels = n if post_deconv else channels + deconv_out_channels = deconv_in_channels + if pre_conv and not post_deconv: + deconv_out_channels = n + channels = n + deconv = deconv_layer( + spatial_dims=spatial_dims, + in_channels=deconv_in_channels, + out_channels=deconv_out_channels, + scale_factor=1 / s, + bias=bias, + ) + if not post_deconv: + self.add_sublayer(name=f"deconv_{i}", sublayer=deconv) + s = 1 + with_pre_conv = pre_conv and (i > 0 or input_layer is None) + with_pre_conv = with_pre_conv or recursive and (s != 1 or channels != n) + with_pre_conv = with_pre_conv and (deconv is None or post_deconv) + if with_pre_conv: + is_first = input_layer is None and i == 0 + conv = conv_layer( + level=i, + is_first=is_first, + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=n, + kernel_size=kernel_size, + padding_mode=padding_mode, + stride=s, + bias=bias, + norm=norm, + acti=acti, + order=order, + ) + self.add_sublayer(name=f"conv_{i}", sublayer=conv) + channels = n + s = 1 + block = None + for j in range(b): + block = resnet_block( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=n, + num_channels=m, + num_layers=num_layers, + kernel_size=kernel_size, + pre_conv=residual_pre_conv, + post_conv=residual_post_conv, + padding_mode=padding_mode, + stride=s, + bias=bias, + norm=norm, + acti=acti, + order=order, + skip=skip, + other=block if recursive else None, + ) + channels = n + s = 1 + self.add_sublayer(name=f"block_{i}_{j}", sublayer=block) + if deconv is not None and post_deconv: + self.add_sublayer(name=f"deconv_{i}", sublayer=deconv) + if num_classes and output_layer is None: + output_layer = classification_head + if output_layer is not None: + head_kwargs = dict( + kernel_size=kernel_size, + padding_mode=padding_mode, + bias=bias, + norm=norm, + acti=acti, + order=order, + ) + if num_classes: + head_kwargs["num_classes"] = num_classes + head = output_layer(spatial_dims, channels, **head_kwargs) + self.add_sublayer(name="output_layer", sublayer=head) + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.out_channels = channels + self.num_classes = num_classes or None + + @classmethod + def from_config(cls: Type[ResNet], config: ResNetConfig) -> ResNet: + return cls( + spatial_dims=config.spatial_dims, + in_channels=config.in_channels, + num_channels=config.num_channels, + num_blocks=config.num_blocks, + num_layers=config.num_layers, + num_classes=config.num_classes, + kernel_size=config.kernel_size, + stride=config.stride, + expansion=config.expansion, + padding_mode=config.padding_mode, + recursive=config.recursive, + pre_conv=config.pre_conv, + post_deconv=config.post_deconv, + bias=config.bias, + norm=config.norm, + acti=config.acti, + order=config.order, + skip=config.skip, + residual_pre_conv=config.residual_pre_conv, + residual_post_conv=config.residual_post_conv, + ) + + @classmethod + def from_dict(cls: Type[ResNet], config: Mapping[str, Any]) -> ResNet: + config = ResNetConfig.from_dict(config) + return cls.from_config(config) + + @classmethod + def from_depth( + cls: Type[ResNet], + model_depth: int, + spatial_dims: int, + in_channels: int = 1, + **kwargs, + ) -> ResNet: + config = ResNetConfig.from_depth( + model_depth, spatial_dims, in_channels, **kwargs + ) + return cls.from_config(config) + + def init_conv_modules(self) -> ResNet: + """Initialize parameters of convolutions.""" + init_conv_modules(self, nonlinearity=self.nonlinearity) + return self + + def init_norm_layers(self) -> ResNet: + """Initialize normalization layer weights and biases.""" + init_norm_layers(self) + return self + + def zero_init_residuals(self) -> ResNet: + """Zero-initialize the last normalization layer in each residual branch.""" + zero_init_residuals(self) + return self + + +def init_conv_modules( + network: paddle.nn.Layer, nonlinearity: str = "relu" +) -> paddle.nn.Layer: + """Initialize parameters of convolutions.""" + nonlinearity = nonlinearity.lower() + if nonlinearity == "lrelu": + nonlinearity = "leaky_relu" + if nonlinearity in ("relu", "leaky_relu"): + for module in network.sublayers(): + if is_convolution(module): + init_kaimingNormal = paddle.nn.initializer.KaimingNormal( + nonlinearity=nonlinearity + ) + init_kaimingNormal(module.weight) + if module.bias is not None: + init_Constant = paddle.nn.initializer.Constant(value=0) + init_Constant(module.bias) + return network + + +def init_norm_layers(network: paddle.nn.Layer) -> paddle.nn.Layer: + """Initialize batch and group norm layers weights to one and biases to zero.""" + for module in network.sublayers(): + if is_batch_norm(module) or is_group_norm(module): + init_Constant = paddle.nn.initializer.Constant(value=1) + init_Constant(module.weight) + init_Constant = paddle.nn.initializer.Constant(value=0) + init_Constant(module.bias) + return network + + +def zero_init_residuals(network: paddle.nn.Layer) -> paddle.nn.Layer: + """Zero-initialize the last normalization layer in each residual branch.""" + for module in network.sublayers(): + if isinstance(module, ResidualUnit): + module.zero_init_residual() + return network diff --git a/jointContribution/HighResolution/deepali/networks/unet.py b/jointContribution/HighResolution/deepali/networks/unet.py new file mode 100644 index 0000000000..d52a1850bd --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/unet.py @@ -0,0 +1,1085 @@ +from __future__ import annotations + +from copy import deepcopy +from dataclasses import dataclass +from dataclasses import field +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Mapping +from typing import NamedTuple +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import Union + +import paddle +from paddle.nn import Identity +from paddle.nn import Layer + +from ..core.config import DataclassConfig +from ..core.enum import PaddingMode +from ..core.image import crop +from ..core.itertools import repeat_last +from ..core.nnutils import as_immutable_container +from ..core.types import ListOrTuple +from ..core.types import ScalarOrTuple +from ..modules import GetItem +from ..modules import ReprWithCrossReferences +from .blocks import ResidualUnit +from .layers import ActivationArg +from .layers import ConvLayer +from .layers import JoinLayer +from .layers import NormArg +from .layers import PoolLayer +from .layers import Upsample +from .layers import UpsampleMode +from .utils import module_output_size + +__all__ = ( + "SequentialUNet", + "UNet", + "UNetConfig", + "UNetDecoder", + "UNetDecoderConfig", + "UNetDownsampleConfig", + "UNetEncoder", + "UNetEncoderConfig", + "UNetLayerConfig", + "UNetOutputConfig", + "UNetUpsampleConfig", + "unet_conv_block", +) +ModuleFactory = Union[Callable[..., Layer], Type[paddle.nn.Layer]] +NumChannels = ListOrTuple[Union[int, Sequence[int]]] +NumBlocks = Union[int, Sequence[int]] +NumLayers = Optional[Union[int, Sequence[int]]] + + +def reversed_num_channels(num_channels: NumChannels) -> NumChannels: + """Reverse order of per-block/-stage number of feature channels.""" + rev_channels = tuple( + tuple(reversed(c)) if isinstance(c, Sequence) else c + for c in reversed(num_channels) + ) + return rev_channels + + +def decoder_num_channels_from_encoder_num_channels( + num_channels: NumChannels, +) -> NumChannels: + """Get default UNetDecoderConfig.num_channels from UNetEncoderConfig.num_channels.""" + num_channels = list(reversed_num_channels(num_channels)) + if isinstance(num_channels[0], Sequence): + num_channels[0] = num_channels[0][0] + return tuple(num_channels) + + +def first_num_channels(num_channels: NumChannels) -> int: + """Get number of feature channels of first block.""" + nc = num_channels[0] + if isinstance(nc, Sequence): + if not nc: + raise ValueError("first_num_channels() 'num_channels[0]' must not be empty") + nc = nc[0] + return nc + + +def last_num_channels(num_channels: NumChannels) -> int: + """Get number of feature channels of last block.""" + nc = num_channels[-1] + if isinstance(nc, Sequence): + if not nc: + raise ValueError("last_num_channels() 'num_channels[-1]' must not be empty") + nc = nc[-1] + return nc + + +@dataclass +class UNetLayerConfig(DataclassConfig): + kernel_size: ScalarOrTuple[int] = 3 + dilation: ScalarOrTuple[int] = 1 + padding: Optional[ScalarOrTuple[int]] = None + padding_mode: Union[PaddingMode, str] = "zeros" + init: str = "default" + bias: Union[str, bool, None] = None + norm: NormArg = "instance" + acti: ActivationArg = "lrelu" + order: str = "cna" + + def __post_init__(self): + self._join_kwargs_in_sequence("acti") + self._join_kwargs_in_sequence("norm") + + +@dataclass +class UNetDownsampleConfig(DataclassConfig): + mode: str = "conv" + factor: Union[int, Sequence[int]] = 2 + kernel_size: Optional[ScalarOrTuple[int]] = None + padding: Optional[ScalarOrTuple[int]] = None + + +@dataclass +class UNetUpsampleConfig(DataclassConfig): + mode: Union[str, UpsampleMode] = "deconv" + factor: Union[int, Sequence[int]] = 2 + kernel_size: Optional[ScalarOrTuple[int]] = None + dilation: Optional[ScalarOrTuple[int]] = None + padding: Optional[ScalarOrTuple[int]] = None + + +@dataclass +class UNetOutputConfig(DataclassConfig): + channels: int = 1 + kernel_size: int = 1 + dilation: int = 1 + padding: Optional[int] = None + padding_mode: Union[PaddingMode, str] = "zeros" + init: str = "default" + bias: Union[str, bool, None] = False + norm: NormArg = None + acti: ActivationArg = None + order: str = "cna" + + def __post_init__(self): + self._join_kwargs_in_sequence("acti") + self._join_kwargs_in_sequence("norm") + + +@dataclass +class UNetEncoderConfig(DataclassConfig): + num_channels: NumChannels = (8, 16, 32, 64) + num_blocks: NumBlocks = 2 + num_layers: NumLayers = None + conv_layer: UNetLayerConfig = field(default_factory=UNetLayerConfig) + downsample: Union[str, UNetDownsampleConfig] = field( + default_factory=UNetDownsampleConfig + ) + residual: bool = False + block_1_dilation: Optional[int] = None + stage_1_dilation: Optional[int] = None + + @property + def num_levels(self) -> int: + """Number of spatial encoder levels.""" + return len(self.num_channels) + + @property + def out_channels(self) -> int: + return last_num_channels(self.num_channels) + + def __post_init__(self): + if isinstance(self.downsample, str): + self.downsample = UNetDownsampleConfig(self.downsample) + + +@dataclass +class UNetDecoderConfig(DataclassConfig): + num_channels: NumChannels = (64, 32, 16, 8) + num_blocks: NumBlocks = 2 + num_layers: NumLayers = None + conv_layer: UNetLayerConfig = field(default_factory=UNetLayerConfig) + upsample: Union[str, UNetUpsampleConfig] = field(default_factory=UNetUpsampleConfig) + join_mode: str = "cat" + crop_skip: bool = False + residual: bool = False + + @property + def num_levels(self) -> int: + """Number of spatial decoder levels, including bottleneck input.""" + return len(self.num_channels) + + @property + def in_channels(self) -> int: + return first_num_channels(self.num_channels) + + @property + def out_channels(self) -> int: + return last_num_channels(self.num_channels) + + def __post_init__(self): + if isinstance(self.upsample, (str, UpsampleMode)): + self.upsample = UNetUpsampleConfig(self.upsample) + + @classmethod + def from_encoder( + cls, + encoder: Union[UNetEncoder, UNetEncoderConfig], + residual: Optional[bool] = None, + **kwargs, + ) -> UNetDecoderConfig: + """Derive decoder configuration from U-net encoder configuration.""" + if isinstance(encoder, UNetEncoder): + encoder = encoder.config + if not isinstance(encoder, UNetEncoderConfig): + raise TypeError( + f"{cls.__name__}.from_encoder() argument must be UNetEncoder or UNetEncoderConfig" + ) + if encoder.num_levels < 2: + raise ValueError( + f"{cls.__name__}.from_encoder() encoder must have at least two levels" + ) + if "upsample_mode" in kwargs: + if "upsample" in kwargs: + raise ValueError( + f"{cls.__name__}.from_encoder() 'upsample' and 'upsample_mode' are mutually exclusive" + ) + kwargs["upsample"] = UNetUpsampleConfig(kwargs.pop("upsample_mode")) + residual = encoder.residual if residual is None else residual + num_channels = decoder_num_channels_from_encoder_num_channels( + encoder.num_channels + ) + num_blocks = encoder.num_blocks + if isinstance(num_blocks, Sequence): + num_blocks = tuple(reversed(repeat_last(num_blocks, encoder.num_levels))) + num_layers = encoder.num_layers + if isinstance(num_layers, Sequence): + num_layers = tuple(reversed(repeat_last(num_layers, encoder.num_levels))) + return cls( + num_channels=num_channels, + num_blocks=num_blocks, + num_layers=num_layers, + conv_layer=encoder.conv_layer, + residual=residual, + **kwargs, + ) + + +@dataclass +class UNetConfig(DataclassConfig): + encoder: UNetEncoderConfig = field(default_factory=UNetEncoderConfig) + decoder: Optional[UNetDecoderConfig] = None + output: Optional[UNetOutputConfig] = None + + def __post_init__(self): + if self.decoder is None: + self.decoder = UNetDecoderConfig.from_encoder(self.encoder) + + +def unet_conv_block( + spatial_dims: int, + in_channels: int, + out_channels: int, + kernel_size: ScalarOrTuple[int] = 3, + stride: ScalarOrTuple[int] = 1, + padding: Optional[ScalarOrTuple[int]] = None, + padding_mode: Union[PaddingMode, str] = "zeros", + dilation: ScalarOrTuple[int] = 1, + groups: int = 1, + init: str = "default", + bias: Optional[Union[bool, str]] = None, + norm: NormArg = None, + acti: ActivationArg = None, + order: str = "CNA", + num_layers: Optional[int] = None, +) -> paddle.nn.Layer: + """Create U-net block of convolutional layers.""" + if num_layers is None: + num_layers = 1 + elif num_layers < 1: + raise ValueError("unet_conv_block() 'num_layers' must be positive") + + def conv_layer(m: int, n: int, s: int, d: int) -> ConvLayer: + return ConvLayer( + spatial_dims=spatial_dims, + in_channels=m, + out_channels=n, + kernel_size=kernel_size, + padding=padding, + padding_mode=padding_mode, + stride=s, + dilation=d, + groups=groups, + init=init, + bias=bias, + norm=norm, + acti=acti, + order=order, + ) + + block = paddle.nn.Sequential() + for i in range(num_layers): + m = in_channels if i == 0 else out_channels + n = out_channels + s = stride if i == 0 else 1 + d = dilation if s == 1 else 1 + conv = conv_layer(m, n, s, d) + block.add_sublayer(name=f"layer_{i + 1}", sublayer=conv) + return block + + +class UNetEncoder(ReprWithCrossReferences, paddle.nn.Layer): + """Downsampling path of U-net model.""" + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int] = None, + config: Optional[UNetEncoderConfig] = None, + conv_block: Optional[ModuleFactory] = None, + input_layer: Optional[ModuleFactory] = None, + ): + super().__init__() + if config is None: + config = UNetEncoderConfig() + elif not isinstance(config, UNetEncoderConfig): + raise TypeError( + f"{type(self).__name__}() 'config' must be UNetEncoderConfig" + ) + if config.num_levels < 2: + raise ValueError( + f"{type(self).__name__} U-net must have at least two spatial resolution levels" + ) + if config.downsample.mode == "none": + down_stride = (1,) * config.num_levels + if isinstance(config.downsample.factor, int): + down_stride = (1,) + (config.downsample.factor,) * (config.num_levels - 1) + else: + down_stride = repeat_last(config.downsample.factor, config.num_levels) + if conv_block is None: + conv_block = ResidualUnit if config.residual else unet_conv_block + elif not isinstance(conv_block, paddle.nn.Layer) and not callable(conv_block): + raise TypeError( + f"{type(self).__name__}() 'conv_block' must be Module or callable" + ) + num_blocks = repeat_last(config.num_blocks, config.num_levels) + num_layers = repeat_last(config.num_layers, config.num_levels) + num_channels = list(config.num_channels) + channels = first_num_channels(num_channels) + for i, (s, b, nc) in enumerate(zip(down_stride, num_blocks, num_channels)): + if not isinstance(b, int): + raise TypeError( + f"{type(self).__name__} 'num_blocks' must be int or Sequence[int]" + ) + if b < 1: + raise ValueError(f"{type(self).__name__} 'num_blocks' must be positive") + if isinstance(nc, int): + nc = (nc,) * b + if s > 1: + if config.downsample.mode == "conv": + nc = (nc[0],) + nc + else: + nc = (channels,) + nc + else: + nc = (nc[0],) + nc + elif not isinstance(nc, Sequence): + raise TypeError( + f"{type(self).__name__}() 'num_channels' values must be int or Sequence[int]" + ) + if not nc: + raise ValueError( + f"{type(self).__name__}() 'num_channels' must not contain empty sequence" + ) + num_channels[i] = list(nc) + channels = nc[-1] + num_channels: List[List[int]] = list(list(nc) for nc in num_channels) + channels = num_channels[0][0] + if in_channels is None: + in_channels = channels + if input_layer is None: + input_layer = ConvLayer if in_channels != channels else Identity + elif not isinstance(input_layer, paddle.nn.Layer) and not callable(input_layer): + raise TypeError( + f"{type(self).__name__}() 'input_layer' must be Module or callable" + ) + stages = paddle.nn.LayerDict() + channels = in_channels + for i, (s, l, nc) in enumerate(zip(down_stride, num_layers, num_channels)): + assert isinstance(nc, Sequence) and len(nc) > 0 + stage = paddle.nn.LayerDict() + if s > 1: + if config.downsample.mode == "conv": + c = nc[0] + k = config.downsample.kernel_size or config.conv_layer.kernel_size + if config.downsample.padding is None: + p = config.conv_layer.padding + else: + p = config.downsample.padding + d = 0 + if i == 0: + d = config.stage_1_dilation + d = d or config.conv_layer.dilation + downsample = conv_block( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=c, + kernel_size=k, + stride=s, + dilation=d, + padding=p, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + num_layers=l, + ) + channels = c + else: + if nc[0] != channels: + raise ValueError( + f"{type(self).__name__}() number of input channels of stage after pooling ({nc[0]}) must match number of channels of previous stage ({channels})" + ) + pool_size = config.downsample.kernel_size or s + pool_args = dict(kernel_size=pool_size, stride=s) + if pool_size % 2 == 0: + pool_args["padding"] = pool_size // 2 - 1 + else: + pool_args["padding"] = pool_size // 2 + if config.downsample.mode == "avg": + pool_args["count_include_pad"] = False + downsample = PoolLayer( + config.downsample.mode, spatial_dims=spatial_dims, **pool_args + ) + stage["downsample"] = downsample + s = 1 + elif i == 0: + c = nc[0] + stage["input"] = input_layer( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=c, + kernel_size=config.conv_layer.kernel_size, + dilation=config.stage_1_dilation or config.conv_layer.dilation, + padding=config.conv_layer.padding, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + ) + channels = c + blocks = paddle.nn.Sequential() + for j, c in enumerate(nc[1:]): + d = 0 + if j == 0: + d = config.block_1_dilation + if i == 0: + d = d or config.stage_1_dilation + d = d or config.conv_layer.dilation + block = conv_block( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=c, + kernel_size=config.conv_layer.kernel_size, + stride=1, + dilation=d, + padding=config.conv_layer.padding, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + num_layers=l, + ) + blocks.add_sublayer(name=f"block_{j + 1}", sublayer=block) + channels = c + s = 1 + stage["blocks"] = blocks + stages[f"stage_{i + 1}"] = stage + config = deepcopy(config) + config.num_channels = num_channels + self.config = config + self.num_channels: List[List[int]] = num_channels + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.stages = stages + + @property + def out_channels(self) -> int: + return self.num_channels[-1][-1] + + def output_size(self, in_size: ScalarOrTuple[int]) -> ScalarOrTuple[int]: + """Calculate output size of last feature map for a tensor of given spatial input size.""" + return self.output_sizes(in_size)[-1] + + def output_sizes(self, in_size: ScalarOrTuple[int]) -> List[ScalarOrTuple[int]]: + """Calculate output sizes of feature maps for a tensor of given spatial input size.""" + size = in_size + fm_sizes = [] + for name, stage in self.stages.items(): + assert isinstance(stage, paddle.nn.LayerDict) + if name == "stage_1": + size = module_output_size(stage["input"], size) + if "downsample" in stage: + size = module_output_size(stage["downsample"], size) + size = module_output_size(stage["blocks"], size) + fm_sizes.append(size) + return fm_sizes + + def forward(self, x: paddle.Tensor) -> Tuple[paddle.Tensor, ...]: + features = [] + for name, stage in self.stages.items(): + if not isinstance(stage, paddle.nn.LayerDict): + raise AssertionError( + f"{type(self).__name__}.forward() expected stage ModuleDict, got {type(stage)}" + ) + if name == "stage_1": + input_layer = stage["input"] + x = input_layer(x) + if "downsample" in stage: + downsample = stage["downsample"] + x = downsample(x) + blocks = stage["blocks"] + x = blocks(x) + features.append(x) + return tuple(features) + + +class UNetDecoder(ReprWithCrossReferences, paddle.nn.Layer): + """Upsampling path of U-net model.""" + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int] = None, + config: Optional[UNetDecoderConfig] = None, + conv_block: Optional[ModuleFactory] = None, + input_layer: Optional[ModuleFactory] = None, + output_all: bool = False, + ) -> None: + super().__init__() + if config is None: + config = UNetDecoderConfig() + elif not isinstance(config, UNetDecoderConfig): + raise TypeError( + f"{type(self).__name__}() 'config' must be UNetDecoderConfig" + ) + if config.num_levels < 2: + raise ValueError( + f"{type(self).__name__} U-net must have at least two spatial resolution levels" + ) + if not isinstance(config.num_channels, Sequence): + raise TypeError( + f"{type(self).__name__}() 'config.num_channels' must be Sequence" + ) + if any(isinstance(nc, Sequence) and not nc for nc in config.num_channels): + raise ValueError( + f"{type(self).__name__}() 'config.num_channels' contains empty sequence" + ) + num_blocks = repeat_last(config.num_blocks, config.num_levels) + num_layers = repeat_last(config.num_layers, config.num_levels) + scale_factor = repeat_last(config.upsample.factor, config.num_levels - 1) + upsample_mode = UpsampleMode(config.upsample.mode) + join_mode = config.join_mode + num_channels = list(config.num_channels) + for i, (b, nc) in enumerate(zip(num_blocks, num_channels)): + if not isinstance(b, int): + raise TypeError( + f"{type(self).__name__} 'num_blocks' must be int or Sequence[int]" + ) + if b < 1: + raise ValueError(f"{type(self).__name__} 'num_blocks' must be positive") + if isinstance(nc, int): + nc = (nc,) * (1 if i == 0 else b + 1) + elif isinstance(nc, Sequence): + nc = list(nc) + else: + raise TypeError( + f"{type(self).__name__}() 'num_channels' values must be int or Sequence[int]" + ) + if not nc: + raise ValueError( + f"{type(self).__name__}() 'num_channels' must not contain empty sequence" + ) + num_channels[i] = list(nc) + if ( + upsample_mode is UpsampleMode.INTERPOLATE + and config.upsample.kernel_size == 0 + ): + for i, nc in enumerate(num_channels[:-1]): + next_nc = num_channels[i + 1][0] + assert isinstance(nc, List) and len(nc) > 0 + if isinstance(config.num_channels[i], int): + if len(nc) == 1: + nc.append(next_nc) + else: + nc[-1] = next_nc + elif nc[-1] != next_nc: + raise ValueError( + f"{type(self).__name__}() 'num_channels' of last feature map in previous stage ({num_channels[i]}) must match number of channels of first feature map ({next_nc}) of next stage when upsampling mode is interpolation without preconv. Either adjust 'num_channels' or use non-zero 'upsample.kernel_size'." + ) + channels = num_channels[0][0] + if conv_block is None: + conv_block = ResidualUnit if config.residual else unet_conv_block + elif not isinstance(conv_block, paddle.nn.Layer) and not callable(conv_block): + raise TypeError( + f"{type(self).__name__}() 'conv_block' must be Module or callable" + ) + if in_channels is None: + in_channels = channels + if input_layer is None: + input_layer = ConvLayer if in_channels != channels else Identity + elif not isinstance(input_layer, paddle.nn.Layer) and not callable(input_layer): + raise TypeError( + f"{type(self).__name__}() 'input_layer' must be Module or callable" + ) + stages = paddle.nn.LayerDict() + stage = paddle.nn.LayerDict() + stage["input"] = input_layer( + spatial_dims=spatial_dims, + in_channels=in_channels, + out_channels=channels, + kernel_size=config.conv_layer.kernel_size, + dilation=config.conv_layer.dilation, + padding=config.conv_layer.padding, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + ) + blocks = paddle.nn.Sequential() + for j, c in enumerate(num_channels[0][1:]): + block = conv_block( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=c, + kernel_size=config.conv_layer.kernel_size, + dilation=config.conv_layer.dilation, + padding=config.conv_layer.padding, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + num_layers=num_layers[0], + ) + blocks.add_sublayer(name=f"block_{j + 1}", sublayer=block) + channels = c + stage["blocks"] = blocks + stages["stage_1"] = stage + for i, (s, l, nc) in enumerate( + zip(scale_factor, num_layers[1:], num_channels[1:]) + ): + assert isinstance(nc, Sequence) and len(nc) > 1 + stage = paddle.nn.LayerDict() + if ( + upsample_mode is UpsampleMode.INTERPOLATE + and config.upsample.kernel_size != 0 + ): + p = config.upsample.padding + if p is None: + p = config.conv_layer.padding + k = config.upsample.kernel_size + if k is None: + k = config.conv_layer.kernel_size + d = config.upsample.dilation or config.conv_layer.dilation + pre_conv = ConvLayer( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=nc[0], + kernel_size=k, + dilation=d, + padding=p, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + ) + else: + pre_conv = "default" + if s > 1: + upsample = Upsample( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=nc[0], + scale_factor=s, + mode=upsample_mode, + align_corners=False, + pre_conv=pre_conv, + kernel_size=config.upsample.kernel_size, + padding_mode=config.conv_layer.padding_mode, + bias=True + if config.conv_layer.bias is None + else config.conv_layer.bias, + ) + stage["upsample"] = upsample + stage["join"] = JoinLayer(join_mode, dim=1) + channels = (2 if join_mode == "cat" else 1) * nc[0] + blocks = paddle.nn.Sequential() + for j, c in enumerate(nc[1:]): + block = conv_block( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=c, + kernel_size=config.conv_layer.kernel_size, + dilation=config.conv_layer.dilation, + padding=config.conv_layer.padding, + padding_mode=config.conv_layer.padding_mode, + init=config.conv_layer.init, + bias=config.conv_layer.bias, + norm=config.conv_layer.norm, + acti=config.conv_layer.acti, + order=config.conv_layer.order, + num_layers=l, + ) + blocks.add_sublayer(name=f"block_{j + 1}", sublayer=block) + channels = c + stage["blocks"] = blocks + stages[f"stage_{i + 2}"] = stage + config = deepcopy(config) + config.num_channels = num_channels + self.config = config + self.spatial_dims = spatial_dims + self.in_channels = in_channels + self.num_channels: List[List[int]] = num_channels + self.stages = stages + self.output_all = output_all + + @classmethod + def from_encoder( + cls, + encoder: Union[UNetEncoder, UNetEncoderConfig], + residual: Optional[bool] = None, + **kwargs, + ) -> UNetDecoder: + """Create U-net decoder given U-net encoder configuration.""" + config = UNetDecoderConfig.from_encoder(encoder, residual=residual, **kwargs) + return cls(spatial_dims=encoder.spatial_dims, config=config) + + @property + def out_channels(self) -> int: + return self.num_channels[-1][-1] + + def output_size(self, in_size: ScalarOrTuple[int]) -> ScalarOrTuple[int]: + """Calculate output size for an initial feature map of given spatial input size.""" + return self.output_sizes(in_size)[-1] + + def output_sizes(self, in_size: ScalarOrTuple[int]) -> List[ScalarOrTuple[int]]: + """Calculate output sizes for an initial feature map of given spatial input size.""" + size = in_size + out_sizes = [] + for name, stage in self.stages.items(): + if not isinstance(stage, paddle.nn.LayerDict): + raise AssertionError( + f"{type(self).__name__}.out_sizes() expected stage ModuleDict, got {type(stage)}" + ) + if name == "stage_1": + size = module_output_size(stage["input"], size) + if "upsample" in stage: + size = module_output_size(stage["upsample"], size) + size = module_output_size(stage["blocks"], size) + out_sizes.append(size) + return out_sizes + + def forward( + self, features: Sequence[paddle.Tensor] + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: + if not isinstance(features, Sequence): + raise TypeError(f"{type(self).__name__}() 'features' must be Sequence") + features = list(features) + if len(features) != len(self.stages): + raise ValueError( + f"{type(self).__name__}() 'features' must contain {len(self.stages)} tensors" + ) + x: paddle.Tensor = features.pop() + output: List[paddle.Tensor] = [] + for name, stage in self.stages.items(): + if not isinstance(stage, paddle.nn.LayerDict): + raise AssertionError( + f"{type(self).__name__}.forward() expected stage ModuleDict, got {type(stage)}" + ) + blocks = stage["blocks"] + if name == "stage_1": + input_layer = stage["input"] + x = input_layer(x) + else: + skip = features.pop() + upsample = stage["upsample"] + join = stage["join"] + x = upsample(x) + if self.config.crop_skip: + margin = tuple( + n - m for m, n in zip(tuple(x.shape)[2:], tuple(skip.shape)[2:]) + ) + assert all(m >= 0 and m % 2 == 0 for m in margin) + margin = tuple(m // 2 for m in margin) + skip = crop(skip, margin=margin) + x = join([x, skip]) + del skip + x = blocks(x) + if self.output_all: + output.append(x) + if self.output_all: + return tuple(output) + return x + + +class SequentialUNet(ReprWithCrossReferences, paddle.nn.Sequential): + """Sequential U-net architecture. + + The final module of this sequential module either outputs a tuple of feature maps at the + different resolution levels (``out_channels=None``), the final decoded feature map at the + highest resolution level (``out_channels == config.decoder.out_channels`` and ``output_layers=None``), + or a tensor with specified number of ``out_channels`` as produced by a final output layer otherwise. + Note that additional layers (e.g., a custom output layer or post-output layers) can be added to the + initialized sequential U-net using ``add_module()``. + + """ + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + config: Optional[UNetConfig] = None, + conv_block: Optional[ModuleFactory] = None, + output_layer: Optional[ModuleFactory] = None, + bridge_layer: Optional[ModuleFactory] = None, + ) -> None: + super().__init__() + if config is None: + config = UNetConfig() + elif not isinstance(config, UNetConfig): + raise TypeError(f"{type(self).__name__}() 'config' must be UNetConfig") + config = deepcopy(config) + self.config = config + self.encoder = UNetEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + config=config.encoder, + conv_block=conv_block, + ) + in_channels = self.encoder.in_channels + self.decoder = UNetDecoder( + spatial_dims=spatial_dims, + in_channels=self.encoder.out_channels, + config=config.decoder, + conv_block=conv_block, + input_layer=bridge_layer, + output_all=output_layer is None and not out_channels, + ) + channels = self.decoder.out_channels + if not out_channels and config.output is not None: + out_channels = config.output.channels + if output_layer is None: + if self.decoder.output_all: + out_channels = self.decoder.num_channels + elif out_channels: + output_layer = ConvLayer + if output_layer is not None: + out_channels = out_channels or in_channels + if config.output is None: + config.output = UNetOutputConfig() + output = output_layer( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=out_channels, + kernel_size=config.output.kernel_size, + padding=config.output.padding, + padding_mode=config.output.padding_mode, + dilation=config.output.dilation, + init=config.output.init, + bias=config.output.bias, + norm=config.output.norm, + acti=config.output.acti, + order=config.output.order, + ) + self.add_sublayer(name="output", sublayer=output) + self.out_channels: Union[int, List[int]] = out_channels + + @property + def spatial_dims(self) -> int: + return self.encoder.spatial_dims + + @property + def in_channels(self) -> int: + return self.encoder.in_channels + + @property + def num_channels(self) -> List[List[int]]: + return self.decoder.num_channels + + @property + def num_levels(self) -> int: + return len(self.num_channels) + + def output_size(self, in_size: ScalarOrTuple[int]) -> ScalarOrTuple[int]: + """Calculate spatial output size given an input tensor with specified spatial size.""" + size = in_size + for module in self: + size = module_output_size(module, size) + return size + + +class UNet(ReprWithCrossReferences, paddle.nn.Layer): + """U-net with optionally multiple output layers.""" + + def __init__( + self, + spatial_dims: int, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + output_modules: Optional[Mapping[str, Layer]] = None, + output_indices: Optional[Union[Mapping[str, int], int]] = None, + config: Optional[UNetConfig] = None, + conv_block: Optional[ModuleFactory] = None, + bridge_layer: Optional[ModuleFactory] = None, + output_layer: Optional[ModuleFactory] = None, + output_name: str = "output", + ) -> None: + super().__init__() + if output_modules is None: + output_modules = {} + if not isinstance(output_modules, Mapping): + raise TypeError(f"{type(self).__name__}() 'output_modules' must be Mapping") + if config is None: + config = UNetConfig() + elif not isinstance(config, UNetConfig): + raise TypeError(f"{type(self).__name__}() 'config' must be UNetConfig") + config = deepcopy(config) + self.config = config + self.encoder = UNetEncoder( + spatial_dims=spatial_dims, + in_channels=in_channels, + config=config.encoder, + conv_block=conv_block, + ) + in_channels = self.encoder.in_channels + self.decoder = UNetDecoder( + spatial_dims=spatial_dims, + in_channels=self.encoder.out_channels, + config=config.decoder, + conv_block=conv_block, + input_layer=bridge_layer, + output_all=True, + ) + channels = self.decoder.out_channels + self.output_modules = paddle.nn.LayerDict() + if not out_channels and config.output is not None: + out_channels = config.output.channels + if output_layer is None: + if out_channels == channels and config.output is None: + self.output_modules[output_name] = GetItem(-1) + elif out_channels: + output_layer = ConvLayer + elif not output_modules: + out_channels = self.decoder.num_channels + if output_layer is not None: + out_channels = out_channels or in_channels + if config.output is None: + config.output = UNetOutputConfig() + output = output_layer( + spatial_dims=spatial_dims, + in_channels=channels, + out_channels=out_channels, + kernel_size=config.output.kernel_size, + padding=config.output.padding, + padding_mode=config.output.padding_mode, + dilation=config.output.dilation, + init=config.output.init, + bias=config.output.bias, + norm=config.output.norm, + acti=config.output.acti, + order=config.output.order, + ) + output = [("input", GetItem(-1)), ("layer", output)] + output = paddle.nn.Sequential(*output) + self.output_modules[output_name] = output + self.out_channels: Union[int, Sequence[int], None] = out_channels + if output_indices is None: + output_indices = {} + elif isinstance(output_indices, int): + output_indices = {name: output_indices for name in output_modules} + for name, output in output_modules.items(): + output_index = output_indices.get(name) + if output_index is not None: + if not isinstance(output_index, int): + raise TypeError( + f"{type(self).__name__}() 'output_indices' must be int" + ) + output = [("input", GetItem(output_index)), ("layer", output)] + output = paddle.nn.Sequential(*output) + self.output_modules[name] = output + + @property + def spatial_dims(self) -> int: + return self.encoder.spatial_dims + + @property + def in_channels(self) -> int: + return self.encoder.in_channels + + @property + def num_channels(self) -> List[List[int]]: + return self.decoder.num_channels + + @property + def num_levels(self) -> int: + return len(self.num_channels) + + @property + def num_output_layers(self) -> int: + return len(self.output_modules) + + def output_names(self) -> Iterable[str]: + return self.output_modules.keys() + + def output_is_dict(self) -> bool: + """Whether model output is dictionary of output tensors.""" + return not (self.output_is_tensor() or self.output_is_tuple()) + + def output_is_tensor(self) -> bool: + """Whether model output is a single output tensor.""" + return len(self.output_modules) == 1 and bool(self.out_channels) + + def output_is_tuple(self) -> bool: + """Whether model output is tuple of decoded feature maps.""" + return not self.output_modules + + def output_size(self, in_size: ScalarOrTuple[int]) -> ScalarOrTuple[int]: + out_sizes = self.output_sizes(in_size) + if self.output_is_tensor(): + assert len(out_sizes) == 1 + return out_sizes[0] + if self.output_is_tuple(): + return out_sizes[-1] + assert isinstance(out_sizes, dict) and len(out_sizes) > 0 + out_size = None + for size in out_sizes.values(): + if out_size is None: + out_size = size + elif out_size != size: + raise RuntimeError( + f"{type(self).__name__}.output_size() is ambiguous, use output_sizes() instead" + ) + assert out_size is not None + return out_size + + def output_sizes( + self, in_size: ScalarOrTuple[int] + ) -> Union[Dict[str, ScalarOrTuple[int]], List[ScalarOrTuple[int]]]: + enc_out_size = self.encoder.output_size(in_size) + dec_out_sizes = self.decoder.output_sizes(enc_out_size) + if self.output_is_tensor(): + return dec_out_sizes[-1:] + if self.output_is_tuple(): + return dec_out_sizes + out_sizes = {} + for name, module in self.output_modules.items(): + out_sizes[name] = module_output_size(module, dec_out_sizes) + return out_sizes + + def forward( + self, x: paddle.Tensor + ) -> Union[paddle.Tensor, NamedTuple, Tuple[paddle.Tensor, ...]]: + outputs = {} + features = self.decoder(self.encoder(x)) + for name, output in self.output_modules.items(): + outputs[name] = output(features) + if not outputs: + return features + if len(outputs) == 1 and self.out_channels: + return next(iter(outputs.values())) + return as_immutable_container(outputs) diff --git a/jointContribution/HighResolution/deepali/networks/utils.py b/jointContribution/HighResolution/deepali/networks/utils.py new file mode 100644 index 0000000000..5401c3d871 --- /dev/null +++ b/jointContribution/HighResolution/deepali/networks/utils.py @@ -0,0 +1,208 @@ +import paddle +from paddle.nn import Layer + +from ..core.nnutils import conv_output_size +from ..core.nnutils import conv_transposed_output_size +from ..core.nnutils import pad_output_size +from ..core.nnutils import pool_output_size +from ..core.nnutils import unpool_output_size +from ..core.nnutils import upsample_output_size +from ..core.types import ScalarOrTuple +from .blocks import SkipConnection +from .layers import Pad +from .layers import is_activation +from .layers import is_norm_layer + + +def module_output_size( + module: paddle.nn.Layer, in_size: ScalarOrTuple[int] +) -> ScalarOrTuple[int]: + """Calculate spatial size of output tensor after the given module is applied.""" + if not isinstance(module, paddle.nn.Layer) or type(module) is Layer: + raise TypeError( + "module_output_size() 'module' must be paddle.nn.Layer subclass" + ) + output_size = getattr(module, "output_size", None) + if callable(output_size): + return output_size(in_size) + if output_size is not None: + device = str("cpu").replace("cuda", "gpu") + m: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=in_size, dtype="int32", place=device) + ) + if m.ndim != 1: + raise ValueError( + "module_output_size() 'in_size' must be scalar or sequence" + ) + ndim = tuple(m.shape)[0] + s: paddle.Tensor = paddle.atleast_1d( + paddle.to_tensor(data=output_size, dtype="int32", place=device) + ) + if s.ndim != 1 or tuple(s.shape)[0] not in (1, ndim): + raise ValueError( + f"module_output_size() 'module.output_size' must be scalar or sequence of length {ndim}" + ) + n = s.expand(shape=ndim) + if isinstance(in_size, int): + return n.item() + return tuple(n.tolist()) + if isinstance(module, paddle.nn.Sequential): + size = in_size + for m in module: + size = module_output_size(m, size) + return size + if isinstance(module, SkipConnection): + return module_output_size(module.func, in_size) + if isinstance(module, (paddle.nn.Conv1D, paddle.nn.Conv2D, paddle.nn.Conv3D)): + return conv_output_size( + in_size, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + ) + if isinstance( + module, + ( + paddle.nn.Conv1DTranspose, + paddle.nn.Conv2DTranspose, + paddle.nn.Conv3DTranspose, + ), + ): + return conv_transposed_output_size( + in_size, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + output_padding=module.output_padding, + dilation=module.dilation, + ) + if isinstance( + module, + ( + paddle.nn.AvgPool1D, + paddle.nn.AvgPool2D, + paddle.nn.AvgPool3D, + paddle.nn.MaxPool1D, + paddle.nn.MaxPool2D, + paddle.nn.MaxPool3D, + ), + ): + return pool_output_size( + in_size, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + dilation=module.dilation, + ceil_mode=module.ceil_mode, + ) + if isinstance( + module, + ( + paddle.nn.AdaptiveAvgPool1D, + paddle.nn.AdaptiveAvgPool2D, + paddle.nn.AdaptiveAvgPool3D, + paddle.nn.AdaptiveMaxPool1D, + paddle.nn.AdaptiveMaxPool2D, + paddle.nn.AdaptiveMaxPool3D, + ), + ): + return module.output_size + if isinstance( + module, (paddle.nn.MaxUnPool1D, paddle.nn.MaxUnPool2D, paddle.nn.MaxUnPool3D) + ): + return unpool_output_size( + in_size, + kernel_size=module.kernel_size, + stride=module.stride, + padding=module.padding, + ) + if isinstance(module, Pad): + raise NotImplementedError() + if isinstance( + module, + ( + paddle.nn.Pad1D, + paddle.nn.Pad2D, + paddle.nn.Pad1D, + paddle.nn.Pad2D, + paddle.nn.Pad3D, + paddle.nn.ZeroPad2D, + paddle.nn.Pad1D, + paddle.nn.Pad2D, + paddle.nn.Pad3D, + ), + ): + return pad_output_size(in_size, module.padding) + if isinstance(module, paddle.nn.Upsample): + return upsample_output_size( + in_size, size=module.size, scale_factor=module.scale_factor + ) + if is_activation(module) or isinstance( + module, + ( + paddle.nn.ELU, + paddle.nn.Hardshrink, + paddle.nn.Hardsigmoid, + paddle.nn.Hardtanh, + paddle.nn.Hardswish, + paddle.nn.LeakyReLU, + paddle.nn.LogSigmoid, + paddle.nn.LogSoftmax, + paddle.nn.PReLU, + paddle.nn.ReLU, + paddle.nn.ReLU6, + paddle.nn.RReLU, + paddle.nn.SELU, + paddle.nn.CELU, + paddle.nn.GELU, + paddle.nn.Sigmoid, + paddle.nn.Softmax, + paddle.nn.Softmax, + paddle.nn.Softmin, + paddle.nn.Softplus, + paddle.nn.Softshrink, + paddle.nn.Softsign, + paddle.nn.Tanh, + paddle.nn.Tanhshrink, + paddle.nn.Threshold, + ), + ): + return in_size + if is_norm_layer(module) or isinstance( + module, + ( + paddle.nn.BatchNorm1D, + paddle.nn.BatchNorm2D, + paddle.nn.BatchNorm3D, + paddle.nn.SyncBatchNorm, + paddle.nn.GroupNorm, + paddle.nn.InstanceNorm1D, + paddle.nn.InstanceNorm2D, + paddle.nn.InstanceNorm3D, + paddle.nn.LayerNorm, + paddle.nn.LocalResponseNorm, + ), + ): + return in_size + if isinstance( + module, + ( + paddle.nn.AlphaDropout, + paddle.nn.Dropout, + paddle.nn.Dropout2D, + paddle.nn.Dropout3D, + ), + ): + return in_size + if isinstance(module, (paddle.nn.LayerDict, paddle.nn.LayerList)): + raise TypeError( + "module_output_size() order of modules in ModuleDict or ModuleList is undetermined" + ) + if isinstance(module, (paddle.nn.ParameterDict, paddle.nn.ParameterList)): + raise TypeError( + "module_output_size() 'module' cannot be paddle.nn.ParameterDict or paddle.nn.ParameterList" + ) + raise NotImplementedError( + f"module_output_size() not implemented for 'module' of type {type(module)}" + ) diff --git a/jointContribution/HighResolution/deepali/spatial/__init__.py b/jointContribution/HighResolution/deepali/spatial/__init__.py new file mode 100644 index 0000000000..67ef11eb3f --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/__init__.py @@ -0,0 +1,233 @@ +"""Spatial coordinate and input data transformation modules. + +.. hint:: + + The spatial transforms defined by this library can be used to implement co-registration + approaches based on traditional optimization as well as those based on machine learning + (amortized optimization). + +A spatial transformation maps points from a target domain defined with respect to the unit cube +of a target sampling grid, to points defined with respect to the same domain, i.e., the domain +and codomain of the spatial coordinate map are identical. In order to transform an image defined +with respect to a different sampling grid, this transformation has to be followed by a mapping +from target cube domain to source cube domain. This is done, for example, by the spatial transformer +implemented by :class:`.ImageTransformer`. + +The ``forward()`` method of a :class:`.SpatialTransform` can be used to spatially transform any set +of points defined with respect to the grid domain of the spatial transformation, including in particular +a tensor of shape ``(N, M, D)``, i.e., a batch of ``N`` point sets with cardinality ``M``. It can +also be applied to a tensor of grid points of shape ``(N, ..., X, D)`` regardless if the grid points +are located at the undeformed grid positions or those of an already deformed grid. In case of a +non-rigid deformation, the point displacements are by default sampled at the input points. The sampled +flow vectors :math:`u` are then added to the input points :math:`x`, producing the output :math:`y = x + u(x)`. +If the boolean flag ``grid=True`` is passed to the :meth:`.SpatialTransform.forward` function, it is assumed +that the coordinates correspond to the positions of undeformed spatial grid points with domain equal to the +domain of the transformation. In this special case, a simple interpolation to resize the vector field to the +size of the input tensor is used. In case of a linear transformation, :math:`y = Ax + t`. + +The coordinate domain is :attr:`.Axes.CUBE_CORNERS` if ``grid.align_corners() == True`` (default), +and :attr:`.Axes.CUBE` otherwise. + +""" +import sys +from typing import Any +from typing import Optional + +from ..core.grid import Grid +from .base import LinearTransform +from .base import NonRigidTransform +from .base import ReadOnlyParameters +from .base import SpatialTransform +from .bspline import BSplineTransform # noqa: F401 +from .bspline import FreeFormDeformation +from .bspline import StationaryVelocityFreeFormDeformation +from .composite import CompositeTransform +from .composite import SequentialTransform # noqa: F401 +from .generic import GenericSpatialTransform +from .generic import TransformConfig +from .generic import affine_first +from .generic import has_affine_component +from .generic import has_nonrigid_component +from .generic import nonrigid_components +from .generic import transform_components +from .image import ImageTransform +from .linear import AffineTransform +from .linear import AnisotropicScaling +from .linear import EulerRotation +from .linear import FullAffineTransform +from .linear import HomogeneousTransform +from .linear import QuaternionRotation +from .linear import RigidQuaternionTransform +from .linear import RigidTransform +from .linear import Shearing +from .linear import SimilarityTransform +from .linear import Translation # noqa: F401 +from .nonrigid import DenseVectorFieldTransform +from .nonrigid import DisplacementFieldTransform +from .nonrigid import StationaryVelocityFieldTransform +from .parametric import ParametricTransform +from .transformer import ImageTransformer +from .transformer import PointSetTransformer +from .transformer import SpatialTransformer + +Affine = AffineTransform +AffineWithShearing = FullAffineTransform +Disp = DisplacementFieldTransform +DispField = DisplacementFieldTransform +DDF = DisplacementFieldTransform +DVF = DisplacementFieldTransform +FFD = FreeFormDeformation +FullAffine = FullAffineTransform +MatrixTransform = HomogeneousTransform +Quaternion = QuaternionRotation +Rigid = RigidTransform +RigidQuaternion = RigidQuaternionTransform +Rotation = EulerRotation +Scaling = AnisotropicScaling +ShearTransform = Shearing +Similarity = SimilarityTransform +SVF = StationaryVelocityFieldTransform +SVField = StationaryVelocityFieldTransform +SVFFD = StationaryVelocityFreeFormDeformation +LINEAR_TRANSFORMS = ( + "Affine", + "AffineTransform", + "AffineWithShearing", + "AnisotropicScaling", + "BSplineTransform", + "EulerRotation", + "IsotropicScaling", + "FullAffine", + "FullAffineTransform", + "HomogeneousTransform", + "MatrixTransform", + "Quaternion", + "QuaternionRotation", + "Rigid", + "RigidTransform", + "RigidQuaternion", + "RigidQuaternionTransform", + "Rotation", + "Scaling", + "Shearing", + "ShearTransform", + "Similarity", + "SimilarityTransform", + "Translation", +) +NONRIGID_TRANSFORMS = ( + "Disp", + "DispField", + "DisplacementFieldTransform", + "DDF", + "DVF", + "FFD", + "FreeFormDeformation", + "StationaryVelocityFieldTransform", + "StationaryVelocityFreeFormDeformation", + "SVF", + "SVField", + "SVFFD", +) +COMPOSITE_TRANSFORMS = "MultiLevelTransform", "SequentialTransform" +__all__ = ( + ( + "CompositeTransform", + "DenseVectorFieldTransform", + "GenericSpatialTransform", + "ImageTransform", + "ImageTransformer", + "LinearTransform", + "NonRigidTransform", + "ParametricTransform", + "PointSetTransformer", + "ReadOnlyParameters", + "SpatialTransform", + "SpatialTransformer", + "TransformConfig", + "affine_first", + "has_affine_component", + "has_nonrigid_component", + "is_linear_transform", + "is_nonrigid_transform", + "is_spatial_transform", + "new_spatial_transform", + "nonrigid_components", + "transform_components", + ) + + COMPOSITE_TRANSFORMS + + LINEAR_TRANSFORMS + + NONRIGID_TRANSFORMS +) + + +def is_spatial_transform(arg: Any) -> bool: + """Whether given object or named transformation is a transformation type. + + Args: + arg: Name of type or object. + + Returns: + Whether type of ``arg`` object or name of type is a transformation model. + + """ + if isinstance(arg, str): + return arg in LINEAR_TRANSFORMS or arg in NONRIGID_TRANSFORMS + return isinstance(arg, SpatialTransform) + + +def is_linear_transform(arg: Any) -> bool: + """Whether given object is a linear transformation type. + + Args: + arg: Name of type or object. + + Returns: + Whether type of ``arg`` object or name of type is a linear transformation. + + """ + if isinstance(arg, str): + return arg in LINEAR_TRANSFORMS + if isinstance(arg, SpatialTransform): + return arg.linear + return False + + +def is_nonrigid_transform(arg: Any) -> bool: + """Whether given object is a non-rigid transformation type. + + Args: + arg: Name of type or object. + + Returns: + Whether type of ``arg`` object or name of type is a non-rigid transformation. + + """ + if isinstance(arg, str): + return arg in NONRIGID_TRANSFORMS + if isinstance(arg, SpatialTransform): + return arg.nonrigid + return False + + +def new_spatial_transform( + name: str, grid: Grid, groups: Optional[int] = None, **kwargs +) -> SpatialTransform: + """Initialize new transformation model of named type. + + Args: + name: Name of transformation model. + grid: Grid of transformation domain. + groups: Number of transformations. + kwargs: Optional keyword arguments of transformation model. + + Returns: + New transformation module with optimizable parameters. + + """ + cls = getattr(sys.modules[__name__], name, None) + if cls is not None and (name in LINEAR_TRANSFORMS or name in NONRIGID_TRANSFORMS): + return cls(grid, groups=groups, **kwargs) + raise ValueError( + f"new_spatial_transform() 'name={name}' is not a valid transformation type" + ) diff --git a/jointContribution/HighResolution/deepali/spatial/base.py b/jointContribution/HighResolution/deepali/spatial/base.py new file mode 100644 index 0000000000..c1dc20554b --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/base.py @@ -0,0 +1,543 @@ +from __future__ import annotations + +from abc import ABCMeta +from abc import abstractmethod +from copy import copy as shallow_copy +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import overload + +import paddle +from typing_extensions import final + +from ..core import functional as U +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.linalg import as_homogeneous_matrix +from ..core.types import Device +from ..data.flow import FlowFields +from ..modules import DeviceProperty + +TSpatialTransform = TypeVar("TSpatialTransform", bound="SpatialTransform") +TLinearTransform = TypeVar("TLinearTransform", bound="LinearTransform") +TNonRigidTransform = TypeVar("TNonRigidTransform", bound="NonRigidTransform") + + +class ReadOnlyParameters(RuntimeError): + """Exception thrown when attempting to set parameters when these are provided by a callable.""" + + ... + + +class SpatialTransform(DeviceProperty, paddle.nn.Layer, metaclass=ABCMeta): + """Base class of all spatial coordinate transformations.""" + + def __init__(self, grid: Grid): + """Initialize base class. + + Args: + grid: Spatial domain with respect to which transformation is defined. + The unit cube domain is :attr:`.Axes.CUBE` if ``grid.align_corners() == False``, + and :attr:`.Axes.CUBE_CORNERS` otherwise. + + """ + if not isinstance(grid, Grid): + raise TypeError("SpatialTransform() 'grid' must be of type Grid") + super().__init__() + self._grid = grid + self._args = () + self._kwargs = {} + self.register_update_hook() + + def __copy__(self: TSpatialTransform) -> TSpatialTransform: + """Make shallow copy of this transformation. + + The copy shares containers for parameters and hooks with this module, but not containers of + buffers and modules. References to currently set buffers and modules are however copied, but + adding/removing a buffer or module to/from the shallow copy will not modify the buffers and + modules of the original module. The same is the case for adding/removing buffers or modules + to/from the original module. + + Returns: + Shallow copy of this spatial transformation module. + + """ + copy = self.__new__(type(self)) + copy.__dict__ = self.__dict__.copy() + for name in ("_buffers", "_non_persistent_buffers_set", "_modules"): + if name in self.__dict__: + copy.__dict__[name] = self.__dict__[name].copy() + return copy + + @overload + def condition(self) -> Tuple[tuple, dict]: + """Get arguments on which transformation is conditioned. + + Returns: + args: Positional arguments. + kwargs: Keyword arguments. + + """ + ... + + @overload + def condition(self: TSpatialTransform, *args, **kwargs) -> TSpatialTransform: + """Get new transformation which is conditioned on the specified arguments.""" + ... + + def condition( + self: TSpatialTransform, *args, **kwargs + ) -> Union[TSpatialTransform, Tuple[tuple, dict]]: + """Get or set data tensor on which transformation is conditioned.""" + if args: + return shallow_copy(self).condition_(*args) + return self._args, self._kwargs + + def condition_(self: TSpatialTransform, *args, **kwargs) -> TSpatialTransform: + """Set data tensor on which this transformation is conditioned.""" + self.clear_buffers() + self._args = args + self._kwargs = kwargs + return self + + def align_corners(self) -> bool: + """Whether extrema -1 and 1 coincide with grid border (False) or corner points (True).""" + return self._grid.align_corners() + + def axes(self) -> Axes: + """Axes with respect to which transformation is defined. + + Returns: + ``Axes.CUBE_CORNERS if self.align_corners() else Axes.CUBE``. + + """ + return Axes.from_align_corners(self.align_corners()) + + @overload + def grid(self) -> Grid: + ... + + @overload + def grid(self: TSpatialTransform, grid: Grid) -> TSpatialTransform: + ... + + def grid(self, grid: Optional[Grid] = None) -> Grid: + """Get grid domain of this transformation or a new transformation with the specified grid.""" + if grid is None: + return self._grid + return shallow_copy(self).grid_(grid) + + def grid_(self: TSpatialTransform, grid: Grid) -> TSpatialTransform: + """Set sampling grid which defines domain and codomain of this transformation.""" + if self._grid == grid: + return self + if grid.ndim != self.ndim: + raise ValueError( + f"{type(self).__name__}.grid_() must be {self.ndim}-dimensional" + ) + self.clear_buffers() + self._grid = grid + return self + + def dim(self) -> int: + """Number of spatial dimensions.""" + return self._grid.ndim + + @property + def ndim(self) -> int: + """Number of spatial dimensions.""" + return self.dim() + + @property + def linear(self) -> bool: + """Whether this transformation is linear.""" + return isinstance(self, LinearTransform) + + @property + def nonrigid(self) -> bool: + """Whether this transformation is non-rigid.""" + return not self.linear + + def fit(self: TSpatialTransform, flow: FlowFields, **kwargs) -> TSpatialTransform: + """Fit transformation to a given flow field. + + Args: + flow: Flow fields to approximate. + kwargs: Optional keyword arguments of fitting algorithm. + Arguments which are unused by a concrete implementation are ignored + without raising an error, e.g., a specified gradient descent step length + when a least square fit is computed instead. + + Returns: + Reference to this transformation. + + Raises: + RuntimeError: When this transformation has no optimizable parameters. + + """ + self.clear_buffers() + grid = self.grid() + flow = flow.to(self.device) + flow = flow.sample(shape=grid) + flow = flow.axes(Axes.from_grid(grid)) + self._fit(flow, **kwargs) + return self + + def _fit(self, flow: FlowFields, **kwargs) -> None: + """Fit transformation to flow field. + + This function may be overidden by subclasses to implement an analytic least squares + or otherwise more suitable fitting approach. Optimization related keyword arguments + may be ignored by these specializations. + + Args: + flow: Batch of flow vector fields sampled on ``self.grid()`` and + defined with respect to either ``Axes.CUBE`` or ``Axes.CUBE_CORNERS`` + depending on flag ``self.grid().align_corners()``. These displacement vector + fields will be approximated by this transformation. + kwargs: Keyword arguments of iterative optimization. Unused arguments are ignored. + lr: Initial step size for iterative gradient-based optimization. + steps: Maximum number of gradient steps after which to terminate. + epsilon: Upper mean squared error threshold at which to terminate. + + Raises: + RuntimeError: When this transformation has no optimizable parameters. + + """ + lr = float(kwargs.get("lr", 0.1)) + steps = int(kwargs.get("steps", 1000)) + epsilon = float(kwargs.get("epsilon", 1e-05)) + verbose = int(kwargs.get("verbose", 0)) + params = list(self.parameters()) + if not params: + raise RuntimeError( + f"{type(self).__name__}.fit() transformation has no optimizable parameters" + ) + optimizer = paddle.optimizer.Adam(parameters=params, lr=lr, weight_decay=0.0) + for step in range(steps): + optimizer.clear_grad() + loss = paddle.nn.functional.mse_loss(input=self.disp(), label=flow.tensor()) + loss.backward() + optimizer.step() + error = loss.detach() + converged = ( + error.less_equal(y=paddle.to_tensor(epsilon)).astype("bool").all() + ) + if verbose > 0 and (converged or step % verbose == 0): + print(f"{type(self).__name__}.fit(): step={step}, mse={error.tolist()}") + if converged: + break + + def forward(self, points: paddle.Tensor, grid: bool = False) -> paddle.Tensor: + """Transform normalized points by this spatial transformation. + + Args: + points: paddle.Tensor of shape ``(N, M, D)`` or ``(N, ..., Y, X, D)``. + grid: Whether ``points`` are the positions of undeformed grid points. + + Returns: + paddle.Tensor of same shape as ``points`` with transformed point coordinates. + + """ + transform = self.tensor().to(points.place) + apply = U.transform_grid if grid else U.transform_points + align_corners = self.align_corners() + return apply(transform, points, align_corners=align_corners) + + def points( + self, + points: paddle.Tensor, + grid: Optional[Grid] = None, + axes: Optional[Union[Axes, str]] = None, + to_grid: Optional[Grid] = None, + to_axes: Optional[Union[Axes, str]] = None, + ) -> paddle.Tensor: + """Transform points by this spatial transformation. + + Args: + points: paddle.Tensor of shape ``(N, M, D)`` or ``(N, ..., Y, X, D)``. + grid: Grid with respect to which input ``points`` are defined. Uses ``self.grid()`` if ``None``. + axes: Coordinate axes with respect to which ``points`` are defined. Uses ``self.axes()`` if ``None``. + to_grid: Grid with respect to which output points are defined. Same as ``grid`` if ``None``. + to_axes: Coordinate axes to which ``points`` should be mapped to. Same as ``axes`` if ``None``. + + Returns: + Point coordinates in ``(grid, axes)`` spatially transformed and mapped to coordinates with respect to ``(to_grid, to_axes)``. + + """ + if grid is None: + grid = self.grid() + if axes is None: + axes = self.axes() + else: + axes = Axes.from_arg(axes) + if to_grid is None: + to_grid = grid + if to_axes is None: + to_axes = axes + else: + to_axes = Axes.from_arg(to_axes) + points = grid.transform_points( + points, axes=axes, to_grid=self.grid(), to_axes=self.axes(), decimals=None + ) + points = self.forward(points) + points = self.grid().transform_points( + points, axes=self.axes(), to_grid=to_grid, to_axes=to_axes, decimals=None + ) + return points + + def disp(self, grid: Optional[Grid] = None) -> paddle.Tensor: + """Get displacement vector field representation of this transformation. + + Args: + grid: Grid on which to sample vector fields. Use ``self.grid()`` if ``None``. + + Returns: + Displacement vector fields as tensor of shape ``(N, D, ..., X)``. + + """ + if grid is None: + grid = self.grid() + data = self.tensor() + if data.ndim < 3: + raise AssertionError( + f"SpatialTransform.disp() expected {type(self).__name__}.tensor() to be at least 3-dimensional" + ) + if tuple(data.shape)[1] != self.dim(): + raise AssertionError( + f"SpatialTransform.disp() expected {type(self).__name__}.tensor() shape[1] to be equal to {type(self).__name__}.dim()={self.dim()}" + ) + if data.ndim == 3: + assert self.linear + data = U.affine_flow(data, grid) + else: + assert not self.linear + align_corners = grid.align_corners() + if grid != self.grid() or align_corners != self.align_corners(): + flow = FlowFields(data, grid=self.grid().reshape(tuple(data.shape)[2:])) + flow = flow.sample(shape=grid) + data = flow.tensor() + elif tuple(grid.shape) != tuple(data.shape)[2:]: + data = U.grid_reshape( + data, tuple(grid.shape), align_corners=align_corners + ) + return data + + @final + def flow( + self, grid: Optional[Grid] = None, device: Optional[Device] = None + ) -> FlowFields: + """Get flow field representation of this transformation. + + Args: + grid: Grid on which to sample flow fields. Use ``self.grid()`` if ``None``. + device: Device on which to store returned flow fields. + + Returns: + Flow fields defined by a tensor of shape ``(N, D, ..., X)`` and a common spatial grid. + + """ + if grid is None: + grid = self.grid() + data = self.disp(grid) + return FlowFields(data, grid=grid, device=device) + + @abstractmethod + def tensor(self) -> paddle.Tensor: + """Get tensor representation of this transformation. + + The tensor representation of a transformation is with respect to the unit cube axes defined + by its sampling grid as specified by ``self.axes()``. For a non-rigid transformation it is + a displacement vector field. For linear transformations, it is a batch of homogeneous + transformation tensors whose shape determines the type of linear transformation. + + Returns: + Returns a batch of homogeneous transformation matrices as tensor of shape ``(N, D, 1)`` + (translation), ``(N, D, D)`` (affine) or ``(N, D, D + 1)``, i.e., a 3-dimensional tensor, + if this transformation is a :class:`.LinearTransform`. In case of a non-rigid transformation, + a displacement vector field is returned as tensor of shape ``(N, D, ..., X)``, i.e., a + higher dimensional tensor, where ``D = self.ndim`` and the number of tensor dimensions is + equal to ``D + 2``. + + """ + raise NotImplementedError(f"{type(self).__name__}.tensor()") + + def update(self: TSpatialTransform) -> TSpatialTransform: + """Update internal state of this transformation. + + This function is called by a pre-forward hook. It can be overriden by subclasses to + update their internal state, e.g., to obtain current predictions of transformation + parameters, to compute a dense vector field from spline coefficients, to compute + a displacement field from a velocity field, etc. When calling other functions than + the module's ``__call__`` function, the ``update()`` of the transformation must be + called explicitly unless it is known that the specific transformation keeps no + internal state other than its (optimizable) parameters. + + """ + return self + + @staticmethod + def _update_hook(transform: paddle.nn.Layer, *args, **kwargs) -> None: + """Update callback which is registered as pre-forward hook.""" + assert isinstance(transform, SpatialTransform) + transform.update() + + def register_update_hook(self) -> None: + """Register a forward pre-hook which invokes :meth:`.SpatialTransform.update`.""" + self._update_hook_handle = self.register_forward_pre_hook( + hook=self._update_hook + ) + + def remove_update_hook(self) -> None: + """Remove previously registered :meth:`.SpatialTransform.update` hook.""" + if self._update_hook_handle is not None: + self._update_hook_handle.remove() + self._update_hook_handle = None + + @property + def inv(self: TSpatialTransform) -> TSpatialTransform: + """Get inverse transformation. + + Convenience property for applying the inverse transformation, e.g., + + .. code-block:: python + + y = transform(x) + x = transform.inv(y) + + """ + return self.inverse(link=True, update_buffers=True) + + def inverse( + self: TSpatialTransform, link: bool = False, update_buffers: bool = False + ) -> TSpatialTransform: + """Get inverse of this transformation. + + Args: + link: Whether the inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters, e.g., parameters obtained by a callable neural network, but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be updated after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation which computes and applies the inverse transformation. + The inverse transformation will share the parameters with this transformation. Not all + transformations may implement this functionality. + + Raises: + NotImplementedError: Transformation does not support sharing parameters with its inverse. + + """ + raise NotImplementedError(f"{type(self).__name__}.inverse()") + + def clear_buffers(self: TSpatialTransform) -> TSpatialTransform: + """Clear any buffers that are registered by ``self.update()``.""" + ... + + def extra_repr(self) -> str: + return f"grid={repr(self.grid())}" + + +class LinearTransform(SpatialTransform): + """Homogeneous coordinate transformation.""" + + @overload + def matrix(self) -> paddle.Tensor: + """Get matrix representation of linear transformation.""" + ... + + @overload + def matrix(self: TLinearTransform, arg: paddle.Tensor) -> TLinearTransform: + """Get shallow copy of this transformation with parameters obtained from given matrix.""" + ... + + @final + def matrix( + self: TLinearTransform, arg: Optional[paddle.Tensor] = None + ) -> Union[TLinearTransform, paddle.Tensor]: + """Get matrix representation of linear transformation or shallow copy with parameters set from matrix.""" + if arg is None: + return as_homogeneous_matrix(self.tensor()) + return shallow_copy(self).matrix_(arg) + + def matrix_(self: TLinearTransform, arg: paddle.Tensor) -> TLinearTransform: + raise NotImplementedError(f"{type(self).__name__}.matrix_()") + + +class NonRigidTransform(SpatialTransform): + """Base class for non-linear transformation models. + + All non-linear transformation models parameterize a dense displacement vector field, either a separate + displacement for each image in an input batch (e.g., for groupwise registration, batched registration), + or a single displacement field used to deform all images in an input batch. The parameterization, and + thereby the set of optimizable parameters, is defined by subclasses. The ``tensor()`` function must be + implemented by subclasses to evaluate the non-parametric dense displacement field given the current model + parameters. The flow vectors must be with respect to the grid ``Axes.CUBE``, i.e., where coordinate + -1 corresponds to the left edge of the unit cube with side length 2, and coordinate 1 the right edge of + this unit cube, respectively. Note that this corresponds to option ``align_corners=False`` of + ``paddle.nn.functional.grid_sample``. + + """ + + @final + def tensor(self) -> paddle.Tensor: + """Get tensor representation of this transformation. + + Returns: + Batch of displacement vector fields as tensor of shape ``(N, D, ..., X)``. + + """ + u = getattr(self, "u", None) + if u is None: + u = getattr(self.update(), "u", None) + if u is None or "u" not in {name for name, _ in self.named_buffers()}: + raise AssertionError( + f"{type(self).__name__}.update() required to register displacement vector field tensor as buffer named 'u'. See also NonRigidTransform.update() docstring." + ) + if not isinstance(u, paddle.Tensor): + raise AssertionError(f"{type(self).__name__}.tensor() 'u' must be tensor") + if u.ndim != self.ndim + 2: + raise AssertionError( + f"{type(self).__name__}.tensor() 'u' must be {self.ndim + 2}-dimensional" + ) + if tuple(u.shape)[1] != self.ndim: + raise AssertionError( + f"{type(self).__name__}.tensor() 'u' must have shape (N, {self.ndim}, ..., X)" + ) + return u + + def update(self: TNonRigidTransform) -> TNonRigidTransform: + """Update buffered vector fields. + + Required: + ``u``: Displacement vector field representation of non-rigid transformation. + + Optional: + ``v``: Velocity vector field representation of non-rigid transformation if applicable. + When this buffer is set, it can be used in a regularization term to encourage smoothness + or other desired properties on the (stationary) velocity field. Alternatively, a + regularization term may be based directly on the optimizable parameters. + + Returns: + Self reference to this updated transformation. + + """ + return self + + def clear_buffers(self: TNonRigidTransform) -> TNonRigidTransform: + """Clear any buffers that are registered by ``self.update()``.""" + super().clear_buffers() + for name in ("u", "v"): + try: + delattr(self, name) + except AttributeError: + pass + return self diff --git a/jointContribution/HighResolution/deepali/spatial/bspline.py b/jointContribution/HighResolution/deepali/spatial/bspline.py new file mode 100644 index 0000000000..4197703bca --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/bspline.py @@ -0,0 +1,314 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from typing import Callable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import cast +from typing import overload + +import paddle + +from ..core import functional as U +from ..core import kernels as K +from ..core.enum import SpatialDim +from ..core.grid import Grid +from ..core.types import ScalarOrTuple +from ..modules import ExpFlow +from .base import NonRigidTransform +from .parametric import ParametricTransform + +TBSplineTransform = TypeVar("TBSplineTransform", bound="BSplineTransform") + + +class BSplineTransform(ParametricTransform, NonRigidTransform): + """Non-rigid transformation parameterized by cubic B-spline function.""" + + def __init__( + self, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + stride: Optional[Union[int, Sequence[int]]] = None, + transpose: bool = False, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. A given image batch can either be deformed by a + single transformation, or a separate transformation for each image in the batch, e.g., + for group-wise or batched registration. The default is one transformation for all images + in the batch, or the batch length of the ``params`` tensor if provided. + params: Initial parameters. If a tensor is given, it is only registered as optimizable module + parameters when of type ``paddle.nn.Parameter``. When a callable is given instead, it will be + called by ``self.update()`` with arguments set and given by ``self.condition()``. When a boolean + argument is given, a new zero-initialized tensor is created. If ``True``, this tensor is registered + as optimizable module parameter. + stride: Number of grid points between control points plus one. This is the stride of the + transposed convolution used to upsample the control point displacements to the sampling ``grid`` + size. If ``None``, a stride of 1 is used. If a sequence of values is given, these must be the + strides for the different spatial grid dimensions in the order ``(sx, sy, sz)``. Note that + when the control point grid is subdivided in order to double its size along each spatial + dimension, the stride with respect to this subdivided control point grid remains the same. + transpose: Whether to use separable transposed convolution as implemented in AIRLab. + When ``False``, a more efficient implementation using multi-channel convolution followed + by a reshuffling of the output is performed. This more efficient and also more accurate + implementation is adapted from the C++ code of MIRTK (``mirtk::BSplineInterpolateImageFunction``). + + """ + if not grid.align_corners(): + raise ValueError( + "BSplineTransform() requires 'grid.align_corners() == True'" + ) + if stride is None: + stride = 5 + if isinstance(stride, int): + stride = (stride,) * grid.ndim + if len(stride) != grid.ndim: + raise ValueError( + f"BSplineTransform() 'stride' must be single int or {grid.ndim} ints" + ) + self.stride = tuple(int(s) for s in stride) + self._transpose = transpose + super().__init__(grid, groups=groups, params=params) + self.register_kernels(stride) + + @property + def data_shape(self) -> list: + """Get shape of transformation parameters tensor.""" + grid = self.grid() + shape = U.cubic_bspline_control_point_grid_size( + tuple(grid.shape), self.data_stride + ) + return tuple((grid.ndim,) + shape) + + @property + def data_stride(self) -> Tuple[int, ...]: + return tuple(reversed([int(s) for s in self.stride])) + + @paddle.no_grad() + def grid_(self: TBSplineTransform, grid: Grid) -> TBSplineTransform: + """Set sampling grid of transformation domain and codomain. + + If ``self.params`` is a callable, only the grid attribute is updated, and + the callable must return a tensor of matching size upon next evaluation. + + Args: + grid: New sampling grid for dense displacement field at which FFD is evaluated. + This function currently only supports subdivision of the control point grid, + i.e., the new ``grid`` must have size ``2 * n - 1`` along each spatial dimension + that should be subdivided, where ``n`` is the current grid size, or have the same + size as the current grid for dimensions that remain the same. + + Returns: + Reference to this modified transformation object. + + """ + params = self.params + current_grid = self._grid + if grid.ndim != current_grid.ndim: + raise ValueError( + f"{type(self).__name__}.grid_() argument must have {current_grid.ndim} dimensions" + ) + subdivide_dims: List[SpatialDim] = [] + if isinstance(params, paddle.Tensor): + current_grid = self._grid + if grid.ndim != current_grid.ndim: + raise ValueError( + f"{type(self).__name__}.grid_() argument must have {current_grid.ndim} dimensions" + ) + if not grid.align_corners(): + raise ValueError( + f"{type(self).__name__}() requires grid.align_corners() to be True" + ) + if not grid.same_domain_as(current_grid): + raise ValueError( + f"{type(self).__name__}.grid_() argument must define same grid domain as current grid" + ) + new_size = grid.size() + current_size = current_grid.size() + for i in range(grid.ndim): + if new_size[i] == 2 * current_size[i] - 1: + subdivide_dims.append(SpatialDim(i)) + elif new_size[i] != current_size[i]: + raise ValueError( + f"{type(self).__name__}.grid_() argument must have same size or new size '2n - 1'" + ) + self._grid = grid + if subdivide_dims: + new_shape = (tuple(params.shape)[0],) + self.data_shape + new_params = U.subdivide_cubic_bspline(params, dims=subdivide_dims) + for dim in subdivide_dims: + dim = dim.tensor_dim(params.ndim) + start = 1 + length = new_shape[dim] + new_params = paddle.slice( + new_params, axes=[dim], starts=[start], ends=[start + length] + ) + self.data_(new_params) + return self + + @staticmethod + def kernel_name(stride: int) -> str: + """Get name of buffer for 1-dimensional kernel for given control point spacing.""" + return "kernel_stride_" + str(stride) + + @overload + def kernel(self) -> Tuple[paddle.Tensor, ...]: + ... + + @overload + def kernel(self, stride: int) -> paddle.Tensor: + ... + + @overload + def kernel(self, stride: Sequence[int]) -> Tuple[paddle.Tensor, ...]: + ... + + def kernel( + self, stride: Optional[ScalarOrTuple[int]] = None + ) -> Union[paddle.Tensor, Tuple[paddle.Tensor, ...]]: + """Get 1-dimensional kernels for given control point spacing.""" + if stride is None: + stride = self.stride + if isinstance(stride, int): + return getattr(self, self.kernel_name(stride)) + return tuple(getattr(self, self.kernel_name(s)) for s in stride) + + def register_kernels(self, stride: Union[int, Sequence[int]]) -> None: + """Precompute cubic B-spline kernels.""" + if isinstance(stride, int): + stride = [stride] + for s in stride: + name = self.kernel_name(s) + if not hasattr(self, name): + if self._transpose: + kernel = K.cubic_bspline1d(s) + else: + kernel = U.bspline_interpolation_weights(degree=3, stride=s) + self.register_buffer(name=name, tensor=kernel, persistable=False) + + def deregister_kernels(self, stride: Union[int, Sequence[int]]) -> None: + """Remove precomputed cubic B-spline kernels.""" + if isinstance(stride, int): + stride = [stride] + for s in stride: + name = self.kernel_name(s) + if hasattr(self, name): + delattr(self, name) + + def evaluate_spline(self) -> paddle.Tensor: + """Evaluate cubic B-spline at sampling grid points.""" + data = self.data() + grid = self.grid() + if not grid.align_corners(): + raise AssertionError( + f"{type(self).__name__}() requires grid.align_corners() to be True" + ) + stride = self.stride + kernel = self.kernel(stride) + u = U.evaluate_cubic_bspline( + data, + shape=tuple(grid.shape), + stride=stride, + kernel=kernel, + transpose=self._transpose, + ) + return u + + +class FreeFormDeformation(BSplineTransform): + """Cubic B-spline free-form deformation model.""" + + def update(self) -> FreeFormDeformation: + """Update buffered displacement vector field.""" + super().update() + u = self.evaluate_spline() + self.register_buffer(name="u", tensor=u, persistable=False) + return self + + +class StationaryVelocityFreeFormDeformation(BSplineTransform): + """Stationary velocity field based transformation model using cubic B-spline parameterization.""" + + def __init__( + self, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + stride: Optional[Union[int, Sequence[int]]] = None, + scale: Optional[float] = None, + steps: Optional[int] = None, + transpose: bool = False, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid on which to sample flow field vectors. + groups: Number of velocity fields. + params: Initial parameters of cubic B-spline velocity fields of shape ``(N, C, ...X)``. + stride: Number of grid points between control points (minus one). + scale: Constant scaling factor of velocity fields. + steps: Number of scaling and squaring steps. + transpose: Whether to use separable transposed convolution as implemented in AIRLab. + When ``False``, a more efficient implementation using multi-channel convolution followed + by a reshuffling of the output is performed. This more efficient and also more accurate + implementation is adapted from the C++ code of MIRTK (``mirtk::BSplineInterpolateImageFunction``). + + """ + align_corners = grid.align_corners() + super().__init__( + grid, groups=groups, params=params, stride=stride, transpose=transpose + ) + self.exp = ExpFlow(scale=scale, steps=steps, align_corners=align_corners) + + def inverse( + self, link: bool = False, update_buffers: bool = False + ) -> StationaryVelocityFreeFormDeformation: + """Get inverse of this transformation. + + Args: + link: Whether the inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters (e.g., parameters obtained by a callable neural network), but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be updated after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation with ``exp`` module which uses negative scaling factor + to scale and square the stationary velocity field to compute the inverse displacement field. + + """ + inv = shallow_copy(self) + if link: + inv.link_(self) + inv.exp = cast(ExpFlow, self.exp).inverse() + if update_buffers: + v = getattr(inv, "v", None) + if v is not None: + u = inv.exp(v) + inv.register_buffer(name="u", tensor=u, persistable=False) + return inv + + def update(self) -> StationaryVelocityFreeFormDeformation: + """Update buffered velocity and displacement vector fields.""" + super().update() + v = self.evaluate_spline() + u = self.exp(v) + self.register_buffer(name="v", tensor=v, persistable=False) + self.register_buffer(name="u", tensor=u, persistable=False) + return self diff --git a/jointContribution/HighResolution/deepali/spatial/composite.py b/jointContribution/HighResolution/deepali/spatial/composite.py new file mode 100644 index 0000000000..fe1a49a58d --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/composite.py @@ -0,0 +1,353 @@ +from __future__ import annotations + +from collections import OrderedDict +from copy import copy as shallow_copy +from typing import Iterable +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import overload + +import paddle + +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.grid import grid_transform_points +from ..core.linalg import as_homogeneous_matrix +from ..core.linalg import homogeneous_matmul +from ..core.tensor import move_dim +from .base import SpatialTransform + +TCompositeTransform = TypeVar("TCompositeTransform", bound="CompositeTransform") + + +class CompositeTransform(SpatialTransform): + """Base class of composite spatial coordinate transformations. + + Base class of modules that apply one or more spatial transformations to map a tensor of + spatial points to another tensor of spatial points of the same shape as the input tensor. + + """ + + @overload + def __init__(self, grid: Grid) -> None: + """Initialize empty composite transformation.""" + ... + + @overload + def __init__(self, grid: Grid, *args: Optional[SpatialTransform]) -> None: + """Initialize composite transformation.""" + ... + + @overload + def __init__( + self, grid: Grid, transforms: Union[OrderedDict, paddle.nn.LayerDict] + ) -> None: + """Initialize composite transformation given named transforms in ordered dictionary.""" + ... + + @overload + def __init__(self, *args: Optional[SpatialTransform]) -> None: + """Initialize composite transformation.""" + ... + + @overload + def __init__(self, transforms: Union[paddle.nn.LayerDict, OrderedDict]) -> None: + """Initialize composite transformation given named transforms in ordered dictionary.""" + ... + + def __init__( + self, + *args: Optional[ + Union[Grid, paddle.nn.LayerDict, OrderedDict, SpatialTransform] + ], + ) -> None: + """Initialize composite transformation.""" + args_ = [arg for arg in args if arg is not None] + grid = None + if isinstance(args_[0], Grid): + grid = args_[0] + args_ = args_[1:] + if args_: + if isinstance(args_[0], (dict, paddle.nn.LayerDict)): + if len(args_) > 1: + raise ValueError( + f"{type(self).__name__}() multiple arguments not allowed when dict is given" + ) + transforms = args_[0] + else: + transforms = OrderedDict([(str(i), t) for i, t in enumerate(args_)]) + else: + transforms = OrderedDict() + if grid is None: + if transforms: + transform = next(iter(transforms.values())) + grid = transform.grid() + else: + raise ValueError( + f"{type(self).__name__}() requires a Grid or at least one SpatialTransform" + ) + for name, transform in transforms.items(): + if not isinstance(transform, SpatialTransform): + raise TypeError( + f"{type(self).__name__}() module '{name}' must be of type SpatialTransform" + ) + if not transform.grid().same_domain_as(grid): + raise ValueError( + f"{type(self).__name__}() transform '{name}' has different 'grid' center, direction, or cube extent" + ) + super().__init__(grid) + self._transforms = paddle.nn.LayerDict(sublayers=transforms) + + def bool(self) -> bool: + """Whether this module has at least one transformation.""" + return len(self._transforms) > 0 + + def __len__(self) -> int: + """Number of spatial transformations.""" + return len(self._transforms) + + @property + def linear(self) -> bool: + """Whether composite transformation is linear.""" + return all(transform.linear for transform in self.transforms()) + + def __contains__(self, name: Union[int, str]) -> bool: + """Whether composite contains named transformation.""" + if isinstance(name, int): + name = str(name) + return name in self._transforms.keys() + + def __getitem__(self, name: Union[int, str]) -> SpatialTransform: + """Get named transformation.""" + if isinstance(name, int): + name = str(name) + return self._transforms[name] + + def get( + self, name: Union[int, str], default: Optional[SpatialTransform] = None + ) -> Optional[SpatialTransform]: + """Get named transformation.""" + if isinstance(name, int): + name = str(name) + for key, transform in self._transforms.items(): + if key == name: + assert isinstance(transform, SpatialTransform) + return transform + return default + + def transforms(self) -> Iterable[SpatialTransform]: + """Iterate transformations in order of composition.""" + return self._transforms.values() + + def named_transforms(self) -> Iterable[Tuple[str, SpatialTransform]]: + """Iterate transformations in order of composition.""" + return self._transforms.items() + + def condition( + self: TCompositeTransform, *args, **kwargs + ) -> Union[TCompositeTransform, Optional[paddle.Tensor]]: + """Get or set data tensor on which transformations are conditioned.""" + if args or kwargs: + return shallow_copy(self).condition_(*args, **kwargs) + return self._args, self._kwargs + + def condition_(self: TCompositeTransform, *args, **kwargs) -> TCompositeTransform: + """Set data tensor on which transformations are conditioned.""" + assert args or kwargs + super().condition_(*args, **kwargs) + for transform in self.transforms(): + transform.condition_(*args, **kwargs) + return self + + def disp(self, grid: Optional[Grid] = None) -> paddle.Tensor: + """Get displacement vector field representation of this transformation. + + Args: + grid: Grid on which to sample vector fields. Use ``self.grid()`` if ``None``. + + Returns: + Displacement vector fields as tensor of shape ``(N, D, ..., X)``. + + """ + if grid is None: + grid = self.grid() + axes = Axes.from_grid(grid) + x = grid.coords(device=self.device).unsqueeze(axis=0) + if grid.same_domain_as(self.grid()): + y = self.forward(x) + else: + y = grid_transform_points(x, grid, axes, self.grid(), self.axes()) + y = self.forward(y) + y = grid_transform_points(y, self.grid(), self.axes(), grid, axes) + u = y - x + u = move_dim(u, -1, 1) + return u + + def update(self: TCompositeTransform) -> TCompositeTransform: + """Update buffered data such as predicted parameters, velocities, and/or displacements.""" + super().update() + for transform in self.transforms(): + transform.update() + return self + + def clear_buffers(self: TCompositeTransform) -> TCompositeTransform: + """Clear any buffers that are registered by ``self.update()``.""" + super().clear_buffers() + for transform in self.transforms(): + transform.clear_buffers() + return self + + +class MultiLevelTransform(CompositeTransform): + """Sum of spatial transformations applied to any set of points. + + A :class:`.MultiLevelTransform` adds the sum of the displacement vectors across all + spatial transforms at the input points that are being mapped to new locations, i.e., + + .. math:: + + \\vec{y} = \\vec{x} + \\sum_{i=0}^{n-1} \\vec{u}_i(\\vec{x}) + + """ + + def forward(self, points: paddle.Tensor, grid: bool = False) -> paddle.Tensor: + """Transform set of points by sum of spatial transformations. + + Args: + points: paddle.Tensor of shape ``(N, M, D)`` or ``(N, ..., Y, X, D)``. + grid: Whether ``points`` are the positions of undeformed grid points. + + Returns: + paddle.Tensor of same shape as ``points`` with transformed point coordinates. + + """ + x = points + if len(self) == 0: + return x + if self.linear: + y = super().forward(points, grid) + else: + u = paddle.zeros_like(x=x) + for i, transform in enumerate(self.transforms()): + y = transform.forward(x, grid=grid and i == 0) + u += y - x + y = x + u + return y + + def tensor(self) -> paddle.Tensor: + """Get tensor representation of this transformation. + + The tensor representation of a transformation is with respect to the unit cube axes defined + by its sampling grid as specified by ``self.axes()``. + + Returns: + In case of a composition of linear transformations, returns a batch of homogeneous transformation + matrices as tensor of shape ``(N, D, 1)`` (translation), ``(N, D, D)`` (affine) or ``(N, D, D + 1)``, + i.e., a 3-dimensional tensor. If this composite transformation contains a non-rigid transformation, + a displacement vector field is returned as tensor of shape ``(N, D, ..., X)``. + + """ + if self.linear: + transforms = list(self.transforms()) + if not transforms: + identity = paddle.eye(num_rows=self.ndim, num_columns=self.ndim + 1) + return identity.unsqueeze(axis=0) + transform = transforms[0] + mat = as_homogeneous_matrix(transform.tensor()) + for transform in transforms[1:]: + mat += as_homogeneous_matrix(transform.tensor()) + return mat + return self.disp() + + +class SequentialTransform(CompositeTransform): + """Composition of spatial transformations applied to any set of points. + + A :class:`.SequentialTransform` is the functional composition of spatial transforms, i.e., + + .. math:: + + \\vec{y} = \\vec{u}_{n-1} \\circ \\cdots \\circ \\vec{u}_0 \\circ \\vec{x} + + """ + + def forward(self, points: paddle.Tensor, grid: bool = False) -> paddle.Tensor: + """Transform points by sequence of spatial transformations. + + Args: + points: paddle.Tensor of shape ``(N, M, D)`` or ``(N, ..., Y, X, D)``. + grid: Whether ``points`` are the positions of undeformed grid points. + + Returns: + paddle.Tensor of same shape as ``points`` with transformed point coordinates. + + """ + if self.linear: + return super().forward(points, grid) + y = points + for i, transform in enumerate(self.transforms()): + y = transform.forward(y, grid=grid and i == 0) + return y + + def tensor(self) -> paddle.Tensor: + """Get tensor representation of this transformation. + + The tensor representation of a transformation is with respect to the unit cube axes defined + by its sampling grid as specified by ``self.axes()``. + + Returns: + In case of a composition of linear transformations, returns a batch of homogeneous transformation + matrices as tensor of shape ``(N, D, 1)`` (translation), ``(N, D, D)`` (affine) or ``(N, D, D + 1)``, + i.e., a 3-dimensional tensor. If this composite transformation contains a non-rigid transformation, + a displacement vector field is returned as tensor of shape ``(N, D, ..., X)``. + + """ + if self.linear: + transforms = list(self.transforms()) + if not transforms: + identity = paddle.eye(num_rows=self.ndim, num_columns=self.ndim + 1) + return identity.unsqueeze(axis=0) + transform = transforms[0] + mat = transform.tensor() + for transform in transforms[1:]: + mat = homogeneous_matmul(transform.tensor(), mat) + return mat + return self.disp() + + def inverse( + self: TCompositeTransform, link: bool = False, update_buffers: bool = False + ) -> TCompositeTransform: + """Get inverse of this transformation. + + Args: + link: Whether to inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters, e.g., parameters obtained by a callable neural network, but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be update after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation which computes and applied the inverse transformation. + The inverse transformation will share the parameters with this transformation. Not all + transformations may implement this functionality. + + Raises: + NotImplementedError: When a transformation does not support sharing parameters with its inverse. + + """ + copy = shallow_copy(self) + transforms = paddle.nn.LayerDict() + for name, transform in reversed(self.named_transforms()): + assert isinstance(transform, SpatialTransform) + transforms[name] = transform.inverse( + link=link, update_buffers=update_buffers + ) + copy._transforms = transforms + return copy diff --git a/jointContribution/HighResolution/deepali/spatial/generic.py b/jointContribution/HighResolution/deepali/spatial/generic.py new file mode 100644 index 0000000000..5398c11170 --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/generic.py @@ -0,0 +1,419 @@ +from __future__ import annotations + +from dataclasses import dataclass +from pathlib import Path +from typing import Callable +from typing import Dict +from typing import List +from typing import Mapping +from typing import Optional +from typing import Union + +import paddle + +from ..core.affine import euler_rotation_angles +from ..core.affine import euler_rotation_matrix +from ..core.config import DataclassConfig +from ..core.grid import Grid +from ..core.linalg import quaternion_to_rotation_matrix +from ..core.linalg import rotation_matrix_to_quaternion +from ..core.types import ScalarOrTuple +from .bspline import FreeFormDeformation +from .bspline import StationaryVelocityFreeFormDeformation +from .composite import SequentialTransform +from .linear import AnisotropicScaling +from .linear import EulerRotation +from .linear import HomogeneousTransform +from .linear import QuaternionRotation +from .linear import Shearing +from .linear import Translation +from .nonrigid import DisplacementFieldTransform +from .nonrigid import StationaryVelocityFieldTransform + +ParamsDict = Mapping[str, paddle.Tensor] +AFFINE_NAMES = { + "A": "affine", + "K": "shearing", + "T": "translation", + "R": "rotation", + "S": "scaling", + "Q": "quaternion", +} +"""Names of elementary affine transformation child modules. + +The dictionary key is the letter used in :attr:`TransformConfig.affine_model`, i.e., + +- ``A``: ``"affine"`` +- ``K``: ``"shearing"`` +- ``T``: ``"translation"`` +- ``R``: ``"rotation"`` +- ``S``: ``"scaling"`` +- ``Q``: ``"quaternion"`` + +""" +AFFINE_TRANSFORMS = { + "A": HomogeneousTransform, + "K": Shearing, + "T": Translation, + "R": EulerRotation, + "S": AnisotropicScaling, + "Q": QuaternionRotation, +} +"""Types of elementary affine transformations. + +The dictionary key is the letter used in :attr:`TransformConfig.affine_model`, i.e., + +- ``A``: :class:`.HomogeneousTransform` +- ``K``: :class:`.Shearing` +- ``T``: :class:`.Translation` +- ``R``: :class:`.EulerRotation` +- ``S``: :class:`.AnisotropicScaling` +- ``Q``: :class:`.QuaternionRotation` + +""" +NONRIGID_TRANSFORMS = { + "DDF": DisplacementFieldTransform, + "FFD": FreeFormDeformation, + "SVF": StationaryVelocityFieldTransform, + "SVFFD": StationaryVelocityFreeFormDeformation, +} +"""Types of non-rigid transformations. + +The dictionary key is the string used in :attr:`TransformConfig.transform`, i.e., + +- ``DDF``: :class:`.DisplacementFieldTransform` +- ``FFD``: :class:`.FreeFormDeformation` +- ``SVF``: :class:`.StationaryVelocityFieldTransform` +- ``SVFFD``: :class:`.StationaryVelocityFreeFormDeformation` + +""" +VALID_COMPONENTS = ("Affine",) + tuple(NONRIGID_TRANSFORMS.keys()) +"""Valid transformation names in :attr:`TransformConfig.transform` string value. + +This includes "Affine" and all keys of the :data:`NONRIGID_TRANSFORMS` dictionary. + +""" + + +def transform_components(model: str) -> List[str]: + """Non-rigid component of transformation or ``None`` if it is a linear transformation.""" + return model.split(" o ") + + +def valid_transform_model( + model: str, max_affine: Optional[int] = None, max_nonrigid: Optional[int] = None +) -> bool: + """Whether given string denotes a valid transformation model.""" + components = transform_components(model) + num_affine = 0 + num_nonrigid = 0 + for component in components: + if component not in VALID_COMPONENTS: + return False + if component == "Affine": + num_affine += 1 + else: + num_nonrigid += 1 + if len(components) < 1: + return False + if max_affine is not None and num_affine > max_affine: + return False + if max_nonrigid is not None and num_nonrigid > max_nonrigid: + return False + return True + + +def has_affine_component(model: str) -> bool: + """Whether transformation model includes an affine component.""" + return "Affine" in transform_components(model) + + +def has_nonrigid_component(model: str) -> bool: + """Whether transformation model includes a non-rigid component.""" + return nonrigid_components(model) + + +def nonrigid_components(model: str) -> List[str]: + """Non-rigid components of transformation model.""" + return [comp for comp in transform_components(model) if comp in NONRIGID_TRANSFORMS] + + +def affine_first(model: str) -> bool: + """Whether transformation applies affine component first.""" + components = transform_components(model) + assert components, "must contain at least one transformation component" + return components[-1] == "Affine" + + +@dataclass +class TransformConfig(DataclassConfig): + """Configuration of generic spatial transformation model.""" + + transform: str = "Affine o SVF" + """String encoding of spatial transformation model to use. + + The linear transforms making up the ``Affine`` component are defined by :attr:`affine_model`. + + The non-rigid component can be one of the following: + + - ``DDF``: :class:`.DisplacementFieldTransform` + - ``FFD``: :class:`.FreeFormDeformation` + - ``SVF``: :class:`.StationaryVelocityFieldTransform` + - ``SVFFD``: :class:`.StationaryVelocityFreeFormDeformation` + + """ + affine_model: str = "TRS" + """String encoding of composition of elementary linear transformations. + + The string value of this configuration entry can be in one of two forms: + + - Matrix notation: Each letter is a factor in the sequence of matrix-matrix products. + - Function composition: Use deliminator " o " between transformations to denote composition. + + Valid elementary linear transform identifiers are: + + - ``A``: :class:`.HomogeneousTransform` + - ``K``: :class:`.Shearing` + - ``T``: :class:`.Translation` + - ``R``: :class:`.EulerRotation` + - ``S``: :class:`.AnisotropicScaling` + - ``Q``: :class:`.QuaternionRotation` + + """ + rotation_model: str = "ZXZ" + """Order of elementary Euler rotations. + + This configuration value is only used when :attr:`affine_model` contains an :class:`EulerRotation` + denoted by letter "R". Valid values are "ZXZ", "XZX", ... (cf. :func:`.core.affine.euler_rotation_matrix`). + + """ + control_point_spacing: ScalarOrTuple[int] = 1 + """Control point spacing of non-rigid transformations. + + The spacing must be given in voxel units of the grid domain with respect to + which the transformations are defined. + + """ + scaling_and_squaring_steps: int = 6 + """Number of scaling and squaring steps in case of a stationary velocity field transform.""" + flip_grid_coords: bool = False + """Whether predicted transformation parameters are with respect to a grid + with point coordinates in the order (..., x) instead of (x, ...).""" + + def _finalize(self, parent: Path) -> None: + """Finalize parameters after loading these from input file.""" + super()._finalize(parent) + + +class GenericSpatialTransform(SequentialTransform): + """Configurable generic spatial transformation.""" + + def __init__( + self, + grid: Grid, + params: Optional[Union[bool, Callable[..., ParamsDict], ParamsDict]] = True, + config: Optional[TransformConfig] = None, + ) -> None: + """Initialize spatial transformation.""" + if ( + params not in (None, False, True) + and not callable(params) + and not isinstance(params, Mapping) + ): + raise TypeError( + f"{type(self).__name__}() 'params' must be bool, callable, dict, or None" + ) + if config is None: + config = getattr(params, "config", None) + if config is None: + raise AssertionError( + f"{type(self).__name__}() 'config' or 'params.config' required" + ) + if not isinstance(config, TransformConfig): + raise TypeError( + f"{type(self).__name__}() 'params.config' must be TransformConfig" + ) + elif not isinstance(config, TransformConfig): + raise TypeError(f"{type(self).__name__}() 'config' must be TransformConfig") + if not valid_transform_model(config.transform, max_affine=1, max_nonrigid=1): + raise ValueError( + f"{type(self).__name__}() 'config.transform' invalid or not supported: {config.transform}" + ) + modules = paddle.nn.LayerDict() + if has_affine_component(config.transform): + for key in reversed(config.affine_model.replace(" o ", "")): + key = key.upper() + if key not in AFFINE_TRANSFORMS: + raise ValueError( + f"{type(self).__name__}() invalid character '{key}' in 'config.affine_model'" + ) + name = AFFINE_NAMES[key] + if name in modules: + raise NotImplementedError( + f"{type(self).__name__}() 'config.affine_model' must contain each elementary transform at most once, but encountered key '{key}' more than once." + ) + kwargs = dict( + grid=grid, params=params if isinstance(params, bool) else None + ) + if key == "R": + kwargs["order"] = config.rotation_model + modules[name] = AFFINE_TRANSFORMS[key](**kwargs) + nonrigid_models = nonrigid_components(config.transform) + if len(nonrigid_models) > 1: + raise ValueError( + f"{type(self).__name__}() 'config.transform' must contain at most one non-rigid component" + ) + if nonrigid_models: + nonrigid_model = nonrigid_models[0] + nonrigid_params = params if isinstance(params, bool) else None + nonrigid_kwargs = dict(grid=grid, params=nonrigid_params) + NonRigidTransform = NONRIGID_TRANSFORMS[nonrigid_model] + if nonrigid_model in ("DDF", "SVF") and config.control_point_spacing > 1: + size = grid.size_tensor() + stride = paddle.to_tensor(data=config.control_point_spacing).to(size) + size = size.div(stride).ceil().astype(dtype="int64") + nonrigid_kwargs["grid"] = grid.resize(size) + if nonrigid_model == "SVF": + nonrigid_kwargs["steps"] = config.scaling_and_squaring_steps + if nonrigid_model in ("FFD", "SVFFD"): + nonrigid_kwargs["stride"] = config.control_point_spacing + _modules = paddle.nn.LayerDict( + sublayers={"nonrigid": NonRigidTransform(**nonrigid_kwargs)} + ) + if affine_first(config.transform): + modules.update(_modules) + else: + _modules.update(modules) + modules = _modules + if isinstance(params, Mapping): + for name, transform in self.named_transforms(): + transform.data_(params[name]) + super().__init__(grid, modules) + self.config = config + self.params = params if callable(params) else None + + def _data(self) -> Dict[str, paddle.Tensor]: + """Get most recent transformation parameters.""" + if not self._transforms: + return {} + params = self.params + if params is None: + params = {} + for name, transform in self.named_transforms(): + params[name] = transform.data() + return params + if isinstance(params, GenericSpatialTransform): + return {} + if callable(params): + args, kwargs = self.condition() + pred = params(*args, **kwargs) + if not isinstance(pred, Mapping): + raise TypeError( + f"{type(self).__name__} 'params' callable return value must be a Mapping" + ) + elif isinstance(params, Mapping): + pred = params + else: + raise TypeError( + f"{type(self).__name__} 'params' attribute must be a callable, Mapping, linked GenericSpatialTransform, or None" + ) + data = {} + flip_grid_coords = self.config.flip_grid_coords + if "affine" in self._transforms: + matrix = pred["affine"] + assert isinstance(matrix, paddle.Tensor) + assert matrix.ndim >= 2 + D = tuple(matrix.shape)[-2] + assert tuple(matrix.shape)[-1] == D + 1 + if flip_grid_coords: + matrix[(...), :D, :D] = matrix[(...), :D, :D].flip(axis=(1, 2)) + matrix[(...), :D, (-1)] = matrix[(...), :D, (-1)].flip(axis=-1) + data["affine"] = matrix + if "translation" in self._transforms: + if "translation" in pred: + offset = pred["translation"] + else: + offset = pred["offset"] + assert isinstance(offset, paddle.Tensor) + if flip_grid_coords: + offset = offset.flip(axis=-1) + data["translation"] = offset + if "rotation" in self._transforms: + if "rotation" in pred: + angles = pred["rotation"] + else: + angles = pred["angles"] + assert isinstance(angles, paddle.Tensor) + if flip_grid_coords: + rotmodel = self.config.rotation_model + rotation = euler_rotation_matrix(angles, order=rotmodel).flip((1, 2)) + angles = euler_rotation_angles(rotation, order=rotmodel) + data["rotation"] = angles + if "scaling" in self._transforms: + if "scaling" in pred: + scales = pred["scaling"] + else: + scales = pred["scales"] + assert isinstance(scales, paddle.Tensor) + if flip_grid_coords: + scales = scales.flip(axis=-1) + data["scaling"] = scales + if "quaternion" in self._transforms: + q = pred["quaternion"] + assert isinstance(q, paddle.Tensor) + if flip_grid_coords: + m = quaternion_to_rotation_matrix(q) + m = m.flip(axis=(1, 2)) + q = rotation_matrix_to_quaternion(m) + data["quaternion"] = q + if "nonrigid" in self._transforms: + if "nonrigid" in pred: + vfield = pred["nonrigid"] + else: + vfield = pred["vfield"] + assert isinstance(vfield, paddle.Tensor) + if flip_grid_coords: + vfield = vfield.flip(axis=1) + data["nonrigid"] = vfield + return data + + def inverse( + self, link: bool = False, update_buffers: bool = False + ) -> GenericSpatialTransform: + """Get inverse of this transformation. + + Args: + link: Whether the inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters (e.g., parameters obtained by a callable neural network), but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be updated after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation which computes and applied the inverse transformation. + The inverse transformation will share the parameters with this transformation. Not all + transformations may implement this functionality. + + Raises: + NotImplementedError: When a transformation does not support sharing parameters with its inverse. + + """ + inv = super().inverse(link=link, update_buffers=update_buffers) + if link: + inv.params = self + return inv + + def update(self) -> GenericSpatialTransform: + """Update transformation parameters.""" + if self.params is not None: + params = self._data() + for k, p in params.items(): + transform = self._transforms[k] + transform.data_(p) + super().update() + return self diff --git a/jointContribution/HighResolution/deepali/spatial/image.py b/jointContribution/HighResolution/deepali/spatial/image.py new file mode 100644 index 0000000000..e6f1822983 --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/image.py @@ -0,0 +1,14 @@ +from deprecation import deprecated + +from ..core import __version__ +from .transformer import ImageTransformer + + +@deprecated( + deprecated_in="0.3", + removed_in="1.0", + current_version=__version__, + details="Use deepali.spatial.ImageTransformer instead", +) +class ImageTransform(ImageTransformer): + ... diff --git a/jointContribution/HighResolution/deepali/spatial/linear.py b/jointContribution/HighResolution/deepali/spatial/linear.py new file mode 100644 index 0000000000..46c314739b --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/linear.py @@ -0,0 +1,847 @@ +from __future__ import annotations + +import math +from collections import OrderedDict +from typing import Callable +from typing import Optional +from typing import Union + +import paddle + +from ..core import affine as U +from ..core.grid import Grid +from ..core.linalg import normalize_quaternion +from ..core.linalg import quaternion_to_rotation_matrix +from ..core.linalg import rotation_matrix_to_quaternion +from ..core.tensor import as_float_tensor +from .base import LinearTransform +from .composite import SequentialTransform +from .parametric import InvertibleParametricTransform + + +class HomogeneousTransform(InvertibleParametricTransform, LinearTransform): + """Arbitrary homogeneous coordinate transformation.""" + + def __init__( + self: HomogeneousTransform, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: Homogeneous transform as tensor of shape ``(N, D, D + 1)``. + + """ + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: HomogeneousTransform) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((self.ndim, self.ndim + 1)) + + def matrix_(self: HomogeneousTransform, arg: paddle.Tensor) -> HomogeneousTransform: + """Set transformation matrix.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("HomogeneousTransform.matrix() 'arg' must be tensor") + if arg.ndim != 3: + raise ValueError( + "HomogeneousTransform.matrix() 'arg' must be 3-dimensional tensor" + ) + return self.data_(arg) + + def tensor(self: HomogeneousTransform) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D + 1)``. + + """ + matrix = self.data() + if self.invert: + N = tuple(matrix.shape)[0] + D = tuple(matrix.shape)[1] + row = paddle.zeros(shape=(1, 1, D + 1), dtype=matrix.dtype) + row[..., -1] = 1 + matrix = paddle.concat(x=(matrix, row.expand(shape=[N, 1, D + 1])), axis=1) + matrix = paddle.linalg.inv(x=matrix) + start_4 = matrix.shape[1] + 0 if 0 < 0 else 0 + matrix = paddle.slice(matrix, [1], [start_4], [start_4 + D]) + matrix = matrix + return matrix + + def extra_repr(self: HomogeneousTransform) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", matrix=" + if self.params is None: + s += "undef" + else: + s += f"{self.matrix().tolist()!r}" + return s + + +class Translation(InvertibleParametricTransform, LinearTransform): + """Translation.""" + + def __init__( + self: Translation, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: Translation offsets as tensor of shape ``(N, D)`` or ``(N, D, 1)``. + + """ + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: Translation) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((self.ndim,)) + + def offset(self: Translation) -> paddle.Tensor: + """Get current translation offset in cube units.""" + return self.data() + + @paddle.no_grad() + def offset_(self: Translation, arg: paddle.Tensor) -> Translation: + """Reset parameters to given translation in cube units.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("Translation.offset() 'arg' must be tensor") + params = as_float_tensor(arg) + if params.isnan().astype("bool").any() or params.isinf().astype("bool").any(): + raise ValueError("Translation.offset() 'arg' must not be nan or inf") + self.data_(params) + return self + + def tensor(self: Translation) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, 1)``. + + """ + offset = self.offset() + if self.invert: + offset = -offset + return U.translation(offset) + + def extra_repr(self: Translation) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", offset=" + if self.params is None: + s += "undef" + else: + s += f"{self.offset().tolist()!r}" + return s + + +class EulerRotation(InvertibleParametricTransform, LinearTransform): + """Euler rotation.""" + + def __init__( + self: EulerRotation, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + order: Optional[str] = None, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be 1 or equal to batch size. + params: Rotation angles in degrees. This parameterization is adopted from MIRTK + and ensures that rotation angles and scaling factors (percentage) are within + a similar range of magnitude which is useful for direct optimization of these + parameters. If parameters are predicted by a callable ``paddle.nn.Layer``, + different output activations may be chosen before converting these to degrees. + order: Order in which to compose elementary rotations. For example in 3D, "zxz" means + that the first rotation occurs about z, the second about x, and the third rotation + about z again. In 2D, this argument is ignored and a single rotation about z + (plane normal) is applied. + + """ + if grid.ndim < 2 or grid.ndim > 3: + raise ValueError("EulerRotation() 'grid' must be 2- or 3-dimensional") + super().__init__(grid, groups=groups, params=params) + self.order = order + + @property + def data_shape(self: EulerRotation) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((self.nangles,)) + + @property + def nangles(self: EulerRotation) -> int: + """Number of Euler angles.""" + return 1 if self.ndim == 2 else self.ndim + + def angles(self: EulerRotation) -> paddle.Tensor: + """Get Euler angles in radians.""" + params = self.data() + if self.has_parameters(): + params = params.tanh().mul(math.pi) + return params + + @paddle.no_grad() + def angles_(self: EulerRotation, arg: paddle.Tensor) -> EulerRotation: + """Reset parameters to given Euler angles in radians.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("EulerRotation.angles() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"EulerRotation.angles() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + self.data_shape + if arg.shape != shape: + raise ValueError(f"EulerRotation.angles() 'arg' must have shape {shape!r}") + params = as_float_tensor(arg) + if self.has_parameters(): + params = params.div(math.pi).atanh() + if params.isnan().astype("bool").any(): + raise ValueError( + "EulerRotation.angles() 'arg' must be in range [-pi, pi]" + ) + self.data_(params) + return self + + def tensor(self: EulerRotation) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D)``. + + """ + mat = U.euler_rotation_matrix(self.angles(), order=self.order) + if self.invert: + x = mat + perm_2 = list(range(x.ndim)) + perm_2[1] = 2 + perm_2[2] = 1 + mat = x.transpose(perm=perm_2) + return mat + + def extra_repr(self: EulerRotation) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", angles=" + if self.params is None: + s += "undef" + else: + s += f"{self.angles().tolist()!r}" + return s + + +class QuaternionRotation(InvertibleParametricTransform, LinearTransform): + """Quaternion based rotation in 3D.""" + + def __init__( + self: QuaternionRotation, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: (nnormalized quaternion as 2-dimensional tensor of ``(N, 4)``. + + """ + if grid.ndim != 3: + raise ValueError("QuaternionRotation() 'grid' must be 3-dimensional") + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: QuaternionRotation) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((4,)) + + @paddle.no_grad() + def reset_parameters(self: QuaternionRotation) -> None: + """Reset transformation parameters.""" + params = self.params + if isinstance(params, paddle.Tensor): + paddle.assign( + paddle.to_tensor( + data=[0, 0, 0, 1], dtype=params.dtype, place=params.place + ), + output=params, + ) + + def quaternion(self: QuaternionRotation) -> paddle.Tensor: + """Get rotation quaternion.""" + params = self.data() + return normalize_quaternion(params) + + @paddle.no_grad() + def quaternion_(self: QuaternionRotation, arg: paddle.Tensor) -> QuaternionRotation: + """Set rotation quaternion.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("QuaternionRotation.quaternion() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"QuaternionRotation.quaternion() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + self.data_shape + if arg.shape != shape: + raise ValueError( + f"QuaternionRotation.quaternion() 'arg' must have shape {shape!r}" + ) + params = as_float_tensor(arg) + params = normalize_quaternion(params) + self.data_(params) + return self + + def matrix_(self: QuaternionRotation, arg: paddle.Tensor) -> QuaternionRotation: + """Set rotation quaternion from rotation matrix.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("QuaternionRotation.matrix() 'arg' must be tensor") + if arg.ndim != 3: + raise ValueError( + "QuaternionRotation.matrix() 'arg' must be 3-dimensional tensor" + ) + shape = arg.shape[0], 3, 3 + if arg.shape != shape: + raise ValueError(f"Rotation matrix must have shape {shape!r}") + arg = rotation_matrix_to_quaternion(arg) + return self.quaternion_(arg) + + def tensor(self: QuaternionRotation) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D)``. + + """ + q = self.data() + mat = quaternion_to_rotation_matrix(q) + if self.invert: + x = mat + perm_3 = list(range(x.ndim)) + perm_3[1] = 2 + perm_3[2] = 1 + mat = x.transpose(perm=perm_3) + return mat + + def extra_repr(self: QuaternionRotation) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", q=" + if self.params is None: + s += "undef" + else: + s += f"{self.quaternion().tolist()!r}" + return s + + +class IsotropicScaling(InvertibleParametricTransform, LinearTransform): + """Isotropic scaling.""" + + def __init__( + self: IsotropicScaling, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: Isotropic scaling factor as a percentage. This parameterization is + adopted from MIRTK and ensures that rotation angles in degrees and scaling + factors are within a similar range of magnitude which is useful for direct + optimization of these parameters. If parameters are predicted by a callable + ``paddle.nn.Layer``, different output activations may be chosen before + converting these to percentages (i.e., multiplied by 100). + + """ + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: IsotropicScaling) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((1,)) + + @paddle.no_grad() + def reset_parameters(self: IsotropicScaling) -> None: + """Reset transformation parameters.""" + params = self.params + if isinstance(params, paddle.Tensor): + init_Constant = paddle.nn.initializer.Constant(value=1) + init_Constant(params) + + def scales(self: IsotropicScaling) -> paddle.Tensor: + """Get scaling factors.""" + params = self.data() + if self.has_parameters(): + params = params.sub(1).tanh().exp() + return params + + @paddle.no_grad() + def scales_(self: IsotropicScaling, arg: paddle.Tensor) -> IsotropicScaling: + """Set transformation parameters from scaling factors.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("IsotropicScaling.scales() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"IsotropicScaling.scales() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + shape + if arg.shape != shape: + raise ValueError( + f"IsotropicScaling.scales() 'arg' must have shape {shape!r}" + ) + params = as_float_tensor(arg) + if self.has_parameters(): + params = params.log().atanh().add(1) + if params.isnan().astype("bool").any(): + raise ValueError("IsotropicScaling.scales() 'arg' must be positive") + self.data_(params) + return self + + def tensor(self: IsotropicScaling) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D)``. + + """ + scales = self.scales() + if self.invert: + scales = 1 / scales + return U.scaling_transform( + scales.expand(shape=[tuple(scales.shape)[0], self.ndim]) + ) + + def extra_repr(self: IsotropicScaling) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", scales=" + if self.params is None: + s += "undef" + else: + s += f"{self.scales().tolist()!r}" + return s + + +class AnisotropicScaling(InvertibleParametricTransform, LinearTransform): + """Anisotropic scaling.""" + + def __init__( + self: AnisotropicScaling, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: Anisotropic scaling factors as percentages. This parameterization is + adopted from MIRTK and ensures that rotation angles in degrees and scaling + factors are within a similar range of magnitude which is useful for direct + optimization of these parameters. If parameters are predicted by a callable + ``paddle.nn.Layer``, different output activations may be chosen before + converting these to percentages (i.e., multiplied by 100). + + """ + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: AnisotropicScaling) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((self.ndim,)) + + @paddle.no_grad() + def reset_parameters(self: AnisotropicScaling) -> None: + """Reset transformation parameters.""" + params = self.params + if isinstance(params, paddle.Tensor): + init_Constant = paddle.nn.initializer.Constant(value=1) + init_Constant(params) + + def scales(self: AnisotropicScaling) -> paddle.Tensor: + """Get scaling factors.""" + params = self.data() + if self.has_parameters(): + params = params.sub(1.0).tanh().exp() + return params + + @paddle.no_grad() + def scales_(self: AnisotropicScaling, arg: paddle.Tensor) -> AnisotropicScaling: + """Set transformation parameters from scaling factors.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("AnisotropicScaling.scales() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"AnisotropicScaling.scales() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + shape + if arg.shape != shape: + raise ValueError( + f"AnisotropicScaling.scales() 'arg' must have shape {shape!r}" + ) + params = as_float_tensor(arg) + if self.has_parameters(): + params = params.log().atanh().add(1) + if params.isnan().astype("bool").any(): + raise ValueError("AnisotropicScaling.scales() 'arg' must be positive") + self.data_(params) + return self + + def tensor(self: AnisotropicScaling) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D)``. + + """ + scales = self.scales() + if self.invert: + scales = 1 / scales + return U.scaling_transform(scales) + + def extra_repr(self: AnisotropicScaling) -> str: + """Print current transformation.""" + s = super().extra_repr() + ", scales=" + if self.params is None: + s += "undef" + else: + s += f"{self.scales().tolist()!r}" + return s + + +class Shearing(InvertibleParametricTransform, LinearTransform): + """Shear transformation.""" + + def __init__( + self: Shearing, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. Must be either 1 or equal to batch size. + params: Shearing angles in degrees. This parameterization is adopted from MIRTK + and ensures that rotation angles and scaling factors (percentage) are within + a similar range of magnitude which is useful for direct optimization of these + parameters. If parameters are predicted by a callable ``paddle.nn.Layer``, + different output activations may be chosen before converting these to degrees. + + """ + if grid.ndim < 2 or grid.ndim > 3: + raise ValueError("Shearing() 'grid' must be 2- or 3-dimensional'") + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self: Shearing) -> list: + """Get shape of transformation parameters tensor.""" + return tuple((self.nangles,)) + + @property + def nangles(self: Shearing) -> int: + """Number of shear angles.""" + return 1 if self.ndim == 2 else self.ndim + + def angles(self: Shearing) -> paddle.Tensor: + """Get shear angles in radians.""" + params = self.data() + if self.has_parameters(): + params = params.tanh().mul(math.pi / 4) + return params + + @paddle.no_grad() + def angles_(self: Shearing, arg: paddle.Tensor) -> Shearing: + """Set transformation parameters from shear angles in radians.""" + if not isinstance(arg, paddle.Tensor): + raise TypeError("Shearing.angles_() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"Shearing.angles_() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + shape + if arg.shape != shape: + raise ValueError(f"Shearing.angles_() 'arg' must have shape {shape!r}") + params = as_float_tensor(arg) + if self.has_parameters(): + params = params.mul(4 / math.pi).atanh() + if params.isnan().astype("bool").any(): + raise ValueError("Shear 'angles' must be in range [-pi/4, pi/4]") + self.data_(params) + return self + + def tensor(self: Shearing) -> paddle.Tensor: + """Get tensor representation of this transformation + + Returns: + Batch of homogeneous transformation matrices as tensor of shape ``(N, D, D)``. + + """ + mat = U.shear_matrix(self.angles()) + if self.invert: + mat = paddle.linalg.inv(x=mat) + return mat + + def extra_repr(self: Shearing) -> str: + """Print current transformation.""" + s = super().extra_repr() + "angles=" + if self.params is None: + s += "undef" + else: + s += f"{self.angles().tolist()!r}" + return s + + +class RigidTransform(SequentialTransform): + """Rigid transformation.""" + + def __init__( + self: RigidTransform, + grid: Grid, + groups: Optional[int] = None, + rotation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + translation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Domain with respect to which transformation is defined. + groups: Number of transformations ``N``. + rotation: Parameters of ``EulerRotation``. + translation: Parameters of ``Translation``. + + """ + transforms = OrderedDict() + transforms["rotation"] = EulerRotation(grid, groups=groups, params=rotation) + transforms["translation"] = Translation(grid, groups=groups, params=translation) + super().__init__(transforms) + + @property + def rotation(self) -> EulerRotation: + return self._transforms["rotation"] + + @property + def translation(self) -> Translation: + return self._transforms["translation"] + + +class RigidQuaternionTransform(SequentialTransform): + """Rigid transformation parameterized by rotation quaternion.""" + + def __init__( + self: RigidQuaternionTransform, + grid: Grid, + groups: Optional[int] = None, + rotation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + translation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Domain with respect to which transformation is defined. + groups: Number of transformations ``N``. + rotation: Parameters of ``QuaternionRotation``. + translation: Parameters of ``Translation``. + + """ + transforms = OrderedDict() + transforms["rotation"] = QuaternionRotation( + grid, groups=groups, params=rotation + ) + transforms["translation"] = Translation(grid, groups=groups, params=translation) + super().__init__(transforms) + + @property + def rotation(self) -> QuaternionRotation: + return self._transforms["rotation"] + + @property + def translation(self) -> Translation: + return self._transforms["translation"] + + +class SimilarityTransform(SequentialTransform): + """Similarity transformation with isotropic scaling.""" + + def __init__( + self: SimilarityTransform, + grid: Grid, + groups: Optional[int] = None, + scaling: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + rotation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + translation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Domain with respect to which transformation is defined. + groups: Number of transformations ``N``. + scaling: Parameters of ``IsotropicScaling``. + rotation: Parameters of ``EulerRotation``. + translation: Parameters of ``Translation``. + + """ + transforms = OrderedDict() + transforms["scaling"] = IsotropicScaling(grid, groups=groups, params=scaling) + transforms["rotation"] = EulerRotation(grid, groups=groups, params=rotation) + transforms["translation"] = Translation(grid, groups=groups, params=translation) + super().__init__(transforms) + + @property + def rotation(self) -> EulerRotation: + return self._transforms["rotation"] + + @property + def scaling(self) -> IsotropicScaling: + return self._transforms["scaling"] + + @property + def translation(self) -> Translation: + return self._transforms["translation"] + + +class AffineTransform(SequentialTransform): + """Affine transformation without shearing.""" + + def __init__( + self: AffineTransform, + grid: Grid, + groups: Optional[int] = None, + scaling: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + rotation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + translation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Domain with respect to which transformation is defined. + groups: Number of transformations ``N``. + scaling: Parameters of ``AnisotropicScaling``. + rotation: Parameters of ``EulerRotation``. + translation: Parameters of ``Translation``. + + """ + transforms = OrderedDict() + transforms["scaling"] = AnisotropicScaling(grid, groups=groups, params=scaling) + transforms["rotation"] = EulerRotation(grid, groups=groups, params=rotation) + transforms["translation"] = Translation(grid, groups=groups, params=translation) + super().__init__(transforms) + + @property + def rotation(self) -> EulerRotation: + return self._transforms["rotation"] + + @property + def scaling(self) -> AnisotropicScaling: + return self._transforms["scaling"] + + @property + def translation(self) -> Translation: + return self._transforms["translation"] + + +class FullAffineTransform(SequentialTransform): + """Affine transformation including shearing.""" + + def __init__( + self: FullAffineTransform, + grid: Grid, + groups: Optional[int] = None, + scaling: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + shearing: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + rotation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + translation: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Domain with respect to which transformation is defined. + groups: Number of transformations ``N``. + scaling: Parameters of ``AnisotropicScaling``. + shearing: Parameters of ``Shearing``. + rotation: Parameters of ``EulerRotation``. + translation: Parameters of ``Translation``. + + """ + transforms = OrderedDict() + transforms["scaling"] = AnisotropicScaling(grid, groups=groups, params=scaling) + transforms["shearing"] = Shearing(grid, groups=groups, params=shearing) + transforms["rotation"] = EulerRotation(grid, groups=groups, params=rotation) + transforms["translation"] = Translation(grid, groups=groups, params=translation) + super().__init__(transforms) + + @property + def rotation(self) -> EulerRotation: + return self._transforms["rotation"] + + @property + def scaling(self) -> AnisotropicScaling: + return self._transforms["scaling"] + + @property + def shearing(self) -> Shearing: + return self._transforms["shearing"] + + @property + def translation(self) -> Translation: + return self._transforms["translation"] diff --git a/jointContribution/HighResolution/deepali/spatial/nonrigid.py b/jointContribution/HighResolution/deepali/spatial/nonrigid.py new file mode 100644 index 0000000000..1fffb273dc --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/nonrigid.py @@ -0,0 +1,283 @@ +from __future__ import annotations + +import math +from copy import copy as shallow_copy +from typing import Callable +from typing import Optional +from typing import Sequence +from typing import TypeVar +from typing import Union +from typing import cast + +import paddle + +from ..core import functional as U +from ..core.grid import Axes +from ..core.grid import Grid +from ..data.flow import FlowFields +from ..modules import ExpFlow +from .base import NonRigidTransform +from .parametric import ParametricTransform + +TDenseVectorFieldTransform = TypeVar( + "TDenseVectorFieldTransform", bound="DenseVectorFieldTransform" +) + + +class DenseVectorFieldTransform(ParametricTransform, NonRigidTransform): + """Dense vector field transformation with linear interpolation at non-grid point locations.""" + + def __init__( + self, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + stride: Optional[Union[float, Sequence[float]]] = None, + resize: bool = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. A given image batch can either be deformed by a + single transformation, or a separate transformation for each image in the batch, e.g., + for group-wise or batched registration. The default is one transformation for all images + in the batch, or the batch length of the ``params`` tensor if provided. + params: Initial parameters. If a tensor is given, it is only registered as optimizable module + parameters when of type ``paddle.nn.Parameter``. When a callable is given instead, it will be + called by ``self.update()`` with arguments set and given by ``self.condition()``. When a boolean + argument is given, a new zero-initialized tensor is created. If ``True``, this tensor is registered + as optimizable module parameter. + stride: Spacing between vector field grid points in units of input ``grid`` points. + Can be used to subsample the dense vector field with respect to the image grid of the fixed target + image. When ``grid.align_corners() is True``, the corner points of the ``grid`` and the resampled + vector field grid are aligned. Otherwise, the edges of the grid domains are aligned. + resize: Whether to resize vector field during transformation update. If ``True``, the buffered vector + field ``u`` (and ``v`` if applicable) is resized to match the image ``grid`` size. This means that + transformation constraints defined on these resized vector fields, such as those based on finite + differences, are evaluated at the image grid resolution rather than the resolution of the underlying + vector field parameterization. This influences the scale at which these constraints are imposed. + + """ + if stride is None: + stride = 1 + if isinstance(stride, (int, float)): + stride = (stride,) * grid.ndim + if len(stride) != grid.ndim: + raise ValueError( + f"{type(self).__name__}() 'stride' must be float or Sequence of length {grid.ndim}" + ) + self.stride = tuple(float(s) for s in stride) + self._resize = resize + super().__init__(grid, groups=groups, params=params) + + @property + def data_shape(self) -> list: + """Get shape of transformation parameters tensor.""" + grid = self.grid() + shape = self.data_grid_shape(grid) + return tuple((grid.ndim,) + shape) + + def data_grid( + self, + grid: Optional[Grid] = None, + stride: Optional[Union[float, Sequence[float]]] = None, + ) -> Grid: + if grid is None: + grid = self.grid() + if stride is None: + stride = self.stride + return grid.reshape(self.data_grid_shape(grid, stride)) + + def data_grid_shape( + self, + grid: Optional[Grid] = None, + stride: Optional[Union[float, Sequence[float]]] = None, + ) -> list: + if grid is None: + grid = self.grid() + if stride is None: + stride = self.stride + return tuple(int(math.ceil(n / s)) for n, s in zip(tuple(grid.shape), stride)) + + @paddle.no_grad() + def grid_( + self: TDenseVectorFieldTransform, grid: Grid + ) -> TDenseVectorFieldTransform: + """Set sampling grid of transformation domain and codomain. + + If ``self.params`` is a callable, only the grid attribute is updated, and + the callable must return a tensor of matching size upon next evaluation. + + """ + params = self.params + if isinstance(params, paddle.Tensor): + prev_grid = self._grid + grid_axes = Axes.from_grid(grid) + flow_axes = self.axes() + flow_grid = prev_grid.reshape(tuple(params.shape)[2:]) + flow = FlowFields(params, grid=flow_grid, axes=flow_axes) + flow = flow.sample(shape=self.data_grid(grid)) + flow = flow.axes(grid_axes) + super().grid_(grid) + try: + self.data_(flow.tensor()) + except Exception: + self._grid = prev_grid + raise + else: + super().grid_(grid) + return self + + def evaluate(self, resize: Optional[bool] = None) -> paddle.Tensor: + """Update buffered displacement vector field.""" + if resize is None: + resize = self._resize + u = self.data() + u = u.view(*tuple(u.shape)) + if resize: + align_corners = self.align_corners() + grid_shape = tuple(self.grid().shape) + u = U.grid_reshape(u, grid_shape, align_corners=align_corners) + return u + + +class DisplacementFieldTransform(DenseVectorFieldTransform): + """Dense displacement field transformation model.""" + + def fit(self, flow: FlowFields, **kwargs) -> DisplacementFieldTransform: + """Fit transformation to a given flow field. + + Args: + flow: Flow fields to approximate. + kwargs: Optional keyword arguments are ignored. + + Returns: + Reference to this transformation. + + Raises: + RuntimeError: When this transformation has no optimizable parameters. + + """ + params = self.params + if params is None: + raise AssertionError( + f"{type(self).__name__}.data() 'params' must be set first" + ) + grid = self.grid() + if not callable(params): + grid = self.grid().resize(self.data_shape[:1:-1]) + flow = flow.to(self.device) + flow = flow.sample(shape=grid) + flow = flow.axes(Axes.from_grid(grid)) + if callable(params): + self._fit(flow, **kwargs) + else: + self.data_(flow.tensor()) + return self + + def update(self) -> DisplacementFieldTransform: + """Update buffered displacement vector field.""" + super().update() + u = self.evaluate() + self.register_buffer(name="u", tensor=u, persistable=False) + return self + + +class StationaryVelocityFieldTransform(DenseVectorFieldTransform): + """Dense stationary velocity field transformation.""" + + def __init__( + self, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + stride: Optional[Union[float, Sequence[float]]] = None, + resize: bool = True, + scale: Optional[float] = None, + steps: Optional[int] = None, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid on which to sample velocity field vectors. + groups: Number of velocity fields. A given image batch can either be deformed by a + single displacement field, or one separate displacement field for each image in the + batch, e.g., for group-wise or batched registration. The default is one displacement + field for all images in the batch, or the batch length N of ``params`` if provided. + params: Initial parameters of velocity fields of shape ``(N, C, ...X)``, where N must match + the value of ``groups``, and vector components are the image channels in the order x, y, z. + Note that a tensor is only registered as optimizable module parameters when of type + ``paddle.nn.Parameter``. When a callable is given instead, it will be called each time the + model parameters are accessed with the arguments set and returned by ``self.condition()``. + When a boolean argument is given, a new zero-initialized tensor is created. If ``True``, + it is registered as optimizable parameter. + stride: Spacing between vector field grid points in units of input ``grid`` points. + Can be used to subsample the dense vector field with respect to the image grid of the fixed target + image. When ``grid.align_corners() is True``, the corner points of the ``grid`` and the resampled + vector field grid are aligned. Otherwise, the edges of the grid domains are aligned. + resize: Whether to resize vector field during transformation update. If ``True``, the buffered vector + fields ``v`` and ``u`` are resized to match the image ``grid`` size. This means that transformation + constraints defined on these resized vector fields, such as those based on finite differences, are + evaluated at the image grid resolution rather than the resolution of the underlying vector field + parameterization. This influences the scale at which these constraints are imposed. + scale: Constant scaling factor of velocity fields. + steps: Number of scaling and squaring steps. + + """ + super().__init__( + grid, groups=groups, params=params, stride=stride, resize=resize + ) + self.exp = ExpFlow(scale=scale, steps=steps, align_corners=grid.align_corners()) + + def grid_(self, grid: Grid) -> StationaryVelocityFieldTransform: + """Set sampling grid of transformation domain and codomain.""" + super().grid_(grid) + self.exp.align_corners = grid.align_corners() + return self + + def inverse( + self, link: bool = False, update_buffers: bool = False + ) -> StationaryVelocityFieldTransform: + """Get inverse of this transformation. + + Args: + link: Whether to inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters, e.g., parameters obtained by a callable neural network, but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be update after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation with ``exp`` module which uses negative scaling factor + to scale and square the stationary velocity field to computes the inverse displacement field. + + """ + inv = shallow_copy(self) + if link: + inv.link_(self) + inv.exp = cast(ExpFlow, self.exp).inverse() + if update_buffers: + v = getattr(inv, "v", None) + if v is not None: + u = inv.exp(v) + inv.register_buffer(name="u", tensor=u, persistable=False) + return inv + + def update(self) -> StationaryVelocityFieldTransform: + """Update buffered velocity and displacement vector fields.""" + super().update() + v = self.evaluate() + u = self.exp(v) + self.register_buffer(name="v", tensor=v, persistable=False) + self.register_buffer(name="u", tensor=u, persistable=False) + return self diff --git a/jointContribution/HighResolution/deepali/spatial/parametric.py b/jointContribution/HighResolution/deepali/spatial/parametric.py new file mode 100644 index 0000000000..aa939ca39b --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/parametric.py @@ -0,0 +1,388 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from typing import Callable +from typing import Optional +from typing import Union +from typing import cast +from typing import overload + +import paddle + +from ..core.grid import Grid +from .base import ReadOnlyParameters +from .base import TSpatialTransform + + +class ParametricTransform: + """Mix-in for spatial transformations that have (optimizable) parameters.""" + + def __init__( + self: Union[TSpatialTransform, ParametricTransform], + grid: Grid, + groups: Optional[int] = None, + params: Optional[Union[bool, paddle.Tensor, Callable]] = True, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. A given image batch can either be deformed by a + single transformation, or a separate transformation for each image in the batch, e.g., + for group-wise or batched registration. The default is one transformation for all images + in the batch, or the batch length of the ``params`` tensor if provided. + params: Initial parameters. If a tensor is given, it is only registered as optimizable module + parameters when of type ``paddle.base.framework.EagerParamBase``. When a callable is given instead, it will be + called by ``self.update()`` with ``SpatialTransform.condition()`` arguments. When a boolean + argument is given, a new zero-initialized tensor is created. If ``True``, this tensor is + registered as optimizable module parameter. If ``None``, parameters must be set using + ``self.data()`` or ``self.data_()`` before this transformation is evaluated. + + """ + if isinstance(params, paddle.Tensor) and params.ndim < 2: + raise ValueError( + f"{type(self).__name__}() 'params' tensor must be at least 2-dimensional" + ) + super().__init__(grid) + if groups is None: + groups = tuple(params.shape)[0] if isinstance(params, paddle.Tensor) else 1 + shape = (groups,) + self.data_shape + if params is None: + self.params = None + elif isinstance(params, bool): + data = paddle.empty(shape=shape, dtype="float32") + if params: + out_0 = paddle.create_parameter( + shape=data.shape, + dtype=data.numpy().dtype, + default_initializer=paddle.nn.initializer.Assign(data), + ) + out_0.stop_gradient = not True + self.params = out_0 + else: + self.register_buffer(name="params", tensor=data, persistable=True) + self.reset_parameters() + elif isinstance(params, paddle.Tensor): + if shape and tuple(params.shape) != shape: + raise ValueError( + f"{type(self).__name__}() 'params' must be tensor of shape {shape!r}" + ) + if isinstance(params, paddle.base.framework.EagerParamBase): + self.params = params + else: + self.register_buffer(name="params", tensor=params, persistable=True) + elif callable(params): + self.params = params + self.register_buffer( + name="p", tensor=paddle.empty(shape=shape), persistable=False + ) + self.reset_parameters() + else: + raise TypeError( + f"{type(self).__name__}() 'params' must be bool, Callable, paddle.Tensor, or None" + ) + + def has_parameters(self) -> bool: + """Whether this transformation has optimizable parameters.""" + return isinstance(self.params, paddle.base.framework.EagerParamBase) + + @paddle.no_grad() + def reset_parameters(self: Union[TSpatialTransform, ParametricTransform]) -> None: + """Reset transformation parameters.""" + params = self.params + if params is None: + return + if callable(params): + params = self.p + init_Constant = paddle.nn.initializer.Constant(value=0.0) + init_Constant(params) + self.clear_buffers() + + @property + def data_shape(self) -> list: + """Get required shape of transformation parameters tensor, excluding batch dimension.""" + raise NotImplementedError(f"{type(self).__name__}.data_shape") + + @overload + def data(self) -> paddle.Tensor: + """Get (buffered) transformation parameters.""" + ... + + @overload + def data(self: TSpatialTransform, arg: paddle.Tensor) -> TSpatialTransform: + """Get shallow copy with specified parameters.""" + ... + + def data( + self: Union[TSpatialTransform, ParametricTransform], + arg: Optional[paddle.Tensor] = None, + ) -> Union[TSpatialTransform, paddle.Tensor]: + """Get transformation parameters or shallow copy with specified parameters, respectively.""" + params = self.params + if arg is None: + if params is None: + raise AssertionError( + f"{type(self).__name__}.data() 'params' must be set first" + ) + if callable(params): + params = getattr(self, "p") + return params + if not isinstance(arg, paddle.Tensor): + raise TypeError(f"{type(self).__name__}.data() 'arg' must be tensor") + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"{type(self).__name__}.data() 'arg' must be {len(shape) + 1}-dimensional tensor" + ) + shape = (arg.shape[0],) + shape + if arg.shape != shape: + raise ValueError( + f"{type(self).__name__}.data() 'arg' must have shape {shape!r}" + ) + copy = shallow_copy(self) + if callable(params): + delattr(copy, "p") + if isinstance(params, paddle.base.framework.EagerParamBase) and not isinstance( + arg, paddle.base.framework.EagerParamBase + ): + out_1 = paddle.create_parameter( + shape=arg.shape, + dtype=arg.numpy().dtype, + default_initializer=paddle.nn.initializer.Assign(arg), + ) + out_1.stop_gradient = not not params.stop_gradient + copy.params = out_1 + else: + copy.params = arg + copy.clear_buffers() + return copy + + def data_( + self: Union[TSpatialTransform, ParametricTransform], arg: paddle.Tensor + ) -> TSpatialTransform: + """Replace transformation parameters. + + Args: + arg: paddle.Tensor of transformation parameters with shape matching ``self.data_shape``, + excluding the batch dimension whose size may be different from the current tensor. + + Returns: + Reference to this in-place modified transformation module. + + Raises: + ReadOnlyParameters: When ``self.params`` is a callable which provides the parameters. + + """ + params = self.params + if callable(params): + raise ReadOnlyParameters( + f"Cannot replace parameters, try {type(self).__name__}.data() instead." + ) + if not isinstance(arg, paddle.Tensor): + raise TypeError( + f"{type(self).__name__}.data_() 'arg' must be tensor, not {type(arg)}" + ) + shape = self.data_shape + if arg.ndim != len(shape) + 1: + raise ValueError( + f"{type(self).__name__}.data_() 'arg' must be {len(shape) + 1}-dimensional tensor, but arg.ndim={arg.ndim}" + ) + shape = (arg.shape[0],) + shape + if tuple(arg.shape) != tuple(shape): + raise ValueError( + f"{type(self).__name__}.data_() 'arg' must have shape {shape!r}, not {arg.shape!r}" + ) + if isinstance(params, paddle.base.framework.EagerParamBase) and not isinstance( + arg, paddle.base.framework.EagerParamBase + ): + out_2 = paddle.create_parameter( + shape=arg.shape, + dtype=arg.numpy().dtype, + default_initializer=paddle.nn.initializer.Assign(arg), + ) + out_2.stop_gradient = not not params.stop_gradient + self.params = out_2 + else: + self.params = arg + self.clear_buffers() + return self + + def _data(self: Union[TSpatialTransform, ParametricTransform]) -> paddle.Tensor: + """Get most recent transformation parameters. + + When transformation parameters are obtained from a callable, this function invokes + this callable with ``self.condition()`` as arguments if set, and returns the parameter + obtained returned by this callable function or module. Otherwise, it simply returns a + reference to the ``self.params`` tensor. + + Returns: + Reference to ``self.params`` tensor or callable return value, respectively. + + """ + params = self.params + if params is None: + raise AssertionError( + f"{type(self).__name__}._data() 'params' must be set first" + ) + if isinstance(params, type(self)): + assert isinstance(params, ParametricTransform) + return cast(ParametricTransform, params).data() + if callable(params): + args, kwargs = self.condition() + pred = params(*args, **kwargs) + if not isinstance(pred, paddle.Tensor): + raise TypeError(f"{type(self).__name__}.params() value must be tensor") + shape = self.data_shape + if pred.ndim != len(shape) + 1: + raise ValueError( + f"{type(self).__name__}.params() tensor must be {len(shape) + 1}-dimensional" + ) + shape = (tuple(pred.shape)[0],) + shape + if tuple(pred.shape) != shape: + raise ValueError( + f"{type(self).__name__}.params() tensor must have shape {shape!r}" + ) + return pred + assert isinstance(params, paddle.Tensor) + return params + + def link( + self: Union[TSpatialTransform, ParametricTransform], other: TSpatialTransform + ) -> TSpatialTransform: + """Make shallow copy of this transformation which is linked to another instance.""" + return shallow_copy(self).link_(other) + + def link_( + self: Union[TSpatialTransform, ParametricTransform], + other: Union[TSpatialTransform, ParametricTransform], + ) -> TSpatialTransform: + """Link this transformation to another of the same type. + + This transformation is modified to use a reference to the given transformation. After linking, + the transformation will not have parameters on its own, and its ``update()`` function will not + recompute possibly previously shared parameters, e.g., parameters obtained by a callable neural + network. Instead, it directly copies the parameters from the linked transformation. + + Args: + other: Other transformation of the same type as ``self`` to which this transformation is linked. + + Returns: + Reference to this transformation. + + """ + if other is self: + raise ValueError( + f"{type(self).__name__}.link() cannot link tranform to itself" + ) + if type(self) != type(other): + raise TypeError( + f"{type(self).__name__}.link() 'other' must be of the same type, got {type(other).__name__}" + ) + self.params = other + if not hasattr(self, "p"): + if other.params is None: + p = paddle.empty(shape=self.data_shape) + else: + p = other.data() + self.register_buffer(name="p", tensor=p, persistable=False) + if other.params is None: + self.reset_parameters() + return self + + def unlink( + self: Union[TSpatialTransform, ParametricTransform] + ) -> TSpatialTransform: + """Make a shallow copy of this transformation with parameters set to ``None``.""" + return shallow_copy(self).unlink_() + + def unlink_( + self: Union[TSpatialTransform, ParametricTransform] + ) -> TSpatialTransform: + """Resets transformation parameters to ``None``.""" + self.params = None + if hasattr(self, "p"): + delattr(self, "p") + return self + + def update( + self: Union[TSpatialTransform, ParametricTransform] + ) -> TSpatialTransform: + """Update buffered data such as predicted parameters, velocities, and/or displacements.""" + if hasattr(self, "p"): + p = self._data() + self.register_buffer(name="p", tensor=p, persistable=False) + super().update() + return self + + +class InvertibleParametricTransform(ParametricTransform): + """Mix-in for spatial transformations that support on-demand inversion.""" + + def __init__( + self, + grid: Grid, + groups: Optional[int] = None, + params: Optional[ + Union[bool, paddle.Tensor, Callable[..., paddle.Tensor]] + ] = True, + invert: bool = False, + ) -> None: + """Initialize transformation parameters. + + Args: + grid: Grid domain on which transformation is defined. + groups: Number of transformations. A given image batch can either be deformed by a + single transformation, or a separate transformation for each image in the batch, e.g., + for group-wise or batched registration. The default is one transformation for all images + in the batch, or the batch length of the ``params`` tensor if provided. + params: Initial parameters. If a tensor is given, it is only registered as optimizable module + parameters when of type ``paddle.base.framework.EagerParamBaser``. When a callable is given instead, it will be + called by ``self.update()`` with ``self.condition()`` arguments. When a boolean argument is + given, a new zero-initialized tensor is created. If ``True``, this tensor is registered as + optimizable module parameter. + invert: Whether ``params`` correspond to the inverse transformation. When this flag is ``True``, + the ``self.tensor()`` and related methods return the transformation corresponding to the + inverse of the transformations with the given ``params``. For example in case of a rotation, + the rotation matrix is first constructed from the rotation parameters (e.g., Euler angles), + and then transposed if ``self.invert == True``. In general, inversion of linear transformations + and non-rigid transformations parameterized by velocity fields can be done efficiently on-the-fly. + + """ + super().__init__(grid, groups=groups, params=params) + self.invert = bool(invert) + + def inverse( + self: Union[TSpatialTransform, InvertibleParametricTransform], + link: bool = False, + update_buffers: bool = False, + ) -> TSpatialTransform: + """Get inverse of this transformation. + + Args: + link: Whether the inverse transformation keeps a reference to this transformation. + If ``True``, the ``update()`` function of the inverse function will not recompute + shared parameters (e.g., parameters obtained by a callable neural network), but + directly access the parameters from this transformation. Note that when ``False``, + the inverse transformation will still share parameters, modules, and buffers with + this transformation, but these shared tensors may be replaced by a call of ``update()`` + (which is implicitly called as pre-forward hook when ``__call__()`` is invoked). + update_buffers: Whether buffers of inverse transformation should be updated after creating + the shallow copy. If ``False``, the ``update()`` function of the returned inverse + transformation has to be called before it is used. + + Returns: + Shallow copy of this transformation which computes and applies the inverse transformation. + The inverse transformation will share the parameters with this transformation. + + """ + inv = shallow_copy(self) + if link: + inv.link_(self) + inv.invert = not self.invert + return inv + + def extra_repr( + self: Union[TSpatialTransform, InvertibleParametricTransform] + ) -> str: + """Print current transformation.""" + return super().extra_repr() + f", invert={self.invert}" diff --git a/jointContribution/HighResolution/deepali/spatial/transformer.py b/jointContribution/HighResolution/deepali/spatial/transformer.py new file mode 100644 index 0000000000..869932e6f3 --- /dev/null +++ b/jointContribution/HighResolution/deepali/spatial/transformer.py @@ -0,0 +1,302 @@ +from __future__ import annotations + +from copy import copy as shallow_copy +from typing import Dict +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import Union +from typing import cast +from typing import overload + +import paddle + +from ..core.enum import PaddingMode +from ..core.enum import Sampling +from ..core.grid import Axes +from ..core.grid import Grid +from ..core.types import Scalar +from ..modules import SampleImage +from .base import SpatialTransform + +TSpatialTransformer = TypeVar("TSpatialTransformer", bound="SpatialTransformer") + + +class SpatialTransformer(paddle.nn.Layer): + """Spatially transform input data. + + A :class:`.SpatialTransformer` applies a :class:`.SpatialTransform` to a given input. How the spatial + transformation is applied to produce a transformed output is determined by the type of spatial transformer. + + The forward method of a spatial transformer invokes the spatial transform as a functor such that any registered + forward pre- and post-hooks are executed as part of the spatial transform evaluation. This includes in particular + the :meth:`.SpatialTransform.update` function if it is registered as a forward pre-hook. Note that this hook is by + default installed during initialization of a spatial transform. When the update of spatial transform parameters, + which may be inferred by a neural network, is done explicitly by the application, use :meth:`.SpatialTransform.remove_update_hook` + to remove this forward pre-hook before subsequent evaluations of the spatial transform. When doing so, ensure to + update the parameters when necessary using either :meth:`.SpatialTransformer.update` or :meth:`.SpatialTransform.update`. + + """ + + def __init__(self, transform: SpatialTransform) -> None: + """Initialize spatial transformer. + + Args: + transform: Spatial coordinate transformation. + + """ + if not isinstance(transform, SpatialTransform): + raise TypeError( + f"{type(self).__name__}() requires 'transform' of type SpatialTransform" + ) + super().__init__() + self._transform = transform + + @property + def transform(self) -> SpatialTransform: + """Spatial grid transformation.""" + return self._transform + + @overload + def condition(self) -> Tuple[tuple, dict]: + """Get arguments on which transformation is conditioned. + + Returns: + args: Positional arguments. + kwargs: Keyword arguments. + + """ + ... + + @overload + def condition(self: TSpatialTransformer, *args, **kwargs) -> TSpatialTransformer: + """Get new transformation which is conditioned on the specified arguments.""" + ... + + def condition( + self: TSpatialTransformer, *args, **kwargs + ) -> Union[TSpatialTransformer, Tuple[tuple, dict]]: + """Get or set data tensors and parameters on which transformation is conditioned.""" + if args: + return shallow_copy(self).condition_(*args) + return self._transform.condition() + + def condition_(self: TSpatialTransformer, *args, **kwargs) -> TSpatialTransformer: + """Set data tensors and parameters on which this transformation is conditioned.""" + self._transform.condition_(*args, **kwargs) + return self + + def update(self: TSpatialTransformer) -> TSpatialTransformer: + """Update internal state of spatial transformation (cf. :meth:`.SpatialTransform.update`).""" + self._transform.update() + return self + + +class ImageTransformer(SpatialTransformer): + """Spatially transform an image. + + The :class:`.ImageTransformer` applies a :class:`.SpatialTransform` to the sampling grid + points of the target domain, optionally followed by linear transformation from target to + source domain, and samples the input image of shape ``(N, C, ..., X)`` at these deformed + source grid points. If the spatial transformation is non-rigid, this is also commonly + referred to as warping the input image. + + Note that the :meth:`.ImageTransformer.forward` method invokes the spatial transform as + functor, i.e., it triggers any pre-forward and post-forward hooks that are registered with + the spatial transform when evaluating it. This in particular includes the forward pre-hook + that invokes :meth:`.SpatialTransform.update` (cf. :class:`.SpatialTransformer`). + + """ + + def __init__( + self, + transform: SpatialTransform, + target: Optional[Grid] = None, + source: Optional[Grid] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + padding: Union[PaddingMode, str, Scalar] = PaddingMode.BORDER, + align_centers: bool = False, + flip_coords: bool = False, + ) -> None: + """Initialize spatial image transformer. + + Args: + transform: Spatial coordinate transformation which is applied to ``target`` grid points. + target: Sampling grid of output images. If ``None``, use ``transform.axes()``. + source: Sampling grid of input images. If ``None``, use ``target``. + sampling: Image interpolation mode. + padding: Image extrapolation mode or scalar out-of-domain value. + align_centers: Whether to implicitly align the ``target`` and ``source`` centers. + If ``True``, only the affine component of the target to source transformation + is applied after the spatial grid ``transform``. If ``False``, also the + translation of grid center points is considered. + flip_coords: Whether spatial transformation applies to flipped grid point coordinates + in the order (z, y, x). The default is grid point coordinates in the order (x, y, z). + + """ + super().__init__(transform) + if target is None: + target = transform.grid() + if source is None: + source = target + if not isinstance(target, Grid): + raise TypeError(f"{type(self).__name__}() 'target' must be of type Grid") + if not isinstance(source, Grid): + raise TypeError(f"{type(self).__name__}() 'source' must be of type Grid") + if not transform.grid().same_domain_as(target): + raise ValueError( + f"{type(self).__name__}() 'target' and 'transform' grid must define the same domain" + ) + device = transform.place + sampler = SampleImage( + target=target, + source=source, + sampling=sampling, + padding=padding, + align_centers=align_centers, + ) + self._sample = sampler.to(device) + grid_coords = target.coords(flip=flip_coords, device=device).unsqueeze(axis=0) + self.register_buffer(name="grid_coords", tensor=grid_coords, persistable=False) + self.flip_coords = bool(flip_coords) + + @property + def sample(self) -> SampleImage: + """Source image sampler.""" + return self._sample + + def target_grid(self) -> Grid: + """Sampling grid of output images.""" + return self._sample.target_grid() + + def source_grid(self) -> Grid: + """Sampling grid of input images.""" + return self._sample.source_grid() + + def align_centers(self) -> bool: + """Whether grid center points are implicitly aligned.""" + return self._sample.align_centers() + + @overload + def forward(self, data: paddle.Tensor) -> paddle.Tensor: + """Sample batch of images at spatially transformed target grid points.""" + ... + + @overload + def forward( + self, data: paddle.Tensor, mask: paddle.Tensor + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + """Sample batch of masked images at spatially transformed target grid points.""" + ... + + @overload + def forward( + self, data: Dict[str, Union[paddle.Tensor, Grid]] + ) -> Dict[str, Union[paddle.Tensor, Grid]]: + """Sample batch of optionally masked images at spatially transformed target grid points.""" + ... + + def forward( + self, + data: Union[paddle.Tensor, Dict[str, Union[paddle.Tensor, Grid]]], + mask: Optional[paddle.Tensor] = None, + ) -> Union[ + paddle.Tensor, + Tuple[paddle.Tensor, paddle.Tensor], + Dict[str, Union[paddle.Tensor, Grid]], + ]: + """Sample batch of images at spatially transformed target grid points.""" + grid: paddle.Tensor = cast(paddle.Tensor, self.grid_coords) + grid = self._transform(grid, grid=True) + if self.flip_coords: + grid = grid.flip(axis=(-1,)) + return self._sample(grid, data, mask) + + +class PointSetTransformer(SpatialTransformer): + """Spatially transform a set of points. + + The :class:`.PointSetTransformer` applies a :class:`.SpatialTransform` to a set of input points + with coordinates defined with respect to a specified target domain. This coordinate map may + further be followed by a linear transformation from the grid domain of the spatial transform + to a given source domain. When no spatial transform is given, use :func:`.grid_transform_points`. + + The forward method of a point set transformer performs the same operation as :meth:`.SpatialTransform.points`, + but with the target and source domain arguments specified during transformer initialization. In addition, + the point set transformer module invokes the spatial transform as a functor such that any registered + forward pre- and post-hooks are executed as part of the spatial transform evaluation. This includes the + forward pre-hook that invokes :meth:`.SpatialTransform.update` (cf. :class:`.SpatialTransformer`). + + """ + + def __init__( + self, + transform: SpatialTransform, + grid: Optional[Grid] = None, + axes: Optional[Union[Axes, str]] = None, + to_grid: Optional[Grid] = None, + to_axes: Optional[Union[Axes, str]] = None, + ) -> None: + """Initialize point set transformer. + + Args: + transform: Spatial coordinate transformation which is applied to input points. + grid: Grid with respect to which input points are defined. Uses ``transform.grid()`` if ``None``. + axes: Coordinate axes with respect to which input points are defined. Uses ``transform.axes()`` if ``None``. + to_grid: Grid with respect to which output points are defined. Same as ``grid`` if ``None``. + to_axes: Coordinate axes to which input points should be mapped to. Same as ``axes`` if ``None``. + + """ + super().__init__(transform) + if grid is None: + grid = transform.grid() + if axes is None: + axes = transform.axes() + else: + axes = Axes.from_arg(axes) + if to_grid is None: + to_grid = grid + if to_axes is None: + to_axes = axes + else: + to_axes = Axes.from_arg(to_axes) + self._grid = grid + self._axes = axes + self._to_grid = to_grid + self._to_axes = to_axes + + def target_axes(self) -> Axes: + """Coordinate axes with respect to which input points are defined.""" + return self._axes + + def target_grid(self) -> Grid: + """Sampling grid with respect to which input points are defined.""" + return self._grid + + def source_axes(self) -> Axes: + """Coordinate axes with respect to which output points are defined.""" + return self._to_axes + + def source_grid(self) -> Grid: + """Sampling grid with respect to which output points are defined.""" + return self._to_grid + + def forward(self, points: paddle.Tensor) -> paddle.Tensor: + """Spatially transform a set of points.""" + transform = self.transform + points = self._grid.transform_points( + points, + axes=self._axes, + to_grid=transform.grid(), + to_axes=transform.axes(), + decimals=None, + ) + points = transform(points) + points = transform.grid().transform_points( + points, + axes=transform.axes(), + to_grid=self._to_grid, + to_axes=self._to_axes, + decimals=None, + ) + return points diff --git a/jointContribution/HighResolution/deepali/utils/__init__.py b/jointContribution/HighResolution/deepali/utils/__init__.py new file mode 100644 index 0000000000..4015a36c08 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/__init__.py @@ -0,0 +1,10 @@ +"""Utility functions and classes for building applications. + +Currently, these utilities are primarily intended for use in example applications and tutorials +of this library. It is not recommended to build external applications on these while it is in a +non-stable release version (<1.0). + +.. note:: + These auxiliary modules should only depend on the :mod:`.core` library and third-party packages. + +""" diff --git a/jointContribution/HighResolution/deepali/utils/aws/__init__.py b/jointContribution/HighResolution/deepali/utils/aws/__init__.py new file mode 100644 index 0000000000..a807b62f68 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/__init__.py @@ -0,0 +1 @@ +"""Interfaces and utilities for Amazon Web Services (AWS).""" diff --git a/jointContribution/HighResolution/deepali/utils/aws/resource.py b/jointContribution/HighResolution/deepali/utils/aws/resource.py new file mode 100644 index 0000000000..1d3104848a --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/resource.py @@ -0,0 +1,437 @@ +from __future__ import annotations + +import os +import re +import shutil +from copy import deepcopy +from pathlib import Path +from typing import Any +from typing import Generator +from typing import Optional +from typing import TypeVar +from typing import Union +from urllib.parse import urlsplit + +PathStr = Union[Path, str] +PathUri = Union[Path, str] +T = TypeVar("T", bound="Resource") + + +class Resource(object): + """Interface for storage objects. + + This base class can be used for storage objects that are only stored locally. + The base implementations of the ``Resource`` interface functions reflect this use + case, where ``Resource("/path/to/file")`` represents such local path object. + The factory function ``Resource.from_uri`` is recommended for creating concrete + instances of ``Resource`` or one of its subclasses. To create a resource instance + for a local file path, use ``Resource.from_uri("file:///path/to/file")``. + An S3 object resource is created by ``Resource.from_uri("s3://bucket/key")``. + By using the ``Resource`` interface when writing tools that read and write + from either local, remote, or cloud storage, the tool CLI can create these + resource instances from input argument URIs or local file path strings + without URI scheme, i.e., without "file://" prefix. The consumer or producer + of a resource object can either directly read/write the object data using + the ``read_(bytes|text)`` and/or ``write_(bytes|text)`` functions, or + download/upload the storage object to/from a local file path using the + ``pull`` and ``push`` operations. Note that these operations directly interact + with the local storage if the resource instance is of base type ``Resource``, + rather than a remote or cloud storage specific subclass. The ``pull`` and ``push`` + operations should be preferred over ``read`` and ``write`` if the resource data + is accessed multiple times, in order to take advantage of the local temporary + copy of the resource object. Otherwise, system IO operations can be saved by + using the direct ``read`` and ``write`` operations instead. + + Additionally, the ``Resource.release`` function should be called by tools when + a resource is no longer required to indicate that the local copy of this resource + can be removed. If the resource object itself represents a local ``Resource``, + the release operation has no effect. To ensure that the ``release`` function is + called also in case of an exception, the ``Resource`` class implements the context + manager interface functions ``__enter__`` and ``__exit__``. + + Example usage with resource context: + + .. code-block:: python + + with Resource.from_uri("s3://bucket/key") as res: + # request download to local storage + path = res.pull().path + # use local storage object referenced by path + # local copy of storage object has been deleted + + The above is equivalent to using a try-finally block: + + .. code-block:: python + + res = Resource.from_uri("s3://bucket/key") + try: + path = res.pull().path + # use local storage object referenced by path + finally: + # delete local copy of storage object + res.release() + + Usage of the ``with`` statement is recommended. + + Nesting of contexts for a resource object is possible and post-pones the + invocation of the ``release`` operation until the outermost context has + been left. This is accomplished by using a counter that is increment by + ``__enter__``, and decremented again by ``__exit__``. + + It should be noted that ``Resource`` operations are generally not thread-safe, + and actual consumers of resource objects should require the main thread to + deal with obtaining, downloading (if ``pull`` is used), and releasing a resource. + For different resources from remote storage (e.g., AWS S3), when using multiple + processes (threads), the main process (thread) must initialize the default client + connection (e.g., using ``S3Client.init_default()``) before spawning processes. + + """ + + def __init__(self: T, path: PathStr) -> None: + """Initialize storage object. + + Args: + path (str, pathlib.Path): Local path of storage object. + + """ + self._path = Path(path).absolute() + self._depth = 0 + + def __enter__(self: T) -> T: + """Enter context.""" + self._depth = max(1, self._depth + 1) + return self + + def __exit__(self: T, *exc) -> None: + """Release resource when leaving outermost context.""" + self._depth = max(0, self._depth - 1) + if self._depth == 0: + self.release() + + @staticmethod + def from_path(*args: Optional[PathStr]) -> Resource: + """Create storage object from path or URI. + + Args: + args: Path or URI components. The last absolute path or URI is the base to which + subsequent arguments are appended. Any ``None`` value is ignored. See also ``to_uri()``. + + Returns: + obj (Resource): Instance of concrete type representing the referenced storage object. + + """ + return Resource.from_uri(to_uri(args)) + + @staticmethod + def from_uri(uri: str) -> Resource: + """Create storage object from URI. + + Args: + uri: URI of storage object. + + Returns: + obj (Resource): Instance of concrete type representing the referenced storage object. + + """ + res = urlsplit(uri, scheme="file") + if res.scheme == "file": + match = re.match("/+([a-zA-Z]:.*)", res.path) + path = match.group(1) if match else res.path + return Resource(Path("/" + res.netloc + "/" + path if res.netloc else path)) + if res.scheme == "s3": + from .s3.object import S3Object + + return S3Object.from_uri(uri) + raise ValueError("Invalid or unsupported storage object URI: %s", uri) + + @property + def uri(self: T) -> str: + """ + Returns: + uri (str): URI of storage object. + + """ + return self.path.as_uri() + + @property + def path(self: T) -> Path: + """Get absolute local path of storage object.""" + return self._path + + @property + def name(self: T) -> str: + """Name of storage object including file name extension, excluding directory path.""" + return self.path.name + + def with_path(self: T, path) -> T: + """Create copy of storage object reference with modified ``path``. + + Args: + path (str, pathlib.Path): New local path of storage object. + + Returns: + self: New storage object reference with modified ``path`` property. + + """ + obj = deepcopy(self) + obj._path = Path(path).absolute() + return obj + + def with_properties(self: T, **kwargs) -> T: + """Create copy of storage object reference with modified properties. + + Args: + **kwargs: New property values. Only specified properties are changed. + + Returns: + self: New storage object reference with modified properties. + + """ + obj = deepcopy(self) + for name, value in kwargs.items(): + setattr(obj, name, value) + return obj + + def exists(self: T) -> bool: + """Whether object exists in storage.""" + return self.path.exists() + + def is_file(self: T) -> bool: + """Whether storage object represents a file.""" + return self.path.is_file() + + def is_dir(self: T) -> bool: + """Whether storage object represents a directory.""" + return self.path.is_dir() + + def iterdir(self: T, prefix: Optional[str] = None) -> Generator[T, None, None]: + """List storage objects within directory, excluding subfolder contents. + + Args: + prefix: Name prefix. + + Returns: + iterable: Generator of storage objects. + + """ + assert type(self) is Resource, "must be implemented by subclass" + for path in self.path.iterdir(): + if not prefix or path.name.startswith(prefix): + yield Resource(path) + + def pull(self: T, force: bool = False) -> T: + """Download content of storage object to local path. + + Args: + force (bool): Whether to force download even if local path already exists. + + Returns: + self: This storage object. + + """ + return self + + def push(self: T, force: bool = False) -> T: + """Upload content of local path to storage object. + + Args: + force (bool): Whether to force upload even if storage object already exists. + + Returns: + self: This storage object. + + """ + return self + + def read_bytes(self: T) -> bytes: + """Read file content from local path if it exists, or referenced storage object otherwise. + + Returns: + data (bytes): Binary file content of storage object. + + """ + return self.pull().path.read_bytes() + + def write_bytes(self: T, data: bytes) -> T: + """Write bytes to storage object. + + Args: + data (bytes): Binary data to write. + + Returns: + self: This storage object. + + """ + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_bytes(data) + return self.push() + + def read_text(self: T, encoding: Optional[str] = None) -> str: + """Read text file content from local path if it exists, or referenced storage object otherwise. + + Args: + encoding (str): Text encoding. + + Returns: + text (str): Decoded text file content of storage object. + + """ + return self.pull().path.read_text() + + def write_text(self: T, text: str, encoding: Optional[str] = None) -> T: + """Write text to storage object. + + Args: + text (str): Text to write. + encoding (str): Text encoding. + + Returns: + self: This storage object. + + """ + self.path.parent.mkdir(parents=True, exist_ok=True) + self.path.write_text(text, encoding=encoding) + return self.push() + + def rmdir(self: T) -> T: + """Remove directory both locally and from remote storage.""" + try: + shutil.rmtree(self.path) + except FileNotFoundError: + pass + return self + + def unlink(self: T) -> T: + """Remove file both locally and from remote storage.""" + try: + self.path.unlink() + except FileNotFoundError: + pass + return self + + def delete(self: T) -> T: + """Remove object both locally and from remote storage.""" + try: + self.rmdir() + except NotADirectoryError: + self.unlink() + return self + + def release(self: T) -> T: + """Release local temporary copy of storage object. + + Only remove local copy of storage object. When the storage object + is only stored locally, i.e., self is not a subclass of Resource, + but of type ``Resource``, this operation does nothing. + + """ + if type(self) is not Resource: + try: + shutil.rmtree(self.path) + except FileNotFoundError: + pass + except NotADirectoryError: + try: + self.path.unlink() + except FileNotFoundError: + pass + return self + + def __str__(self: T) -> str: + """Get human-readable string representation of storage object reference.""" + return self.uri + + def __repr__(self: T) -> str: + """Get human-readable string representation of storage object reference.""" + return type(self).__name__ + "(path='{}')".format(self.path) + + +def is_absolute(path: Union[Path, str]) -> bool: + """Check whether given path string or URI is absolute.""" + if is_uri(path): + return True + return Path(path).is_absolute() + + +def is_uri(arg: Any) -> bool: + """Check whether a given argument is a URI.""" + if isinstance(arg, Path): + return False + if isinstance(arg, str): + if os.name == "nt" and re.match("([a-zA-Z]):[/\\\\](.*)", arg): + return False + return re.match("([a-zA-Z0-9]+)://(.*)", arg) is not None + return False + + +def to_uri(*args: Optional[PathStr]) -> str: + """Create valid URI from resource paths. + + Args: + args: Local path components or an already valid URI. The last absolute path or URI in this + list of arguments is the base path or URI prefix for subsequent relative paths which + are appended to this base to construct the URI. Any ``None`` values are ignored. + + Returns: + Valid URI. + + """ + args = [arg for arg in args if arg is not None] + for i, arg in enumerate(reversed(args)): + if is_uri(arg): + base = str(arg) + args = args[len(args) - i :] + break + elif Path(arg).is_absolute(): + base = Path(arg) + args = args[len(args) - i :] + break + else: + base = Path.cwd() + if isinstance(base, Path): + uri = local_path_uri(base.joinpath(*args)) + else: + uri = norm_uri(f"{base}/{'/'.join(args)}" if args else base) + return uri + + +def norm_uri(uri: str) -> str: + """Normalize URI. + + Args: + uri: A valid URI string. + + Returns: + Normalized URI string. + + """ + match = re.match("(?P[a-zA-Z0-9]+)://(?P.*)", uri) + if not match: + raise ValueError(f"norm_uri() 'uri' is not a valid URI: {uri}") + scheme = match["scheme"].lower() + path = re.sub("^/+", "", re.sub("[/\\\\]{1,}", "/", match["path"])) + if scheme == "file": + if os.name != "nt" or re.match("(?P[a-zA-Z]):[/\\\\]", path) is None: + path = "/" + path + return "file://" + path + if scheme == "s3": + return "s3://" + path + return urlsplit(uri, scheme="file").geturl() + + +def local_path_uri(arg: PathStr) -> str: + """Create valid URI from local path. + + Unlike Path.as_uri(), this function does not escape special characters as used in format template strings. + + Args: + arg: Local path. + + Returns: + Valid URI. + + """ + uri = norm_uri(f"file://{Path(arg).absolute()}") + if uri.endswith("/"): + uri = uri[:-1] + return uri diff --git a/jointContribution/HighResolution/deepali/utils/aws/s3/__init__.py b/jointContribution/HighResolution/deepali/utils/aws/s3/__init__.py new file mode 100644 index 0000000000..8b12caaaf9 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/s3/__init__.py @@ -0,0 +1,6 @@ +"""Interfaces and utilities for AWS Simple Storage Service (S3).""" +from .client import S3Client +from .config import S3Config +from .object import S3Object + +__all__ = "S3Config", "S3Client", "S3Object" diff --git a/jointContribution/HighResolution/deepali/utils/aws/s3/client.py b/jointContribution/HighResolution/deepali/utils/aws/s3/client.py new file mode 100644 index 0000000000..f9bc241477 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/s3/client.py @@ -0,0 +1,627 @@ +"""Client for AWS Simple Storage Service (S3).""" +from __future__ import annotations + +import io +from contextlib import contextmanager +from enum import Enum +from pathlib import Path +from typing import Any +from typing import Callable +from typing import Dict +from typing import Generator +from typing import List +from typing import Optional +from typing import Set +from typing import Tuple +from typing import Union + +import boto3 + +from .config import S3Config + +PathStr = Union[Path, str] + + +def match_all(key: str) -> bool: + """Default 'match' function used by 'S3Client.download_files()'.""" + return True + + +class S3Client(object): + """Client for AWS Simple Storage Service (S3).""" + + class Operation(Enum): + """Enumeration of permissible S3 operations. + + Note that these permissions restrict the set of allowed operations and + is independent of S3 object permissions which are configured in AWS S3. + + """ + + READ = "r" + WRITE = "w" + DELETE = "d" + + _default: List[S3Client] = [] + + @classmethod + def default(cls, client: Optional[S3Client] = None) -> S3Client: + """Default client instance. + + This function is not thread-safe if not called first by the main thread. + + Args: + client: If ``None``, a new client instance is created if no current default client exists. + Otherwise, the given client instance replaces the current default instance. + + Returns: + Current default client instance. If no default client instance has been set before, + a new instance with default configuration is created. + + """ + if client is not None: + if cls._default: + cls._default[-1] = client + else: + cls._default.append(client) + elif not cls._default: + cls._default.append(cls()) + return cls._default[-1] + + @classmethod + @contextmanager + def init_default(cls, *args, **kwargs) -> Generator[S3Client, None, None]: + """Set new default client to use in current context. + + This function is not thread-safe. Only the main thread may set the default client + instance to use by worker threads. Any multi-threaded operations should share + a client instance, a reference to which should be provided by the main thread. + Nested ``with`` statements using this context manager function are supported, + as long as these occur in the main thread only. + + Args: + *args: Positional arguments for client ``__init__`` function call. + **kwargs: Keyword arguments for client ``__init__`` function call. + + Returns: + client (S3Client): Generator function that returns the connected default client + instance and resets the default to the previously active one upon completion. + + """ + client = cls(*args, **kwargs) + client.connect() + cls._default.append(client) + try: + yield client + finally: + client = cls._default.pop() + client.close() + + def __init__(self, config: Optional[S3Config] = None, **kwargs) -> None: + """Initialize S3 client. + + Args: + config: Client configuration. + **kwargs: Individual client configuration settings. + + """ + super().__init__() + if config is None: + config = S3Config() + self._config = config._replace(**kwargs) + self._client = None + + @classmethod + def from_arg(cls, arg: Union[S3Config, S3Client, Dict[str, Any], None]) -> S3Client: + """Get client instance given function argument. + + Args: + arg: Function argument. + + Returns: + client: S3 client instance. If ``arg`` is an S3Client instance, + this instance is returned. If ``arg`` is of type S3Config or dict, a new + instance is returned, which has been initialized with this configuration. + If ``arg`` is ``None``, the default client instance is returned. + + """ + if isinstance(arg, cls): + return arg + if isinstance(arg, S3Config): + return cls(arg) + if isinstance(arg, dict): + return cls(**arg) + return cls.default() + + @property + def exceptions(self): + """Get boto3 exceptions for connected client.""" + if self._client is None: + raise AssertionError( + f"{type(self).__name__} must be connected to access exceptions. Mabye use botocore.exceptions module instead." + ) + return self._client.exceptions + + @property + def config(self) -> S3Config: + """Get client configuration object.""" + return self._config + + @property + def ops(self) -> Set[S3Client.Operation]: + """Get list of permissible client operations.""" + return set(S3Client.Operation(c) for c in self.config.mode) + + def connect(self) -> S3Client: + """Establish connection with AWS S3.""" + if self._client is None: + self._client = boto3.client( + "s3", region_name=self.config.region, verify=self.config.verify + ) + return self + + def close(self) -> None: + """Close connection with AWS S3.""" + self._client = None + + def is_closed(self) -> bool: + """Check if S3 connection is closed.""" + return self._client is None + + def is_open(self) -> bool: + """Check if client connection is open.""" + return not self.is_closed() + + def __enter__(self) -> S3Client: + """Ensure client connection is open when entering context. + + This function increments a context depth counter in a non-thread-safe way. + Only the main thread should create client connection and pass these on to + worker threads that require access to this resource. + + """ + self.connect() + self._depth = max(1, self._depth + 1) + return self + + def __exit__(self, *exc) -> None: + """Close connection when exiting outermost context.""" + self._depth = max(0, self._depth - 1) + if self._depth == 0: + self.close() + + def exists(self, bucket: str, key: str) -> bool: + """Check if specified S3 object exists. + + Args: + bucket: Bucket name. + key: S3 object key. To query the existence of a folder, the + key must end with a forward slash character ('/'). + + Returns: + Whether S3 object exists in specified bucket. + + """ + assert not self.is_closed(), "client connection required" + if key.endswith("/"): + resp = self._client.list_objects_v2(Bucket=bucket, Prefix=key) + flag = bool(resp.get("Contents", [])) + else: + flag = False + try: + self._client.head_object(Bucket=bucket, Key=key) + flag = True + except self._client.exceptions.ClientError as error: + error_info = getattr(error, "response", {}).get("Error", {}) + if error_info.get("Code") != "404": + raise error + except self._client.exceptions.NoSuchKey: + pass + return flag + + def keys( + self, bucket: str, prefix: Optional[str] = None + ) -> Generator[str, None, None]: + """List S3 objects with specified key prefix. + + Args: + bucket: Bucket name. + prefix: Common key prefix. + + Returns: + Generator of S3 object keys. + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + if S3Client.Operation.READ not in self.ops: + raise PermissionError("S3 client has no read permissions") + kwargs = {"Bucket": bucket} + if prefix: + kwargs["Prefix"] = prefix + while True: + resp: dict = self._client.list_objects_v2(**kwargs) + for obj in resp.get("Contents", []): + yield obj["Key"] + try: + kwargs["ContinuationToken"] = resp["NextContinuationToken"] + except KeyError: + break + + def iterdir( + self, bucket: str, prefix: Optional[str] = None + ) -> Generator[str, None, None]: + """List S3 objects with specified key prefix, excluding subfolder contents. + + Args: + bucket: Bucket name. + prefix: Common key prefix. To list the contents of a folder, the + key prefix must end with a forward slash character ('/'), otherwise + the result will only contain the folder itself. + + Returns: + Generator of S3 object keys, where subfolders are represented + by keys ending with a forward slash character ('/'). + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + if S3Client.Operation.READ not in self.ops: + raise PermissionError("S3 client has no read permissions") + kwargs = {"Bucket": bucket, "Delimiter": "/"} + if prefix: + kwargs["Prefix"] = prefix + while True: + resp: dict = self._client.list_objects_v2(**kwargs) + for obj in resp.get("Contents", []): + yield obj["Key"] + for obj in resp.get("CommonPrefixes", []): + yield obj["Prefix"] + try: + kwargs["ContinuationToken"] = resp["NextContinuationToken"] + except KeyError: + break + + def listdir( + self, bucket: str, prefix: Optional[str] = None + ) -> Generator[str, None, None]: + """List names of S3 objects whose keys match a given prefix, excluding subfolder contents. + + This convenience function uses ``iterdir`` to obtain the keys of the S3 objects matching + a given ``prefix`` that correspond to the respective folder, and removes the folder key + prefix from the resulting object keys. It thus behaves similar to ``os.listdir``. + + Args: + bucket: Bucket name. + prefix: Common key prefix. To list the contents of a folder, the + key prefix must end with a forward slash character ('/'), otherwise + the result will only contain the folder itself. + + Returns: + Generator of S3 object keys, where subfolders are represented + by keys ending with a forward slash character ('/'). + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + skip = (prefix or "").rfind("/") + 1 + for key in self.iterdir(bucket, prefix): + yield key[skip:] + + def read_bytes(self, bucket: str, key: str) -> bytes: + """Download binary object data. + + Args: + bucket: Bucket name. + key: S3 object key. + + Returns: + data: S3 object data. + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + assert key, "S3 object key must be specified" + if S3Client.Operation.READ not in self.ops: + raise PermissionError("S3 client has no read permissions") + buffer = io.BytesIO() + self._client.download_fileobj(Bucket=bucket, Key=key, Fileobj=buffer) + return buffer.getvalue() + + def write_bytes(self, bucket: str, key: str, data: bytes) -> None: + """Upload binary object data. + + Args: + bucket: Bucket name. + key: S3 object key. + data: S3 object data. + + Raises: + PermissionError: If write operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + assert key, "S3 object key must be specified" + if S3Client.Operation.WRITE not in self.ops: + raise PermissionError("S3 client has no write permissions") + buffer = io.BytesIO(data) + self._client.upload_fileobj(Fileobj=buffer, Bucket=bucket, Key=key) + + def read_text(self, bucket: str, key: str, encoding: Optional[str] = None) -> str: + """Download text file content. + + Args: + bucket: Bucket name. + key: S3 object key. + encoding: Text encoding. + + Returns: + text: Decoded text. + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + data = self.read_bytes(bucket, key) + return data.decode(encoding) if encoding else data.decode() + + def write_text( + self, bucket: str, key: str, text: str, encoding: Optional[str] = None + ) -> None: + """Upload text file content. + + Args: + bucket: Bucket name. + key: S3 object key. + text: S3 object data. + encoding: Text encoding. + + Raises: + PermissionError: If write operations have not been enabled for this client. + + """ + data = text.encode(encoding) if encoding else text.encode() + self.write_bytes(bucket=bucket, key=key, data=data) + + def download_file( + self, bucket: str, key: str, path: PathStr, overwrite: bool = True + ) -> None: + """Download S3 object to local file. + + Args: + bucket: Bucket name. + key: S3 object key. + path: Local file path or path of existing output directory. + overwrite: Whether to overwrite existing local file. + + Raises: + PermissionError: If read operations have not been enabled for this client. + FileExistsError: If local ``path`` exists and ``overwrite=False``. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + assert key, "S3 object key must be specified" + if S3Client.Operation.READ not in self.ops: + raise PermissionError("S3 client has no read permissions") + path = Path(path).absolute() + if path.is_dir(): + path = path.joinpath(key.rsplit("/", 1)[-1]) + if not overwrite and path.is_file(): + raise FileExistsError( + "Use overwrite=True to force overwriting existing local file '{}'".format( + path + ) + ) + data = self.read_bytes(bucket, key) + try: + path.unlink() + except FileNotFoundError: + path.parent.mkdir(parents=True, exist_ok=True) + try: + path.write_bytes(data) + except (FileNotFoundError, PermissionError): + raise + except Exception as e: + try: + path.unlink() + except Exception: + pass + raise e + + def download_files( + self, + bucket: str, + prefix: str, + path: PathStr, + match: Optional[Callable[[str], bool]] = None, + overwrite: bool = True, + ) -> Tuple[int, int]: + """Download S3 objects to local directory. + + Args: + bucket: Bucket name. + prefix: Common key prefix. + path: Local directory path. If the directory exists, downloaded + files are merged into this existing directory and subdirectories. + match: Filter function that takes a subkey without ``prefix`` and + returns either True or False. If False, the corresponding file is skipped. + overwrite: Whether to overwrite existing local files. + + Returns: + total: Number of matching S3 objects. + count: Number of actually downloaded files. + + Raises: + PermissionError: If read operations have not been enabled for this client. + + """ + total = 0 + count = 0 + skip = prefix.rfind("/") + 1 + if match is None: + match = match_all + for key in self.keys(bucket=bucket, prefix=prefix): + subkey = key[skip:] + if subkey and match(subkey): + total += 1 + try: + self.download_file( + bucket=bucket, + key=key, + path=path.joinpath(subkey), + overwrite=overwrite, + ) + count += 1 + except FileExistsError: + pass + return total, count + + def upload_file( + self, path: PathStr, bucket: str, key: str, overwrite: bool = True + ) -> None: + """Upload local file to S3. + + Args: + path: Local file path. + bucket: Bucket name. + key: S3 object key. + overwrite: Whether to overwrite existing S3 object. + + Raises: + PermissionError: If write operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + assert key, "S3 object key must be specified" + assert not key.endswith("/"), "S3 object key must not end with forward slash" + if S3Client.Operation.WRITE not in self.ops: + raise PermissionError("S3 client has no write permissions") + if not overwrite and self.exists(bucket=bucket, key=key): + raise FileExistsError( + "Use overwrite=True to force overwriting " + + "existing S3 object '{k}' in bucket {b}".format(k=key, b=bucket) + ) + path = Path(path).absolute() + self.write_bytes(bucket=bucket, key=key, data=path.read_bytes()) + + def upload_files( + self, + path: PathStr, + bucket: str, + prefix: Optional[str] = None, + overwrite: bool = True, + ) -> Tuple[int, int]: + """Upload local directory to S3. + + Args: + path: Local directory path. + bucket: Bucket name. + prefix: Common S3 object key prefix. If ``None`` or empty string, + the content of the local directory is uploaded to the specified + ``bucket`` without any object key prefix, i.e., the files will + be located in S3 directly underneath the destination bucket. + overwrite: Whether to overwrite existing S3 objects. + + Returns: + total: Number of local files. + count: Number of uploaded files. + + Raises: + PermissionError: If write operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + total = 0 + count = 0 + if prefix is None: + prefix = "/" + if not prefix.endswith("/"): + prefix += "/" + path = Path(path).absolute() + for child in path.iterdir(): + key = prefix + child.name + if child.is_dir(): + tot, cnt = self.upload_files( + path=child, bucket=bucket, prefix=key, overwrite=overwrite + ) + total += tot + count += cnt + else: + try: + self.upload_file( + path=child, bucket=bucket, key=key, overwrite=overwrite + ) + count += 1 + except FileExistsError: + pass + total += 1 + return total, count + + def delete_file(self, bucket: str, key: str) -> int: + """Delete S3 object. + + Args: + bucket: Bucket name. + key: S3 object key. + + Returns: + count: Number of deleted S3 objects (0 or 1). + + Raises: + PermissionError: If delete operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + if S3Client.Operation.DELETE not in self.ops: + raise PermissionError("S3 client has no permission to delete objects") + resp = self._client.delete_object(Bucket=bucket, Key=key) + return 1 if "DeleteMarker" in resp else 0 + + def delete_files(self, bucket: str, prefix: str) -> int: + """Delete S3 objects. + + Args: + bucket: Bucket name. + prefix: Common key prefix. + + Returns + count: Number of deleted S3 objects. + + Raises: + PermissionError: If delete operations have not been enabled for this client. + + """ + assert not self.is_closed(), "client connection required" + assert bucket, "S3 bucket must be specified" + if S3Client.Operation.DELETE not in self.ops: + raise PermissionError("S3 client has no permission to delete objects") + count = 0 + token = None + while True: + resp = self._client.list_objects_v2( + Bucket=bucket, Prefix=prefix, ContinuationToken=token + ) + count += len( + self._client.delete_objects( + Bucket=bucket, Delete=dict(Objects=resp["Contents"]) + )["Deleted"] + ) + try: + token = resp["NextContinuationToken"] + except KeyError: + break + return count diff --git a/jointContribution/HighResolution/deepali/utils/aws/s3/config.py b/jointContribution/HighResolution/deepali/utils/aws/s3/config.py new file mode 100644 index 0000000000..be1937470f --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/s3/config.py @@ -0,0 +1,10 @@ +from typing import NamedTuple +from typing import Optional + + +class S3Config(NamedTuple): + """Configuration of AWS Simple Storage Service.""" + + mode: str = "r" + verify: bool = True + region: Optional[str] = None diff --git a/jointContribution/HighResolution/deepali/utils/aws/s3/object.py b/jointContribution/HighResolution/deepali/utils/aws/s3/object.py new file mode 100644 index 0000000000..f8431aa586 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/aws/s3/object.py @@ -0,0 +1,317 @@ +"""Representation of objects stored in AWS Simple Storage Service (S3).""" +from __future__ import annotations + +import re +from copy import deepcopy +from pathlib import Path +from tempfile import gettempdir +from typing import Generator +from typing import Optional + +from ..resource import PathStr +from ..resource import Resource +from .client import S3Client + + +class S3Object(Resource): + """Object stored in AWS Simple Storage Service (S3).""" + + def __init__(self, bucket: str, key: str, path: PathStr = None) -> None: + """Initialize AWS S3 object. + + Args: + bucket: Name of AWS S3 bucket containing this object. + key: Key of object in AWS S3 bucket with forward slashes as path separators. + When string ends with a forward slash, the S3 object represents a set + of S3 objects which share this common key prefix. The local path is + in this case a directory tree rather than a single file path. + path: File path of local file or directory corresponding to AWS S3 object. + When not specified, a temporary path in the tempfile.gettempdir() is + constructed from the bucket name and S3 object key. + + """ + if bucket is None: + raise TypeError("S3 bucket name must be str") + if bucket == "": + raise ValueError("S3 bucket name must not be an empty string") + key = self.normkey(key) + path = Path(path) if path else self.default_path(bucket=bucket, key=key) + if not key.endswith("/") and path.is_dir(): + key += "/" + super().__init__(path) + self._bucket = bucket + self._key = key + + @property + def s3(self) -> S3Client: + """ + Returns: + client: Underlying connected S3 client. + + """ + return S3Client.default().connect() + + @staticmethod + def normkey(key: str) -> str: + """Normalize S3 object key string.""" + if not isinstance(key, str): + raise TypeError("S3 object key must be str") + if key == "": + return "/" + key = re.sub("[/\\\\]{1,}", "/", key) + if len(key) > 1 and key.startswith("/"): + key = key[1:] + return key + + @staticmethod + def default_path(bucket: str, key: str) -> Path: + """Get default local path of S3 object copy. + + Args: + bucket: Name of S3 bucket. + key: Key of S3 object. + + Returns: + path: Absolute path of local copy of S3 object. + + """ + key = S3Object.normkey(key) + if key.startswith("/"): + key = key[1:] + return Path(gettempdir()).joinpath("deepali", "cache", "s3", bucket, key) + + def reset_path(self) -> S3Object: + """Reset ``path`` to ``default_path`` for given ``bucket`` and ``key``. + + Returns: + self: This instance. + + """ + self._path = self.default_path(bucket=self.bucket, key=self.key) + return self + + @classmethod + def from_uri(cls, uri: str) -> S3Object: + """Create AWS S3 object from URI. + + Args: + uri: URI of S3 object. Must start with 's3://' followed by the bucket name. + The remainder of the URI represents the object key, excluding the forward + slash separating the bucket name from the object key. + + Returns: + obj: S3 object instance. + + """ + match = re.match("[sS]3://(?P[^/]+)/(?P.*)", uri) + if not match: + raise ValueError("Invalid AWS S3 object URI: %s", uri) + return cls(bucket=match["bucket"], key=match["key"]) + + @property + def uri(self) -> str: + """URI of storage object.""" + return "s3://" + self.bucket + "/" + self.key + + @property + def bucket(self) -> str: + """Name of S3 bucket containing storage object.""" + return self._bucket + + @property + def key(self) -> str: + """Name of S3 key corresponding to storage object referenced by this instance.""" + return self._key + + @property + def name(self) -> str: + """Name of storage object including file name extension, excluding directory path.""" + key = self.key + if key.endswith("/"): + key = key[:-1] + return key.rsplit("/", 1)[-1] + + def with_bucket(self, bucket: str) -> S3Object: + """Create copy of storage object reference with modified ``bucket``. + + Args: + bucket: New bucket name. + + Returns: + self: New storage object reference with modified ``bucket`` property. + + """ + if bucket is None: + raise TypeError("Bucket name must be str") + if bucket == "": + raise ValueError("Invalid S3 bucket name") + obj = deepcopy(self) + obj._bucket = bucket + return obj + + def with_key(self, key: str) -> S3Object: + """Create copy of storage object reference with modified ``key``. + + Args: + key: New S3 object key. + + Returns: + self: New storage object reference with modified ``key`` property. + + """ + if key is None: + raise TypeError("S3 object key must be str") + if key == "": + raise ValueError("Invalid S3 object key") + obj = deepcopy(self) + obj._key = key + return obj + + def exists(self) -> bool: + """Whether object exists in AWS S3.""" + return self.s3.exists(bucket=self.bucket, key=self.key) + + def is_file(self) -> bool: + """Whether AWS S3 object exists and represents a file.""" + return not self.key.endswith("/") and self.exists() + + def is_dir(self) -> bool: + """Whether AWS S3 object exists and represents a directory.""" + return self.key.endswith("/") and self.exists() + + def iterdir(self, prefix: str = None) -> Generator[Resource, None, None]: + """List S3 objects within directory, excluding subfolder contents. + + Args: + prefix: Name prefix. + + Returns: + iterable: Generator of S3 objects. + + """ + if self.key.endswith("/"): + if prefix is None: + prefix = "" + if self.key != "/": + prefix = self.key + prefix + for key in self.s3.iterdir(bucket=self.bucket, prefix=prefix): + yield S3Object(bucket=self.bucket, key=key) + + def pull(self, force: bool = False) -> S3Object: + """Download content from AWS S3 to local path. + + Args: + force: Whether to force download even if local path already exists. + + Returns: + self: This instance. + + """ + if self.key.endswith("/"): + self.s3.download_files( + bucket=self.bucket, prefix=self.key, path=self.path, overwrite=force + ) + else: + try: + self.s3.download_file( + bucket=self.bucket, key=self.key, path=self.path, overwrite=force + ) + except FileExistsError: + pass + return self + + def push(self, force: bool = False) -> S3Object: + """Upload content of local path to AWS S3. + + Args: + force: Whether to force upload even if S3 object already exists. + + Returns: + self: This instance. + + """ + if self.key.endswith("/"): + self.s3.upload_files( + bucket=self.bucket, prefix=self.key, path=self.path, overwrite=force + ) + else: + self.s3.upload_file( + bucket=self.bucket, key=self.key, path=self.path, overwrite=force + ) + return self + + def read_bytes(self) -> bytes: + """Read file content from local path if it exists, or corresponding S3 object otherwise. + + Returns: + Binary file content of storage object. + + """ + assert not self.key.endswith("/"), "S3 object key must referrence a file object" + if self.path.is_file(): + return self.path.read_bytes() + return self.s3.read_bytes(bucket=self.bucket, key=self.key) + + def write_bytes(self, data: bytes) -> S3Object: + """Write bytes directly to storage object. + + Args: + data: Binary data to write. + + Returns: + self: This storage object. + + """ + assert not self.key.endswith("/"), "S3 object key must referrence a file object" + self.s3.write_bytes(bucket=self.bucket, key=self.key, data=data) + return self + + def read_text(self, encoding: Optional[str] = None) -> str: + """Read text file content from local path if it exists, or corresponding S3 object otherwise. + + Args: + encoding: Text encoding. + + Returns: + text: Decoded text file content of storage object. + + """ + data = self.read_bytes() + return data.decode(encoding) if encoding else data.decode() + + def write_text(self, text: str, encoding: Optional[str] = None) -> S3Object: + """Write text to storage object. + + Args: + text: Text to write. + encoding: Text encoding. + + Returns: + self: This storage object. + + """ + self.write_bytes(text.encode(encoding) if encoding else text.encode()) + return self + + def rmdir(self) -> S3Object: + """Remove S3 objects both locally and from cloud storage.""" + if not self.key.endswith("/"): + raise NotADirectoryError("S3 object key prefix must end with '/' character") + super().rmdir() + self.s3.delete_files(bucket=self.bucket, prefix=self.key) + return self + + def unlink(self) -> S3Object: + """Remove S3 object both locally and from cloud storage.""" + assert not self.key.endswith( + "/" + ), "S3 object key must not end with '/' character" + super().unlink() + self.s3.delete_file(bucket=self.bucket, key=self.key) + return self + + def __repr__(self) -> str: + """Get human-readable string representation of storage object reference.""" + return type(self).__name__ + "(bucket='{}', key='{}', path='{}')".format( + self.bucket, self.key, self.path + ) diff --git a/jointContribution/HighResolution/deepali/utils/cli/__init__.py b/jointContribution/HighResolution/deepali/utils/cli/__init__.py new file mode 100644 index 0000000000..d74c5c1a54 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/cli/__init__.py @@ -0,0 +1,34 @@ +"""Auxiliary functions for implementing `argparse `_ based command line interfaces.""" +from .argparse import ArgumentParser +from .argparse import MainCallable +from .argparse import ParsedArguments +from .argparse import ParserCallable +from .argparse import UnknownArguments +from .argparse import entry_point +from .argparse import main_func +from .environ import check_cuda_visible_devices +from .environ import cuda_visible_devices +from .environ import init_omp_num_threads +from .logging import LOG_FORMAT +from .logging import LogLevel +from .logging import configure_logging +from .warnings import filter_warning_of_experimental_named_tensors_feature + +Args = ParsedArguments +__all__ = ( + "Args", + "ArgumentParser", + "LogLevel", + "LOG_FORMAT", + "entry_point", + "check_cuda_visible_devices", + "configure_logging", + "cuda_visible_devices", + "filter_warning_of_experimental_named_tensors_feature", + "init_omp_num_threads", + "main_func", + "MainCallable", + "ParsedArguments", + "ParserCallable", + "UnknownArguments", +) diff --git a/jointContribution/HighResolution/deepali/utils/cli/argparse.py b/jointContribution/HighResolution/deepali/utils/cli/argparse.py new file mode 100644 index 0000000000..f014020e6a --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/cli/argparse.py @@ -0,0 +1,342 @@ +import argparse +import inspect +import logging +import os +import sys +from collections import namedtuple +from importlib import import_module +from pkgutil import walk_packages +from typing import Any +from typing import Callable +from typing import Dict +from typing import Iterable +from typing import List +from typing import Optional +from typing import Sequence +from typing import Union + +from pkg_resources import DistributionNotFound +from pkg_resources import get_distribution +from typing_extensions import Protocol + +ParsedArguments = argparse.Namespace +UnknownArguments = List[str] +STANDARD_ARGUMENTS_GROUP_TITLE = "Standard arguments" + + +class HelpFormatter(argparse.ArgumentDefaultsHelpFormatter): + def add_usage(self, usage, actions, groups, prefix=None): + if prefix is None: + prefix = "Usage: " + return super().add_usage(usage, actions, groups, prefix) + + +class ArgumentParser(argparse.ArgumentParser): + """Customized ArgumentParser subclass to be propagated to subparsers.""" + + def __init__(self, add_help=False, formatter_class=None, version=None, **kwargs): + if formatter_class is None: + formatter_class = HelpFormatter + super().__init__(add_help=False, formatter_class=formatter_class, **kwargs) + self._positionals.title = "Positional arguments" + self._optionals.title = "Optional arguments" + group = self.add_argument_group(STANDARD_ARGUMENTS_GROUP_TITLE) + if version: + group.add_argument( + "--version", + action="version", + version="%(prog)s " + version, + help="Show program version number and exit.", + ) + if add_help: + group.add_argument( + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit.", + ) + + +class ParserCallable(Protocol): + """Type annotation for argument parser constructor.""" + + def __call__(self, **kwargs) -> ArgumentParser: + ... + + +class MainCallable(Protocol): + """Type annotation for main function.""" + + def __call__(self, args: Optional[List[Any]] = None) -> int: + ... + + +def entry_point( + package: str, + path: Iterable[str], + description: Optional[str] = None, + subcommands: Optional[Union[Dict[str, str], Sequence[str]]] = None, +) -> Callable[[Optional[List[Any]]], int]: + """Create entry point for running named subcommand with given arguments. + + Args: + package: Name of package containing CLI modules. + path: Package paths containing CLI modules. + description: Description of entry point command. + subcommands: Description of subcommands. Dictionary keys are subcommand module + names excluding the ``package`` name of the main CLI module. The included + Python modules must define functions ``parser()`` and ``func()`` within + their global scope. If ``None``, all Python subpackages and modules within + the ``package`` are added automatically. + + Returns: + main: Main entry point of package CLI which takes a list + of command arguments as ``argv`` argument and returns an exit code. + + """ + if subcommands and not isinstance(subcommands, dict): + subcommands = {name: "" for name in sorted(subcommands)} + + def main(argv: Optional[List[Any]] = None) -> int: + """Run named subcommand with given arguments. + + Args: + argv (list): Command-line arguments. If None, sys.argv[1:] is used. + + Returns: + exit_code (int): Exit code of subcommand. + + """ + if argv is None: + argv = sys.argv[1:] + argv = [str(arg) for arg in argv] + if not argv: + argv = ["-h"] + elif argv[0] == "help": + argv = argv[1:] + ["-h"] + dist = _pypi_package_name(package) + if "__main__" in sys.argv[0]: + prog = f"{os.path.basename(sys.executable)} -m {package}" + else: + prog = os.path.basename(sys.argv[0]) + try: + version = get_distribution(dist).version + except DistributionNotFound: + version = None + log = logging.getLogger(dist) + log.addHandler(logging.StreamHandler()) + mainparser = ArgumentParser( + prog=prog, description=description, add_help=True, version=version + ) + subparsers = mainparser.add_subparsers() + default_options_parser = ArgumentParser(add_help=False) + group = default_options_parser.add_argument_group( + STANDARD_ARGUMENTS_GROUP_TITLE + ) + group.add_argument( + "-h", + "--help", + action="help", + default=argparse.SUPPRESS, + help="Show this help message and exit.", + ) + if version: + group.add_argument( + "--version", + action="version", + version="%(prog)s " + version, + help="Show program version number and exit.", + ) + common_options_parser = ArgumentParser(add_help=False) + group = common_options_parser.add_argument_group(STANDARD_ARGUMENTS_GROUP_TITLE) + group.add_argument("--log-level", default="INFO", help="Set logging level.") + CommandInfo = namedtuple("CommandInfo", ["module", "commands"]) + + def find_commands(name: str, path: List[str]) -> Dict[str, CommandInfo]: + commands = {} + module_infos = [ + (basename, ispkg) + for _, basename, ispkg in walk_packages(path) + if not basename.startswith("_") + ] + for basename, ispkg in module_infos: + module_name = ".".join([name, basename]) + command_name = basename.replace("_", "-") + commands[command_name] = CommandInfo( + module=module_name, + commands=find_commands( + name=module_name, + path=[os.path.join(prefix, basename) for prefix in path], + ) + if ispkg + else None, + ) + return commands + + if subcommands: + commands = {} + for module_name in subcommands.keys(): + name_parts = module_name.split(".") + command_name = name_parts[-1].replace("_", "-") + parent_module = package + parent_commands = commands + for parent_name in name_parts[:-1]: + parent_module = ".".join([parent_module, parent_name]) + parent_command = parent_name.replace("_", "-") + if parent_command not in parent_commands: + parent_commands[parent_command] = CommandInfo( + module=parent_module, commands={} + ) + parent_commands = parent_commands[parent_command].commands + parent_commands[command_name] = CommandInfo( + module=".".join([package, module_name]), commands=None + ) + else: + commands = find_commands(name=package, path=path) + + def get_description(module): + description = None + if subcommands and module.startswith(package + "."): + description = subcommands.get(module[len(package) + 1 :]) + if description is None: + module = import_module(module) + description = module.__doc__ + parser_fn = getattr(module, "parser", None) + if parser_fn is not None: + description = parser_fn(add_help=False).description + return description or "" + + def add_commands(subparsers, commands, argv): + if not argv: + return + if argv[0] in commands: + command_name = argv[0] + command_info = commands[command_name] + module = import_module(command_info.module) + if command_info.commands is None: + parser_fn = getattr(module, "parser") + parser = parser_fn(add_help=False) + func = getattr(module, "func") + func = _func_wrapper(func=func, init=getattr(module, "init", None)) + subparsers.add_parser( + command_name, + parents=[parser, default_options_parser, common_options_parser], + help=parser.description, + ).set_defaults(func=func) + else: + add_commands( + subparsers.add_parser( + command_name, + parents=[default_options_parser], + help=module.__doc__, + ).add_subparsers(), + command_info.commands, + argv[1:], + ) + else: + for command_name, command_info in commands.items(): + subparsers.add_parser( + command_name, + parents=[default_options_parser], + help=get_description(command_info.module), + ) + + add_commands(subparsers, commands, argv) + args, unknown = mainparser.parse_known_args(argv) + if hasattr(args, "func"): + log.setLevel(args.log_level) + log.info("") + try: + exit_code = args.func(args, unknown) + except KeyboardInterrupt: + log.debug("Execution interrupted by user") + exit_code = 1 + else: + mainparser.print_usage() + exit_code = 1 + return exit_code + + return main + + +def main_func( + parser: ParserCallable, + func: Callable[[ParsedArguments], int], + init: Callable[[ParsedArguments], int] = None, +) -> MainCallable: + """Create main function of subcommand. + + Args: + parser (callable): Builder function of ``ArgumentParser`` object. + func (callable): Function accepting ``ParsedArguments`` and returning exit code. + init (callable): Function accepting ``ParsedArguments`` and returning exit code. + This initialization function is called before ``func`` if provided. + + Returns: + main (callable): Main entry point of subcommand which takes a list + of command arguments as ``argv`` argument and returns an exit code. + + """ + func = _func_wrapper(func, init=init) + + def main(argv: Optional[List[Any]] = None) -> int: + """Call main function with parsed CLI arguments. + + Args: + argv (list): Command-line arguments. If None, sys.argv[1:] is used. + + Returns: + Exit code, zero on success. Note that not every CLI must return an exit code, + but throw an exception in case of an error instead. If the exception is not caught, + it will result in a non-zero exit code of the Python interpreter command. + Use the exit code return value for different success (or error) codes if needed. + + """ + argv = None if argv is None else [str(arg) for arg in argv] + try: + p = parser(add_help=True) + args = p.parse_args(argv) + exit_code = func(args) + except KeyboardInterrupt: + sys.stderr.write("Interrupted by user\n") + exit_code = 1 + return exit_code + + return main + + +def _func_wrapper( + func: Callable[[ParsedArguments], int], + init: Callable[[ParsedArguments], int] = None, +) -> Callable[[ParsedArguments], int]: + """Wrap command 'init' and 'func' callables.""" + nparams = len(inspect.signature(func).parameters) + assert nparams == 1 or nparams == 2, "func() must have one or two parameters" + + def call_init_then_func( + args: ParsedArguments, unknown: UnknownArguments = [] + ) -> int: + exit_code = 0 + if nparams == 1 and unknown: + raise RuntimeError(f"Unparsed arguments: {unknown}") + if init is not None: + exit_code = init(args) + if exit_code == 0: + if nparams == 2: + exit_code = func(args, unknown) + else: + exit_code = func(args) + return exit_code + + return call_init_then_func + + +def _pypi_package_name(module: str) -> str: + """Get Python package name given CLI module __name__.""" + parts = [] + for part in module.split("."): + if part in ("app", "apps", "cli"): + break + parts.append(part) + return ".".join(parts) diff --git a/jointContribution/HighResolution/deepali/utils/cli/environ.py b/jointContribution/HighResolution/deepali/utils/cli/environ.py new file mode 100644 index 0000000000..3aa1f948d0 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/cli/environ.py @@ -0,0 +1,45 @@ +import os +import re +from typing import Optional +from typing import Tuple + + +def cuda_visible_devices() -> Tuple[int, ...]: + """Get IDs of GPUs specified by CUDA_VISIBLE_DEVICES environment variable.""" + gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "") + if not gpus: + return () + gpus = [x for x in gpus.split(",") if x] + gpu_ids = [] + RE_RANGE = re.compile("^(?P[0-9]+)-(?P[0-9]+)$") + for gpu in gpus: + match = RE_RANGE.match(gpu) + if match is None: + try: + gpu_ids.append(int(gpu)) + except TypeError: + raise TypeError(f"CUDA_VISIBLE_DEVICES contains invalid value {gpu}") + else: + gpu_ids.extend( + range(int(match.group("start")), int(match.group("end")) + 1) + ) + for gpu_id in gpu_ids: + if gpu_id < 0: + raise ValueError("CUDA_VISIBLE_DEVICES contains negative GPU ID") + return gpu_ids + + +def check_cuda_visible_devices(num: Optional[int] = None) -> int: + """Check if CUDA_VISIBLE_DEVICES environment variable is set.""" + gpu_ids = cuda_visible_devices() + if num and len(gpu_ids) != num: + raise RuntimeError(f"CUDA_VISIBLE_DEVICES must be set to {num} GPUs") + return len(gpu_ids) + + +def init_omp_num_threads(threads: Optional[int] = None, default: int = 1) -> int: + if threads is None or threads < 0: + threads = os.environ.get("OMP_NUM_THREADS", default) + threads = max(1, int(threads)) + os.environ["OMP_NUM_THREADS"] = str(threads) + return threads diff --git a/jointContribution/HighResolution/deepali/utils/cli/logging.py b/jointContribution/HighResolution/deepali/utils/cli/logging.py new file mode 100644 index 0000000000..ccb37eb246 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/cli/logging.py @@ -0,0 +1,28 @@ +"""Auxiliary functions to set up logging in main scripts.""" +import logging +from enum import Enum +from typing import Optional + +LOG_FORMAT = "%(asctime)-15s [%(levelname)s] %(message)s" + + +class LogLevel(str, Enum): + """Enumeration of logging levels for use as type annotation when using Typer.""" + + DEBUG = "DEBUG" + INFO = "INFO" + WARNING = "WARNING" + ERROR = "ERROR" + + def __int__(self) -> int: + """Cast enumeration to int logging level.""" + return int(getattr(logging, self.value)) + + +def configure_logging(log, args, format: Optional[str] = None): + """Initialize logging.""" + logging.basicConfig(format=format or LOG_FORMAT) + if hasattr(args, "log_level"): + log.setLevel(args.log_level) + else: + log.setLevel(logging.INFO) diff --git a/jointContribution/HighResolution/deepali/utils/cli/warnings.py b/jointContribution/HighResolution/deepali/utils/cli/warnings.py new file mode 100644 index 0000000000..aeea9f00d2 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/cli/warnings.py @@ -0,0 +1,10 @@ +import warnings + + +def filter_warning_of_experimental_named_tensors_feature() -> None: + """Filter out warning reminding that named tensors are still an experimental feature.""" + warnings.filterwarnings( + "ignore", + message="Named tensors and all their associated APIs are an experimental feature and subject to change.", + category=UserWarning, + ) diff --git a/jointContribution/HighResolution/deepali/utils/ignite/__init__.py b/jointContribution/HighResolution/deepali/utils/ignite/__init__.py new file mode 100644 index 0000000000..dc51750efa --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/__init__.py @@ -0,0 +1 @@ +"""Utilities.""" diff --git a/jointContribution/HighResolution/deepali/utils/ignite/handlers.py b/jointContribution/HighResolution/deepali/utils/ignite/handlers.py new file mode 100644 index 0000000000..0995318934 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/handlers.py @@ -0,0 +1,347 @@ +from logging import Logger +from typing import Any +from typing import Callable +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Union + +import paddle +from ignite.engine import Engine +from ignite.engine import Events +from ignite.engine import State +from paddle.io import DataLoader + +from ...core import RE_OUTPUT_KEY_INDEX +from ..tensorboard import add_summary_images +from ..tensorboard import escape_channel_index_format_string + + +def clamp_learning_rate( + engine: Engine, + optimizer: paddle.optim.optimizer.Optimizer, + min_learning_rate: Optional[float] = None, + max_learning_rate: Optional[float] = None, +): + if min_learning_rate is None: + min_learning_rate = 1e-12 + if max_learning_rate is None: + max_learning_rate = float("inf") + for param_group in optimizer.param_groups: + param_group["lr"] = max( + min_learning_rate, min(param_group["lr"], max_learning_rate) + ) + + +def set_distributed_sampler_epoch(engine: Engine, epoch: Optional[int] = None) -> None: + data = engine.state.dataloader + if epoch is None: + epoch = engine.state.epoch - 1 + if isinstance(data, paddle.io.DataLoader): + set_sampler_epoch = getattr(data.sampler, "set_epoch", None) + if callable(set_sampler_epoch): + set_sampler_epoch(epoch) + set_sampler_epoch = getattr(data.batch_sampler, "set_epoch", None) + if callable(set_sampler_epoch): + set_sampler_epoch(epoch) + if isinstance(data.batch_sampler, paddle.io.BatchSampler): + set_sampler_epoch = getattr(data.batch_sampler.sampler, "set_epoch", None) + if callable(set_sampler_epoch): + set_sampler_epoch(epoch) + data = data.dataset + set_epoch = getattr(data, "set_epoch", None) + if callable(set_epoch): + set_epoch(epoch) + + +def print_metrics( + engine: Engine, + prefix: Optional[str] = None, + names: Optional[Union[Mapping[str, str], Sequence[str], Set[str]]] = None, + logger: Optional[Logger] = None, + global_step_transform: Optional[Callable] = None, +) -> None: + """Log evaluated performance metrics stored in ``engine.state.metrics``.""" + metrics = get_scalar_metrics(engine.state) + if not metrics: + return + if names is not None: + metrics = {name: metrics[name] for name in names if name in metrics} + if isinstance(names, Mapping): + metrics = {names[k]: v for k, v in metrics.items()} + global_step = get_global_step(engine, global_step_transform) + if prefix is None: + prefix = "" + if global_step_transform is None: + prefix = "epoch={e:03d}, " + prefix += "iter={i:06d}, " + msg = prefix.format(e=engine.state.epoch, i=global_step) + msg += ", ".join( + f"{key}=" + (f"{value:.5f}" if isinstance(value, float) else f"{value}") + for key, value in metrics.items() + ) + print(msg) if logger is None else logger.info(msg) + + +def reset_iterable_dataset(engine: Engine) -> None: + """Reset iterator over engine.state.dataloader. + + This handler also calls ``set_distributed_sampler_epoch(engine)`` such that the + epoch number is set prior to calling ``engine.set_data(engine.state.dataloader)``, + which invokes ``iter(engine.state.dataloader)``. This is to ensure that if the + latter shuffles the data using the epoch number as seed, that the dataset is + then shuffled with the new epoch number. + + """ + set_distributed_sampler_epoch(engine, engine.state.epoch) + engine.set_data(engine.state.dataloader) + + +def reset_model_grads(engine: Engine, model: paddle.nn.Layer) -> None: + """Set ``grad`` attributes of model parameters to ``None``.""" + for p in model.parameters(): + p.grad = None + + +def reset_optimizer_grads( + engine: Engine, optimizer: paddle.optim.optimizer.Optimizer +) -> None: + """Set ``grad`` attributes of optimizer parameters to ``None``.""" + for group in optimizer.param_groups: + for p in group["params"]: + p.grad = None + + +def set_engine_state(engine: Engine, **kwargs) -> None: + """Add dictionary values as attributes to ``engine.state`` object.""" + for name, value in kwargs.items(): + setattr(engine.state, name, value) + + +def terminate_on_max_iteration(engine: Engine, max_iterations: int): + """Terminate training when maximum number of global iterations reached.""" + if max_iterations > 0 and engine.state.iteration >= max_iterations: + engine.terminate() + + +def write_summary_hists( + engine: Engine, + writer: paddle.utils.tensorboard.SummaryWriter, + model: paddle.nn.Layer, + prefix: Optional[str] = None, + weights: bool = True, + grads: bool = True, + global_step_transform: Optional[Callable] = None, +) -> None: + """Add histograms of model weights to TensorBoard summary.""" + global_step = get_global_step(engine, global_step_transform) + prefix = (prefix or "").format(e=engine.state.epoch, i=engine.state.iteration) + for name, p in model.named_parameters(): + if p.grad is None: + continue + name = name.replace(".", "/") + if weights: + writer.add_histogram( + tag=f"{prefix}weights/{name}", + values=p.data.detach().cpu().numpy(), + global_step=global_step, + ) + if grads: + writer.add_histogram( + tag=f"{prefix}grads/{name}", + values=p.grad.detach().cpu().numpy(), + global_step=global_step, + ) + + +def write_summary_images( + engine: Engine, + writer: paddle.utils.tensorboard.SummaryWriter, + names: Optional[Union[Mapping[str, str], Sequence[str], Set[str]]] = None, + prefix: Optional[str] = None, + image_transform: Optional[Callable[[str, paddle.Tensor], paddle.Tensor]] = True, + rescale_transform: Union[ + bool, Callable[[str, paddle.Tensor], paddle.Tensor] + ] = False, + global_step_transform: Optional[Callable[[Engine, Events], int]] = None, + channel_offset: Union[bool, Callable[[Engine], int]] = False, +) -> None: + """Add images stored in ``engine.state.batch`` and ``engine.state.output`` to TensorBoard summary. + + Args: + engine: Ignite engine. + writer: TensorBoard writer with open event file. + names: Possibly hierarchical keys of input batch or output image entries to write. + When a dictionary is given, replace image name by the map value. + prefix: Prefix string for TensorBoard tags. Can be format string with + placeholders ``{e}`` for epoch and ``{i}`` global iteration. + image_transform: Callable used to extract a 2D image tensor of shape + ``(C, Y, X)`` from each image. When a multi-channel tensor is returnd (C > 1), + each channel is saved as separate image to the TensorBoard event file. + By default, the central slice of the first image in the batch is extracted. + The first argument is the name of the image tensor if applicable, or an empty + string otherwise. This can be used to differently process different images. + rescale_transform: Image intensities must be in the closed interval ``[0, 1]``. + Set to ``False``, if this is already the case or when a custom + ``image_transform`` is used which normalizes the image intensities. + global_step_transform: Callable used to obtain global iteration. + Called with arguments ``engine`` and ``Events.ITERATION_COMPLETED``. + If not specified, use ``engine.state.iteration``. + channel_offset: Offset to use for channel index in image tag format string. + If ``True``, use batch index times batch size of data loader plus one. If ``False``, + use zero. Otherwise, a callable must be given which receives the engine on which + the event is triggered as argument. + + """ + + def filter_images( + arg: Union[Mapping[str, Any], Sequence[Any]], prefix: str = "" + ) -> Dict[str, paddle.Tensor]: + images = {} + if not isinstance(arg, dict): + arg = {str(i): value for i, value in enumerate(arg)} + for name, value in arg.items(): + if name == "loss": + continue + if isinstance(value, (dict, tuple, list)): + images.update(filter_images(value, prefix + name + ".")) + elif not isinstance(value, paddle.Tensor) or value.ndim < 4: + continue + else: + name = prefix + name + if isinstance(names, dict): + if name not in names: + continue + name = names[name] + elif names is not None and name not in names: + continue + images[name] = value + return images + + def format_tag(arg: str) -> str: + arg = escape_channel_index_format_string(arg) + arg = arg.format(e=engine.state.epoch, i=engine.state.iteration, b=batch_index) + return arg + + batch_index = (engine.state.iteration - 1) % engine.state.epoch_length + 1 + global_step = get_global_step(engine, global_step_transform) + channel_offset_value = 0 + if channel_offset is True: + dataloader: DataLoader = engine.state.dataloader + if not dataloader.batch_size: + raise RuntimeError( + "write_summary_images() 'channel_offset' is True, but 'engine.state.dataloader.batch_size' is None" + ) + channel_offset_value = (batch_index - 1) * dataloader.batch_size + 1 + elif callable(channel_offset): + channel_offset_value = channel_offset(engine) + elif channel_offset is not False: + raise TypeError( + "write_summary_images() 'channel_offset' must be bool or callable" + ) + prefix = format_tag(prefix or "") + if isinstance(names, dict): + names = { + RE_OUTPUT_KEY_INDEX.sub(".\\1", name): tag for name, tag in names.items() + } + names = {name: format_tag(tag) for name, tag in names.items()} + elif names is not None: + names = {RE_OUTPUT_KEY_INDEX.sub(".\\1", name) for name in names} + names = {format_tag(name) for name in names} + batch = engine.state.batch + if isinstance(batch, paddle.Tensor): + batch = {"batch": batch} + batch = filter_images(batch) + add_summary_images( + writer, + prefix, + batch, + global_step=global_step, + image_transform=image_transform, + rescale_transform=rescale_transform, + channel_offset=channel_offset_value, + ) + output = engine.state.output + if isinstance(output, paddle.Tensor): + output = {"output": output} + output = filter_images(output) + add_summary_images( + writer, + prefix, + output, + global_step=global_step, + image_transform=image_transform, + rescale_transform=rescale_transform, + channel_offset=channel_offset_value, + ) + writer.flush() + + +def write_summary_metrics( + engine: Engine, + writer: paddle.utils.tensorboard.SummaryWriter, + prefix: Optional[str] = None, + global_step_transform: Optional[Callable] = None, +): + """Add computed values stored in ``engine.state.metrics`` to Tensorboard summary.""" + global_step = get_global_step(engine, global_step_transform) + prefix = (prefix or "").format(e=engine.state.epoch, i=engine.state.iteration) + metrics = get_scalar_metrics(engine.state) + for name, value in metrics.items(): + writer.add_scalar(prefix + name, value, global_step=global_step) + writer.flush() + + +def write_summary_optim_params( + engine: Engine, + writer: paddle.utils.tensorboard.SummaryWriter, + optimizer: paddle.optim.optimizer.Optimizer, + params: Optional[Union[str, Sequence[str]]] = None, + prefix: Optional[str] = None, + global_step_transform: Optional[Callable] = None, +) -> None: + """Add optimization parameters to TensorBoard summary.""" + global_step = get_global_step(engine, global_step_transform) + prefix = (prefix or "").format(e=engine.state.epoch, i=engine.state.iteration) + if isinstance(params, str): + params = [params] + for group_id, param_group in enumerate(optimizer.param_groups): + for param_name in params or param_group.keys(): + try: + param_value = float(param_group[param_name]) + except (KeyError, TypeError): + continue + tag = prefix + param_name + if len(optimizer.param_groups) > 1: + tag += f"/group_{group_id}" + writer.add_scalar(tag, param_value, global_step) + + +def get_global_step( + engine: Engine, global_step_transform: Optional[Callable] = None +) -> int: + """Get global step for summary event.""" + if global_step_transform is None: + return engine.state.iteration + global_step = global_step_transform(engine, Events.ITERATION_COMPLETED) + if not isinstance(global_step, int): + raise TypeError( + f"'global_step_transform' must return an int, got {type(global_step)}" + ) + return global_step + + +def get_scalar_metrics(state: State) -> Dict[str, Union[float, int]]: + """Get scalar metrics stored in ``state.metrics``.""" + metrics = {} + for key, value in state.metrics.items(): + if isinstance(value, paddle.Tensor): + if value.squeeze().dim() != 0: + continue + value = value.item() + value = int(value) if isinstance(value, int) else float(value) + if isinstance(value, (float, int)): + metrics[key] = value + return metrics diff --git a/jointContribution/HighResolution/deepali/utils/ignite/metrics/__init__.py b/jointContribution/HighResolution/deepali/utils/ignite/metrics/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/jointContribution/HighResolution/deepali/utils/ignite/metrics/average_loss.py b/jointContribution/HighResolution/deepali/utils/ignite/metrics/average_loss.py new file mode 100644 index 0000000000..9e38ef1717 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/metrics/average_loss.py @@ -0,0 +1,56 @@ +from typing import Mapping +from typing import Union + +import paddle +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric +from ignite.metrics.metric import reinit__is_reduced +from ignite.metrics.metric import sync_all_reduce + + +class AverageLoss(Metric): + """Calculates the average loss given batched loss values stored in ``engine.state.output``. + + This metric can be used instead of ``ignite.metrics.loss.Loss`` when the average loss values + for each example in the input batch are stored as 1-dimensional tensor in the ``engine.state.output``. + The (transformed) return value of the engines ``process_function`` must be either a tensor of the + computed loss values, or a dictionary containing the key ``"loss"``. + + """ + + @reinit__is_reduced + def reset(self) -> None: + self.accumulator = 0 + self.num_examples = 0 + + @reinit__is_reduced + def update(self, output: Union[paddle.Tensor, Mapping]) -> None: + loss = output.get("loss") if isinstance(output, Mapping) else output + if not isinstance(loss, paddle.Tensor): + raise TypeError( + "AverageLoss.update() 'output' loss value must be paddle.Tensor" + ) + if loss.ndim == 0: + self.accumulator += loss.item() + self.num_examples += 1 + elif loss.dim() == 1: + self.accumulator += loss.sum().item() + self.num_examples += tuple(loss.shape)[0] + else: + raise ValueError( + "AverageLoss.update() 'output' loss value must be scalar or 1-dimensional tensor" + ) + + @sync_all_reduce("accumulator", "num_examples") + def compute(self) -> float: + if self.num_examples == 0: + raise NotComputableError( + "Loss must have at least one example before it can be computed." + ) + return self.accumulator / self.num_examples + + @paddle.no_grad() + def iteration_completed(self, engine: Engine) -> None: + output = self._output_transform(engine.state.output) + self.update(output) diff --git a/jointContribution/HighResolution/deepali/utils/ignite/metrics/binary_classification.py b/jointContribution/HighResolution/deepali/utils/ignite/metrics/binary_classification.py new file mode 100644 index 0000000000..c7f14a7f62 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/metrics/binary_classification.py @@ -0,0 +1,123 @@ +import paddle +from ignite.metrics.confusion_matrix import ConfusionMatrix +from ignite.metrics.metrics_lambda import MetricsLambda + +EPSILON = 1e-15 + + +def DiceCoefficient(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates Dice similarity coefficient from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the Dice coefficient. + + """ + cm = cm.astype(paddle.float64) + return (2 * cm[1, 1] + EPSILON) / (2 * cm[1, 1] + cm[0, 1] + cm[1, 0] + EPSILON) + + +def FalseNegativeRate(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates miss rate from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the true negative rate. + + """ + cm = cm.astype(paddle.float64) + return cm[1, 0] / (cm[1, 0] + cm[1, 1] + EPSILON) + + +def TrueNegativeRate(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates specificity from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the true negative rate. + + """ + cm = cm.astype(paddle.float64) + return cm[0, 0] / (cm[0, 0] + cm[0, 1] + EPSILON) + + +def FalsePositiveRate(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates fall-out from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the true positive rate. + + """ + cm = cm.astype(paddle.float64) + return cm[0, 1] / (cm[0, 1] + cm[0, 0] + EPSILON) + + +def TruePositiveRate(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates sensitivity from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the true positive rate. + + """ + cm = cm.astype(paddle.float64) + return cm[1, 1] / (cm[1, 0] + cm[1, 1] + EPSILON) + + +def PositivePredictiveValue(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates precision from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the positive predictive value. + + """ + cm = cm.astype(paddle.float64) + return cm[1, 1] / (cm[0, 1] + cm[1, 1] + EPSILON) + + +def NegativePredictiveValue(cm: ConfusionMatrix) -> MetricsLambda: + """Calculates negative predictive value from ``ignite.metrics.ConfusionMatrix``. + + Args: + cm: Instance of confusion matrix metric, where + ``cm[0, 0]=TN``, ``cm[0, 1]=FP``, + ``cm[1, 0]=FN``, ``cm[1, 1]=TP``. + + Returns: + MetricsLambda instance computing the negative predictive value. + + """ + cm = cm.astype(paddle.float64) + return cm[0, 0] / (cm[0, 0] + cm[1, 0] + EPSILON) + + +Precision = PositivePredictiveValue +Recall = TruePositiveRate +Sensitivity = TruePositiveRate +Specificity = TrueNegativeRate diff --git a/jointContribution/HighResolution/deepali/utils/ignite/metrics/multilabel_classification.py b/jointContribution/HighResolution/deepali/utils/ignite/metrics/multilabel_classification.py new file mode 100644 index 0000000000..1384c3780c --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/metrics/multilabel_classification.py @@ -0,0 +1,97 @@ +from typing import Callable +from typing import Optional +from typing import Sequence +from typing import Union + +import paddle +from ignite.engine import Engine +from ignite.exceptions import NotComputableError +from ignite.metrics.metric import Metric +from ignite.metrics.metric import reinit__is_reduced +from ignite.metrics.metric import sync_all_reduce + + +class MultiLabelScore(Metric): + """Compute a score for each class label.""" + + def __init__( + self, + score_fn: Callable, + num_classes: int, + output_transform: Callable = lambda x: x, + device: Optional[Union[str, str]] = None, + ): + self.score_fn = score_fn + self.num_classes = num_classes + self.accumulator = None + self.num_examples = 0 + super().__init__(output_transform=output_transform, device=device) + + @reinit__is_reduced + def reset(self) -> None: + self.accumulator = paddle.zeros(shape=self.num_classes, dtype="float32") + self.num_examples = 0 + + @reinit__is_reduced + def update(self, output: Sequence[paddle.Tensor]) -> None: + y_pred, y = output + if y_pred.ndim < 2: + raise ValueError( + f"MultiLabelScore.update() y_pred must have shape (N, C, ...), but given {tuple(y_pred.shape)}" + ) + if tuple(y_pred.shape)[1] not in (1, self.num_classes): + raise ValueError( + f"MultiLabelScore.update() expected y_pred to have 1 or {self.num_channels} channels" + ) + if y.ndim + 1 == y_pred.ndim: + y = y.unsqueeze(axis=1) + elif y.ndim != y_pred.ndim: + raise ValueError( + f"MultiLabelScore.update() y_pred must have shape (N, C, ...) and y must have shape (N, ...) or (N, 1, ...), but given {tuple(y.shape)} vs {tuple(y_pred.shape)}" + ) + if tuple(y.shape) != (tuple(y_pred.shape)[0], 1) + tuple(y_pred.shape)[2:]: + raise ValueError("y and y_pred must have compatible shapes.") + scores = multilabel_score( + self.score_fn, y_pred, y, num_classes=self.num_classes + ) + self.accumulator += scores + self.num_examples += tuple(y_pred.shape)[0] + + @sync_all_reduce("accumulator", "num_examples") + def compute(self) -> float: + if self.num_examples == 0: + raise NotComputableError( + "Loss must have at least one example before it can be computed." + ) + return self.accumulator / self.num_examples + + @paddle.no_grad() + def iteration_completed(self, engine: Engine) -> None: + output = self._output_transform(engine.state.output) + self.update(output) + + +def multilabel_score( + score_fn, + preds: paddle.Tensor, + labels: paddle.Tensor, + num_classes: Optional[int] = None, +) -> paddle.Tensor: + """Evaluate score for each class label.""" + assert tuple(labels.shape)[1] == 1 + if num_classes is None: + num_classes = tuple(preds.shape)[1] + if num_classes == 1: + raise ValueError( + "multilabel_score() 'num_classes' required when 'preds' is not one-hot encoded" + ) + if tuple(preds.shape)[1] == num_classes: + preds = preds.argmax(axis=1, keepdim=True) + elif tuple(preds.shape)[1] != 1: + raise ValueError("multilabel_score() 'preds' must have shape (N, C|1, ..., X)") + result = paddle.zeros(shape=num_classes, dtype="float32") + for label in range(num_classes): + y_pred = preds.equal(y=label).astype(dtype="float32") + y = labels.equal(y=label).astype(dtype="float32") + result[label] = score_fn(y_pred, y).mean() + return result diff --git a/jointContribution/HighResolution/deepali/utils/ignite/output_transforms.py b/jointContribution/HighResolution/deepali/utils/ignite/output_transforms.py new file mode 100644 index 0000000000..8104c2b1b9 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/output_transforms.py @@ -0,0 +1,209 @@ +from typing import Callable +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import paddle +from ignite.engine import Engine + +from ...core import ALIGN_CORNERS +from ...core import TensorCollection +from ...core import get_tensor +from ...core.image import grid_reshape +from ...core.tensor import as_one_hot_tensor + + +def get_output_transform(key: str) -> Callable[[TensorCollection], paddle.Tensor]: + """Get tensor at specified nested engine output map key entry.""" + + def output_transform(output: TensorCollection) -> paddle.Tensor: + return get_tensor(output, key) + + return output_transform + + +def cm_binary_output_transform( + y_pred: str = "y_pred", + y: str = "y", + reshape: str = "none", + align_corners: bool = ALIGN_CORNERS, +) -> Callable[[TensorCollection], Tuple[paddle.Tensor, paddle.Tensor]]: + """Convert engine output to a vector-valued one-hot prediction. + + Engine output transformation for use with ``ignite.metrics.ConfusionMatrix()``. + + Args: + output: Ignite engine state output. + + Returns: + y_pred: paddle.Tensor of one-hot encoded predictions with shape ``(N, 2, ..., X)``. + y: paddle.Tensor of target labels with shape ``(N, ..., X)``. + + """ + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + dtype = "int32" + y_pred_tensor = get_tensor(output, y_pred) + y_tensor = get_tensor(output, y) + assert y_tensor.ndim == y_pred_tensor.ndim + assert tuple(y_tensor.shape)[1] == 1 + if tuple(y_tensor.shape)[0] == 1 and tuple(y_pred_tensor.shape)[0] > 1: + y_tensor = y_tensor.expand( + shape=tuple(y_pred_tensor.shape)[0:1] + tuple(y_tensor.shape)[1:] + ) + assert tuple(y_pred_tensor.shape)[0] == tuple(y_tensor.shape)[0] + assert tuple(y_pred_tensor.shape)[1] == 1 + grid_shape = tuple(y_tensor.shape)[2:] + if tuple(y_pred_tensor.shape)[2:] != grid_shape and reshape == "y_pred": + y_pred_tensor = grid_reshape( + y_pred_tensor, grid_shape, align_corners=align_corners + ) + y_pred_tensor = y_pred_tensor.round().astype("int64") + y_pred_tensor = as_one_hot_tensor(y_pred_tensor, num_classes=2, dtype=dtype) + if tuple(y_tensor.shape)[2:] != grid_shape and reshape == "y": + y_tensor = grid_reshape(y_tensor, grid_shape, align_corners=align_corners) + y_tensor = y_tensor.round().astype(dtype).squeeze(axis=1) + assert tuple(y_pred_tensor.shape)[2:] == tuple(y_tensor.shape)[1:] + return y_pred_tensor, y_tensor + + return output_transform + + +def cm_multilabel_output_transform( + y_pred: str = "y_pred", y: str = "y" +) -> Callable[[TensorCollection], Tuple[paddle.Tensor, paddle.Tensor]]: + """Convert engine output to a vector-valued one-hot prediction. + + Engine output transformation for use with ``ignite.metrics.ConfusionMatrix()``. + + Args: + output: Ignite engine state output. + + Returns: + y_pred: paddle.Tensor of multi-class probabilites with shape ``(N, C, ..., X)``. + y: paddle.Tensor of target labels with shape ``(N, ..., X)``. + + """ + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + y_pred_tensor = get_tensor(output, y_pred) + y_tensor = get_tensor(output, y) + if tuple(y_tensor.shape)[1] == 1: + if y_tensor.is_floating_point(): + y_tensor = y_tensor.round().astype("int32") + y_tensor = y_tensor.squeeze(axis=1) + else: + y_tensor = y_tensor.argmax(axis=1) + if tuple(y_tensor.shape)[0] == 1 and tuple(y_pred_tensor.shape)[0] > 1: + y_tensor = y_tensor.expand( + shape=tuple(y_pred_tensor.shape)[0:1] + tuple(y_tensor.shape)[1:] + ) + return y_pred_tensor, y_tensor + + return output_transform + + +def cm_output_transform( + y_pred: str = "y_pred", y: str = "y", multilabel: bool = False +) -> Callable[[TensorCollection], Tuple[paddle.Tensor, paddle.Tensor]]: + """Get engine output transformation for binary or multi-class segmentation output, respectively.""" + if multilabel: + return cm_multilabel_output_transform(y_pred, y) + return cm_binary_output_transform(y_pred, y) + + +def negative_loss_score_function(engine: Engine) -> paddle.Tensor: + """Get negated loss value from ``engine.state.output``. + + This output transformation can be used as ``score_function`` argument of + ``ignite.handlers.ModelCheckpoint``, for example. + + """ + output = engine.state.output + if isinstance(output, dict): + output = output["loss"] + assert isinstance(output, paddle.Tensor) + return -output + + +def y_pred_y_output_transform( + y_pred: str = "y_pred", + y: str = "y", + channels: Optional[Union[int, Sequence[int]]] = None, +) -> Callable[[TensorCollection], Tuple[paddle.Tensor, paddle.Tensor]]: + """Get engine output transformation which returns (y_pred, y) tuples. + + Args: + y_pred: Output dictionary key for ``y_pred`` tensor. + y: Output dictionary key for ``y`` tensor. + channels: Indices of image channels to extract. + + """ + if isinstance(channels, int): + channels = (channels,) + if channels: + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + a = get_tensor(output, y_pred) + b = get_tensor(output, y) + return a[:, (channels)], b[:, (channels)] + + else: + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + a = get_tensor(output, y_pred) + b = get_tensor(output, y) + return a, b + + return output_transform + + +def y_pred_y_with_weight_output_transform( + y_pred: str = "y_pred", + y: str = "y", + weight: str = "weight", + kwarg: str = "weight", + channels: Optional[Union[int, Sequence[int]]] = None, +) -> Callable[[TensorCollection], Tuple[paddle.Tensor, paddle.Tensor]]: + """Get engine output transformation which returns (y_pred, y) tuples. + + Args: + y_pred: Output dictionary key for ``y_pred`` tensor. + y: Output dictionary key for ``y`` tensor. + weight: Output dictionary key for ``weight`` tensor. + kwarg: Name of weight keyword argument. + channels: Indices of image channels to extract. + + """ + if isinstance(channels, int): + channels = (channels,) + if channels: + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + a = get_tensor(output, y_pred) + b = get_tensor(output, y) + w = get_tensor(output, weight) + return a[:, (channels)], b[:, (channels)], {kwarg: w} + + else: + + def output_transform( + output: TensorCollection, + ) -> Tuple[paddle.Tensor, paddle.Tensor]: + a = get_tensor(output, y_pred) + b = get_tensor(output, y) + w = get_tensor(output, weight) + return a, b, {kwarg: w} + + return output_transform diff --git a/jointContribution/HighResolution/deepali/utils/ignite/score_functions.py b/jointContribution/HighResolution/deepali/utils/ignite/score_functions.py new file mode 100644 index 0000000000..da28334d31 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/ignite/score_functions.py @@ -0,0 +1,21 @@ +from typing import Mapping + +import paddle +from ignite.engine import Engine + +from ...core import get_tensor + + +def negative_loss_score_function(engine: Engine, key: str = "loss") -> paddle.Tensor: + """Get negated loss value from ``engine.state.output``.""" + output = engine.state.output + if isinstance(output, Mapping): + loss = get_tensor(output, key) + elif isinstance(output, paddle.Tensor): + loss = output + else: + raise ValueError( + "negative_loss_score_function() engine output loss must be a paddle.Tensor" + ) + loss = loss.detach().sum() + return -float(loss) diff --git a/jointContribution/HighResolution/deepali/utils/paddle_aux.py b/jointContribution/HighResolution/deepali/utils/paddle_aux.py new file mode 100644 index 0000000000..463645233a --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/paddle_aux.py @@ -0,0 +1,321 @@ +# This file is generated by PaConvert ToolKit, please Don't edit it! +import paddle + + +def repeat(self, *args, **kwargs): + if args: + if len(args) == 1 and isinstance(args[0], (tuple, list)): + return paddle.tile(self, args[0]) + else: + return paddle.tile(self, list(args)) + elif kwargs: + assert "repeats" in kwargs + return paddle.tile(self, repeat_times=kwargs["repeats"]) + + +setattr(paddle.Tensor, "repeat", repeat) + + +def reshape(self, *args, **kwargs): + if args: + if len(args) == 1 and isinstance(args[0], (tuple, list)): + return paddle.reshape(self, args[0]) + else: + return paddle.reshape(self, list(args)) + elif kwargs: + assert "shape" in kwargs + return paddle.reshape(self, shape=kwargs["shape"]) + + +setattr(paddle.Tensor, "reshape", reshape) + + +def mul(self, *args, **kwargs): + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y, dtype=self.dtype) + + return paddle.multiply(self, y.astype(self.dtype)) + + +setattr(paddle.Tensor, "mul", mul) +setattr(paddle.Tensor, "multiply", mul) +setattr(paddle.Tensor, "multiply_", mul) + + +def div(self, *args, **kwargs): + dtype = self.dtype + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y) + + res = paddle.divide(self.astype("float32"), y.astype("float32")) + + if "rounding_mode" in kwargs: + rounding_mode = kwargs["rounding_mode"] + if rounding_mode == "trunc": + res = paddle.trunc(res) + elif rounding_mode == "floor": + res = paddle.floor(res) + + return res.astype(dtype) + + +setattr(paddle.Tensor, "div", div) +setattr(paddle.Tensor, "divide", div) +setattr(paddle.Tensor, "divide_", div) + + +def sub(self, *args, **kwargs): + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + + if "alpha" in kwargs: + alpha = kwargs["alpha"] + if alpha != 1: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(alpha * y) + else: + y = alpha * y + else: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y) + + return paddle.subtract(self, y.astype(self.dtype)) + + +setattr(paddle.Tensor, "sub", sub) +setattr(paddle.Tensor, "subtract", sub) +setattr(paddle.Tensor, "subtract_", sub) + + +def add(self, *args, **kwargs): + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + + if "alpha" in kwargs: + alpha = kwargs["alpha"] + if alpha != 1: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(alpha * y) + else: + y = alpha * y + else: + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y) + + return paddle.add(self, y.astype(self.dtype)) + + +setattr(paddle.Tensor, "add", add) +setattr(paddle.Tensor, "add_", add) + + +def det(self): + try: + return self.det() + except Exception: + return paddle.linalg.det(self) + + +setattr(paddle.Tensor, "det", det) + + +def view(self, *args, **kwargs): + if args: + if len(args) == 1: + if isinstance(args[0], (tuple, list)): + return paddle.reshape(self, args[0]) # To change reshape => view + elif isinstance(args[0], str): + return paddle.view(self, args[0]) + else: + return paddle.reshape(self, list(args)) # To change reshape => view + else: + return paddle.reshape(self, list(args)) # To change reshape => view + elif kwargs: + key = [k for k in kwargs.keys()] + if "dtype" in kwargs: + return paddle.view(self, shape_or_dtype=kwargs[key[0]]) + else: + return paddle.reshape( + self, shape=kwargs[key[0]] + ) # To change reshape => view + + +setattr(paddle.Tensor, "view", view) + + +def _FUNCTIONAL_PAD(x, pad, mode="constant", value=0.0, data_format="NCHW"): + if len(x.shape) * 2 == len(pad) and mode == "constant": + pad = ( + paddle.to_tensor(pad, dtype="int32") + .reshape((-1, 2)) + .flip([0]) + .flatten() + .tolist() + ) + return paddle.nn.functional.pad(x, pad, mode, value, data_format) + + +def min_class_func(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.minimum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.minimum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.min(self, *args, **kwargs), paddle.argmin( + self, *args, **kwargs + ) + else: + ret = paddle.min(self, *args, **kwargs) + + return ret + + +def max_class_func(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.maximum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.max(self, *args, **kwargs), paddle.argmax( + self, *args, **kwargs + ) + else: + ret = paddle.max(self, *args, **kwargs) + + return ret + + +setattr(paddle.Tensor, "min", min_class_func) +setattr(paddle.Tensor, "max", max_class_func) + + +def min(*args, **kwargs): + if "input" in kwargs: + kwargs["x"] = kwargs.pop("input") + + out_v = None + if "out" in kwargs: + out_v = kwargs.pop("out") + + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.minimum(*args, **kwargs) + elif len(args) == 2 and isinstance(args[1], paddle.Tensor): + ret = paddle.minimum(*args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 2: + if out_v: + ret = paddle.min(*args, **kwargs), paddle.argmin(*args, **kwargs) + paddle.assign(ret[0], out_v[0]) + paddle.assign(ret[1], out_v[1]) + return out_v + else: + ret = paddle.min(*args, **kwargs), paddle.argmin(*args, **kwargs) + return ret + else: + ret = paddle.min(*args, **kwargs) + return ret + + if out_v: + paddle.assign(ret, out_v) + return out_v + else: + return ret + + +def max(*args, **kwargs): + if "input" in kwargs: + kwargs["x"] = kwargs.pop("input") + + out_v = None + if "out" in kwargs: + out_v = kwargs.pop("out") + + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(*args, **kwargs) + elif len(args) == 2 and isinstance(args[1], paddle.Tensor): + ret = paddle.maximum(*args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 2: + if out_v: + ret = paddle.max(*args, **kwargs), paddle.argmax(*args, **kwargs) + paddle.assign(ret[0], out_v[0]) + paddle.assign(ret[1], out_v[1]) + return out_v + else: + ret = paddle.max(*args, **kwargs), paddle.argmax(*args, **kwargs) + return ret + return out_v + else: + ret = paddle.max(*args, **kwargs) + return ret + + if out_v: + paddle.assign(ret, out_v) + return out_v + else: + return ret + + +def split_tensor_func(self, split_size, dim=0): + if isinstance(split_size, int): + return paddle.split(self, self.shape[dim] // split_size, dim) + else: + return paddle.split(self, split_size, dim) + + +setattr(paddle.Tensor, "split", split_tensor_func) + + +def _STR_2_PADDLE_DTYPE(type): + type_map = { + "uint8": paddle.uint8, + "int8": paddle.int8, + "int16": paddle.int16, + "int32": paddle.int32, + "int64": paddle.int64, + "float16": paddle.float16, + "float32": paddle.float32, + "float64": paddle.float64, + "bfloat16": paddle.bfloat16, + } + return type_map.get(type) diff --git a/jointContribution/HighResolution/deepali/utils/sitk/__init__.py b/jointContribution/HighResolution/deepali/utils/sitk/__init__.py new file mode 100644 index 0000000000..f76ea3438f --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/__init__.py @@ -0,0 +1,5 @@ +"""Utility functions for `SimpleITK `_ data objects.""" +from .imageio import read_image +from .imageio import write_image + +__all__ = "read_image", "write_image" diff --git a/jointContribution/HighResolution/deepali/utils/sitk/grid.py b/jointContribution/HighResolution/deepali/utils/sitk/grid.py new file mode 100644 index 0000000000..902b1638f6 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/grid.py @@ -0,0 +1,306 @@ +import itertools +from pathlib import Path +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Union + +import numpy as np +import SimpleITK as sitk + +Coords = Union[np.ndarray, Sequence[float]] + + +class Grid(object): + """Finite and discrete image sampling grid oriented in world space.""" + + __slots__ = ["_size", "origin", "spacing", "direction"] + + def __init__( + self, + size: Tuple[Union[int, float], ...], + origin: Optional[Tuple[float, ...]] = None, + spacing: Optional[Tuple[float, ...]] = None, + direction: Optional[Tuple[float, ...]] = None, + ): + ndim = len(size) + if origin is None: + origin = (0.0,) * ndim + elif not isinstance(origin, tuple): + origin = tuple(origin) + if spacing is None: + spacing = (1.0,) * ndim + elif not isinstance(spacing, tuple): + spacing = tuple(spacing) + if direction is None: + direction = np.eye(ndim, ndim) + if isinstance(direction, np.ndarray): + direction = direction.flatten().astype(float) + if not isinstance(direction, tuple): + direction = tuple(direction) + self._size = tuple(float(n) for n in size) + self.origin = origin + self.spacing = spacing + self.direction = direction + + @property + def ndim(self) -> int: + """Number of image grid dimensions.""" + return len(self.shape) + + @property + def npts(self) -> int: + """Total number of image grid points.""" + return int(np.prod(self.size)) + + @staticmethod + def _int_size(size) -> Tuple[int, ...]: + """Get grid size from internal floating point representation.""" + return tuple([(int(n + 0.5) if n > 1 or n <= 0 else 1) for n in size]) + + @property + def size(self) -> Tuple[int, ...]: + """Size of image data array.""" + return self._int_size(self._size) + + @property + def shape(self) -> Tuple[int, ...]: + """Shape of image data array.""" + return tuple(reversed(self.size)) + + @classmethod + def from_file(cls, path: Union[Path, str]): + reader = sitk.ImageFileReader() + reader.SetFileName(str(path)) + reader.ReadImageInformation() + return cls.from_reader(reader) + + @classmethod + def from_reader(cls, reader: sitk.ImageFileReader): + return cls( + size=reader.GetSize(), + origin=reader.GetOrigin(), + spacing=reader.GetSpacing(), + direction=reader.GetDirection(), + ) + + @classmethod + def from_image(cls, image: sitk.Image): + return cls( + size=image.GetSize(), + origin=image.GetOrigin(), + spacing=image.GetSpacing(), + direction=image.GetDirection(), + ) + + def zeros_image(self, dtype: int = sitk.sitkInt16, channels: int = 1): + """Create empty image from grid.""" + img = sitk.Image(self.size, dtype, channels) + img.SetOrigin(self.origin) + img.SetDirection(self.direction) + img.SetSpacing(self.spacing) + return img + + def with_margin(self, margin: int): + """Create new image grid with an additional margin along each grid axis.""" + if not margin: + return self + return self.__class__( + size=tuple(n + 2 * margin for n in self._size), + origin=self.index_to_physical_space((-margin,) * self.ndim), + spacing=self.spacing, + direction=self.direction, + ) + + def with_spacing(self, *args: float): + """Create new image grid with specified spacing.""" + spacing = np.asarray(*args) + assert 0 <= spacing.ndim <= 1 + if spacing.ndim == 0: + spacing = spacing.repeat(self.ndim) + assert spacing.size == self.ndim + cur_size = np.asarray(self._size) + cur_spacing = np.asarray(self.spacing) + size = cur_size * cur_spacing / spacing + shift = 0.5 * (np.round(cur_size) - np.round(size) * spacing / cur_spacing) + origin = self.index_to_physical_space(shift) + return self.__class__( + size=size, origin=origin, spacing=spacing, direction=self.direction + ) + + def down(self, levels: int = 1): + """Create new image grid of half the size.""" + size = self._size + for _ in range(levels): + size = tuple(n / 2 for n in size) + cur_size = self.size + new_size = self._int_size(size) + spacing = tuple( + self.spacing[i] * cur_size[i] / new_size[i] + if new_size[i] > 0 + else self.spacing[i] + for i in range(self.ndim) + ) + return self.__class__(size=size, spacing=spacing) + + def up(self, levels: int = 1): + """Create new image grid of double the size.""" + size = self._size + for _ in range(levels): + size = tuple(2 * n for n in size) + cur_size = self.size + new_size = self._int_size(size) + spacing = tuple( + self.spacing[i] * cur_size[i] / new_size[i] + if new_size[i] > 0 + else self.spacing[i] + for i in range(self.ndim) + ) + return self.__class__(size=size, spacing=spacing) + + @property + def dcm(self) -> np.ndarray: + """Get direction cosine matrix.""" + return np.array(self.direction).reshape(self.ndim, self.ndim) + + @property + def transform(self) -> np.ndarray: + """Get homogeneous coordinate transformation from image grid to world space.""" + rotation = self.dcm + scaling = np.diag(self.spacing) + matrix = homogeneous_matrix(rotation @ scaling) + matrix[0:-1, (-1)] = self.origin + return matrix + + @property + def inverse_transform(self) -> np.ndarray: + """Get homogeneous coordinate transformation from world space to image grid.""" + rotation = self.dcm.T + scaling = np.diag([(1 / s) for s in self.spacing]) + translation = translation_matrix([(-t) for t in self.origin]) + return homogeneous_matrix(scaling @ rotation) @ translation + + @property + def indices(self) -> np.ndarray: + """Get array of image grid point coordinates in image space.""" + return np.flip( + np.stack( + np.meshgrid(*[np.arange(arg) for arg in self.shape], indexing="ij"), + axis=self.ndim, + ), + axis=-1, + ) + + def axis_indices(self, axis: int) -> np.ndarray: + """Get array of image grid indices along specified axis.""" + return np.arange(self.shape[axis]) + + @property + def points(self) -> np.ndarray: + """Get array of image grid point coordinates in world space.""" + return self.index_to_physical_space(self.indices) + + @property + def coords(self) -> Tuple[np.ndarray, ...]: + """Get 1D arrays of grid point coordinates along each axis.""" + return tuple(self.axis_coords(axis) for axis in range(self.ndim)) + + def axis_coords(self, axis: int) -> np.ndarray: + """Get array of image grid point coordinates in world space along specified axis.""" + indices = [[0]] * self.ndim + indices[axis] = [index for index in range(self.shape[axis])] + mesh = np.stack(np.meshgrid(*indices, indexing="ij"), axis=self.ndim) + return self.index_to_physical_space(mesh)[..., axis].flatten() + + @property + def corners(self) -> np.ndarray: + """Get corners of image domain in world space.""" + limits = [] + for axis in range(self.ndim): + limits.append((0, self.shape[axis])) + corners = [tuple(reversed(indices)) for indices in itertools.product(*limits)] + return self.index_to_physical_space(corners) + + def index_to_physical_space(self, points: Coords) -> np.ndarray: + """Map point coordinates from image to world space.""" + return transform_point(self.transform, points) + + def physical_space_to_index(self, points: Coords) -> np.ndarray: + """Map point coordinates from world to discrete image space.""" + return np.round(self.physical_space_to_continuous_index(points)).astype(int) + + def physical_space_to_continuous_index(self, points: Coords) -> np.ndarray: + """Map point coordinates from world to continuous image space.""" + return np.round(transform_point(self.inverse_transform, points), decimals=12) + + def __repr__(self) -> str: + return ( + self.__class__.__name__ + + "(size={size}, origin={origin}, spacing={spacing}, direction={direction})".format( + size=self.size, + origin=self.origin, + spacing=self.spacing, + direction=self.direction, + ) + ) + + +def homogeneous_coords( + point: np.ndarray, ndim: int = None, copy: bool = True +) -> np.ndarray: + """Create array with homogeneous point coordinates.""" + if ndim is not None: + if tuple(point.shape)[-1] == ndim + 1: + return np.copy(point) if copy else point + assert tuple(point.shape)[-1] == ndim + pts = np.ones( + tuple(point.shape)[0:-1] + (tuple(point.shape)[-1] + 1,), dtype=point.dtype + ) + pts[(...), :-1] = point + return pts + + +def homogeneous_matrix( + transform: np.ndarray, ndim: int = None, copy: bool = True +) -> np.ndarray: + """Create homogeneous transformation matrix from affine coordinate transformation.""" + assert transform.ndim == 2 + rows, cols = tuple(transform.shape) + if ndim is None: + ndim = rows + assert ( + tuple(transform.shape)[1] == ndim or tuple(transform.shape)[1] == ndim + 1 + ), "transform.shape={}".format(tuple(transform.shape)) + elif rows == ndim + 1 and cols == ndim + 1: + return np.copy(transform) if copy else transform + matrix = np.eye(ndim + 1, ndim + 1, dtype=transform.dtype) + matrix[0:rows, 0:cols] = transform + return matrix + + +def translation_matrix(displacement: Sequence[float]) -> np.ndarray: + """Create translation matrix for homogeneous coordinates.""" + displacement = np.asanyarray(displacement) + assert displacement.ndim == 1 + matrix = np.eye(displacement.size + 1) + matrix[0:-1, (-1)] = displacement + return matrix + + +def transform_point(matrix: np.ndarray, point: np.ndarray) -> np.ndarray: + """Transform one or more points given a transformation matrix.""" + pts = np.asanyarray(point) + dim = tuple(pts.shape)[-1] + mat = homogeneous_matrix(matrix, ndim=dim, copy=False) + x = pts.reshape(-1, dim) + y = np.matmul(x, mat[0:-1, 0:-1].T) + np.expand_dims(mat[0:-1, (-1)], axis=0) + return y.reshape(tuple(pts.shape)) + + +def transform_vector(matrix: np.ndarray, vector: np.ndarray) -> np.ndarray: + """Transform one or more vectors given a transformation matrix.""" + vec = np.asanyarray(vector) + dim = tuple(vec.shape)[-1] + v = vec.reshape(-1, dim) + u = np.matmul(v, matrix[0:dim, 0:dim].T) + return u.reshape(tuple(vec.shape)) diff --git a/jointContribution/HighResolution/deepali/utils/sitk/imageio.py b/jointContribution/HighResolution/deepali/utils/sitk/imageio.py new file mode 100644 index 0000000000..b043dde393 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/imageio.py @@ -0,0 +1,24 @@ +from pathlib import Path +from typing import Union + +import SimpleITK as sitk + +PathStr = Union[Path, str] + + +def read_image(path: PathStr) -> sitk.Image: + """Read image from file.""" + path = Path(path).absolute() + if not path.exists(): + raise FileNotFoundError(f"Image file '{path}' does not exist") + return sitk.ReadImage(str(path)) + + +def write_image(image: sitk.Image, path: PathStr, compress: bool = True): + """Write image to file.""" + path = Path(path).absolute() + try: + path.unlink() + except FileNotFoundError: + path.parent.mkdir(parents=True, exist_ok=True) + return sitk.WriteImage(image, str(path), compress) diff --git a/jointContribution/HighResolution/deepali/utils/sitk/numpy.py b/jointContribution/HighResolution/deepali/utils/sitk/numpy.py new file mode 100644 index 0000000000..0073141532 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/numpy.py @@ -0,0 +1,48 @@ +from typing import Optional +from typing import Sequence + +import numpy as np +import SimpleITK as sitk + + +def image_dtype(image: sitk.Image) -> np.dtype: + """Get NumPy data type of SimpleITK image.""" + return sitk.GetArrayViewFromImage(image).dtype + + +def image_from_array( + values: np.ndarray, + origin: Optional[Sequence[float]] = None, + spacing: Optional[Sequence[float]] = None, + direction: Optional[Sequence[float]] = None, +) -> sitk.Image: + """Create image from sampling grid and image data array.""" + components = 1 + ndim = 0 + if origin: + ndim = len(origin) + elif spacing: + ndim = len(spacing) + elif direction: + ndim = int(np.sqrt(len(direction))) + assert len(direction) == ndim * ndim + if ndim: + if values.ndim < ndim: + values = values.reshape((1,) * (ndim - values.ndim) + tuple(values.shape)) + elif values.ndim > ndim: + values = values.reshape(tuple(values.shape)[:ndim] + (-1,)) + components = tuple(values.shape)[-1] + assert ndim <= values.ndim <= ndim + 1 + image = sitk.GetImageFromArray(values, isVector=components > 1) + if origin: + image.SetOrigin(origin) + if spacing: + image.SetSpacing(spacing) + if direction: + image.SetDirection(direction) + return image + + +def array_from_image(image: sitk.Image, view: bool = False) -> np.ndarray: + """Get NumPy array with copy of image values.""" + return sitk.GetArrayViewFromImage(image) if view else sitk.GetArrayFromImage(image) diff --git a/jointContribution/HighResolution/deepali/utils/sitk/paddle.py b/jointContribution/HighResolution/deepali/utils/sitk/paddle.py new file mode 100644 index 0000000000..d3af006926 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/paddle.py @@ -0,0 +1,88 @@ +"""Auxiliary functions for conversion between SimpleITK and P.""" +from typing import Optional +from typing import Sequence + +import paddle +import SimpleITK as sitk + + +def image_from_tensor( + data: paddle.Tensor, + origin: Optional[Sequence[float]] = None, + spacing: Optional[Sequence[float]] = None, + direction: Optional[Sequence[float]] = None, +) -> sitk.Image: + """Create ``SimpleITK.Image`` from image data tensor. + + Args: + data: Image tensor of shape ``(C, ..., X)``. + origin: World coordinates of center of voxel with zero indices. + spacing: Voxel size in world units in each dimension. + direction: Flattened image orientation cosine matrix. + + Returns: + SimpleITK image. + + """ + data = data.detach().cpu() + nchannels = tuple(data.shape)[0] + if nchannels == 1: + data = data[0] + else: + x = data.unsqueeze(axis=-1) + perm_0 = list(range(x.ndim)) + perm_0[0] = -1 + perm_0[-1] = 0 + data = x.transpose(perm=perm_0).squeeze(axis=0) + image = sitk.GetImageFromArray(data.numpy(), isVector=nchannels > 1) + if origin: + image.SetOrigin(origin) + if spacing: + image.SetSpacing(spacing) + if direction: + image.SetDirection(direction) + return image + + +# def tensor_from_image(image: sitk.Image, dtype: Optional[paddle.dtype]=None, +# device: Optional=None) ->paddle.Tensor: +# """Create image data tensor from ``SimpleITK.Image``.""" +# if image.GetPixelID() == sitk.sitkUInt16: +# image = sitk.Cast(image, sitk.sitkInt32) +# elif image.GetPixelID() == sitk.sitkUInt32: +# image = sitk.Cast(image, sitk.sitkInt64) +# data = paddle.to_tensor(sitk.GetArrayFromImage(image)) +# data = data.unsqueeze(axis=0) +# if image.GetNumberOfComponentsPerPixel() > 1: +# x = data +# perm_1 = list(range(x.ndim)) +# perm_1[0] = -1 +# perm_1[-1] = 0 +# data = x.transpose(perm=perm_1).squeeze(axis=-1) +# return data.astype(dtype) + + +def tensor_from_image( + image: sitk.Image, + dtype: Optional[paddle.dtype] = None, + device: Optional[str] = None, +) -> paddle.Tensor: + """Create image data tensor from ``SimpleITK.Image``.""" + if image.GetPixelID() == sitk.sitkUInt16: + image = sitk.Cast(image, sitk.sitkInt32) + elif image.GetPixelID() == sitk.sitkUInt32: + image = sitk.Cast(image, sitk.sitkInt64) + data = paddle.to_tensor(sitk.GetArrayFromImage(image), place=device) + data = data.unsqueeze(axis=0) + if image.GetNumberOfComponentsPerPixel() > 1: + x = data + perm_1 = list(range(x.ndim)) + perm_1[0] = x.ndim - 1 # 将第0维度移动到最后 + perm_1[-1] = 0 # 将最后一维移动到第0维 + data = x.transpose(perm=perm_1).squeeze(axis=-1) + if dtype is not None: + data = data.astype(dtype) + else: + data = data.astype(paddle.float32) # 默认使用float32类型 + + return data diff --git a/jointContribution/HighResolution/deepali/utils/sitk/sample.py b/jointContribution/HighResolution/deepali/utils/sitk/sample.py new file mode 100644 index 0000000000..6024b1ce16 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/sitk/sample.py @@ -0,0 +1,151 @@ +from typing import Optional +from typing import Union + +import numpy as np +import scipy.interpolate +import scipy.ndimage +import scipy.spatial +import SimpleITK as sitk + +from .grid import Grid + + +def resample_image( + image: sitk.Image, + reference: Union[Grid, sitk.Image], + interpolator: int = sitk.sitkLinear, + padding_value: float = 0, +) -> sitk.Image: + """Interpolate image values at grid points of reference image. + + Args: + image: Scalar or vector-valued image to evaluate at the specified points. + reference: Sampling grid on which to evaluate interpolated image. If a path is specified, + interpolator: Enumeration value of SimpleITK interpolator to use. + padding_value: Output value when sampling point is outside the input image domain. + + Returns: + output: Image interpolated at reference grid points. + + """ + resampler = sitk.ResampleImageFilter() + resampler.SetInterpolator(interpolator) + resampler.SetDefaultPixelValue(padding_value) + if isinstance(reference, sitk.Image): + resampler.SetReferenceImage(reference) + else: + resampler.SetSize(reference.size) + resampler.SetOutputOrigin(reference.origin) + resampler.SetOutputDirection(reference.direction) + resampler.SetOutputSpacing(reference.spacing) + return resampler.Execute(image) + + +def warp_image( + image: sitk.Image, + displacement: sitk.Image, + reference: Optional[sitk.Image] = None, + interpolator: int = sitk.sitkLinear, + padding_value: float = 0, +) -> sitk.Image: + """Interpolate image values at displaced output grid points. + + Args: + image: Scalar or vector-valued image to evaluate at the specified points. + displacement: Sampling grid on which to evaluate interpolated continuous image. + interpolator: Enumeration value of SimpleITK interpolator to use. + padding_value: Output value when sampling point is outside the input image domain. + + Returns: + output: Image interpolated at displacement field grid points at the input image + positions obtained by adding the given displacement to these grid coordinates. + + """ + if reference is None: + reference = displacement + resampler = sitk.ResampleImageFilter() + resampler.SetInterpolator(interpolator) + resampler.SetDefaultPixelValue(padding_value) + resampler.SetReferenceImage(reference) + assert resampler.GetSize() == reference.GetSize() + disp_field = sitk.Cast(displacement, sitk.sitkVectorFloat64) + transform = sitk.DisplacementFieldTransform(disp_field) + resampler.SetTransform(transform) + return resampler.Execute(image) + + +def interpolate_ndimage( + image: sitk.Image, points: np.ndarray, padding_value: float = 0, order: int = 1 +) -> np.ndarray: + """Use ``scipy.ndimage.map_coordinates`` to interpolate image.""" + grid = Grid.from_image(image) + idxs = np.moveaxis(grid.physical_space_to_continuous_index(points), -1, 0) + idxs = np.flip(idxs, axis=0) + vals = sitk.GetArrayViewFromImage(image) + nval = image.GetNumberOfComponentsPerPixel() + if nval > 1: + out = np.stack( + [ + scipy.ndimage.map_coordinates( + vals[..., c], idxs, cval=padding_value, order=order + ) + for c in range(nval) + ], + axis=-1, + ) + else: + out = scipy.ndimage.map_coordinates(vals, idxs, cval=padding_value, order=order) + return out + + +def interpolate_regular_grid( + image: sitk.Image, points: np.ndarray, padding_value: float = 0 +) -> np.ndarray: + """Use ``scipy.interpolate.RegularGridInterpolator`` to interpolate image data.""" + size = tuple(points.shape)[0:-1] + grid = Grid.from_image(image) + vals = sitk.GetArrayViewFromImage(image) + nval = image.GetNumberOfComponentsPerPixel() + idxs = grid.physical_space_to_continuous_index(points.reshape(-1, grid.ndim)) + idxs = np.flip(idxs, axis=-1) + coords = tuple(np.arange(tuple(vals.shape)[axis]) for axis in range(grid.ndim)) + if nval > 1: + out = [] + for c in range(nval): + func = scipy.interpolate.RegularGridInterpolator( + coords, vals[..., c], bounds_error=False, fill_value=padding_value + ) + out.append(func(idxs).reshape(size).astype(vals.dtype)) + out = np.stack(out, axis=-1) + else: + func = scipy.interpolate.RegularGridInterpolator( + coords, vals, bounds_error=False, fill_value=padding_value + ) + out = func(idxs).reshape(size).astype(vals.dtype) + return out + + +def interpolate_griddata(image: sitk.Image, points: np.ndarray) -> np.ndarray: + """Interpolate image similar to ``scipy.interpolate.griddata``. + + This method should only be used for comparison. The used Delaunay triangulation + is not suited for interpolating image data sampled on a regular grid. + + Use ``warp_image`` (for regularly spaced ``points``) or ``interpolate_image``. + """ + size = tuple(points.shape)[0:-1] + grid = Grid.from_image(image) + tess = scipy.spatial.qhull.Delaunay(grid.points.reshape(-1, grid.ndim)) + vals = sitk.GetArrayViewFromImage(image) + nval = image.GetNumberOfComponentsPerPixel() + coor = points.reshape(-1, grid.ndim) + if nval > 1: + out = [] + for c in range(nval): + img = scipy.interpolate.LinearNDInterpolator(tess, vals[..., c].flatten()) + out.append(img(coor).reshape(size)) + out = np.stack(out, axis=-1) + else: + img = scipy.interpolate.LinearNDInterpolator(tess, vals.flatten()) + out = img(coor).reshape(size) + return out diff --git a/jointContribution/HighResolution/deepali/utils/tensorboard.py b/jointContribution/HighResolution/deepali/utils/tensorboard.py new file mode 100644 index 0000000000..619e1c1717 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/tensorboard.py @@ -0,0 +1,176 @@ +import re +from typing import Callable +from typing import Optional +from typing import Union + +import paddle + +from ..core.image import normalize_image +from ..core.types import TensorCollection + +RE_CHANNEL_INDEX = re.compile("\\{c(:[^}]+)?\\}") + + +def escape_channel_index_format_string(tag: str) -> str: + """Escape image channel index format before str.format() call.""" + return RE_CHANNEL_INDEX.sub("{{c\\1}}", tag) + + +def add_summary_image( + writer: paddle.utils.tensorboard.SummaryWriter, + tag: str, + image: paddle.Tensor, + global_step: Optional[int] = None, + walltime: Optional[float] = None, + image_transform: Union[bool, Callable[[str, paddle.Tensor], paddle.Tensor]] = True, + rescale_transform: Union[ + bool, Callable[[str, paddle.Tensor], paddle.Tensor] + ] = False, + channel_offset: int = 0, +) -> None: + """Add image to TensorBoard summary. + + Args: + writer: TensorBoard writer with open event file. + tag: Image tag passed to ``writer.add_image``. + image: Image data tensor. + global_step: Global step value to record. + walltime: Optional override default walltime seconds. + image_transform: Callable used to extract a 2D image tensor of shape + ``(C, Y, X)`` from ``data``. When a multi-channel tensor is returnd, + each channel is saved as separate image to the TensorBoard event file. + with channel index appended to the ``tag`` separated with an underscore ``_``. + By default, the central slice of the first image in the batch is extracted. + The first argument is the ``tag`` of the image tensor. This can be used to + differently process different images by the same callable transform. + rescale_transform: Image intensities must be in the closed interval ``[0, 1]``. + When ``False``, this must already the case, e.g., as part of ``image_transform``. + When ``True``, images with values outside this interval are rescaled unless + the ``tag`` matches "y" or "y_pred". + channel_offset: Offset to add to channel index in tag format string. + + """ + img = image.detach() + if image_transform is True: + image_transform = first_central_image_slice + if image_transform not in (False, None): + img = image_transform(tag, img) + if rescale_transform is True: + rescale_transform = normalize_summary_image + if rescale_transform not in (False, None): + img = rescale_transform(tag, img) + if img.ndim != 3: + raise AssertionError( + "add_summary_image() transformed tensor must have shape (C, H, W)" + ) + if tuple(img.shape)[0] > 1 and RE_CHANNEL_INDEX.search(tag) is None: + tag = tag + "/{c}" + kwargs = dict(global_step=global_step, walltime=walltime) + for c in range(tuple(img.shape)[0]): + start_0 = img.shape[0] + c if c < 0 else c + writer.add_image( + tag.format(c=c + channel_offset), + paddle.slice(img, [0], [start_0], [start_0 + 1]), + **kwargs + ) + + +def add_summary_images( + writer: paddle.utils.tensorboard.SummaryWriter, + prefix: str, + images: TensorCollection, + global_step: Optional[int] = None, + walltime: Optional[float] = None, + image_transform: Union[bool, Callable[[str, paddle.Tensor], paddle.Tensor]] = True, + rescale_transform: Union[ + bool, Callable[[str, paddle.Tensor], paddle.Tensor] + ] = False, + channel_offset: int = 0, +) -> None: + """Add slices of image tensors to TensorBoard summary. + + Args: + writer: TensorBoard writer with open event file. + prefix: Prefix string for TensorBoard tags. + images: Possibly nested dictionary and/or sequence of image tensors. + global_step: Global step value to record. + walltime: Optional override default walltime seconds. + image_transform: Callable used to extract a 2D image tensor of shape + ``(C, Y, X)`` from each image. When a multi-channel tensor is returnd (C > 1), + each channel is saved as separate image to the TensorBoard event file. + By default, the central slice of the first image in the batch is extracted. + The first argument is the name of the image tensor if applicable, or an empty + string otherwise. This can be used to differently process different images. + rescale_transform: Image intensities must be in the closed interval ``[0, 1]``. + Set to ``False``, if this is already the case or when a custom + ``image_transform`` is used which normalizes the image intensities. + channel_offset: Offset to add to channel index in tag format string. + + """ + if prefix is None: + prefix = "" + if not isinstance(images, dict): + images = {str(i): value for i, value in enumerate(images)} + for name, value in images.items(): + if isinstance(value, paddle.Tensor): + add_summary_image( + writer, + prefix + name, + value, + global_step=global_step, + walltime=walltime, + image_transform=image_transform, + rescale_transform=rescale_transform, + channel_offset=channel_offset, + ) + else: + add_summary_images( + writer, + prefix + name + "/", + value, + global_step=global_step, + walltime=walltime, + image_transform=image_transform, + rescale_transform=rescale_transform, + channel_offset=channel_offset, + ) + + +def central_image_slices( + tag: str, data: paddle.Tensor, start: int = 0, length: int = -1 +) -> paddle.Tensor: + if data.ndim < 4 or tuple(data.shape)[1] != 1: + raise AssertionError( + "central_image_slices() expects image tensors of shape (N, 1, ..., Y, X)" + ) + start_1 = data.shape[1] + 0 if 0 < 0 else 0 + data = paddle.slice(data, [1], [start_1], [start_1 + 1]).squeeze(axis=1) + for dim in range(1, data.ndim - 2): + start_2 = ( + data.shape[dim] + tuple(data.shape)[dim] // 2 + if tuple(data.shape)[dim] // 2 < 0 + else tuple(data.shape)[dim] // 2 + ) + data = paddle.slice(data, [dim], [start_2], [start_2 + 1]) + for dim in range(1, data.ndim - 2): + data = data.squeeze(axis=dim) + if length < 1: + length = tuple(data.shape)[0] + length = min(length, tuple(data.shape)[0]) + start_3 = data.shape[0] + start if start < 0 else start + return paddle.slice(data, [0], [start_3], [start_3 + length]) + + +def all_central_image_slices(tag: str, data: paddle.Tensor) -> paddle.Tensor: + """Extract central slice from each scalar image in batch.""" + return central_image_slices(tag, data, start=0, length=-1) + + +def first_central_image_slice(tag: str, data: paddle.Tensor) -> paddle.Tensor: + """Extract central slice of first image in batch.""" + return central_image_slices(tag, data, start=0, length=1) + + +def normalize_summary_image(tag: str, data: paddle.Tensor) -> paddle.Tensor: + """Linearly rescale image values to [0, 1].""" + return normalize_image(data, mode="unit") diff --git a/jointContribution/HighResolution/deepali/utils/vtk/__init__.py b/jointContribution/HighResolution/deepali/utils/vtk/__init__.py new file mode 100644 index 0000000000..a4628c2a68 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/__init__.py @@ -0,0 +1,14 @@ +"""Auxiliary functions for working with `VTK `_ data structures.""" +from .idlist import iter_cell_point_ids +from .idlist import iter_id_list +from .idlist import iter_point_cell_ids +from .polydataio import read_polydata +from .polydataio import write_polydata + +__all__ = ( + "iter_id_list", + "iter_cell_point_ids", + "iter_point_cell_ids", + "read_polydata", + "write_polydata", +) diff --git a/jointContribution/HighResolution/deepali/utils/vtk/idlist.py b/jointContribution/HighResolution/deepali/utils/vtk/idlist.py new file mode 100644 index 0000000000..4a42aae41d --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/idlist.py @@ -0,0 +1,35 @@ +"""Common auxiliary functions for working with VTK data structures.""" +from typing import Generator + +from vtk import vtkIdList +from vtk import vtkPolyData + + +def iter_id_list( + ids: vtkIdList, start: int = 0, stop: int = None +) -> Generator[int, None, None]: + """Iterate over IDs in vtkIdList.""" + if stop is None: + stop = ids.GetNumberOfIds() + elif stop < 0: + stop = stop % (ids.GetNumberOfIds() + 1) + for idx in range(start, stop): + yield ids.GetId(idx) + + +def iter_cell_point_ids( + polydata: vtkPolyData, cell_id: int, start: int = 0, stop: int = None +) -> Generator[int, None, None]: + """Iterate over IDs of points that a given vtkPolyData cell is made up of.""" + point_ids = vtkIdList() + polydata.GetCellPoints(cell_id, point_ids) + yield from iter_id_list(point_ids, start, stop) + + +def iter_point_cell_ids( + polydata: vtkPolyData, point_id: int, start: int = 0, stop: int = None +) -> Generator[int, None, None]: + """Iterate over IDs of vtkPolyData cells that contain a specified point.""" + cell_ids = vtkIdList() + polydata.GetPointCells(point_id, cell_ids) + yield from iter_id_list(cell_ids, start, stop) diff --git a/jointContribution/HighResolution/deepali/utils/vtk/image.py b/jointContribution/HighResolution/deepali/utils/vtk/image.py new file mode 100644 index 0000000000..6f66663799 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/image.py @@ -0,0 +1,66 @@ +"""Auxiliary functions for working with vtkImageData.""" +from typing import Optional + +import numpy as np +from vtk import vtkImageData +from vtk import vtkImageStencilData +from vtk import vtkImageStencilToImage +from vtk import vtkMatrixToLinearTransform +from vtk import vtkPolyData +from vtk import vtkPolyDataToImageStencil +from vtk import vtkTransformPolyDataFilter + +from ..sitk.grid import Grid +from .numpy import numpy_to_vtk_matrix4x4 + + +def surface_mesh_grid(*mesh: vtkPolyData, resolution: Optional[float] = None) -> Grid: + """Compute image grid for given surface mesh discretization with specified resolution.""" + bounds = np.zeros((6,), dtype=float) + for pointset in mesh: + _bounds = np.asarray(pointset.GetBounds()) + bounds[0::2] = np.minimum(bounds[0::2], _bounds[0::2]) + bounds[1::2] = np.maximum(bounds[1::2], _bounds[1::2]) + if resolution is None or resolution <= 0 or np.isnan(resolution): + resolution = np.sqrt(np.sum(np.square(bounds[1::2] - bounds[0::2]))) / 256 + return Grid( + size=np.ceil((bounds[1::2] - bounds[0::2]) / resolution).astype(int), + origin=bounds[0::2] + 0.5 * resolution, + spacing=np.asarray([resolution] * 3), + ) + + +def surface_image_stencil(mesh: vtkPolyData, grid: Grid) -> vtkImageStencilData: + """Convert vtkPolyData surface mesh to image stencil.""" + rot = np.eye(4, dtype=np.float) + rot[:3, :3] = np.array(grid.direction).reshape(3, 3) + rot = numpy_to_vtk_matrix4x4(rot) + transform = vtkMatrixToLinearTransform() + transform.SetInput(rot) + transformer = vtkTransformPolyDataFilter() + transformer.SetInputData(mesh) + transformer.SetTransform(transform) + converter = vtkPolyDataToImageStencil() + converter.SetInputConnection(transformer.GetOutputPort()) + converter.SetOutputOrigin(grid.origin) + converter.SetOutputSpacing(grid.spacing) + converter.SetOutputWholeExtent( + [0, grid.size[0] - 1, 0, grid.size[1] - 1, 0, grid.size[2] - 1] + ) + converter.Update() + stencil = vtkImageStencilData() + stencil.DeepCopy(converter.GetOutput()) + return stencil + + +def binary_image_stencil_mask(stencil: vtkImageStencilData) -> vtkImageData: + """Set values inside image stencil to specified value.""" + converter = vtkImageStencilToImage() + converter.SetInsideValue(1) + converter.SetOutsideValue(0) + converter.SetInputData(stencil) + converter.SetOutputScalarTypeToUnsignedChar() + converter.Update() + mask = vtkImageData() + mask.DeepCopy(converter.GetOutput()) + return mask diff --git a/jointContribution/HighResolution/deepali/utils/vtk/numpy.py b/jointContribution/HighResolution/deepali/utils/vtk/numpy.py new file mode 100644 index 0000000000..42ca2f013f --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/numpy.py @@ -0,0 +1,119 @@ +"""Bridge between VTK and NumPy.""" +import warnings +from typing import Optional +from typing import Union + +import numpy as np +from vtk import VTK_CHAR +from vtk import VTK_DOUBLE +from vtk import VTK_FLOAT +from vtk import VTK_LONG +from vtk import VTK_LONG_LONG +from vtk import VTK_SHORT +from vtk import VTK_UNSIGNED_CHAR +from vtk import VTK_UNSIGNED_LONG +from vtk import VTK_UNSIGNED_LONG_LONG +from vtk import VTK_UNSIGNED_SHORT +from vtk import vtkCellArray +from vtk import vtkDataArray +from vtk import vtkMatrix4x4 +from vtk import vtkPoints +from vtk import vtkPointSet +from vtk.util import numpy_support + +VTK_DATA_TYPE_FROM_NUMPY_DTYPE = { + np.dtype("int8"): VTK_CHAR, + np.dtype("int16"): VTK_SHORT, + np.dtype("int32"): VTK_LONG, + np.dtype("int64"): VTK_LONG_LONG, + np.dtype("uint8"): VTK_UNSIGNED_CHAR, + np.dtype("uint16"): VTK_UNSIGNED_SHORT, + np.dtype("uint32"): VTK_UNSIGNED_LONG, + np.dtype("uint64"): VTK_UNSIGNED_LONG_LONG, + np.dtype("float32"): VTK_FLOAT, + np.dtype("float64"): VTK_DOUBLE, +} +NUMPY_DTYPE_FROM_VTK_DATA_TYPE = { + VTK_CHAR: np.dtype("int8"), + VTK_SHORT: np.dtype("int16"), + VTK_LONG: np.dtype("int32"), + VTK_LONG_LONG: np.dtype("int64"), + VTK_UNSIGNED_CHAR: np.dtype("uint8"), + VTK_UNSIGNED_SHORT: np.dtype("uint16"), + VTK_UNSIGNED_LONG: np.dtype("uint32"), + VTK_UNSIGNED_LONG_LONG: np.dtype("uint64"), + VTK_FLOAT: np.dtype("float32"), + VTK_DOUBLE: np.dtype("float64"), +} + + +def numpy_to_vtk_matrix4x4(arr: np.ndarray) -> vtkMatrix4x4: + """Create vtkMatrix4x4 from NumPy array.""" + assert tuple(arr.shape) == (4, 4) + matrix = vtkMatrix4x4() + for i in range(4): + for j in range(4): + matrix.SetElement(i, j, arr[i, j]) + return matrix + + +def numpy_to_vtk_array(*args, name: Optional[str] = None, **kwargs) -> vtkDataArray: + """Convert NumPy array to vtkDataArray.""" + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + data_array = numpy_support.numpy_to_vtk(*args, **kwargs) + if name is not None: + data_array.SetName(name) + return data_array + + +def numpy_to_vtk_points(arr: np.ndarray) -> vtkPoints: + """Convert NumPy array to vtkPoints.""" + data = numpy_to_vtk_array(arr) + points = vtkPoints() + points.SetData(data) + return points + + +def numpy_to_vtk_cell_array(arr: np.ndarray) -> vtkCellArray: + """Convert NumPy array to vtkCellArray.""" + if arr.ndim != 2: + raise ValueError("numpy_to_vtk_cell_array() 'arr' must be 2-dimensional") + if arr.dtype not in (np.dtype("int32"), np.dtype("int64")): + raise TypeError( + "numpy_to_vtk_cell_array() 'arr' must have dtype int32 or int64" + ) + cells = vtkCellArray() + use_set_data = False + if use_set_data: + dtype = arr.dtype + offsets = np.repeat( + np.array(tuple(arr.shape)[1], dtype=dtype), tuple(arr.shape)[0] + ) + offsets = np.concatenate([[0], np.cumsum(offsets)]) + offsets = numpy_to_vtk_array(offsets.astype(dtype)) + connectivity = numpy_to_vtk_array(arr.flatten().astype(dtype)) + if not cells.SetData(offsets, connectivity): + raise RuntimeError( + "numpy_to_vtk_cell_array() failed to convert NumPy array to vtkCellArray" + ) + else: + cells.AllocateExact(tuple(arr.shape)[0], np.prod(tuple(arr.shape))) + for cell in arr: + cells.InsertNextCell(len(cell)) + for ptId in cell: + cells.InsertCellPoint(ptId) + assert cells.IsValid() + return cells + + +def vtk_to_numpy_array(data: vtkDataArray) -> np.ndarray: + """Convert vtkDataArray to NumPy array.""" + return numpy_support.vtk_to_numpy(data) + + +def vtk_to_numpy_points(points: Union[vtkPoints, vtkPointSet]) -> np.ndarray: + """Convert vtkPoints to NumPy array.""" + if isinstance(points, vtkPointSet): + points = points.GetPoints() + return vtk_to_numpy_array(points.GetData()) diff --git a/jointContribution/HighResolution/deepali/utils/vtk/polydataio.py b/jointContribution/HighResolution/deepali/utils/vtk/polydataio.py new file mode 100644 index 0000000000..d57e42fd70 --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/polydataio.py @@ -0,0 +1,133 @@ +from io import StringIO +from pathlib import Path +from typing import List +from typing import Tuple +from typing import Union + +import numpy as np +from vtk import vtkCellArray +from vtk import vtkPLYReader +from vtk import vtkPLYWriter +from vtk import vtkPoints +from vtk import vtkPolyData +from vtk import vtkPolyDataReader +from vtk import vtkPolyDataWriter +from vtk import vtkXMLPolyDataReader +from vtk import vtkXMLPolyDataWriter + +from .numpy import vtk_to_numpy_array +from .numpy import vtk_to_numpy_points + +PathStr = Union[Path, str] + + +def read_polydata(path: PathStr) -> vtkPolyData: + """Read vtkPolyData from specified file.""" + path = Path(path).absolute() + if not path.is_file(): + raise FileNotFoundError(str(path)) + suffix = path.suffix.lower() + if suffix == ".off": + return read_polydata_off(path) + elif suffix == ".ply": + reader = vtkPLYReader() + elif suffix == ".vtp": + reader = vtkXMLPolyDataReader() + elif suffix == ".vtk": + reader = vtkPolyDataReader() + else: + raise ValueError("Unsupported file name extension: {}".format(suffix)) + reader.SetFileName(str(path)) + reader.Update() + polydata = vtkPolyData() + polydata.DeepCopy(reader.GetOutput()) + return polydata + + +def read_off(path: PathStr) -> Tuple[List[List[float]], List[List[int]]]: + """Read values from .off file.""" + data = Path(path).read_text() + stream = StringIO(data) + magic = stream.readline().strip() + if magic not in ("OFF", "CNOFF"): + raise ValueError(f"Invalid OFF file header: {path}") + header = tuple([int(s) for s in stream.readline().strip().split(" ")]) + n_verts, n_faces = header[:2] + verts = [ + [float(s) for s in stream.readline().strip().split(" ")] for _ in range(n_verts) + ] + faces = [ + [int(s) for s in stream.readline().strip().split(" ")] for _ in range(n_faces) + ] + assert ( + len(verts) == n_verts + ), f"Expected {n_verts} vertices, found only {len(verts)}" + assert ( + len(faces) == n_faces + ), f"Expected {n_faces} vertices, found only {len(faces)}" + return verts, faces + + +def read_polydata_off(path: PathStr) -> vtkPolyData: + """Read vtkPolyData from .off file.""" + verts, faces = read_off(path) + points = vtkPoints() + polys = vtkCellArray() + for vert in verts: + points.InsertNextPoint(vert[:3]) + for poly in faces: + polys.InsertNextCell(poly[0], poly[1 : 1 + poly[0]]) + output = vtkPolyData() + output.SetPoints(points) + output.SetPolys(polys) + return output + + +def write_polydata(polydata: vtkPolyData, path: PathStr): + """Write vtkPolyData to specified file in XML format.""" + path = Path(path).absolute() + suffix = path.suffix.lower() + if suffix == ".off": + write_polydata_off(polydata, path) + return + if suffix == ".ply": + writer = vtkPLYWriter() + writer.SetFileTypeToBinary() + elif suffix == ".vtp": + writer = vtkXMLPolyDataWriter() + elif suffix == ".vtk": + writer = vtkPolyDataWriter() + else: + raise ValueError("Unsupported file name extension: {}".format(suffix)) + try: + path.unlink() + except FileNotFoundError: + path.parent.mkdir(parents=True, exist_ok=True) + writer.SetFileName(str(path)) + writer.SetInputData(polydata) + writer.Update() + + +def write_polydata_off(polydata: vtkPolyData, path: PathStr): + """Write vtkPolyData to specified file in OFF format.""" + path = Path(path).absolute() + try: + path.unlink() + except FileNotFoundError: + path.parent.mkdir(parents=True, exist_ok=True) + verts = vtk_to_numpy_points(polydata) + F = polydata.GetPolys().GetNumberOfCells() + faces = vtk_to_numpy_array(polydata.GetPolys().GetData()) + assert faces.ndim == 1 + if len(faces) / F != 4: + raise ValueError( + "write_polydata_off() only supports triangulated surface meshes" + ) + faces = faces.reshape(-1, 4) + with path.open(mode="wt") as fp: + fp.write("OFF\n") + fp.write(f"{len(verts)} {len(faces)} 0\n") + np.savetxt(fp, verts, delimiter=" ", newline="\n", header="", footer="") + np.savetxt( + fp, faces, delimiter=" ", newline="\n", header="", footer="", fmt="%d" + ) diff --git a/jointContribution/HighResolution/deepali/utils/vtk/simpleitk.py b/jointContribution/HighResolution/deepali/utils/vtk/simpleitk.py new file mode 100644 index 0000000000..b32a5137bd --- /dev/null +++ b/jointContribution/HighResolution/deepali/utils/vtk/simpleitk.py @@ -0,0 +1,179 @@ +from typing import Optional +from typing import Sequence +from typing import Union + +import SimpleITK as sitk +from vtk import VTK_CHAR +from vtk import VTK_DOUBLE +from vtk import VTK_FLOAT +from vtk import VTK_LONG +from vtk import VTK_LONG_LONG +from vtk import VTK_SHORT +from vtk import VTK_UNSIGNED_CHAR +from vtk import VTK_UNSIGNED_LONG +from vtk import VTK_UNSIGNED_LONG_LONG +from vtk import VTK_UNSIGNED_SHORT +from vtk import vtkImageData +from vtk import vtkImageImport +from vtk import vtkPoints +from vtk import vtkPointSet + +from ..sitk.grid import Grid +from ..sitk.sample import interpolate_ndimage +from .numpy import numpy_to_vtk_array +from .numpy import vtk_to_numpy_array +from .numpy import vtk_to_numpy_points + +VTK_DATA_TYPE_FROM_SITK_PIXEL_ID = { + sitk.sitkInt8: VTK_CHAR, + sitk.sitkInt16: VTK_SHORT, + sitk.sitkInt32: VTK_LONG, + sitk.sitkInt64: VTK_LONG_LONG, + sitk.sitkUInt8: VTK_UNSIGNED_CHAR, + sitk.sitkUInt16: VTK_UNSIGNED_SHORT, + sitk.sitkUInt32: VTK_UNSIGNED_LONG, + sitk.sitkUInt64: VTK_UNSIGNED_LONG_LONG, + sitk.sitkFloat32: VTK_FLOAT, + sitk.sitkFloat64: VTK_DOUBLE, + sitk.sitkVectorInt8: VTK_CHAR, + sitk.sitkVectorInt16: VTK_SHORT, + sitk.sitkVectorInt32: VTK_LONG, + sitk.sitkVectorInt64: VTK_LONG_LONG, + sitk.sitkVectorUInt8: VTK_UNSIGNED_CHAR, + sitk.sitkVectorUInt16: VTK_UNSIGNED_SHORT, + sitk.sitkVectorUInt32: VTK_UNSIGNED_LONG, + sitk.sitkVectorUInt64: VTK_UNSIGNED_LONG_LONG, + sitk.sitkVectorFloat32: VTK_FLOAT, + sitk.sitkVectorFloat64: VTK_DOUBLE, +} + + +def apply_warp_field_to_points( + warp_field: sitk.Image, points: vtkPoints, is_def_field: bool = False +) -> vtkPoints: + """Transform vtkPoints by linearly interpolated dense vector field. + + Args: + warp_field: Vector field that acts on the given points. + If ``None``, an identity deformation is assumed. + points: Input points in physical space of warp field. + is_def_field: If ``True``, the input ``warp_field`` must be a vector field + of output coordinates. Otherwise, ``warp_field`` must contain displacements + in physical image space. + + Returns: + Transformed output points. + + """ + out = points.NewInstance() + if warp_field is None: + out.DeepCopy(points) + return out + x = vtk_to_numpy_points(points) + y = interpolate_ndimage(warp_field, x) + if not is_def_field: + y += x + out.SetData(numpy_to_vtk_array(y)) + return out + + +def apply_warp_field_to_pointset( + warp_field: sitk.Image, pointset: vtkPointSet, is_def_field: bool = False +) -> vtkPointSet: + """Transform vtkPoints of vtkPointSet by linearly interpolated dense displacement field. + + Args: + warp_field: Vector field that acts on the given points. + If ``None``, an identity deformation is assumed. + pointset: Input point set (e.g., vtkPolyData surface mesh). + is_def_field: If ``True``, the input ``warp_field`` must be a vector field + of output coordinates. Otherwise, ``warp_field`` must contain displacements + in physical image space. + + Returns: + Deep copy of input ``pointset`` with transformed points. + + """ + output = pointset.NewInstance() + output.DeepCopy(pointset) + if warp_field is not None: + points = apply_warp_field_to_points( + warp_field, pointset.GetPoints(), is_def_field=is_def_field + ) + output.SetPoints(points) + return output + + +def image_data_grid(data: vtkImageData) -> Grid: + """Create image grid from vtkImageData object.""" + extent = data.GetExtent() + return Grid( + size=( + extent[1] - extent[0] + 1, + extent[3] - extent[2] + 1, + extent[5] - extent[4] + 1, + ), + origin=data.GetOrigin(), + spacing=data.GetSpacing(), + ) + + +def vtk_image_from_sitk_image( + image: sitk.Image, spacing: Optional[Union[float, Sequence[float]]] = None +) -> vtkImageData: + """Create vtkImageData from SimpleITK image. + + VTK versions before 9.0 do not support orientation information. Therefore, use only the pixel spacing + information for index to physical space transformations when working with VTK data structures. These + can be mapped into the original physical space by applying the rotation and translation from the + original image grid. + + Args: + image: SimpleITK image. + spacing: Data spacing to use instead of ``image.GetSpacing()``. + + Returns: + vtk_image: vtkImageData instance with spacing set, but origin equal to (0, 0, 0). + + """ + data = sitk.GetArrayFromImage(image).tobytes() + pixel_id = image.GetPixelIDValue() + data_type = VTK_DATA_TYPE_FROM_SITK_PIXEL_ID[pixel_id] + size = list(image.GetSize()) + if spacing is None: + spacing = image.GetSpacing() + elif isinstance(spacing, (int, float)): + spacing = (spacing,) * len(size) + spacing = list(spacing) + while len(size) < 3: + size.append(1) + spacing.append(1.0) + importer = vtkImageImport() + importer.CopyImportVoidPointer(data, len(data)) + importer.SetDataScalarType(data_type) + importer.SetNumberOfScalarComponents(image.GetNumberOfComponentsPerPixel()) + importer.SetDataSpacing(spacing) + importer.SetWholeExtent(0, size[0] - 1, 0, size[1] - 1, 0, size[2] - 1) + importer.SetDataExtentToWholeExtent() + importer.UpdateWholeExtent() + output = vtkImageData() + output.DeepCopy(importer.GetOutput()) + return output + + +def sitk_image_from_vtk_image( + image: vtkImageData, grid: Optional[Grid] = None +) -> sitk.Image: + """Create SimpleITK image from vtkImageData.""" + if image.GetNumberOfScalarComponents() != 1: + raise NotImplementedError( + "sitk_image_from_vtk_image() only supports scalar 'image'" + ) + data = image.GetPointData().GetScalars() + data = vtk_to_numpy_array(data) + data = data.reshape(tuple(reversed(grid.size))) + output = sitk.GetImageFromArray(data) + output.SetOrigin(grid.origin) + output.SetSpacing(grid.spacing) + output.SetDirection(grid.direction) + return output diff --git a/jointContribution/HighResolution/ffd/engine.py b/jointContribution/HighResolution/ffd/engine.py new file mode 100644 index 0000000000..9d12a0fe45 --- /dev/null +++ b/jointContribution/HighResolution/ffd/engine.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import math +import weakref +from collections import OrderedDict +from timeit import default_timer as timer +from typing import Any +from typing import Callable +from typing import Tuple + +import paddle + +from .losses import RegistrationLoss +from .losses import RegistrationResult +from .optim import slope_of_least_squares_fit + +PROFILING = False + + +class RegistrationEngine(object): + """Minimize registration loss until convergence.""" + + def __init__( + self, + model: paddle.nn.Layer, + loss: RegistrationLoss, + optimizer: paddle.optimizer.Optimizer, + max_steps: int = 500, + min_delta: float = 1e-06, + min_value: float = float("nan"), + max_history: int = 10, + ): + """Initialize registration loop.""" + self.model = model + self.loss = loss + self.optimizer = optimizer + self.num_steps = 0 + self.max_steps = max_steps + self.min_delta = min_delta + self.min_value = min_value + self.max_history = max(2, max_history) + self.loss_values = [] + self._eval_hooks = OrderedDict() + self._step_hooks = OrderedDict() + + @property + def loss_value(self) -> float: + if not self.loss_values: + return float("inf") + return self.loss_values[-1] + + def step(self) -> float: + """Perform one registration step. + + Returns: + Loss value prior to taking gradient step. + + """ + num_evals = 0 + + def closure() -> float: + self.optimizer.clear_grad() + t_start = timer() + result = self.loss.eval() + if PROFILING: + print(f"Forward pass in {timer() - t_start:.3f}s") + loss = result["loss"] + assert isinstance(loss, paddle.Tensor) + t_start = timer() + loss.backward() + if PROFILING: + print(f"Backward pass in {timer() - t_start:.3f}s") + nonlocal num_evals + num_evals += 1 + with paddle.no_grad(): + for hook in self._eval_hooks.values(): + hook(self, self.num_steps, num_evals, result) + return float(loss) + + loss_value = closure() + self.optimizer.step() + assert loss_value is not None + with paddle.no_grad(): + for hook in self._step_hooks.values(): + hook(self, self.num_steps, num_evals, loss_value) + return loss_value + + def run(self) -> float: + """Perform registration steps until convergence. + + Returns: + Loss value prior to taking last gradient step. + + """ + self.loss_values = [] + self.num_steps = 0 + while self.num_steps < self.max_steps and not self.converged(): + value = self.step() + self.num_steps += 1 + if math.isnan(value): + raise RuntimeError( + f"NaN value in registration loss at gradient step {self.num_steps}" + ) + if math.isinf(value): + raise RuntimeError( + f"Inf value in registration loss at gradient step {self.num_steps}" + ) + self.loss_values.append(value) + if len(self.loss_values) > self.max_history: + self.loss_values.pop(0) + return self.loss_value + + def converged(self) -> bool: + """Check convergence criteria.""" + values = self.loss_values + if not values: + return False + value = values[-1] + if self.min_delta < 0: + epsilon = abs(self.min_delta * value) + else: + epsilon = self.min_delta + slope = slope_of_least_squares_fit(values) + if abs(slope) < epsilon: + return True + if value < self.min_value: + return True + return False + + def register_eval_hook( + self, + hook: Callable[["RegistrationEngine", int, int, "RegistrationResult"], None], + ) -> "RemovableHandle": + r"""Registers an evaluation hook.""" + handle = RemovableHandle(self._eval_hooks) + self._eval_hooks[handle.id] = hook + return handle + + def register_step_hook( + self, hook: Callable[["RegistrationEngine", int, int, float], None] + ) -> "RemovableHandle": + r"""Registers a gradient step hook.""" + handle = RemovableHandle(self._step_hooks) + self._step_hooks[handle.id] = hook + return handle + + +class RemovableHandle: + r""" + A handle which provides the capability to remove a hook. + + Args: + hooks_dict (dict): A dictionary of hooks, indexed by hook ``id``. + extra_dict (Union[dict, List[dict]]): An additional dictionary or list of + dictionaries whose keys will be deleted when the same keys are + removed from ``hooks_dict``. + """ + + id: int + next_id: int = 0 + + def __init__(self, hooks_dict: Any, *, extra_dict: Any = None) -> None: + self.hooks_dict_ref = weakref.ref(hooks_dict) + self.id = RemovableHandle.next_id + RemovableHandle.next_id += 1 + + self.extra_dict_ref: Tuple = () + if isinstance(extra_dict, dict): + self.extra_dict_ref = (weakref.ref(extra_dict),) + elif isinstance(extra_dict, list): + self.extra_dict_ref = tuple(weakref.ref(d) for d in extra_dict) + + def remove(self) -> None: + hooks_dict = self.hooks_dict_ref() + if hooks_dict is not None and self.id in hooks_dict: + del hooks_dict[self.id] + + for ref in self.extra_dict_ref: + extra_dict = ref() + if extra_dict is not None and self.id in extra_dict: + del extra_dict[self.id] + + def __getstate__(self): + if self.extra_dict_ref is None: + return (self.hooks_dict_ref(), self.id) + else: + return ( + self.hooks_dict_ref(), + self.id, + tuple(ref() for ref in self.extra_dict_ref), + ) + + def __setstate__(self, state) -> None: + if state[0] is None: + # create a dead reference + self.hooks_dict_ref = weakref.ref(OrderedDict()) + else: + self.hooks_dict_ref = weakref.ref(state[0]) + self.id = state[1] + RemovableHandle.next_id = max(RemovableHandle.next_id, self.id + 1) + + if len(state) < 3 or state[2] is None: + self.extra_dict_ref = () + else: + self.extra_dict_ref = tuple(weakref.ref(d) for d in state[2]) + + def __enter__(self) -> "RemovableHandle": + return self + + def __exit__(self, type: Any, value: Any, tb: Any) -> None: + self.remove() diff --git a/jointContribution/HighResolution/ffd/hooks.py b/jointContribution/HighResolution/ffd/hooks.py new file mode 100644 index 0000000000..628f4fa606 --- /dev/null +++ b/jointContribution/HighResolution/ffd/hooks.py @@ -0,0 +1,72 @@ +from typing import Callable + +import paddle +from deepali.core import functional as U +from deepali.core.kernels import gaussian1d +from deepali.spatial import is_linear_transform + +from .engine import RegistrationEngine +from .engine import RegistrationResult + +RegistrationEvalHook = Callable[ + [RegistrationEngine, int, int, RegistrationResult], None +] +RegistrationStepHook = Callable[[RegistrationEngine, int, int, float], None] + + +def noop(reg: RegistrationEngine, *args, **kwargs) -> None: + """Dummy no-op loss evaluation hook.""" + ... + + +def normalize_linear_grad(reg: RegistrationEngine, *args, **kwargs) -> None: + """Loss evaluation hook for normalization of linear transformation gradient after backward pass.""" + denom = None + for param in reg.model.parameters(): + if not param.stop_gradient and param.grad is not None: + max_abs_grad = paddle.max(paddle.abs(param.grad)) + if denom is None or denom < max_abs_grad: + denom = max_abs_grad + if denom is None: + return + for param in reg.model.parameters(): + if not param.stop_gradient and param.grad is not None: + param.grad /= denom + + +def normalize_nonrigid_grad(reg: RegistrationEngine, *args, **kwargs) -> None: + """Loss evaluation hook for normalization of non-rigid transformation gradient after backward pass.""" + for param in reg.model.parameters(): + if not param.stop_gradient and param.grad is not None: + paddle.assign( + paddle.nn.functional.normalize(x=param.grad, p=2, axis=1), + output=param.grad, + ) + + +def normalize_grad_hook(transform) -> RegistrationEvalHook: + """Loss evaluation hook for normalization of transformation gradient after backward pass.""" + if is_linear_transform(transform): + return normalize_linear_grad + return normalize_nonrigid_grad + + +def _smooth_nonrigid_grad(reg: RegistrationEngine, sigma: float = 1) -> None: + """Loss evaluation hook for Gaussian smoothing of non-rigid transformation gradient after backward pass.""" + if sigma <= 0: + return + kernel = gaussian1d(sigma) + for param in reg.model.parameters(): + if not param.stop_gradient and param.grad is not None: + param.grad.copy_(U.conv(param.grad, kernel)) + + +def smooth_grad_hook(transform, sigma: float) -> RegistrationEvalHook: + """Loss evaluation hook for Gaussian smoothing of non-rigid gradient after backward pass.""" + if is_linear_transform(transform): + return noop + + def fn(reg: RegistrationEngine, *args, **kwargs): + return _smooth_nonrigid_grad(reg, sigma=sigma) + + return fn diff --git a/jointContribution/HighResolution/ffd/losses.py b/jointContribution/HighResolution/ffd/losses.py new file mode 100644 index 0000000000..7694cbd9d1 --- /dev/null +++ b/jointContribution/HighResolution/ffd/losses.py @@ -0,0 +1,382 @@ +import re +from collections import defaultdict +from typing import Dict +from typing import Generator +from typing import List +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import Type +from typing import TypeVar +from typing import Union + +import paddle +from deepali.core import Grid +from deepali.core import PaddingMode +from deepali.core import Sampling +from deepali.core import functional as U +from deepali.losses import BSplineLoss +from deepali.losses import DisplacementLoss +from deepali.losses import LandmarkPointDistance +from deepali.losses import PairwiseImageLoss +from deepali.losses import ParamsLoss +from deepali.losses import PointSetDistance +from deepali.losses import RegistrationLoss +from deepali.losses import RegistrationLosses +from deepali.losses import RegistrationResult +from deepali.modules import SampleImage +from deepali.spatial import BSplineTransform +from deepali.spatial import CompositeTransform +from deepali.spatial import SequentialTransform +from deepali.spatial import SpatialTransform +from paddle import Tensor +from paddle.nn import Layer + +RE_WEIGHT = re.compile( + "^((?P[0-9]+(\\.[0-9]+)?)\\s*[\\* ])?\\s*(?P[a-zA-Z0-9_-]+)\\s*(\\+\\s*(?P[0-9]+(\\.[0-9]+)?))?$" +) +RE_TERM_VAR = re.compile("^[a-zA-Z0-9_-]+\\((?P[a-zA-Z0-9_]+)\\)$") +TLayer = TypeVar("TLayer", bound=Layer) +TSpatialTransform = TypeVar("TSpatialTransform", bound=SpatialTransform) + + +class PairwiseImageRegistrationLoss(RegistrationLoss): + """Loss function for pairwise multi-channel image registration.""" + + def __init__( + self, + source_data: paddle.Tensor, + target_data: paddle.Tensor, + source_grid: Grid, + target_grid: Grid, + source_chns: Mapping[str, Union[int, Tuple[int, int]]], + target_chns: Mapping[str, Union[int, Tuple[int, int]]], + source_pset: Optional[paddle.Tensor] = None, + target_pset: Optional[paddle.Tensor] = None, + source_landmarks: Optional[paddle.Tensor] = None, + target_landmarks: Optional[paddle.Tensor] = None, + losses: Optional[RegistrationLosses] = None, + weights: Mapping[str, Union[float, str]] = None, + transform: Optional[Union[CompositeTransform, SpatialTransform]] = None, + sampling: Union[Sampling, str] = Sampling.LINEAR, + ): + """Initialize multi-channel registration loss function. + + Args: + source_data: Moving normalized multi-channel source image batch tensor. + source_data: Fixed normalized multi-channel target image batch tensor. + source_grid: Sampling grid of source image. + source_grid: Sampling grid of target image. + source_chns: Mapping from channel (loss, weight) name to index or range. + target_chns: Mapping from channel (loss, weight) name to index or range. + source_pset: Point sets defined with respect to source image grid. + target_pset: Point sets defined with respect to target image grid. + source_landmarks: Landmark points defined with respect to source image grid. + target_landmarks: Landmark points defined with respect to target image grid. + losses: Dictionary of named loss terms. Loss terms must be either a subclass of + ``PairwiseImageLoss``, ``DisplacementLoss``, ``PointSetDistance``, ``ParamsLoss``, + or ``paddle.nn.Layer``. In case of a ``PairwiseImageLoss``, the key (name) of the + loss term must be found in ``channels`` which identifies the corresponding ``target`` + and ``source`` data channels that this loss term relates to. If the name is not found + in the ``channels`` mapping, the loss term is called with all image channels as input. + If a loss term is not an instance of a known registration loss type, it is assumed to be a + regularization term without arguments, e.g., a ``paddle.nn.Layer`` which itself has a reference + to the parameters of the transformation that it is based on. + weights: Scalar weights of loss terms or name of channel with locally adaptive weights. + transform: Spatial transformation to apply to ``source`` image. + sampling: Image interpolation mode. + + """ + super().__init__() + self.register_buffer(name="_source_data", tensor=source_data) + self.register_buffer(name="_target_data", tensor=target_data) + self.source_grid = source_grid + self.target_grid = target_grid + self.source_chns = dict(source_chns or {}) + self.target_chns = dict(target_chns or {}) + self.source_pset = source_pset + self.target_pset = target_pset + self.source_landmarks = source_landmarks + self.target_landmarks = target_landmarks + if transform is None: + transform = SequentialTransform(self.target_grid) + elif isinstance(transform, SpatialTransform): + transform = SequentialTransform(transform) + elif not isinstance(transform, CompositeTransform): + raise TypeError( + "PairwiseImageRegistrationLoss() 'transform' must be of type CompositeTransform" + ) + self.transform = transform + self._sample_image = SampleImage( + target=self.target_grid, + source=self.source_grid, + sampling=sampling, + padding=PaddingMode.ZEROS, + align_centers=False, + ) + points = self.target_grid.coords(device=self._target_data.place) + self.register_buffer(name="grid_points", tensor=points.unsqueeze(axis=0)) + self.loss_terms = self.as_module_dict(losses) + self.weights = dict(weights or {}) + + @property + def device(self) -> str: + """Device on which loss is evaluated.""" + device = self._target_data.place + assert isinstance(device, paddle.base.libpaddle.Place) + return device + + def loss_terms_of_type(self, loss_type: Type[TLayer]) -> Dict[str, TLayer]: + """Get dictionary of loss terms of a specifictype.""" + return { + name: module + for name, module in self.loss_terms.items() + if isinstance(module, loss_type) + } + + def transforms_of_type( + self, transform_type: Type[TSpatialTransform] + ) -> List[TSpatialTransform]: + """Get list of spatial transformations of a specific type.""" + + def _iter_transforms(transform) -> Generator[SpatialTransform, None, None]: + if isinstance(transform, transform_type): + yield transform + elif isinstance(transform, CompositeTransform): + for t in transform.transforms(): + yield from _iter_transforms(t) + + transforms = list(_iter_transforms(self.transform)) + return transforms + + @property + def has_transform(self) -> bool: + """Whether a spatial transformation is set.""" + return len(self.transform) > 0 + + def target_data(self) -> paddle.Tensor: + """Target image tensor.""" + data = self._target_data + assert isinstance(data, paddle.Tensor) + return data + + def source_data(self, grid: Optional[paddle.Tensor] = None) -> paddle.Tensor: + """Sample source image at transformed target grid points.""" + data = self._source_data + assert isinstance(data, paddle.Tensor) + if grid is None: + return data + return self._sample_image(grid, data) + + def data_mask( + self, data: paddle.Tensor, channels: Dict[str, Union[int, Tuple[int, int]]] + ) -> paddle.Tensor: + """Get boolean mask from data tensor.""" + slice_ = self.as_slice(channels["msk"]) + start, stop = slice_.start, slice_.stop + start_0 = data.shape[1] + start if start < 0 else start + mask = paddle.slice(data, [1], [start_0], [start_0 + (stop - start)]) + return mask > 0.9 + + def overlap_mask( + self, source: paddle.Tensor, target: paddle.Tensor + ) -> Optional[paddle.Tensor]: + """Overlap mask at which to evaluate pairwise data term.""" + mask = self.data_mask(source, self.source_chns) + mask &= self.data_mask(target, self.target_chns) + return mask + + @classmethod + def as_slice(cls, arg: Union[int, Sequence[int]]) -> slice: + """Slice of image data channels associated with the specified name.""" + if isinstance(arg, int): + arg = (arg,) + if len(arg) == 1: + arg = arg[0], arg[0] + 1 + if len(arg) == 2: + arg = arg[0], arg[1], 1 + if len(arg) != 3: + raise ValueError( + f"{cls.__name__}.as_slice() 'arg' must be int or sequence of length 1, 2, or 3" + ) + return slice(*arg) + + @classmethod + def data_channels(cls, data: paddle.Tensor, c: slice) -> paddle.Tensor: + """Get subimage data tensor of named channel.""" + i = (slice(0, tuple(data.shape)[0]), c) + tuple( + slice(0, tuple(data.shape)[dim]) for dim in range(2, data.ndim) + ) + return data[i] + + def loss_input( + self, + name: str, + data: paddle.Tensor, + channels: Dict[str, Union[int, Tuple[int, int]]], + ) -> paddle.Tensor: + """Get input for named loss term.""" + if name in channels: + c = channels[name] + elif "img" not in channels: + raise RuntimeError( + f"Channels map contains neither entry for '{name}' nor 'img'" + ) + else: + c = channels["img"] + i: slice = self.as_slice(c) + return self.data_channels(data, i) + + def loss_mask( + self, + name: str, + data: paddle.Tensor, + channels: Dict[str, Union[int, Tuple[int, int]]], + mask: paddle.Tensor, + ) -> paddle.Tensor: + """Get mask for named loss term.""" + weight = self.weights.get(name, 1.0) + if not isinstance(weight, str): + return mask + match = RE_WEIGHT.match(weight) + if match is None: + raise RuntimeError( + f"Invalid weight string ('{weight}') for loss term '{name}'" + ) + chn = match.group("chn") + mul = match.group("mul") + add = match.group("add") + c = channels.get(chn) + if c is None: + raise RuntimeError( + f"Channels map contains no entry for '{name}' weight string '{weight}'" + ) + i = self.as_slice(c) + w = self.data_channels(data, i) + if mul is not None: + w = w * float(mul) + if add is not None: + w = w + float(add) + return w * mask + + def eval(self) -> RegistrationResult: + """Evaluate pairwise image registration loss.""" + result = {} + losses = {} + misc_excl = set() + x: Tensor = self.grid_points + y: Tensor = self.transform(x, grid=True) + variables = defaultdict(list) + for name, buf in self.transform.named_buffers(): + if not buf.stop_gradient: + var = name.rsplit(".", 1)[-1] + variables[var].append(buf) + variables["w"] = [U.move_dim(y - x, -1, 1)] + data_terms = self.loss_terms_of_type(PairwiseImageLoss) + misc_excl |= set(data_terms.keys()) + if data_terms: + source = self.source_data(y) + target = self.target_data() + mask = self.overlap_mask(source, target) + for name, term in data_terms.items(): + s = self.loss_input(name, source, self.source_chns) + t = self.loss_input(name, target, self.target_chns) + m = self.loss_mask(name, target, self.target_chns, mask) + losses[name] = term(s, t, mask=m) + result["source"] = source + result["target"] = target + result["mask"] = mask + dist_terms = self.loss_terms_of_type(PointSetDistance) + misc_excl |= set(dist_terms.keys()) + ldist_terms = { + k: v for k, v in dist_terms.items() if isinstance(v, LandmarkPointDistance) + } + dist_terms = {k: v for k, v in dist_terms.items() if k not in ldist_terms} + if dist_terms: + if self.source_pset is None: + raise RuntimeError(f"{type(self).__name__}() missing source point set") + if self.target_pset is None: + raise RuntimeError(f"{type(self).__name__}() missing target point set") + s = self.source_pset + t = self.transform(self.target_pset) + for name, term in dist_terms.items(): + losses[name] = term(t, s) + if ldist_terms: + if self.source_landmarks is None: + raise RuntimeError(f"{type(self).__name__}() missing source landmarks") + if self.target_landmarks is None: + raise RuntimeError(f"{type(self).__name__}() missing target landmarks") + s = self.source_landmarks + t = self.transform(self.target_landmarks) + for name, term in ldist_terms.items(): + losses[name] = term(t, s) + disp_terms = self.loss_terms_of_type(DisplacementLoss) + misc_excl |= set(disp_terms.keys()) + for name, term in disp_terms.items(): + match = RE_TERM_VAR.match(name) + if match: + var = match.group("var") + elif "v" in variables: + var = "v" + elif "u" in variables: + var = "u" + else: + var = "w" + bufs = variables.get(var) + if not bufs: + raise RuntimeError(f"Unknown variable in loss term name '{name}'") + value = paddle.to_tensor(data=0, dtype="float32", place=self.device) + for buf in bufs: + value += term(buf) + losses[name] = value + bspline_transforms = self.transforms_of_type(BSplineTransform) + bspline_terms = self.loss_terms_of_type(BSplineLoss) + misc_excl |= set(bspline_terms.keys()) + for name, term in bspline_terms.items(): + value = paddle.to_tensor(data=0, dtype="float32", place=self.device) + for bspline_transform in bspline_transforms: + value += term(bspline_transform.data()) + losses[name] = value + params_terms = self.loss_terms_of_type(ParamsLoss) + misc_excl |= set(params_terms.keys()) + for name, term in params_terms.items(): + value = paddle.to_tensor(data=0, dtype="float32", place=self.device) + count = 0 + for params in self.transform.parameters(): + value += term(params) + count += 1 + if count > 1: + value /= count + losses[name] = value + misc_terms = {k: v for k, v in self.loss_terms.items() if k not in misc_excl} + for name, term in misc_terms.items(): + losses[name] = term() + result["losses"] = losses + result["weights"] = self.weights + result["loss"] = self._weighted_sum(losses) + return result + + def _weighted_sum(self, losses: Mapping[str, paddle.Tensor]) -> paddle.Tensor: + """Compute weighted sum of loss terms.""" + loss = paddle.to_tensor(data=0, dtype="float32", place=self.device) + weights = self.weights + for name, value in losses.items(): + w = weights.get(name, 1.0) + if not isinstance(w, str): + value = w * value + loss += value.sum() + return loss + + +def weight_channel_names(weights: Mapping[str, Union[float, str]]) -> Dict[str, str]: + """Get names of channels that are used to weight loss term of another channel.""" + names = {} + for term, weight in weights.items(): + if not isinstance(weight, str): + continue + match = RE_WEIGHT.match(weight) + if match is None: + continue + names[term] = match.group("chn") + return names diff --git a/jointContribution/HighResolution/ffd/optim.py b/jointContribution/HighResolution/ffd/optim.py new file mode 100644 index 0000000000..3e9626755c --- /dev/null +++ b/jointContribution/HighResolution/ffd/optim.py @@ -0,0 +1,56 @@ +from typing import Sequence + +import paddle +import paddle.optimizer as optim + + +def new_optimizer( + name: str, model: paddle.nn.Layer, **kwargs +) -> paddle.optimizer.Optimizer: + """Initialize new optimizer for parameters of given model. + + Args: + name: Name of optimizer. + model: Module whose parameters are to be optimized. + kwargs: Keyword arguments for named optimizer. + + Returns: + New optimizer instance. + + """ + cls = getattr(optim, name, None) + if cls is None: + raise ValueError(f"Unknown optimizer: {name}") + if not issubclass(cls, paddle.optimizer.Optimizer): + raise TypeError( + f"Requested type '{name}' is not a subclass of paddle.optimizer.Optimizer" + ) + if "learning_rate" in kwargs: + if "lr" in kwargs: + raise ValueError( + "new_optimizer() 'lr' and 'learning_rate' are mutually exclusive" + ) + kwargs["lr"] = kwargs.pop("learning_rate") + kwargs["learning_rate"] = kwargs.pop("lr") + return cls(parameters=model.parameters(), **kwargs) + + +def slope_of_least_squares_fit(values: Sequence[float]) -> float: + """Compute slope of least squares fit of line to last n objective function values + + See also: + - https://www.che.udel.edu/pdf/FittingData.pdf + - https://en.wikipedia.org/wiki/1_%2B_2_%2B_3_%2B_4_%2B_%E2%8B%AF + - https://proofwiki.org/wiki/Sum_of_Sequence_of_Squares + + """ + n = len(values) + if n < 2: + return float("nan") + if n == 2: + return values[1] - values[0] + sum_x1 = (n + 1) / 2 + sum_x2 = n * (n + 1) * (2 * n + 1) / 6 + sum_y1 = sum(values) + sum_xy = sum((x + 1) * y for x, y in enumerate(values)) + return (sum_xy - sum_x1 * sum_y1) / (sum_x2 - n * sum_x1 * sum_x1) diff --git a/jointContribution/HighResolution/ffd/pairwise.py b/jointContribution/HighResolution/ffd/pairwise.py new file mode 100644 index 0000000000..8755e90cef --- /dev/null +++ b/jointContribution/HighResolution/ffd/pairwise.py @@ -0,0 +1,869 @@ +from pathlib import Path +from timeit import default_timer as timer +from typing import Any +from typing import Dict +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import Set +from typing import Tuple +from typing import Union + +import paddle +from deepali.core import Axes +from deepali.core import Device +from deepali.core import Grid +from deepali.core import PathStr +from deepali.core import functional as U +from deepali.core import join_kwargs_in_sequence +from deepali.data import FlowField +from deepali.data import Image +from deepali.losses import RegistrationResult +from deepali.losses import new_loss +from deepali.spatial import DisplacementFieldTransform +from deepali.spatial import HomogeneousTransform +from deepali.spatial import NonRigidTransform +from deepali.spatial import QuaternionRotation +from deepali.spatial import RigidQuaternionTransform +from deepali.spatial import SequentialTransform +from deepali.spatial import SpatialTransform +from deepali.spatial import Translation +from deepali.spatial import new_spatial_transform + +from .engine import RegistrationEngine +from .hooks import RegistrationEvalHook +from .hooks import RegistrationStepHook +from .hooks import normalize_grad_hook +from .hooks import smooth_grad_hook +from .losses import PairwiseImageRegistrationLoss +from .losses import weight_channel_names +from .optim import new_optimizer + + +def register_pairwise( + target: Union[PathStr, Dict[str, PathStr]], + source: Union[PathStr, Dict[str, PathStr]], + config: Optional[Dict[str, Any]] = None, + outdir: Optional[PathStr] = None, + verbose: Union[bool, int] = False, + debug: Union[bool, int] = False, + device: Optional[Device] = None, +) -> SpatialTransform: + """Register pair of images.""" + if config is None: + config = {} + if outdir is not None: + outdir = Path(outdir).absolute() + outdir.mkdir(parents=True, exist_ok=True) + loss_config, loss_weights = get_loss_config(config) + model_name, model_args, model_init = get_model_config(config) + optim_name, optim_args, optim_loop = get_optim_config(config) + levels, coarsest_level, finest_level = get_levels_config(config) + finest_spacing, min_size, pyramid_dims = get_pyramid_config(config) + device = get_device_config(config, device) + verbose = int(verbose) + debug = int(debug) + if verbose > 0: + print() + start = timer() + target_keys = set(loss_config.keys()) | set( + weight_channel_names(loss_weights).values() + ) + target_image, target_chns = read_images(target, names=target_keys, device=device) + source_image, source_chns = read_images( + source, names=loss_config.keys(), device=device + ) + if verbose > 3: + print(f"Read images from files in {timer() - start:.3f}s") + start_reg = timer() + target_image = append_mask(target_image, target_chns, config) + source_image = append_mask(source_image, source_chns, config) + norm_params = get_normalize_config(config, target_image, target_chns) + target_image = normalize_data_(target_image, target_chns, **norm_params) + source_image = normalize_data_(source_image, source_chns, **norm_params) + start = timer() + target_pyramid = target_image.pyramid( + levels, + start=finest_level, + end=coarsest_level, + dims=pyramid_dims, + spacing=finest_spacing, + min_size=min_size, + ) + source_pyramid = source_image.pyramid( + levels, + start=finest_level, + end=coarsest_level, + dims=pyramid_dims, + spacing=finest_spacing, + min_size=min_size, + ) + if verbose > 3: + print(f"Constructed Gaussian resolution pyramids in {timer() - start:.3f}s\n") + if verbose > 2: + print("Target image pyramid:") + print_pyramid_info(target_pyramid) + print("Source image pyramid:") + print_pyramid_info(source_pyramid) + del target_image + del source_image + source_grid = source_pyramid[finest_level].grid() + finest_grid = target_pyramid[finest_level].grid() + coarsest_grid = target_pyramid[coarsest_level].grid() + post_transform = get_post_transform(config, finest_grid, source_grid) + transform_downsample = model_args.pop("downsample", 0) + transform_grid = coarsest_grid.downsample(transform_downsample) + # here is ok + transform = new_spatial_transform( + model_name, grid=transform_grid, groups=1, **model_args + ) + if model_init: + if verbose > 1: + print(f"Fitting '{model_init}'...") + disp_field = FlowField.read(model_init).to(device=device) + assert isinstance(disp_field, FlowField) + start = timer() + transform = transform.to(device=device).fit(disp_field.batch()) + if verbose > 0: + print(f"Fitted initial displacement field in {timer() - start:.3f}s") + del disp_field + grid_transform = SequentialTransform(transform, post_transform) + grid_transform = grid_transform.to(device=device) + for level in range(coarsest_level, finest_level - 1, -1): + target_image = target_pyramid[level] + source_image = source_pyramid[level] + # here is ok + if outdir and debug > 0: + write_channels( + data=target_image.tensor(), + grid=target_image.grid(), + channels=target_chns, + outdir=outdir, + prefix=f"level_{level}_target_", + ) + write_channels( + data=source_image.tensor(), + grid=source_image.grid(), + channels=source_chns, + outdir=outdir, + prefix=f"level_{level}_source_", + ) + if level != coarsest_level: + start = timer() + transform_grid = target_image.grid().downsample(transform_downsample) + transform.grid_(transform_grid) + if verbose > 3: + print(f"Subdivided control point grid in {timer() - start:.3f}s") + grid_transform.grid_(target_image.grid()) + loss_terms = new_loss_terms(loss_config) + loss = PairwiseImageRegistrationLoss( + losses=loss_terms, + source_data=source_image.tensor().unsqueeze(axis=0), + target_data=target_image.tensor().unsqueeze(axis=0), + source_grid=source_image.grid(), + target_grid=target_image.grid(), + source_chns=source_chns, + target_chns=target_chns, + transform=grid_transform, + weights=loss_weights, + ) + loss = loss.to(device=device) + if outdir and debug > 1: + start = timer() + result = loss.eval() + if verbose > 3: + print(f"Evaluated initial loss in {timer() - start:.3f}s") + write_result( + result, + grid=target_image.grid(), + channels=source_chns, + outdir=outdir, + prefix=f"level_{level}_initial_", + ) + flow = grid_transform.flow(target_image.grid(), device=device) + flow[0].write(outdir / f"level_{level}_initial_def.mha") + optimizer = new_optimizer(optim_name, model=grid_transform, **optim_args) + engine = RegistrationEngine( + model=grid_transform, + loss=loss, + optimizer=optimizer, + max_steps=optim_loop.get("max_steps", 250), + min_delta=float(optim_loop.get("min_delta", "nan")), + ) + grad_sigma = float(optim_loop.get("smooth_grad", 0)) + if isinstance(transform, NonRigidTransform) and grad_sigma > 0: + engine.register_eval_hook(smooth_grad_hook(transform, sigma=grad_sigma)) + engine.register_eval_hook(normalize_grad_hook(transform)) + if verbose > 2: + engine.register_eval_hook(print_eval_loss_hook(level)) + elif verbose > 1: + engine.register_step_hook(print_step_loss_hook(level)) + if outdir and debug > 2: + engine.register_eval_hook( + write_result_hook( + level=level, + grid=target_image.grid(), + channels=source_chns, + outdir=outdir, + ) + ) + engine.run() + if verbose > 0 or outdir and debug > 0: + start = timer() + result = loss.eval() + if verbose > 3: + print(f"Evaluated final loss in {timer() - start:.3f}s") + if verbose > 0: + loss_value = float(result["loss"]) + print( + f"level={level:d}: loss={loss_value:.5f} ({engine.num_steps:d} steps)", + flush=True, + ) + if outdir and debug > 0: + write_result( + result, + grid=target_image.grid(), + channels=source_chns, + outdir=outdir, + prefix=f"level_{level}_final_", + ) + flow = grid_transform.flow(device=device) + flow[0].write(outdir / f"level_{level}_final_def.mha") + if verbose > 3: + print(f"Registered images in {timer() - start_reg:.3f}s") + if verbose > 0: + print() + return grid_transform + + +def append_mask( + image: Image, channels: Dict[str, Tuple[int, int]], config: Dict[str, Any] +) -> Image: + """Append foreground mask to data tensor.""" + data = image.tensor() + if "img" in channels: + lower_threshold, upper_threshold = get_clamp_config(config, "img") + mask = U.threshold( + data[slice(*channels["img"])], lower_threshold, upper_threshold + ) + else: + mask = paddle.ones(shape=(1,) + tuple(data.shape)[1:], dtype=data.dtype) + data = paddle.concat(x=[data, mask.astype(data.dtype)], axis=0) + channels["msk"] = tuple(data.shape)[0] - 1, tuple(data.shape)[0] + return Image(data, image.grid()) + + +def append_data( + data: Optional[paddle.Tensor], + channels: Dict[str, Tuple[int, int]], + name: str, + other: paddle.Tensor, +) -> paddle.Tensor: + """Append image data.""" + if data is None: + data = other + else: + data = paddle.concat(x=[data, other], axis=0) + channels[name] = tuple(data.shape)[0] - tuple(other.shape)[0], tuple(data.shape)[0] + return data + + +def read_images( + sample: Union[PathStr, Dict[str, PathStr]], names: Set[str], device: str +) -> Tuple[Image, Dict[str, Tuple[int, int]]]: + """Read image data from input files.""" + data = None + grid = None + if isinstance(sample, (Path, str)): + sample = {"img": sample} + img_path = sample.get("img") + seg_path = sample.get("seg") + sdf_path = sample.get("sdf") + for path in (img_path, seg_path, sdf_path): + if not path: + continue + grid = Grid.from_file(path).align_corners_(True) + break + else: + raise ValueError( + "One of 'img', 'seg', or 'sdf' input image file paths is required" + ) + assert grid is not None + dtype = "float32" + channels = {} + if "img" in names: + temp = Image.read(img_path, dtype=dtype, device=device) + data = append_data(data, channels, "img", temp.tensor()) + if "seg" in names: + if seg_path is None: + raise ValueError("Missing segmentation label image file path") + temp = Image.read(seg_path, dtype="int64", device=device) + temp_grid = temp.grid() + num_classes = int(temp.max()) + 1 + temp = temp.tensor().unsqueeze(axis=0) + temp = U.as_one_hot_tensor(temp, num_classes).to(dtype=dtype) + temp = temp.squeeze(axis=0) + temp = Image(temp, grid=temp_grid).sample(grid) + data = append_data(data, channels, "seg", temp.tensor()) + if "sdf" in names: + if sdf_path is None: + raise ValueError( + "Missing segmentation boundary signed distance field file path" + ) + temp = Image.read(sdf_path, dtype=dtype, device=device) + temp = temp.sample(shape=grid) + data = append_data(data, channels, "sdf", temp.tensor()) + if data is None: + if img_path is None: + raise ValueError("Missing intensity image file path") + data = Image.read(img_path, dtype=dtype, device=device) + channels = {"img": (0, 1)} + image = Image(data, grid=grid) + return image, channels + + +def get_device_config( + config: Dict[str, Any], device: Optional[Union[str, str]] = None +) -> str: + """Get configured PyTorch device.""" + if device is None: + device = config.get("device", "cpu") + if isinstance(device, int): + device = f"cuda:{device}" + elif device == "cuda": + device = "cuda:0" + return str(device).replace("cuda", "gpu") + + +def load_transform(path: PathStr, grid: Grid) -> SpatialTransform: + """Load transformation from file. + + Args: + path: File path from which to load spatial transformation. + grid: Target domain grid with respect to which transformation is defined. + + Returns: + Loaded spatial transformation. + + """ + target_grid = grid + + def convert_matrix( + matrix: paddle.Tensor, grid: Optional[Grid] = None + ) -> paddle.Tensor: + if grid is None: + pre = target_grid.transform(Axes.CUBE_CORNERS, Axes.WORLD) + post = target_grid.transform(Axes.WORLD, Axes.CUBE_CORNERS) + matrix = U.homogeneous_matmul(post, matrix, pre) + elif grid != target_grid: + pre = target_grid.transform(Axes.CUBE_CORNERS, grid=grid) + post = grid.transform(Axes.CUBE_CORNERS, grid=target_grid) + matrix = U.homogeneous_matmul(post, matrix, pre) + return matrix + + path = Path(path) + if path.suffix == ".pt": + value = paddle.load(path=path) + if isinstance(value, dict): + matrix = value.get("matrix") + if matrix is None: + raise KeyError( + "load_transform() .pt file dict must contain key 'matrix'" + ) + grid = value.get("grid") + elif isinstance(value, paddle.Tensor): + matrix = value + grid = None + else: + raise RuntimeError("load_transform() .pt file must contain tensor or dict") + if matrix.ndim == 2: + matrix = matrix.unsqueeze(axis=0) + if matrix.ndim != 3 or tuple(matrix.shape)[1:] != (3, 4): + raise RuntimeError( + "load_transform() .pt file tensor must have shape (N, 3, 4)" + ) + params = convert_matrix(matrix, grid) + return HomogeneousTransform(target_grid, params=params) + flow = FlowField.read(path, axes=Axes.WORLD) + flow = flow.axes(Axes.from_grid(target_grid)) + flow = flow.sample(shape=target_grid) + return DisplacementFieldTransform( + target_grid, params=flow.tensor().unsqueeze(axis=0) + ) + + +def get_post_transform( + config: Dict[str, Any], target_grid: Grid, source_grid: Grid +) -> Optional[SpatialTransform]: + """Get constant rigid transformation between image grid domains.""" + align = config.get("align", False) + if align is False or align is None: + return None + if isinstance(align, (Path, str)): + return load_transform(align, target_grid) + if align is True: + align_centers = True + align_directions = True + elif isinstance(align, dict): + align_centers = bool(align.get("centers", True)) + align_directions = bool(align.get("directions", True)) + else: + raise ValueError( + "get_post_transform() 'config' has invalid 'align' value: {align}" + ) + center_offset = ( + target_grid.world_to_cube(source_grid.center()).unsqueeze(axis=0) + if align_centers + else None + ) + rotation_matrix = ( + source_grid.direction() @ target_grid.direction().t().unsqueeze(axis=0) + if align_directions + else None + ) + transform = None + if center_offset is not None and rotation_matrix is not None: + transform = RigidQuaternionTransform( + target_grid, translation=center_offset, rotation=False + ) + transform.rotation.matrix_(rotation_matrix) + elif center_offset is not None: + transform = Translation(target_grid, params=center_offset) + elif rotation_matrix is not None: + transform = QuaternionRotation(target_grid, params=False) + transform.matrix_(rotation_matrix) + return transform + + +def get_clamp_config( + config: Dict[str, Any], channel: str +) -> Tuple[Optional[float], Optional[float]]: + """Get thresholds for named image channel. + + Args: + config: Configuration. + channel: Name of image channel. + + Returns: + lower_threshold: Lower threshold. + upper_threshold: Upper threshold. + + """ + input_config = config.get("input", {}) + if not isinstance(input_config, dict): + raise ValueError("get_clamp_config() 'input' value must be dict") + channel_config = input_config.get(channel) + if not isinstance(channel_config, dict): + channel_config = {"clamp": channel_config} + thresholds = channel_config.get("clamp", input_config.get("clamp")) + if thresholds is None: + thresholds = None, None + elif isinstance(thresholds, (int, float)): + thresholds = float(thresholds), None + if not isinstance(thresholds, (list, tuple)): + raise ValueError("get_clamp_config() value must be scalar or sequence") + if len(thresholds) != 2: + raise ValueError("get_clamp_config() value must be scalar or [min, max]") + thresholds = tuple(None if v is None else float(v) for v in thresholds) + lower_threshold, upper_threshold = thresholds + return lower_threshold, upper_threshold + + +def get_scale_config(config: Dict[str, Any], channel: str) -> Optional[float]: + """Get channel scaling factor.""" + input_config = config.get("input", {}) + if not isinstance(input_config, dict): + return None + channel_config = input_config.get(channel) + if not isinstance(channel_config, dict): + return None + value = channel_config.get("scale", input_config.get("scale")) + if value is None: + return None + return float(value) + + +def get_normalize_config( + config: Dict[str, Any], image: Image, channels: Dict[str, Tuple[int, int]] +) -> Dict[str, Dict[str, paddle.Tensor]]: + """Calculate data normalization parameters. + + Args: + config: Configuration. + image: Image data. + channels: Map of image channel slices. + + Returns: + Dictionary of normalization parameters. + + """ + scale = {} + shift = {} + for channel, (start, stop) in channels.items(): + start_1 = image.tensor().shape[0] + start if start < 0 else start + data = paddle.slice(image.tensor(), [0], [start_1], [start_1 + (stop - start)]) + lower_threshold, upper_threshold = get_clamp_config(config, channel) + scale_factor = get_scale_config(config, channel) + if channel in ("msk", "seg"): + if lower_threshold is None: + lower_threshold = 0 + if upper_threshold is None: + upper_threshold = 1 + else: + if lower_threshold is None: + lower_threshold = data.min() + if upper_threshold is None: + upper_threshold = data.max() + if scale_factor is None: + if upper_threshold > lower_threshold: + scale_factor = upper_threshold - lower_threshold + else: + scale_factor = 1 + else: + scale_factor = 1 / scale_factor + shift[channel] = lower_threshold + scale[channel] = scale_factor + return dict(shift=shift, scale=scale) + + +def normalize_data_( + image: Image, + channels: Dict[str, Tuple[int, int]], + shift: Optional[Dict[str, paddle.Tensor]] = None, + scale: Optional[Dict[str, paddle.Tensor]] = None, +) -> Image: + """Normalize image data.""" + if shift is None: + shift = {} + if scale is None: + scale = {} + for channel, (start, stop) in channels.items(): + start_2 = image.tensor().shape[0] + start if start < 0 else start + data = paddle.slice(image.tensor(), [0], [start_2], [start_2 + (stop - start)]) + offset = shift.get(channel) + if offset is not None: + data -= offset + norm = scale.get(channel) + if norm is not None: + data /= norm + if channel in ("msk", "seg"): + data.clip_(min=0, max=1) + return image + + +def get_levels_config(config: Dict[str, Any]) -> Tuple[int, int, int]: + """Get indices of coarsest and finest level from configuration.""" + cfg = config.get("pyramid", {}) + levels = cfg.get("levels", 4) + if isinstance(levels, int): + levels = levels - 1, 0 + if not isinstance(levels, (list, tuple)): + raise TypeError( + "register_pairwise() 'config' key 'pyramid.levels': value must be int, tuple, or list" + ) + coarsest_level, finest_level = levels + if finest_level > coarsest_level: + raise ValueError( + "register_pairwise() 'config' key 'pyramid.levels':" + + " finest level must be less or equal than coarsest level" + ) + levels = coarsest_level + 1 + if "max_level" in cfg: + levels = max(levels, cfg["max_level"]) + return levels, coarsest_level, finest_level + + +def get_pyramid_config( + config: Dict[str, Any] +) -> Tuple[Optional[Union[float, Sequence[float]]], int, Optional[Union[str, int]]]: + """Get settings of Gaussian resolution pyramid from configuration.""" + cfg = config.get("pyramid", {}) + min_size = cfg.get("min_size", 16) + finest_spacing = cfg.get("spacing") + dims = cfg.get("dims") + return finest_spacing, min_size, dims + + +def get_loss_config( + config: Dict[str, Any] +) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, float]]: + """Instantiate terms of registration loss given configuration object. + + Args: + config: Configuration. + + Returns: + losses: Preparsed configuration of loss terms (cf. ``new_loss_terms()``). + weights: Weights of loss terms. + + """ + cfg = None + losses = {} + weights = {} + sections = "loss", "losses", "energy" + for name in sections: + if name in config: + if cfg is not None: + raise ValueError( + "get_loss_config() keys {sections} are mutually exclusive" + ) + cfg = config[name] + if not cfg: + cfg = "SSD" + if isinstance(cfg, str): + cfg = (cfg,) + if isinstance(cfg, Sequence): + names, cfg = cfg, {} + for i, name in enumerate(names): + cfg[f"loss_{i}"] = str(name) + if isinstance(cfg, dict): + for key, value in cfg.items(): + name = None + weight = 1 + kwargs = {} + if isinstance(value, str): + name = value + elif isinstance(value, Sequence): + if not value: + raise ValueError(f"get_loss_config() '{key}' loss entry is empty") + if len(value) == 1: + if isinstance(value[0], str): + value = {"name": value[0]} + elif len(value) > 1: + if isinstance(value[0], (int, float)): + value[0] = {"weight": value[0]} + if isinstance(value[1], str): + value[1] = {"name": value[1]} + value = join_kwargs_in_sequence(value) + if isinstance(value, dict): + kwargs = dict(value) + name = kwargs.pop("name", None) + weight = kwargs.pop("weight", 1) + elif len(value) == 2: + name = value[0] + kwargs = dict(value[1]) + elif len(value) == 3: + weight = float(value[0]) + name = value[1] + kwargs = dict(value[2]) + else: + raise ValueError( + f"get_loss_config() '{key}' invalid loss configuration" + ) + elif isinstance(value, dict): + kwargs = dict(value) + name = kwargs.pop("name", None) + weight = kwargs.pop("weight", 1) + else: + weight, name = value + if name is None: + raise ValueError(f"get_loss_config() missing 'name' for loss '{key}'") + if not isinstance(name, str): + raise TypeError(f"get_loss_config() 'name' of loss '{key}' must be str") + kwargs["name"] = name + losses[key] = kwargs + weights[key] = weight + else: + raise TypeError( + "get_loss_config() 'config' \"losses\" must be str, tuple, list, or dict" + ) + weights_config = config.get("weights", {}) + if isinstance(weights_config, (int, float)): + weights_config = (weights_config,) + if isinstance(weights_config, (list, tuple)): + names, weights_config = weights_config, {} + for i, weight in enumerate(names): + weights_config[f"loss_{i}"] = weight + if not isinstance(weights_config, dict): + raise TypeError( + "get_loss_config() 'weights' must be scalar, tuple, list, or dict" + ) + weights.update(weights_config) + losses = {k: v for k, v in losses.items() if weights.get(k, 0)} + weights = {k: v for k, v in weights.items() if k in losses} + return losses, weights + + +def new_loss_terms(config: Dict[str, Any]) -> Dict[str, paddle.nn.Layer]: + """Instantiate terms of registration loss. + + Args: + config: Preparsed configuration of loss terms. + target_tree: Target vessel centerline tree. + + Returns: + Mapping from channel or loss name to loss module instance. + + """ + losses = {} + for key, value in config.items(): + kwargs = dict(value) + name = kwargs.pop("name", None) + _ = kwargs.pop("weight", None) + if name is None: + raise ValueError(f"new_loss_terms() missing 'name' for loss '{key}'") + if not isinstance(name, str): + raise TypeError(f"new_loss_terms() 'name' of loss '{key}' must be str") + loss = new_loss(name, **kwargs) + losses[key] = loss + return losses + + +def get_model_config( + config: Dict[str, Any] +) -> Tuple[str, Dict[str, Any], Optional[str]]: + """Get configuration of transformation model to use.""" + cfg = config.get("model", {}) + cfg = dict(name=cfg) if isinstance(cfg, str) else dict(cfg) + model_name = cfg.pop("name") + assert isinstance(model_name, str) + assert model_name != "" + model_init = cfg.pop("init", None) + if model_init is not None: + model_init = Path(model_init).as_posix() + model_args = dict(cfg.get(model_name, cfg)) + return model_name, model_args, model_init + + +def get_optim_config( + config: Dict[str, Any] +) -> Tuple[str, Dict[str, Any], Dict[str, Any]]: + """Get configuration of optimizer to use.""" + cfg = config.get("optim", {}) + cfg = dict(name=cfg) if isinstance(cfg, str) else dict(cfg) + if "optimizer" in cfg: + if "name" in cfg: + raise ValueError( + "get_optim_config() keys ('name', 'optimizer') are mutually exclusive" + ) + cfg["name"] = cfg.pop("optimizer") + optim_name = str(cfg.pop("name", "LBFGS")) + optim_loop = {} + for key in ("max_steps", "min_delta", "smooth_grad"): + if key in cfg: + optim_loop[key] = cfg.pop(key) + optim_args = {k: v for k, v in cfg.items() if isinstance(k, str) and k[0].islower()} + optim_args.update(cfg.get(optim_name, {})) + lr_keys = "step_size", "learning_rate" + for lr_key in ("step_size", "learning_rate"): + if lr_key in optim_args: + if "lr" in optim_args: + raise ValueError( + f"get_optim_config() keys {lr_keys + ('lr',)} are mutually exclusive" + ) + optim_args["lr"] = optim_args.pop(lr_key) + return optim_name, optim_args, optim_loop + + +@paddle.no_grad() +def write_channels( + data: paddle.Tensor, + grid: Grid, + channels: Mapping[str, Tuple[int, int]], + outdir: PathStr, + prefix: str = "", +) -> None: + """Write image channels.""" + for name, (start, stop) in channels.items(): + image = data[slice(start, stop, 1)] + if name == "seg": + image = image.argmax(axis=0, keepdim=True).astype("uint8") + elif name == "msk": + image = image.mul(255).clip_(min=0, max=255).astype("uint8") + if not isinstance(image, Image): + image = Image(image, grid=grid) + image.write(outdir / f"{prefix}{name}.mha") + + +@paddle.no_grad() +def write_result( + result: RegistrationResult, + grid: Grid, + channels: Mapping[str, Tuple[int, int]], + outdir: PathStr, + prefix: str = "", +) -> None: + """Write registration result to output directory.""" + data = result["source"] + assert isinstance(data, paddle.Tensor) + write_channels(data[0], grid=grid, channels=channels, outdir=outdir, prefix=prefix) + data = result["mask"] + assert isinstance(data, paddle.Tensor) + if data.dtype == "bool": + data = data.astype("uint8").multiply_(y=paddle.to_tensor(255)) + mask = Image(data[0], grid=grid) + mask.write(outdir / f"{prefix}olm.mha") + + +def write_result_hook( + level: int, grid: Grid, channels: Mapping[str, Tuple[int, int]], outdir: Path +) -> RegistrationEvalHook: + """Get callback function for writing results after each evaluation.""" + + def fn( + _: RegistrationEngine, + num_steps: int, + num_evals: int, + result: RegistrationResult, + ) -> None: + prefix = f"level_{level}_step_{num_steps:03d}_eval_{num_evals}_" + write_result(result, grid=grid, channels=channels, outdir=outdir, prefix=prefix) + + return fn + + +def print_eval_loss_hook(level: int) -> RegistrationEvalHook: + """Get callback function for printing loss after each evaluation.""" + + def fn( + _: RegistrationEngine, num_steps: int, num_eval: int, result: RegistrationResult + ) -> None: + loss = float(result["loss"]) + message = f" {num_steps:>4d}:" + message += f" {loss:>12.05f} (loss)" + weights: Dict[str, Union[str, float]] = result.get("weights", {}) + losses: Dict[str, paddle.Tensor] = result["losses"] + for name, value in losses.items(): + value = float(value) + weight = weights.get(name, 1.0) + if not isinstance(weight, str): + value *= weight + elif "+" in weight: + weight = f"({weight})" + message += f", {value:>12.05f} [{weight} * {name}]" + if num_eval > 1: + message += " [evals={num_eval:d}]" + print(message, flush=True) + + return fn + + +def print_step_loss_hook(level: int) -> RegistrationStepHook: + """Get callback function for printing loss after each step.""" + + def fn(_: RegistrationEngine, num_steps: int, num_eval: int, loss: float) -> None: + message = f" {num_steps:>4d}: {loss:>12.05f}" + if num_eval > 1: + message += " [evals={num_eval:d}]" + print(message, flush=True) + + return fn + + +def print_pyramid_info(pyramid: Dict[str, Image]) -> None: + """Print information of image resolution pyramid.""" + levels = sorted(pyramid.keys()) + for level in reversed(levels): + grid = pyramid[level].grid() + size = ", ".join([f"{n:>3d}" for n in tuple(grid.shape)]) + origin = ", ".join([f"{n:.2f}" for n in grid.origin()]) + extent = ", ".join([f"{n:.2f}" for n in grid.extent()]) + domain = ", ".join([f"{n:.2f}" for n in grid.cube_extent()]) + print( + f"- Level {level}:" + + f" size=({size})" + + f", origin=({origin})" + + f", extent=({extent})" + + f", domain=({domain})" + ) + print() diff --git a/jointContribution/HighResolution/ffd/params_atlas_affine.yaml b/jointContribution/HighResolution/ffd/params_atlas_affine.yaml new file mode 100644 index 0000000000..5c43427ef6 --- /dev/null +++ b/jointContribution/HighResolution/ffd/params_atlas_affine.yaml @@ -0,0 +1,17 @@ +# Using free-form deformation model +model: + name: FullAffine +# Loss terms of objective function to minimize +energy: + seg: [1, MSE] +# Optimization scheme and parameters +optim: + name: Adam + step_size: 0.01 + min_delta: -0.01 + max_steps: 100 +# Gaussian resolution pyramid +pyramid: + dims: ["x", "y", "z"] + levels: 3 + spacing: [1., 1., 1.] diff --git a/jointContribution/HighResolution/ffd/params_seg.yaml b/jointContribution/HighResolution/ffd/params_seg.yaml new file mode 100644 index 0000000000..a1aa7701bb --- /dev/null +++ b/jointContribution/HighResolution/ffd/params_seg.yaml @@ -0,0 +1,22 @@ +# Using free-form deformation model +model: + name: FFD + stride: &stride [8, 8, 8] +# Loss terms of objective function to minimize +energy: + seg: [1, MSE] +# seg: [1, MSE] + be: [0.01, BSplineBending, stride: *stride] + # To approximate bending energy on coarser grid, use smaller stride, e.g.: + # be: [0.005, BSplineBending, stride: 1] +# Optimization scheme and parameters +optim: + name: Adam + step_size: 0.001 + min_delta: -0.01 + max_steps: 100 +# Gaussian resolution pyramid +pyramid: + dims: ["x", "y", "z"] + levels: 3 + spacing: [1., 1., 1.] diff --git a/jointContribution/HighResolution/ffd/register.py b/jointContribution/HighResolution/ffd/register.py new file mode 100644 index 0000000000..ae2f5b8eaf --- /dev/null +++ b/jointContribution/HighResolution/ffd/register.py @@ -0,0 +1,171 @@ +import json +import logging +import sys +from pathlib import Path +from timeit import default_timer as timer +from typing import Any +from typing import Dict + +import deepali +import paddle +import yaml +from deepali.core import Grid +from deepali.core import PathStr +from deepali.core import unlink_or_mkdir +from deepali.data import Image +from deepali.modules import TransformImage +from deepali.utils.cli import Args +from deepali.utils.cli import ArgumentParser +from deepali.utils.cli import configure_logging +from deepali.utils.cli import cuda_visible_devices +from deepali.utils.cli import filter_warning_of_experimental_named_tensors_feature +from deepali.utils.cli import main_func +from paddle import Tensor + +from .pairwise import register_pairwise + +log = logging.getLogger() + + +def parser(**kwargs) -> ArgumentParser: + """Construct argument parser.""" + if "description" not in kwargs: + kwargs["description"] = globals()["__doc__"] + parser = ArgumentParser(**kwargs) + parser.add_argument( + "-c", + "--config", + help="Configuration file", + default=Path(__file__).parent / "params_seg.yaml", + ) + parser.add_argument( + "-t", "--target", "--target-img", dest="target_img", help="Fixed target image" + ) + parser.add_argument( + "-s", "--source", "--source-img", dest="source_img", help="Moving source image" + ) + parser.add_argument("--target-seg", help="Fixed target segmentation label image") + parser.add_argument("--source-seg", help="Moving source segmentation label image") + parser.add_argument( + "-o", + "--output", + "--output-transform", + dest="output_transform", + help="Output transformation parameters", + ) + parser.add_argument( + "-w", + "--warped", + "--warped-img", + "--output-img", + dest="warped_img", + help="Deformed source image", + ) + parser.add_argument( + "--warped-seg", + "--output-seg", + dest="warped_seg", + help="Deformed source segmentation label image", + ) + parser.add_argument( + "--device", + help="Device on which to execute registration", + choices=("cpu", "cuda"), + default="cpu", + ) + parser.add_argument("--debug-dir", help="Output directory for intermediate files") + parser.add_argument( + "--debug", "--debug-level", help="Debug level", type=int, default=0 + ) + parser.add_argument( + "-v", "--verbose", help="Verbosity of output messages", type=int, default=0 + ) + parser.add_argument( + "--log-level", + help="Logging level", + choices=["DEBUG", "INFO", "WARNING", "ERROR"], + default="INFO", + ) + return parser + + +def init(args: Args) -> int: + """Initialize registration.""" + configure_logging(log, args) + if args.device == "cuda": + if not paddle.device.cuda.device_count() >= 1: + log.error("Cannot use --device 'cuda' ") + return 1 + gpu_ids = cuda_visible_devices() + if len(gpu_ids) != 1: + log.error("CUDA_VISIBLE_DEVICES must be set to one GPU") + return 1 + filter_warning_of_experimental_named_tensors_feature() + return 0 + + +def register_func(args: Args) -> deepali.spatial.SpatialTransform: + """Execute registration given parsed arguments.""" + config = load_config_para(args.config) + device = str("cuda:0" if args.device == "cuda" else "cpu").replace("cuda", "gpu") + start = timer() + transform = register_pairwise( + target={"img": args.target_img, "seg": args.target_seg}, + source={"img": args.source_img, "seg": args.source_seg}, + config=config, + outdir=args.debug_dir, + device=args.device, + verbose=args.verbose, + debug=args.debug, + ) + log.info(f"Elapsed time: {timer() - start:.3f}s") + if args.warped_img: + target_grid = Grid.from_file(args.target_img) + source_image = Image.read(args.source_img, device=device) + warp_image = TransformImage( + target=target_grid, + source=source_image.grid(), + sampling="linear", + padding=source_image.min(), + ).to(device) + # here is ok + data: Tensor = warp_image(transform.tensor(), source_image) + warped_image = Image(data, target_grid) + warped_image.write(unlink_or_mkdir(args.warped_img)) + if args.warped_seg: + target_grid = Grid.from_file(args.target_seg) + source_image = Image.read(args.source_seg, device=device) + warp_labels = TransformImage( + target=target_grid, + source=source_image.grid(), + sampling="nearest", + padding=0, + ).to(device) + data: Tensor = warp_labels(transform.tensor(), source_image) + warped_image = Image(data, target_grid) + warped_image.write(unlink_or_mkdir(args.warped_seg)) + if args.output_transform: + path = unlink_or_mkdir(args.output_transform) + if path.suffix == ".pt": + transform.clear_buffers() + paddle.save(obj=transform, path=path) + else: + transform.flow()[0].write(path) + return transform + + +main = main_func(parser, register_func, init=init) + + +def load_config_para(path: PathStr) -> Dict[str, Any]: + """Load registration parameters from configuration file.""" + config_path = Path(path).absolute() + log.info(f"Load configuration from {config_path}") + config_text = config_path.read_text() + if config_path.suffix == ".json": + return json.loads(config_text) + return yaml.safe_load(config_text) + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/jointContribution/HighResolution/main_ACDC.py b/jointContribution/HighResolution/main_ACDC.py new file mode 100644 index 0000000000..3494583ecc --- /dev/null +++ b/jointContribution/HighResolution/main_ACDC.py @@ -0,0 +1,221 @@ +import os +import shutil + +import nibabel as nib +import numpy as np +import util.pre_process as pre_process +import util.utils as util +from scipy.ndimage import zoom +from util.image_utils import combine_labels +from util.image_utils import crop_3Dimage +from util.image_utils import np_mean_dice +from util.image_utils import refineFusionResults + + +def atlas_selection( + atlas_path, + atlas_img_type, + atlas_nums_top, + target_path, + frame, + dice_scores, + parameter_file, +): + # calculate the similarity between the target image and the ED/ES of atlas + # decide to use ED or ES + atlas_top_list = [] + for index, atlas_id in enumerate(dice_scores[:atlas_nums_top]): + source_img = f"{atlas_path}/{atlas_id[0]}/seg_{frame}_{atlas_img_type}.nii.gz" + target_img_atlas = f"{target_path}/tmps/seg_{frame}_{index}.nii.gz" + affine_warped_image_path = f"{target_path}/tmps/affine_warped_source.nii.gz" + # read atlas sa and seg, and save them in the image space + + pre_process.register_with_deepali( + target_img, + source_img, + target_seg_file=target_img, + source_seg_file=source_img, + ffd_params_file=f"{parameter_file}_atlas_affine.yaml", + warped_img_path=affine_warped_image_path, + ) + + pre_process.register_with_deepali( + target_img, + affine_warped_image_path, + target_seg_file=target_img, + source_seg_file=affine_warped_image_path, + ffd_params_file=f"{parameter_file}_seg.yaml", + warped_img_path=target_img_atlas, + ) + + target_img_atlas = f"{target_path}/tmps/seg_{frame}_{index}.nii.gz" + seg_EDES = nib.load( + f"{target_path}/tmps/seg_{frame}_{index}.nii.gz" + ).get_fdata() + seg_EDES = refineFusionResults(seg_EDES, 2) + nib.save(nib.Nifti1Image(seg_EDES, np.eye(4)), target_img_atlas) + atlas_top_list.append(target_img_atlas) + return atlas_top_list + + +def crop_data_into_atlas_size(seg_nib, img_nib, atlas_size): + # 0 - background, 1 - LV, 2 - MYO, 4 - RV, img_data: 4D, WHD*time_t + # the template information 140, 140, 56 + seg_data = seg_nib.get_fdata() + img_data = img_nib.get_fdata().squeeze() + seg_data[seg_data == 1] = 0 + seg_data[seg_data == 3] = 1 + + affine_sa = seg_nib.affine + new_affine = affine_sa.copy() + # align with the atlas data + seg_data = np.flip(np.flip(seg_data, 2), 1) + img_data = np.flip(np.flip(img_data, 2), 1) + + res_xy, res_z = seg_nib.header["pixdim"][1], seg_nib.header["pixdim"][3] + atlas_xy, atlas_z = 1.25, 2.0 + + raw_size = seg_data.shape + # resize to atlas + if seg_data.ndim == 3: + seg_data = zoom( + seg_data, + zoom=(res_xy / atlas_xy, res_xy / atlas_xy, res_z / atlas_z), + order=0, + ) + img_data = zoom( + img_data, + zoom=(res_xy / atlas_xy, res_xy / atlas_xy, res_z / atlas_z), + order=1, + ) + else: + seg_data = zoom( + seg_data, + zoom=(res_xy / atlas_xy, res_xy / atlas_xy, res_z / atlas_z, 1), + order=0, + ) + img_data = zoom( + img_data, + zoom=(res_xy / atlas_xy, res_xy / atlas_xy, res_z / atlas_z, 1), + order=1, + ) + new_affine[:3, 0] /= seg_data.shape[0] / raw_size[0] + new_affine[:3, 1] /= seg_data.shape[1] / raw_size[1] + new_affine[:3, 2] /= seg_data.shape[2] / raw_size[2] + + # calculate coordinates of heart center and crop + heart_mask = (seg_data > 0).astype(np.uint8) + + c0 = np.median(np.where(heart_mask.sum(axis=-1).sum(axis=-1))[0]).astype(int) + c1 = np.median(np.where(heart_mask.sum(axis=0).sum(axis=-1))[0]).astype(int) + c2 = np.median(np.where(heart_mask.sum(axis=0).sum(axis=0))[0]).astype(int) + + crop_seg, crop_sa, new_affine = crop_3Dimage( + seg_data, img_data, (c0, c1, c2), atlas_size, affine_matrix=new_affine + ) + + return crop_seg, crop_sa, new_affine + + +if __name__ == "__main__": + data_dir = "./data" + atlas_path = "./Hammersmith_myo2" + parameter_file = "./ffd/params" + atlas_img_type = "image_space_crop" + atlas_nums_top = 3 + atlas_list = sorted(os.listdir(atlas_path)) + + device = "gpu" + tag = 0 + os.environ["CUDA_VISIBLE_DEVICES"] = str(0) + + subject_list = sorted(os.listdir(data_dir)) + interval = len(subject_list) + + print(f"----- from dataset {tag * interval} to {(tag + 1) * interval} ------") + + for i in range(tag * interval, (tag + 1) * interval): + subid = subject_list[i] + + print(f"----- processing {i}:{subid} ------") + img_path = f"{data_dir}/{subid}" + target_path = f"{data_dir}/{subid}/image_space_pipemesh" + util.setup_dir(target_path) + + Info_path = os.path.join(img_path, "Info.cfg") + with open(Info_path, "r") as file: + lines = file.readlines() + config = {} + for line in lines: + if ":" in line: + key, value = line.strip().split(":", 1) + config[key.strip()] = value.strip() + ED = int(config.get("ED", 0)) + ED = str(ED).zfill(2) + ES = int(config.get("ES", 0)) + ES = str(ES).zfill(2) + + sa_ED_path = f"{img_path}/{subid}_frame{ED}.nii.gz" + seg_sa_ED_path = f"{img_path}/{subid}_frame{ED}_gt.nii.gz" + sa_ED_nib = nib.load(sa_ED_path) + seg_ED_nib = nib.load(seg_sa_ED_path) + crop_seg_ED, crop_sa_ED, new_affine_ED = crop_data_into_atlas_size( + seg_ED_nib, sa_ED_nib, (140, 140, 56) + ) + + sa_ES_path = f"{img_path}/{subid}_frame{ES}.nii.gz" + seg_sa_ES_path = f"{img_path}/{subid}_frame{ES}_gt.nii.gz" + sa_ES_nib = nib.load(sa_ES_path) + seg_ES_nib = nib.load(seg_sa_ES_path) + crop_seg_ES, crop_sa_ES, new_affine_ES = crop_data_into_atlas_size( + seg_ES_nib, sa_ES_nib, (140, 140, 56) + ) + + # calculate top 3 similar atlases + dice_scores = [] + seg_ED = crop_seg_ED + for atlas_id in atlas_list: + source_img = f"{atlas_path}/{atlas_id}/seg_ED_{atlas_img_type}.nii.gz" + atlas_img = nib.load(source_img).get_fdata() + if atlas_img.shape[2] == 56: + # calculate the similarity between the target image and the atlas + dice_score = np_mean_dice(seg_ED, atlas_img) + dice_scores.append((atlas_id, dice_score)) + dice_scores.sort(key=lambda x: x[1], reverse=True) + + for frame in ["ED", "ES"]: + # save it in the image space + if frame == "ED": + seg_flip_time = crop_seg_ED + sa_flip_time = crop_sa_ED + else: + seg_flip_time = crop_seg_ES + sa_flip_time = crop_sa_ES + + target_img = f"{target_path}/seg_sa_{frame}.nii.gz" + nib.save(nib.Nifti1Image(seg_flip_time, np.eye(4)), target_img) + util.setup_dir(f"{target_path}/tmps") + + atlas_top_list = atlas_selection( + atlas_path, + atlas_img_type, + atlas_nums_top, + target_path, + frame, + dice_scores, + parameter_file, + ) + # vote the top 3 atlases + seg = combine_labels(atlas_top_list) + + nib.save( + nib.Nifti1Image(sa_flip_time, np.eye(4)), + f"{target_path}/sa_{frame}.nii.gz", + ) + + nib.save(nib.Nifti1Image(seg, np.eye(4)), target_img) + + try: + shutil.rmtree(f"{target_path}/tmps") + except FileNotFoundError: + pass diff --git a/jointContribution/HighResolution/requirements.txt b/jointContribution/HighResolution/requirements.txt new file mode 100644 index 0000000000..72320b3db7 --- /dev/null +++ b/jointContribution/HighResolution/requirements.txt @@ -0,0 +1,11 @@ +boto3==1.35.37 +dacite==1.8.1 +deprecation==2.1.0 +nibabel==5.3.0 +pandas==2.2.3 +pyvista==0.44.1 +PyYAML==6.0.2 +scipy==1.14.1 +setuptools==72.1.0 +SimpleITK==2.4.0 +typing_extensions==4.12.2 diff --git a/jointContribution/HighResolution/util/image_utils.py b/jointContribution/HighResolution/util/image_utils.py new file mode 100644 index 0000000000..f7bb8e485f --- /dev/null +++ b/jointContribution/HighResolution/util/image_utils.py @@ -0,0 +1,183 @@ +import nibabel as nib +import numpy as np +from scipy.ndimage import gaussian_filter + + +def crop_3Dimage(seg, image_sa, center, size, affine_matrix=None): + """Crop a 3D image using a bounding box centred at (c0, c1, c2) with specified size (size0, size1, size2)""" + c0, c1, c2 = center + size0, size1, size2 = size + S_seg = tuple(seg.shape) + S0, S1, S2 = S_seg[0], S_seg[1], S_seg[2] + r0, r1, r2 = int(size0 / 2), int(size1 / 2), int(size2 / 2) + start0, end0 = c0 - r0, c0 + r0 + start1, end1 = c1 - r1, c1 + r1 + start2, end2 = c2 - r2, c2 + r2 + start0_, end0_ = max(start0, 0), min(end0, S0) + start1_, end1_ = max(start1, 0), min(end1, S1) + start2_, end2_ = max(start2, 0), min(end2, S2) + crop = seg[start0_:end0_, start1_:end1_, start2_:end2_] + crop_img = image_sa[start0_:end0_, start1_:end1_, start2_:end2_] + if crop_img.ndim == 3: + crop_img = np.pad( + crop_img, + ( + (start0_ - start0, end0 - end0_), + (start1_ - start1, end1 - end1_), + (start2_ - start2, end2 - end2_), + ), + "constant", + ) + crop = np.pad( + crop, + ( + (start0_ - start0, end0 - end0_), + (start1_ - start1, end1 - end1_), + (start2_ - start2, end2 - end2_), + ), + "constant", + ) + else: + crop_img = np.pad( + crop_img, + ( + (start0_ - start0, end0 - end0_), + (start1_ - start1, end1 - end1_), + (start2_ - start2, end2 - end2_), + (0, 0), + ), + "constant", + ) + crop = np.pad( + crop, + ( + (start0_ - start0, end0 - end0_), + (start1_ - start1, end1 - end1_), + (start2_ - start2, end2 - end2_), + (0, 0), + ), + "constant", + ) + if affine_matrix is None: + return crop, crop_img + else: + R, b = affine_matrix[0:3, 0:3], affine_matrix[0:3, -1] + affine_matrix[0:3, -1] = R.dot(np.array([c0 - r0, c1 - r1, c2 - r2])) + b + return crop, crop_img, affine_matrix + + +def np_categorical_dice(pred, truth, k): + """Dice overlap metric for label k""" + A = (pred == k).astype(np.float32) + B = (truth == k).astype(np.float32) + return 2 * np.sum(A * B) / (np.sum(A) + np.sum(B)) + + +def np_mean_dice(pred, truth): + """Dice mean metric""" + dsc = [] + for k in np.unique(truth)[1:]: + dsc.append(np_categorical_dice(pred, truth, k)) + return np.mean(dsc) + + +def combine_labels(input_paths, pad=-1, seed=None): + def get_most_popular(count_map): + return max(count_map, key=count_map.get) + + def is_equivocal(count_map): + return len(set(count_map.values())) > 1 + + def decide_on_tie(count_map, rng): + max_count = max(count_map.values()) + tied_labels = [ + label for label, count in count_map.items() if count == max_count + ] + return rng.choice(tied_labels) + + def calculate_counts(input_paths, output_shape): + counts = [{} for _ in range(np.prod(output_shape))] + for input_path in input_paths: + input_image = nib.load(input_path).get_fdata().astype(np.int32) + contended_voxel_indices = np.where( + np.logical_and( + output != input_image, + np.logical_or(output > pad, input_image > pad), + ) + ) + idx = np.ravel_multi_index(contended_voxel_indices, output_shape) + labels = input_image[contended_voxel_indices] + _, counts_per_label = np.unique(idx, return_counts=True) + for idx, label, count in zip(idx, labels, counts_per_label): + counts[idx][label] = counts[idx].get(label, 0) + count + return counts + + output_image = nib.load(input_paths[0]) + output_data = output_image.get_fdata().astype(np.uint8) + output_shape = tuple(output_data.shape) + unanimous_mask = np.ones(output_shape, dtype=np.uint8) + output = output_data.copy() + counts = calculate_counts(input_paths, output_shape) + contended_voxel_indices = np.where(unanimous_mask == 0) + idx = np.ravel_multi_index(contended_voxel_indices, output_shape) + for idx, (z, y, x) in zip(idx, np.transpose(contended_voxel_indices)): + output[z, y, x] = get_most_popular(counts[idx]) + if seed is not None: + rng = np.random.default_rng(seed) + else: + rng = np.random.default_rng() + equivocal_voxel_indices = np.where(unanimous_mask == 0) + idx = np.ravel_multi_index(equivocal_voxel_indices, output_shape) + unique_indices, counts_per_voxel = np.unique(idx, return_counts=True) + for idx, (z, y, x) in zip(unique_indices, np.transpose(equivocal_voxel_indices)): + if is_equivocal(counts[idx]): + output[z, y, x] = decide_on_tie(counts[idx], rng) + return output + + +def threshold_image(data, threshold=130): + # Perform thresholding using NumPy operations + thresholded_data = data.copy() + thresholded_data[data <= threshold] = 0 + thresholded_data[(data > threshold)] = 1 + return thresholded_data + + +def blur_image(data, sigma): + # Apply Gaussian blurring to the data using scipy.ndimage.gaussian_filter + blurred_data = gaussian_filter(data, sigma=sigma) + return blurred_data + + +def binarize_image(data, lower_threshold=4, upper_threshold=4, binary_value=255): + # Perform binarization using NumPy operations + binarized_data = np.zeros_like(data) + binarized_data[(data >= lower_threshold) & (data <= upper_threshold)] = binary_value + return binarized_data + + +def padding(imageA, imageB, threshold, padding, invert=False): + # Create a mask for positions that require padding + if invert: + mask = imageB != threshold + else: + mask = imageB == threshold + + # Update 'imageA' using the mask and padding value + imageA[mask] = padding + return imageA + + +def refineFusionResults(data, alfa): + data = np.round(data) + + hrt = threshold_image(blur_image(binarize_image(data, 1, 4), alfa), 130) + rvendo = threshold_image(blur_image(binarize_image(data, 4, 4), alfa), 130) + lvepi = threshold_image(blur_image(binarize_image(data, 1, 2), alfa), 115) + lvendo = threshold_image(blur_image(binarize_image(data, 1, 1), alfa), 130) + + hrt = padding(hrt, hrt, 1, 4) + rvendo = padding(hrt, rvendo, 1, 4) + lvepi = padding(rvendo, lvepi, 1, 2) + data_final = padding(lvepi, lvendo, 1, 1) + return data_final diff --git a/jointContribution/HighResolution/util/paddle_aux.py b/jointContribution/HighResolution/util/paddle_aux.py new file mode 100644 index 0000000000..d11f7d8de6 --- /dev/null +++ b/jointContribution/HighResolution/util/paddle_aux.py @@ -0,0 +1,63 @@ +import paddle + + +def min_class_func(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.minimum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.minimum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.min(self, *args, **kwargs), paddle.argmin( + self, *args, **kwargs + ) + else: + ret = paddle.min(self, *args, **kwargs) + + return ret + + +def max_class_func(self, *args, **kwargs): + if "other" in kwargs: + kwargs["y"] = kwargs.pop("other") + ret = paddle.maximum(self, *args, **kwargs) + elif len(args) == 1 and isinstance(args[0], paddle.Tensor): + ret = paddle.maximum(self, *args, **kwargs) + else: + if "dim" in kwargs: + kwargs["axis"] = kwargs.pop("dim") + + if "axis" in kwargs or len(args) >= 1: + ret = paddle.max(self, *args, **kwargs), paddle.argmax( + self, *args, **kwargs + ) + else: + ret = paddle.max(self, *args, **kwargs) + + return ret + + +setattr(paddle.Tensor, "min", min_class_func) +setattr(paddle.Tensor, "max", max_class_func) + + +def mul(self, *args, **kwargs): + if "other" in kwargs: + y = kwargs["other"] + elif "y" in kwargs: + y = kwargs["y"] + else: + y = args[0] + + if not isinstance(y, paddle.Tensor): + y = paddle.to_tensor(y) + + return paddle.multiply(self, y.astype(self.dtype)) + + +setattr(paddle.Tensor, "mul", mul) +setattr(paddle.Tensor, "multiply", mul) diff --git a/jointContribution/HighResolution/util/pre_process.py b/jointContribution/HighResolution/util/pre_process.py new file mode 100644 index 0000000000..731d331942 --- /dev/null +++ b/jointContribution/HighResolution/util/pre_process.py @@ -0,0 +1,64 @@ +from dataclasses import dataclass + +import ffd.register as ffd_register +import paddle +import pyvista as pv +from deepali.core import PathStr + + +@dataclass +class DeepaliFFDRuntimeArgs: + """Dataclass packing registration arguments""" + + target_img: PathStr + source_img: PathStr + target_seg: PathStr = None + source_seg: PathStr = None + config: PathStr = None + output_transform: PathStr = None + warped_img: PathStr = None + warped_seg: PathStr = None + device: str = "cuda" + debug_dir: PathStr = None + debug: int = 0 + verbose: int = 0 + log_level: str = "WARNING" + + +def register_with_deepali( + target_img_file: PathStr = None, + source_img_file: PathStr = None, + target_seg_file: PathStr = None, + source_seg_file: PathStr = None, + target_mesh_file: PathStr = None, + ffd_params_file: PathStr = None, + output_transform_path: PathStr = None, + warped_img_path: PathStr = None, + warped_mesh_path: PathStr = None, + warped_seg_path: PathStr = None, +): + """Register two images using FFD with GPU-enabled Deepali and transform the mesh.""" + args = DeepaliFFDRuntimeArgs( + target_img=target_img_file, + source_img=source_img_file, + target_seg=target_seg_file, + source_seg=source_seg_file, + config=ffd_params_file, + output_transform=output_transform_path, + warped_img=warped_img_path, + warped_seg=warped_seg_path, + ) + ffd_register.init(args) + transform = ffd_register.register_func(args) + if target_mesh_file is not None: + warp_transform_on_mesh(transform, target_mesh_file, warped_mesh_path) + return transform + + +def warp_transform_on_mesh(transform, target_mesh_file, warped_mesh_path): + target_mesh = pv.read(target_mesh_file) + target_points = paddle.to_tensor(data=target_mesh.points).unsqueeze(axis=0) + target_points = target_points.to(device=transform.place) + warped_target_points = transform.points(target_points, axes="grid") + target_mesh.points = warped_target_points.squeeze(axis=0).detach().cpu().numpy() + target_mesh.save(warped_mesh_path) diff --git a/jointContribution/HighResolution/util/utils.py b/jointContribution/HighResolution/util/utils.py new file mode 100644 index 0000000000..776e1b17eb --- /dev/null +++ b/jointContribution/HighResolution/util/utils.py @@ -0,0 +1,7 @@ +import os + + +def setup_dir(dir_path): + if not os.path.exists(dir_path): + os.makedirs(dir_path) + return dir_path