Skip to content

Commit 28d2126

Browse files
committed
Cast multiple indexes on slice, closes #1472
1 parent adc6cd0 commit 28d2126

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

exla/lib/exla/defn.ex

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1184,6 +1184,15 @@ defmodule EXLA.Defn do
11841184
limit_indices = Enum.zip_with(start_indices, lengths, fn i, len -> i + len end)
11851185
Value.slice(tensor, start_indices, limit_indices, strides)
11861186
else
1187+
sample = Enum.find(start_indices, &(not is_integer(&1)))
1188+
1189+
type =
1190+
Enum.reduce(start_indices, op_type(sample), fn
1191+
index, acc when is_integer(index) -> acc
1192+
value, acc -> merge_type(op_type(value), acc)
1193+
end)
1194+
1195+
start_indices = Enum.map(start_indices, &to_type(&1, type))
11871196
zeros = List.duplicate(0, tuple_size(ans.shape))
11881197
slice = Value.dynamic_slice(tensor, start_indices, lengths)
11891198

exla/test/exla/backend_test.exs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,16 @@ defmodule EXLA.BackendTest do
181181
end
182182
end
183183

184+
describe "access" do
185+
test "multiple indexes" do
186+
tensor = Nx.eye({4, 4})
187+
index = Nx.u32(2)
188+
swap = Nx.s64(0)
189+
assert tensor[[index, swap]] |> Nx.to_number() == 0
190+
assert tensor[[0, swap]] |> Nx.to_number() == 1
191+
end
192+
end
193+
184194
test "conjugate" do
185195
assert inspect(Nx.conjugate(~V[1 2-0i 3+0i 0-i 0-2i])) =~
186196
"1.0-0.0i, 2.0+0.0i, 3.0-0.0i, 0.0+1.0i, 0.0+2.0i"

0 commit comments

Comments
 (0)