Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ ci:

repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
Expand Down Expand Up @@ -80,7 +80,7 @@ repos:
)$

- repo: https://github.com/pre-commit/mirrors-prettier
rev: v3.1.0
rev: v4.0.0-alpha.8
hooks:
- id: prettier
files: \.(json|yml|yaml|toml)
Expand Down Expand Up @@ -112,7 +112,7 @@ repos:
- id: text-unicode-replacement-char

- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
rev: v0.13.3
hooks:
# try to fix what is possible
- id: ruff
Expand All @@ -123,7 +123,7 @@ repos:
- id: ruff

- repo: https://github.com/tox-dev/pyproject-fmt
rev: v2.6.0
rev: v2.7.0
hooks:
- id: pyproject-fmt
additional_dependencies: [tox]
Expand Down
12 changes: 7 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@ ______________________________________________________________________
</div>

# Looking for GPUs?
Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.
- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.

Over 340,000 developers use [Lightning Cloud](https://lightning.ai/?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) - purpose-built for PyTorch and PyTorch Lightning.

- [GPUs](https://lightning.ai/pricing?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme) from $0.19.
- [Clusters](https://lightning.ai/clusters?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): frontier-grade training/inference clusters.
- [AI Studio (vibe train)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you debug, tune and vibe train.
- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models.
- [AI Studio (vibe deploy)](https://lightning.ai/studios?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): workspaces where AI helps you optimize, and deploy models.
- [Notebooks](https://lightning.ai/notebooks?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Persistent GPU workspaces where AI helps you code and analyze.
- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.
- [Inference](https://lightning.ai/deploy?utm_source=tm_readme&utm_medium=referral&utm_campaign=tm_readme): Deploy models as inference APIs.

# Installation

Expand Down
12 changes: 6 additions & 6 deletions src/torchmetrics/functional/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def _error_relative_global_dimensionless_synthesis(
preds: Tensor,
target: Tensor,
ratio: float = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Wrapper for deprecated import.

Expand Down Expand Up @@ -82,7 +82,7 @@ def _peak_signal_noise_ratio(
target: Tensor,
data_range: Union[float, tuple[float, float]] = 3.0,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
dim: Optional[Union[int, tuple[int, ...]]] = None,
) -> Tensor:
"""Wrapper for deprecated import.
Expand Down Expand Up @@ -135,7 +135,7 @@ def _root_mean_squared_error_using_sliding_window(
def _spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Wrapper for deprecated import.

Expand All @@ -156,7 +156,7 @@ def _multiscale_structural_similarity_index_measure(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand Down Expand Up @@ -194,7 +194,7 @@ def _structural_similarity_index_measure(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand Down Expand Up @@ -226,7 +226,7 @@ def _structural_similarity_index_measure(
)


def _total_variation(img: Tensor, reduction: Literal["mean", "sum", "none", None] = "sum") -> Tensor:
def _total_variation(img: Tensor, reduction: Optional[Literal["mean", "sum", "none"]] = "sum") -> Tensor:
"""Wrapper for deprecated import.

>>> from torch import rand
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import Tensor
from typing_extensions import Literal
Expand Down Expand Up @@ -45,7 +47,7 @@ def _ergas_compute(
preds: Tensor,
target: Tensor,
ratio: float = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Erreur Relative Globale Adimensionnelle de Synthèse.

Expand Down Expand Up @@ -85,7 +87,7 @@ def error_relative_global_dimensionless_synthesis(
preds: Tensor,
target: Tensor,
ratio: float = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Calculates `Error relative global dimensionless synthesis`_ (ERGAS) metric.

Expand Down
4 changes: 2 additions & 2 deletions src/torchmetrics/functional/image/psnr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def _psnr_compute(
num_obs: Tensor,
data_range: Tensor,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Compute peak signal-to-noise ratio.

Expand Down Expand Up @@ -97,7 +97,7 @@ def peak_signal_noise_ratio(
target: Tensor,
data_range: Union[float, tuple[float, float]],
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
dim: Optional[Union[int, tuple[int, ...]]] = None,
) -> Tensor:
"""Compute the peak signal-to-noise ratio.
Expand Down
6 changes: 4 additions & 2 deletions src/torchmetrics/functional/image/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional

import torch
from torch import Tensor
from typing_extensions import Literal
Expand Down Expand Up @@ -49,7 +51,7 @@ def _sam_update(preds: Tensor, target: Tensor) -> tuple[Tensor, Tensor]:
def _sam_compute(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Compute Spectral Angle Mapper.

Expand Down Expand Up @@ -81,7 +83,7 @@ def _sam_compute(
def spectral_angle_mapper(
preds: Tensor,
target: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Universal Spectral Angle Mapper.

Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/image/scc.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def spatial_correlation_coefficient(
target: Tensor,
hp_filter: Optional[Tensor] = None,
window_size: int = 8,
reduction: Optional[Literal["mean", "none", None]] = "mean",
reduction: Optional[Optional[Literal["mean", "none"]]] = "mean",
) -> Tensor:
"""Compute Spatial Correlation Coefficient (SCC_).

Expand Down
8 changes: 4 additions & 4 deletions src/torchmetrics/functional/image/ssim.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ def _ssim_update(

def _ssim_compute(
similarities: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Apply the specified reduction to pre-computed structural similarity.

Expand All @@ -213,7 +213,7 @@ def structural_similarity_index_measure(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand Down Expand Up @@ -427,7 +427,7 @@ def _multiscale_ssim_update(

def _multiscale_ssim_compute(
mcs_per_image: Tensor,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
) -> Tensor:
"""Apply the specified reduction to pre-computed multi-scale structural similarity.

Expand All @@ -452,7 +452,7 @@ def multiscale_structural_similarity_index_measure(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/pairwise/cosine.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _pairwise_cosine_similarity_update(
def pairwise_cosine_similarity(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise cosine similarity.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/pairwise/euclidean.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def _pairwise_euclidean_distance_update(
def pairwise_euclidean_distance(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise euclidean distances.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/pairwise/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def _pairwise_linear_similarity_update(
def pairwise_linear_similarity(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise linear similarity.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/pairwise/manhattan.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def _pairwise_manhattan_distance_update(
def pairwise_manhattan_distance(
x: Tensor,
y: Optional[Tensor] = None,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise manhattan distance.
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/functional/pairwise/minkowski.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def pairwise_minkowski_distance(
x: Tensor,
y: Optional[Tensor] = None,
exponent: float = 2,
reduction: Literal["mean", "sum", "none", None] = None,
reduction: Optional[Literal["mean", "sum", "none"]] = None,
zero_diagonal: Optional[bool] = None,
) -> Tensor:
r"""Calculate pairwise minkowski distances.
Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/regression/js_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -53,7 +53,7 @@ def _jsd_update(p: Tensor, q: Tensor, log_prob: bool) -> tuple[Tensor, int]:


def _jsd_compute(
measures: Tensor, total: Union[int, Tensor], reduction: Literal["mean", "sum", "none", None] = "mean"
measures: Tensor, total: Union[int, Tensor], reduction: Optional[Literal["mean", "sum", "none"]] = "mean"
) -> Tensor:
"""Compute and reduce the Jensen-Shannon divergence based on the type of reduction."""
if reduction == "sum":
Expand All @@ -66,7 +66,7 @@ def _jsd_compute(


def jensen_shannon_divergence(
p: Tensor, q: Tensor, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean"
p: Tensor, q: Tensor, log_prob: bool = False, reduction: Optional[Literal["mean", "sum", "none"]] = "mean"
) -> Tensor:
r"""Compute `Jensen-Shannon divergence`_.

Expand Down
6 changes: 3 additions & 3 deletions src/torchmetrics/functional/regression/kl_divergence.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Union
from typing import Optional, Union

import torch
from torch import Tensor
Expand Down Expand Up @@ -48,7 +48,7 @@ def _kld_update(p: Tensor, q: Tensor, log_prob: bool) -> tuple[Tensor, int]:


def _kld_compute(
measures: Tensor, total: Union[int, Tensor], reduction: Literal["mean", "sum", "none", None] = "mean"
measures: Tensor, total: Union[int, Tensor], reduction: Optional[Literal["mean", "sum", "none"]] = "mean"
) -> Tensor:
"""Compute the KL divergenece based on the type of reduction.

Expand Down Expand Up @@ -80,7 +80,7 @@ def _kld_compute(


def kl_divergence(
p: Tensor, q: Tensor, log_prob: bool = False, reduction: Literal["mean", "sum", "none", None] = "mean"
p: Tensor, q: Tensor, log_prob: bool = False, reduction: Optional[Literal["mean", "sum", "none"]] = "mean"
) -> Tensor:
r"""Compute `KL divergence`_.

Expand Down
14 changes: 7 additions & 7 deletions src/torchmetrics/image/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class _ErrorRelativeGlobalDimensionlessSynthesis(ErrorRelativeGlobalDimensionles
def __init__(
self,
ratio: float = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
**kwargs: Any,
) -> None:
_deprecated_root_import_class("ErrorRelativeGlobalDimensionlessSynthesis", "image")
Expand All @@ -54,12 +54,12 @@ def __init__(
gaussian_kernel: bool = True,
kernel_size: Union[int, Sequence[int]] = 11,
sigma: Union[float, Sequence[float]] = 1.5,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
betas: tuple[float, ...] = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333),
normalize: Literal["relu", "simple", None] = "relu",
normalize: Optional[Literal["relu", "simple"]] = "relu",
**kwargs: Any,
) -> None:
_deprecated_root_import_class("MultiScaleStructuralSimilarityIndexMeasure", "image")
Expand Down Expand Up @@ -93,7 +93,7 @@ def __init__(
self,
data_range: Union[float, tuple[float, float]] = 3.0,
base: float = 10.0,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
dim: Optional[Union[int, tuple[int, ...]]] = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -200,7 +200,7 @@ def __init__(
gaussian_kernel: bool = True,
sigma: Union[float, Sequence[float]] = 1.5,
kernel_size: Union[int, Sequence[int]] = 11,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
data_range: Optional[Union[float, tuple[float, float]]] = None,
k1: float = 0.01,
k2: float = 0.03,
Expand Down Expand Up @@ -234,7 +234,7 @@ class _TotalVariation(TotalVariation):

"""

def __init__(self, reduction: Literal["mean", "sum", "none", None] = "sum", **kwargs: Any) -> None:
def __init__(self, reduction: Optional[Literal["mean", "sum", "none"]] = "sum", **kwargs: Any) -> None:
_deprecated_root_import_class("TotalVariation", "image")
super().__init__(reduction=reduction, **kwargs)

Expand All @@ -255,7 +255,7 @@ def __init__(
self,
kernel_size: Sequence[int] = (11, 11),
sigma: Sequence[float] = (1.5, 1.5),
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
**kwargs: Any,
) -> None:
_deprecated_root_import_class("UniversalImageQualityIndex", "image")
Expand Down
2 changes: 1 addition & 1 deletion src/torchmetrics/image/ergas.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ class ErrorRelativeGlobalDimensionlessSynthesis(Metric):
def __init__(
self,
ratio: float = 4,
reduction: Literal["elementwise_mean", "sum", "none", None] = "elementwise_mean",
reduction: Optional[Literal["elementwise_mean", "sum", "none"]] = "elementwise_mean",
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
Expand Down
Loading
Loading