Skip to content

Commit cfa77a3

Browse files
Fix resize methods with kernel (#554)
1 parent 7a2e9bc commit cfa77a3

File tree

2 files changed

+88
-20
lines changed

2 files changed

+88
-20
lines changed

lib/axon/layers.ex

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2041,32 +2041,35 @@ defmodule Axon.Layers do
20412041
deftransformp resize_with_kernel(input, out_shape, spatial_axes, kernel_fun) do
20422042
for axis <- spatial_axes, reduce: input do
20432043
input ->
2044-
input_shape = Nx.shape(input)
2045-
input_size = elem(input_shape, axis)
2046-
output_size = elem(out_shape, axis)
2044+
resize_axis_with_kernel(input,
2045+
axis: axis,
2046+
output_size: elem(out_shape, axis),
2047+
kernel_fun: kernel_fun
2048+
)
2049+
end
2050+
end
20472051

2048-
inv_scale = input_size / output_size
2049-
kernel_scale = Nx.max(1, inv_scale)
2052+
defnp resize_axis_with_kernel(input, opts) do
2053+
axis = opts[:axis]
2054+
output_size = opts[:output_size]
2055+
kernel_fun = opts[:kernel_fun]
20502056

2051-
sample_f =
2052-
Nx.add(Nx.iota({1, output_size}), 0.5) |> Nx.multiply(Nx.subtract(inv_scale, 0.5))
2057+
input_size = Nx.axis_size(input, axis)
20532058

2054-
x = Nx.abs(Nx.subtract(sample_f, Nx.iota({input_size, 1}))) |> Nx.divide(kernel_scale)
2055-
weights = kernel_fun.(x)
2059+
inv_scale = input_size / output_size
2060+
kernel_scale = max(1, inv_scale)
20562061

2057-
weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
2062+
sample_f = (Nx.iota({1, output_size}) + 0.5) * inv_scale - 0.5
2063+
x = Nx.abs(sample_f - Nx.iota({input_size, 1})) / kernel_scale
2064+
weights = kernel_fun.(x)
20582065

2059-
weights =
2060-
Nx.select(
2061-
Nx.greater(Nx.abs(weights), 1000 * @f32_eps),
2062-
safe_divide(weights, weights_sum),
2063-
0
2064-
)
2066+
weights_sum = Nx.sum(weights, axes: [0], keep_axes: true)
20652067

2066-
input = Nx.dot(input, [axis], weights, [0])
2067-
# The transformed axis is moved to the end, so we transpose back
2068-
reorder_axis(input, -1, axis)
2069-
end
2068+
weights = Nx.select(Nx.abs(weights) > 1000 * @f32_eps, safe_divide(weights, weights_sum), 0)
2069+
2070+
input = Nx.dot(input, [axis], weights, [0])
2071+
# The transformed axis is moved to the end, so we transpose back
2072+
reorder_axis(input, -1, axis)
20702073
end
20712074

20722075
defnp fill_linear_kernel(x) do

test/axon/layers_test.exs

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -944,6 +944,71 @@ defmodule Axon.LayersTest do
944944
Axon.Layers.resize(inp)
945945
end
946946
end
947+
948+
# Adapted from NxImage
949+
test "methods" do
950+
# Reference values computed in jax
951+
952+
image = Nx.iota({1, 2, 2, 3}, type: :f32)
953+
954+
assert_equal(
955+
Axon.Layers.resize(image, size: {3, 3}, method: :nearest),
956+
Nx.tensor([
957+
[
958+
[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0], [3.0, 4.0, 5.0]],
959+
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]],
960+
[[6.0, 7.0, 8.0], [9.0, 10.0, 11.0], [9.0, 10.0, 11.0]]
961+
]
962+
])
963+
)
964+
965+
assert_equal(
966+
Axon.Layers.resize(image, size: {3, 3}, method: :bilinear),
967+
Nx.tensor([
968+
[
969+
[[0.0, 1.0, 2.0], [1.5, 2.5, 3.5], [3.0, 4.0, 5.0]],
970+
[[3.0, 4.0, 5.0], [4.5, 5.5, 6.5], [6.0, 7.0, 8.0]],
971+
[[6.0, 7.0, 8.0], [7.5, 8.5, 9.5], [9.0, 10.0, 11.0]]
972+
]
973+
])
974+
)
975+
976+
assert_all_close(
977+
Axon.Layers.resize(image, size: {3, 3}, method: :bicubic),
978+
Nx.tensor([
979+
[
980+
[[-0.5921, 0.4079, 1.4079], [1.1053, 2.1053, 3.1053], [2.8026, 3.8026, 4.8026]],
981+
[[2.8026, 3.8026, 4.8026], [4.5, 5.5, 6.5], [6.1974, 7.1974, 8.1974]],
982+
[[6.1974, 7.1974, 8.1974], [7.8947, 8.8947, 9.8947], [9.5921, 10.5921, 11.5921]]
983+
]
984+
]),
985+
atol: 1.0e-4
986+
)
987+
988+
assert_all_close(
989+
Axon.Layers.resize(image, size: {3, 3}, method: :lanczos3),
990+
Nx.tensor([
991+
[
992+
[[-1.1173, -0.1173, 0.8827], [0.7551, 1.7551, 2.7551], [2.6276, 3.6276, 4.6276]],
993+
[[2.6276, 3.6276, 4.6276], [4.5, 5.5, 6.5], [6.3724, 7.3724, 8.3724]],
994+
[[6.3724, 7.3724, 8.3724], [8.2449, 9.2449, 10.2449], [10.1173, 11.1173, 12.1173]]
995+
]
996+
]),
997+
atol: 1.0e-4
998+
)
999+
1000+
assert_all_close(
1001+
Axon.Layers.resize(image, size: {3, 3}, method: :lanczos5),
1002+
Nx.tensor([
1003+
[
1004+
[[-1.3525, -0.3525, 0.6475], [0.5984, 1.5984, 2.5984], [2.5492, 3.5492, 4.5492]],
1005+
[[2.5492, 3.5492, 4.5492], [4.5, 5.5, 6.5], [6.4508, 7.4508, 8.4508]],
1006+
[[6.4508, 7.4508, 8.4508], [8.4016, 9.4016, 10.4016], [10.3525, 11.3525, 12.3525]]
1007+
]
1008+
]),
1009+
atol: 1.0e-4
1010+
)
1011+
end
9471012
end
9481013

9491014
describe "lstm_cell" do

0 commit comments

Comments
 (0)