Skip to content

Commit 2723fe5

Browse files
committed
Removed unused import, added type signature of eps parameter and documentation on its usage.
Signed-off-by: Matteo Fasulo <[email protected]>
1 parent 6d8c637 commit 2723fe5

File tree

2 files changed

+7
-8
lines changed

2 files changed

+7
-8
lines changed

art/defences/preprocessor/variance_minimization_pytorch.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -167,17 +167,17 @@ def closure():
167167

168168
@staticmethod
169169
def _loss_func(
170-
z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps=1e-6
170+
z_init: "torch.Tensor", x: "torch.Tensor", mask: "torch.Tensor", norm: float, lamb: float, eps: float = 1e-6
171171
) -> "torch.Tensor":
172172
"""
173-
Loss function to be minimized - try to match SciPy implementation closely.
174-
175-
:param z_init: Initial guess.
173+
Calculate the total variance minimization loss function.
174+
:param z_init: Initial guess for the optimization.
176175
:param x: Original image.
177-
:param mask: A matrix that decides which points are kept.
178-
:param norm: The norm (positive integer).
176+
:param mask: Mask indicating which pixels to consider.
177+
:param norm: The norm to use (1, 2, or p).
179178
:param lamb: The lambda parameter in the objective function.
180-
:return: A single scalar loss value.
179+
:param eps: Small constant to avoid division by zero.
180+
:return: The total variance minimization loss.
181181
"""
182182
import torch
183183

tests/defences/preprocessor/test_variance_minimization_pytorch.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
import torch
66
import numpy as np
7-
from numpy.testing import assert_array_equal
87
import pytest
98

109
from art.defences.preprocessor.variance_minimization_pytorch import TotalVarMinPyTorch

0 commit comments

Comments
 (0)