Skip to content

Commit 7990b7e

Browse files
feat: use func.return when returning from func.func (#1495)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent e0ed58a commit 7990b7e

File tree

4 files changed

+74
-47
lines changed

4 files changed

+74
-47
lines changed

exla/lib/exla/defn.ex

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ defmodule EXLA.Defn do
237237
output = wrap_tuple_result(acc, acc_typespec)
238238

239239
outfeed = outfeed |> Outfeed.with_token(out_token) |> Outfeed.close(builder)
240-
Value.return(builder, output)
240+
Value.func_return(builder, output)
241241

242242
{{input_typespecs, input_indexes}, outfeed}
243243
end
@@ -307,7 +307,7 @@ defmodule EXLA.Defn do
307307
{res, cache} = recur_flatten(expr, state, new_cache(outfeed))
308308
outfeed = cache |> get_outfeed() |> Outfeed.close(function)
309309

310-
Value.return(function, res)
310+
Value.func_return(function, res)
311311

312312
{:ok, outfeed}
313313
end
@@ -433,6 +433,15 @@ defmodule EXLA.Defn do
433433
comp_arg_typespecs =
434434
for {i, typespec} <- inputs_and_typespecs, i >= used_buffers, do: typespec
435435

436+
outputs =
437+
if stream? do
438+
# The computation returns the final accumulator value
439+
{_chunk_result, acc} = outputs
440+
acc
441+
else
442+
outputs
443+
end
444+
436445
out_typespecs =
437446
[outputs]
438447
|> Nx.Defn.Composite.flatten_list()
@@ -1669,9 +1678,9 @@ defmodule EXLA.Defn do
16691678
{res, comp_cache} = recur_composite(expr, state, reset_token(cache, inner_token))
16701679

16711680
if outer_token do
1672-
Value.return(function, [get_token(comp_cache) | List.flatten(res)])
1681+
Value.func_return(function, [get_token(comp_cache) | List.flatten(res)])
16731682
else
1674-
Value.return(function, List.flatten(res))
1683+
Value.func_return(function, List.flatten(res))
16751684
end
16761685

16771686
{function, merge_outfeed(cache, comp_cache)}

exla/lib/exla/mlir/value.ex

Lines changed: 58 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -125,57 +125,71 @@ defmodule EXLA.MLIR.Value do
125125
end
126126
end
127127

128-
def is_infinity(%Value{function: func} = operand, typespec) do
128+
def is_infinity(%Value{function: func} = operand, out_typespec) do
129129
%{type: type} = get_typespec(operand)
130130

131-
typespec = Typespec.to_type(typespec, {:pred, 8})
131+
typespec = Typespec.to_type(out_typespec, {:pred, 8})
132132

133-
cond do
134-
Nx.Type.complex?(type) ->
135-
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
136-
real = real(operand, float_typespec)
137-
imag = imag(operand, float_typespec)
138-
is_inf_real = is_infinity(real, typespec)
139-
is_inf_imag = is_infinity(imag, typespec)
140-
bitwise_or(is_inf_real, is_inf_imag, typespec)
141-
142-
Nx.Type.integer?(type) ->
143-
# Integers are never infinity. We use inequality to make sure
144-
# the operand is still a part of the computation
145-
not_equal(operand, operand, typespec)
133+
result =
134+
cond do
135+
Nx.Type.complex?(type) ->
136+
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
137+
real = real(operand, float_typespec)
138+
imag = imag(operand, float_typespec)
139+
is_inf_real = is_infinity(real, typespec)
140+
is_inf_imag = is_infinity(imag, typespec)
141+
bitwise_or(is_inf_real, is_inf_imag, typespec)
142+
143+
Nx.Type.integer?(type) ->
144+
# Integers are never infinity. We use inequality to make sure
145+
# the operand is still a part of the computation
146+
not_equal(operand, operand, typespec)
147+
148+
true ->
149+
result_types = typespecs_to_mlir_types([typespec])
150+
op(func, "chlo.is_inf", [operand], result_types) |> one!()
151+
end
146152

147-
true ->
148-
result_types = typespecs_to_mlir_types([typespec])
149-
op(func, "chlo.is_inf", [operand], result_types) |> one!()
153+
if out_typespec.type == typespec.type do
154+
result
155+
else
156+
convert(result, out_typespec)
150157
end
151158
end
152159

153-
def is_nan(%Value{function: func} = operand, typespec) do
160+
def is_nan(%Value{function: func} = operand, out_typespec) do
154161
%{type: type} = get_typespec(operand)
155162

156-
typespec = Typespec.to_type(typespec, {:pred, 8})
163+
typespec = Typespec.to_type(out_typespec, {:pred, 8})
157164

158-
cond do
159-
Nx.Type.complex?(type) ->
160-
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
161-
real = real(operand, float_typespec)
162-
imag = imag(operand, float_typespec)
163-
is_nan_real = is_nan(real, typespec)
164-
is_nan_imag = is_nan(imag, typespec)
165-
bitwise_or(is_nan_real, is_nan_imag, typespec)
166-
167-
Nx.Type.integer?(type) ->
168-
# Integers are never nan. We use inequality to make sure
169-
# the operand is still a part of the computation
170-
not_equal(operand, operand, typespec)
165+
result =
166+
cond do
167+
Nx.Type.complex?(type) ->
168+
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
169+
real = real(operand, float_typespec)
170+
imag = imag(operand, float_typespec)
171+
is_nan_real = is_nan(real, typespec)
172+
is_nan_imag = is_nan(imag, typespec)
173+
bitwise_or(is_nan_real, is_nan_imag, typespec)
174+
175+
Nx.Type.integer?(type) ->
176+
# Integers are never nan. We use inequality to make sure
177+
# the operand is still a part of the computation
178+
not_equal(operand, operand, typespec)
179+
180+
true ->
181+
result_types = typespecs_to_mlir_types([typespec])
182+
is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
183+
is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
184+
is_not_inf = bitwise_not(is_inf, typespec)
185+
is_not_finite = bitwise_not(is_finite, typespec)
186+
bitwise_and(is_not_inf, is_not_finite, typespec)
187+
end
171188

172-
true ->
173-
result_types = typespecs_to_mlir_types([typespec])
174-
is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
175-
is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
176-
is_not_inf = bitwise_not(is_inf, typespec)
177-
is_not_finite = bitwise_not(is_finite, typespec)
178-
bitwise_and(is_not_inf, is_not_finite, typespec)
189+
if out_typespec.type == typespec.type do
190+
result
191+
else
192+
convert(result, out_typespec)
179193
end
180194
end
181195

@@ -706,6 +720,10 @@ defmodule EXLA.MLIR.Value do
706720
op(func, "stablehlo.while", initial, result_types, regions: regions)
707721
end
708722

723+
def func_return(func, values) when is_list(values) do
724+
op(func, "func.return", values, [])
725+
end
726+
709727
def return(func, values) when is_list(values) do
710728
op(func, "stablehlo.return", values, [])
711729
end

exla/test/exla/executable_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,7 @@ defmodule EXLA.ExecutableFeedTest do
160160

161161
assert res =
162162
Task.async(fn ->
163-
run_one([], [], [Typespec.token()], fn b ->
163+
run_one([], [], [t.typespec], fn b ->
164164
token = Value.create_token(b)
165165

166166
{new_token, [val]} = Value.infeed(token, [t.typespec])
@@ -185,7 +185,7 @@ defmodule EXLA.ExecutableFeedTest do
185185

186186
assert res =
187187
Task.async(fn ->
188-
run_one([], [], [token_shape, t.typespec], fn b ->
188+
run_one([], [], [t.typespec], fn b ->
189189
token = Value.create_token(b)
190190

191191
arg_shapes = [token_shape, t.typespec]

exla/test/support/exla_helpers.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ defmodule EXLAHelpers do
1515

1616
fun
1717
|> apply([builder | params])
18-
|> then(&EXLA.MLIR.Value.return(builder, List.wrap(&1)))
18+
|> then(&EXLA.MLIR.Value.func_return(builder, List.wrap(&1)))
1919

2020
EXLA.MLIR.Module.compile(
2121
builder.module,

0 commit comments

Comments
 (0)