1010import torch
1111
1212def laplacian_pyramid_loss (
13- cortical_sheet : TensorType ["height" , "width" , "e" ], factor_w : float , factor_h : float
13+ cortical_sheet : TensorType ["height" , "width" , "e" ], factor_w : float , factor_h : float , interpolation : str = "bilinear"
1414):
1515 grid = cortical_sheet
1616 assert grid .ndim == 3 , "Expected grid to be a 3d tensor of shape (h, w, e)"
@@ -24,10 +24,10 @@ def laplacian_pyramid_loss(
2424 ), f"Expected factor_w to be <= grid.shape[2] = { grid .shape [2 ]} but got: { factor_w } "
2525 # Downscale the grid tensor
2626 downscaled_grid = F .interpolate (
27- grid , scale_factor = (1 / factor_h , 1 / factor_w ), mode = "bilinear"
27+ grid , scale_factor = (1 / factor_h , 1 / factor_w ), mode = interpolation
2828 )
2929 # Upscale the downscaled grid tensor
30- upscaled_grid = F .interpolate (downscaled_grid , size = grid .shape [2 :], mode = "bilinear" )
30+ upscaled_grid = F .interpolate (downscaled_grid , size = grid .shape [2 :], mode = interpolation )
3131
3232 # Calculate the MSE loss between the original grid and upscaled grid
3333 # loss = F.mse_loss(upscaled_grid, grid)
@@ -50,21 +50,23 @@ class LaplacianPyramid:
5050 layer_name : str
5151 factor_h : float
5252 factor_w : float
53+ interpolation : str = "bilinear"
5354 scale : Optional [Union [None , float ]] = field (default = 1.0 )
5455
5556 @classmethod
56- def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 ):
57+ def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 , interpolation : str = "bilinear" ):
5758 layer_name = get_name_by_layer (model = model , layer = layer )
5859 return cls (
5960 layer_name = layer_name ,
6061 scale = scale ,
6162 factor_h = factor_h ,
6263 factor_w = factor_w ,
64+ interpolation = interpolation
6365 )
6466
6567
6668def laplacian_pyramid_loss_on_bias (
67- cortical_sheet : TensorType ["h" , "w" ], factor_w : float , factor_h : float
69+ cortical_sheet : TensorType ["h" , "w" ], factor_w : float , factor_h : float , interpolation : str = "bilinear"
6870):
6971
7072 grid = cortical_sheet
@@ -81,10 +83,10 @@ def laplacian_pyramid_loss_on_bias(
8183 grid = grid .unsqueeze (0 ).unsqueeze (0 )
8284 # Downscale the grid tensor
8385 downscaled_grid = F .interpolate (
84- grid , scale_factor = (1 / factor_h , 1 / factor_w ), mode = "bilinear"
86+ grid , scale_factor = (1 / factor_h , 1 / factor_w ), mode = interpolation
8587 )
8688 # Upscale the downscaled grid tensor
87- upscaled_grid = F .interpolate (downscaled_grid , size = grid .shape [2 :], mode = "bilinear" )
89+ upscaled_grid = F .interpolate (downscaled_grid , size = grid .shape [2 :], mode = interpolation )
8890
8991 grid = rearrange (grid .squeeze (0 ).squeeze (0 ), "h w -> (h w)" ).unsqueeze (0 )
9092 upscaled_grid = rearrange (
@@ -107,10 +109,11 @@ class LaplacianPyramidOnBias:
107109 layer_name : str
108110 factor_h : float
109111 factor_w : float
112+ interpolation : str = "bilinear"
110113 scale : Optional [Union [None , float ]] = field (default = 1.0 )
111114
112115 @classmethod
113- def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 ):
116+ def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 , interpolation : str = "bilinear" ):
114117 assert (
115118 layer .bias is not None
116119 ), "Expected layer to have a bias, but got None. *sad sad sad*"
@@ -120,6 +123,7 @@ def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
120123 scale = scale ,
121124 factor_h = factor_h ,
122125 factor_w = factor_w ,
126+ interpolation = interpolation
123127 )
124128
125129@dataclass
@@ -136,10 +140,11 @@ class LaplacianPyramidOnInput:
136140 layer_name : str
137141 factor_h : float
138142 factor_w : float
143+ interpolation : str = "bilinear"
139144 scale : Optional [Union [None , float ]] = field (default = 1.0 )
140145
141146 @classmethod
142- def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 ):
147+ def from_layer (cls , model , layer , factor_h , factor_w , scale = 1.0 , interpolation : str = "bilinear" ):
143148 assert (
144149 layer .bias is not None
145150 ), "Expected layer to have a bias, but got None. *sad sad sad*"
@@ -149,4 +154,5 @@ def from_layer(cls, model, layer, factor_h, factor_w, scale=1.0):
149154 scale = scale ,
150155 factor_h = factor_h ,
151156 factor_w = factor_w ,
157+ interpolation = interpolation
152158 )
0 commit comments