Skip to content

Commit d1ad7d8

Browse files
authored
fix: while with vectorized cond inside (#1474)
1 parent c6f8cec commit d1ad7d8

File tree

2 files changed

+38
-3
lines changed

2 files changed

+38
-3
lines changed

exla/test/exla/defn/expr_test.exs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1465,6 +1465,38 @@ defmodule EXLA.Defn.ExprTest do
14651465
assert_equal(while_inside_cond(1, 1), 1)
14661466
assert_equal(while_inside_cond(1, 0), 2)
14671467
end
1468+
1469+
defn cond_inside_while_vectorized(t, size) do
1470+
down = Nx.u8(0)
1471+
up = Nx.u8(1)
1472+
mode = down
1473+
i = Nx.s64(0)
1474+
1475+
[t, node, i, size, mode] =
1476+
Nx.broadcast_vectors([t, 0, i, size, mode])
1477+
1478+
{t, _} =
1479+
while {t, {node, i, _mode = mode, size, up}},
1480+
node != -1 and i >= 0 do
1481+
mode =
1482+
cond do
1483+
node >= size -> up
1484+
true -> -1
1485+
end
1486+
1487+
{t, {node, i - 1, mode, size, up}}
1488+
end
1489+
1490+
t
1491+
end
1492+
1493+
test "cond inside vectorized while" do
1494+
assert_raise CompileError,
1495+
~r/the do-block in while must return tensors with the same shape, type, and names as the initial arguments./,
1496+
fn ->
1497+
cond_inside_while_vectorized(Nx.vectorize(Nx.tensor([1, 2, 3]), :a), 3)
1498+
end
1499+
end
14681500
end
14691501

14701502
describe "map" do

nx/lib/nx/defn/expr.ex

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ defmodule Nx.Defn.Expr do
191191
last
192192
end
193193

194-
defp non_vectorized_cond(clauses, last = out, context) do
194+
defp non_vectorized_cond(clauses, out_template = last, context) do
195195
{preds, exprs} = Enum.unzip(clauses)
196196

197197
{broadcasted_clauses, vectorized_axes} =
@@ -211,10 +211,10 @@ defmodule Nx.Defn.Expr do
211211

212212
out =
213213
if vectorized_axes == [] do
214-
out
214+
out_template
215215
else
216216
{result, []} =
217-
Composite.traverse(out, vectorized_axes, fn
217+
Composite.traverse(out_template, vectorized_axes, fn
218218
%T{} = expr, [axes | tail] ->
219219
{%{expr | vectorized_axes: axes}, tail}
220220

@@ -538,6 +538,9 @@ defmodule Nx.Defn.Expr do
538538
{inner_arg, inner_context} = to_param_expr(inner_initial, :while)
539539
inner_condition = condition_body.(:condition, inner_arg) |> to_pred(line, file, :while)
540540
inner_body = condition_body.(:body, inner_arg) |> to_container_expr()
541+
542+
compatible_while!(file, line, inner_initial, inner_body)
543+
541544
inner_while = while(inner_initial, inner_context, inner_arg, inner_condition, inner_body)
542545

543546
vectorized_while__build_outer_while(

0 commit comments

Comments
 (0)