Skip to content

Commit ad28ea7

Browse files
authored
fix: add support for defn compilation in EXLA.to_mlir_module (#1530)
1 parent 2e140b7 commit ad28ea7

File tree

3 files changed

+82
-12
lines changed

3 files changed

+82
-12
lines changed

exla/lib/exla.ex

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -360,11 +360,18 @@ defmodule EXLA do
360360
Takes in a function, the argument templates and the compilation
361361
options and returns the textual representation of the MLIR module.
362362
363+
## Options
364+
365+
* `:within_defn_compiler` - a boolean that indicates whether
366+
this function is being called from within a `defn` compiler.
367+
Defaults to `false`.
368+
363369
## Examples
364370
365371
iex> fun = fn x, y -> Nx.add(Nx.sin(x), Nx.cos(y)) end
366372
iex> args = [1.0, 2.0]
367-
iex> EXLA.to_mlir_module(fun, args)
373+
iex> %{mlir_module: mlir_module} = EXLA.to_mlir_module(fun, args)
374+
iex> mlir_module
368375
"""
369376
module {
370377
func.func public @main(%arg0: tensor<f32>, %arg1: tensor<f32>) -> tensor<f32> {
@@ -377,20 +384,26 @@ defmodule EXLA do
377384
"""
378385
'''
379386
def to_mlir_module(function, args, options \\ []) do
380-
comp_fun = fn _key, callback ->
381-
{:ok, {_xla_time, executable, _extra, _outfeed}} = callback.()
382-
throw({:mlir_module, executable.ref})
383-
end
387+
{nested_compilation?, options} = Keyword.pop(options, :within_defn_compiler, false)
384388

385-
opts = [
386-
{EXLA, {&EXLA.Defn.LockedCache.run/2, comp_fun}},
387-
{:module_compilation, :to_mlir} | options
388-
]
389+
opts =
390+
Keyword.merge(options,
391+
module_compilation: :to_mlir,
392+
compiler: EXLA
393+
)
389394

390-
jit_apply(function, args, opts)
395+
if nested_compilation? do
396+
EXLA.Defn.__compile__(function, args, function, opts)
397+
else
398+
Nx.Defn.compile(function, args, opts)
399+
end
391400
catch
392-
{:mlir_module, ref} ->
393-
EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
401+
{:mlir_module, ref, used_inputs, output_container} ->
402+
%{
403+
used_inputs: used_inputs,
404+
output_container: output_container,
405+
mlir_module: EXLA.MLIR.Module.as_string(%EXLA.MLIR.Module{ref: ref})
406+
}
394407
end
395408

396409
@doc """

exla/lib/exla/defn.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,10 @@ defmodule EXLA.Defn do
228228
{executable, {used_inputs, outputs, outfeed, _input_typespecs?}} =
229229
compile(key, vars, fun, compile_options, 0, [], _stream = false, callback)
230230

231+
if compile_options[:module_compilation] == :to_mlir do
232+
throw({:mlir_module, executable.ref, MapSet.new(Map.keys(used_inputs)), outputs})
233+
end
234+
231235
fn [args] ->
232236
{time, lock} =
233237
:timer.tc(fn ->

exla/test/exla_test.exs

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,4 +17,57 @@ defmodule EXLATest do
1717
end
1818
end
1919
end
20+
21+
defmodule ValidCompiler do
22+
def __jit__(key, vars, fun, args_list, opts) do
23+
__compile__(key, vars, fun, opts).(args_list)
24+
end
25+
26+
def __compile__(_key, vars, fun, opts) do
27+
result = EXLA.to_mlir_module(fun, vars, Keyword.put(opts, :within_defn_compiler, true))
28+
throw({__MODULE__, result})
29+
end
30+
end
31+
32+
defmodule InvalidCompiler do
33+
def __jit__(key, vars, fun, args_list, opts) do
34+
__compile__(key, vars, fun, opts).(args_list)
35+
end
36+
37+
def __compile__(_key, vars, fun, opts) do
38+
# Keyword.delete to ensure default is false
39+
EXLA.to_mlir_module(fun, vars, Keyword.delete(opts, :within_defn_compiler))
40+
end
41+
end
42+
43+
describe "to_mlir_module/3" do
44+
test "fails if the compiler doesn't set the nested compilation flag" do
45+
assert_raise BadArityError, fn ->
46+
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.InvalidCompiler)
47+
end
48+
end
49+
50+
test "works if the compiler sets the nested compilation flag" do
51+
try do
52+
Nx.Defn.jit_apply(&Nx.add/2, [1, 2], compiler: __MODULE__.ValidCompiler)
53+
catch
54+
{__MODULE__.ValidCompiler, result} ->
55+
assert %{mlir_module: module, output_container: container, used_inputs: used_inputs} =
56+
result
57+
58+
assert module == """
59+
module {
60+
func.func public @main(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<i32> {
61+
%0 = stablehlo.add %arg0, %arg1 : tensor<i32>
62+
return %0 : tensor<i32>
63+
}
64+
}
65+
"""
66+
67+
assert Nx.compatible?(container, Nx.template({}, :s32))
68+
69+
assert MapSet.equal?(used_inputs, MapSet.new([0, 1]))
70+
end
71+
end
72+
end
2073
end

0 commit comments

Comments
 (0)