Skip to content

Commit 9220f46

Browse files
committed
Add `interpolation` argument to the Laplacian pyramid losses (plumbed through tests, core, and loss functions)
1 parent 620ee3e commit 9220f46

File tree

4 files changed

+28
-17
lines changed

tests/test_laplacian_pyramid.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,15 @@
1515
@pytest.mark.parametrize("dtype2", supported_dtypes)
1616
@pytest.mark.parametrize("height, width", [(16, 16), (32, 32)])
1717
@pytest.mark.parametrize("factor_w, factor_h", [(2.0, 2.0), (3.0, 3.0), (4.0, 4.0)])
18-
def test_laplacian_pyramid_loss_precision(dtype1, dtype2, height, width, factor_w, factor_h):
18+
@pytest.mark.parametrize("interpolation", ["bilinear", "nearest"])
19+
def test_laplacian_pyramid_loss_precision(dtype1, dtype2, height, width, factor_w, factor_h, interpolation):
1920
# Create a sample cortical_sheet tensor
2021
e = 16 # Example depth
2122
torch.manual_seed(42) # Set seed for reproducibility
2223
cortical_sheet = torch.rand(height, width, e).to(dtype1)
2324

2425
# Call the function with the given precision
25-
loss = laplacian_pyramid_loss(cortical_sheet, factor_w=factor_w, factor_h=factor_h)
26+
loss = laplacian_pyramid_loss(cortical_sheet, factor_w=factor_w, factor_h=factor_h, interpolation=interpolation)
2627

2728
# Ensure the loss is finite
2829
assert torch.isfinite(loss), f"Loss is not finite for dtype {dtype1}"
@@ -33,7 +34,7 @@ def test_laplacian_pyramid_loss_precision(dtype1, dtype2, height, width, factor_
3334
# Compare results to float32 (considered ground truth for higher precision)
3435
if dtype1 !=dtype2:
3536
float32_sheet = cortical_sheet.to(dtype2)
36-
expected_loss = laplacian_pyramid_loss(cortical_sheet=float32_sheet, factor_w=factor_w, factor_h=factor_h)
37+
expected_loss = laplacian_pyramid_loss(cortical_sheet=float32_sheet, factor_w=factor_w, factor_h=factor_h, interpolation=interpolation)
3738
assert_close(
3839
loss.to(dtype2), expected_loss, rtol=1e-3, atol=1e-4,
3940
msg=f"Loss mismatch for dtype1 {dtype1} loss: {loss.to(dtype2)} and dtype2 {dtype2} loss2: {expected_loss}"

tests/test_loss_conv.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,9 @@
1515
@pytest.mark.parametrize("hidden_channels", [16, 32])
1616
@pytest.mark.parametrize("init_from_layer", [True, False])
1717
@pytest.mark.parametrize("dtype", supported_dtypes)
18+
@pytest.mark.parametrize("interpolation", ["bilinear", "nearest"])
1819
def test_loss_conv(
19-
num_steps: int, hidden_channels: int, init_from_layer: bool, dtype
20+
num_steps: int, hidden_channels: int, init_from_layer: bool, dtype, interpolation: str
2021
): # num_steps is now passed by the fixture
2122

2223
# Define the model
@@ -30,16 +31,16 @@ def test_loss_conv(
3031
if init_from_layer:
3132
losses = [
3233
LaplacianPyramid.from_layer(
33-
model=model, layer=model[0], scale=1.0, factor_h=3.0, factor_w=3.0
34+
model=model, layer=model[0], scale=1.0, factor_h=3.0, factor_w=3.0, interpolation=interpolation
3435
),
3536
LaplacianPyramid.from_layer(
36-
model=model, layer=model[2], scale=1.0, factor_h=3.0, factor_w=3.0
37+
model=model, layer=model[2], scale=1.0, factor_h=3.0, factor_w=3.0, interpolation=interpolation
3738
),
3839
]
3940
else:
4041
losses = [
41-
LaplacianPyramid(layer_name="0", scale=1.0, factor_h=3.0, factor_w=3.0),
42-
LaplacianPyramid(layer_name="2", scale=1.0, factor_h=3.0, factor_w=3.0),
42+
LaplacianPyramid(layer_name="0", scale=1.0, factor_h=3.0, factor_w=3.0, interpolation=interpolation),
43+
LaplacianPyramid(layer_name="2", scale=1.0, factor_h=3.0, factor_w=3.0, interpolation=interpolation),
4344
]
4445

4546
# Define the TopoLoss

topoloss/core.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def get_layerwise_topo_losses(self, model, do_scaling: bool = True) -> dict:
4444
cortical_sheet=cortical_sheet,
4545
factor_h=loss_info.factor_h,
4646
factor_w=loss_info.factor_w,
47+
interpolation=loss_info.interpolation,
4748
)
4849
elif isinstance(loss_info, LaplacianPyramidOnBias):
4950
assert isinstance(
@@ -61,6 +62,7 @@ def get_layerwise_topo_losses(self, model, do_scaling: bool = True) -> dict:
6162
cortical_sheet=cortical_sheet,
6263
factor_h=loss_info.factor_h,
6364
factor_w=loss_info.factor_w,
65+
interpolation=loss_info.interpolation,
6466
)
6567

6668
elif isinstance(loss_info, LaplacianPyramidOnInput):
@@ -72,6 +74,7 @@ def get_layerwise_topo_losses(self, model, do_scaling: bool = True) -> dict:
7274
cortical_sheet=cortical_sheet,
7375
factor_h=loss_info.factor_h,
7476
factor_w=loss_info.factor_w,
77+
interpolation=loss_info.interpolation,
7578
)
7679

7780
if do_scaling:

topoloss/losses/laplacian_pyramid.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
import torch
1111

1212
def laplacian_pyramid_loss(
13-
cortical_sheet: TensorType["height", "width", "e"], factor_w: float, factor_h: float
13+
cortical_sheet: TensorType["height", "width", "e"], factor_w: float, factor_h: float, interpolation: str = "bilinear"
1414
):
1515
grid = cortical_sheet
1616
assert grid.ndim == 3, "Expected grid to be a 3d tensor of shape (h, w, e)"
@@ -24,10 +24,10 @@ def laplacian_pyramid_loss(
2424
), f"Expected factor_w to be <= grid.shape[2] = {grid.shape[2]} but got: {factor_w}"
2525
# Downscale the grid tensor
2626
downscaled_grid = F.interpolate(
27-
grid, scale_factor=(1 / factor_h, 1 / factor_w), mode="bilinear"
27+
grid, scale_factor=(1 / factor_h, 1 / factor_w), mode=interpolation
2828
)
2929
# Upscale the downscaled grid tensor
30-
upscaled_grid = F.interpolate(downscaled_grid, size=grid.shape[2:], mode="bilinear")
30+
upscaled_grid = F.interpolate(downscaled_grid, size=grid.shape[2:], mode=interpolation)
3131

3232
# Calculate the MSE loss between the original grid and upscaled grid
3333
# loss = F.mse_loss(upscaled_grid, grid)
@@ -50,21 +50,23 @@ class LaplacianPyramid:
5050
layer_name: str
5151
factor_h: float
5252
factor_w: float
53+
interpolation: str = "bilinear"
5354
scale: Optional[Union[None, float]] = field(default=1.0)
5455

5556
@classmethod
56-
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
57+
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0, interpolation: str ="bilinear"):
5758
layer_name = get_name_by_layer(model=model, layer=layer)
5859
return cls(
5960
layer_name=layer_name,
6061
scale=scale,
6162
factor_h=factor_h,
6263
factor_w=factor_w,
64+
interpolation=interpolation
6365
)
6466

