Skip to content

Commit d0d3665

Browse files
authored
fix: device in window_scatter_function now uses from_nx (#1640)
1 parent 7591682 commit d0d3665

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

torchx/lib/torchx/backend.ex

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1528,13 +1528,15 @@ defmodule Torchx.Backend do
15281528
|> then(unfold_flat)
15291529
|> then(function)
15301530

1531+
{device, _} = from_nx(tensor)
1532+
15311533
indices_to_flatten =
15321534
tensor
15331535
|> Nx.axes()
15341536
|> Enum.map(fn axis ->
15351537
tensor
15361538
|> Nx.shape()
1537-
|> Nx.iota(axis: axis, backend: Torchx.Backend)
1539+
|> Nx.iota(axis: axis, backend: {Torchx.Backend, device: device})
15381540
|> then(unfold_flat)
15391541
|> Nx.take_along_axis(Nx.new_axis(arg_idx, -1), axis: -1)
15401542
end)

torchx/test/torchx/device_test.exs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,12 @@ defmodule Torchx.DeviceTest do
4545
# assert_raise ArgumentError, fn -> Nx.backend_transfer(t) end
4646
end
4747
end
48+
49+
describe "indices_to_flatten" do
50+
test "works" do
51+
t = Nx.tensor([[1, 2], [3, 4]], backend: {TB, device: @device})
52+
t2 = Nx.tensor([[2, 6], [3, 1]], backend: {TB, device: @device})
53+
assert_equal Nx.window_scatter_max(t, t2, 0, {2, 3}), Nx.tensor([[0, 0, 0, 0, 6, 0], [0, 0, 2, 0, 0, 0], [0, 0, 3, 0, 0, 0], [0, 0, 0, 0, 0, 1]], backend: {TB, device: @device})
54+
end
55+
end
4856
end

0 commit comments

Comments
 (0)