6567

6668
def laplacian_pyramid_loss_on_bias(
67-
cortical_sheet: TensorType["h", "w"], factor_w: float, factor_h: float
69+
cortical_sheet: TensorType["h", "w"], factor_w: float, factor_h: float, interpolation: str = "bilinear"
6870
):
6971

7072
grid = cortical_sheet
@@ -81,10 +83,10 @@ def laplacian_pyramid_loss_on_bias(
8183
grid = grid.unsqueeze(0).unsqueeze(0)
8284
# Downscale the grid tensor
8385
downscaled_grid = F.interpolate(
84-
grid, scale_factor=(1 / factor_h, 1 / factor_w), mode="bilinear"
86+
grid, scale_factor=(1 / factor_h, 1 / factor_w), mode=interpolation
8587
)
8688
# Upscale the downscaled grid tensor
87-
upscaled_grid = F.interpolate(downscaled_grid, size=grid.shape[2:], mode="bilinear")
89+
upscaled_grid = F.interpolate(downscaled_grid, size=grid.shape[2:], mode=interpolation)
8890

8991
grid = rearrange(grid.squeeze(0).squeeze(0), "h w -> (h w)").unsqueeze(0)
9092
upscaled_grid = rearrange(
@@ -107,10 +109,11 @@ class LaplacianPyramidOnBias:
107109
layer_name: str
108110
factor_h: float
109111
factor_w: float
112+
interpolation: str = "bilinear"
110113
scale: Optional[Union[None, float]] = field(default=1.0)
111114

112115
@classmethod
113-
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
116+
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0, interpolation: str ="bilinear"):
114117
assert (
115118
layer.bias is not None
116119
), "Expected layer to have a bias, but got None. *sad sad sad*"
@@ -120,6 +123,7 @@ def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
120123
scale=scale,
121124
factor_h=factor_h,
122125
factor_w=factor_w,
126+
interpolation=interpolation
123127
)
124128

125129
@dataclass
@@ -136,10 +140,11 @@ class LaplacianPyramidOnInput:
136140
layer_name: str
137141
factor_h: float
138142
factor_w: float
143+
interpolation: str = "bilinear"
139144
scale: Optional[Union[None, float]] = field(default=1.0)
140145

141146
@classmethod
142-
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
147+
def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0, interpolation: str ="bilinear"):
143148
assert (
144149
layer.bias is not None
145150
), "Expected layer to have a bias, but got None. *sad sad sad*"
@@ -149,4 +154,5 @@ def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
149154
scale=scale,
150155
factor_h=factor_h,
151156
factor_w=factor_w,
157+
interpolation=interpolation
152158
)

0 commit comments

Comments (0)