Skip to content

Commit 495d57e

Browse files
authored
fix: disable compiler mode until nif_call fix" (#89)
* fix: disable compiler mode until nif_call fix * fix: remove nif call runner from app tree
1 parent 4e927d1 commit 495d57e

File tree

2 files changed

+3
-120
lines changed

2 files changed

+3
-120
lines changed

lib/emlx.ex

Lines changed: 2 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -326,125 +326,10 @@ defmodule EMLX do
326326
@behaviour Nx.Defn.Compiler
327327

328328
@impl Nx.Defn.Compiler
329-
def __jit__(key, vars, fun, args_list, opts) do
330-
__compile__(key, vars, fun, opts).(args_list)
331-
end
329+
defdelegate __jit__(key, vars, fun, args_list, opts), to: Nx.Defn.Evaluator
332330

333331
@impl Nx.Defn.Compiler
334-
def __compile__(key, vars, fun, opts) do
335-
backend = Nx.default_backend()
336-
337-
target_backend =
338-
case backend do
339-
EMLX.Backend ->
340-
backend
341-
342-
{EMLX.Backend, _} ->
343-
backend
344-
345-
Nx.BinaryBackend ->
346-
EMLX.Backend
347-
348-
{Nx.BinaryBackend, _} ->
349-
EMLX.Backend
350-
351-
other ->
352-
raise ArgumentError,
353-
"EMLX can only be used with the EMLX.Backend or Nx.BinaryBackend, got: #{inspect(other)}"
354-
end
355-
356-
expr = fun.(vars)
357-
358-
fn [args] ->
359-
{devices, nif_args} =
360-
Enum.map(args, fn arg ->
361-
case arg.() do
362-
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} ->
363-
{device, ref}
364-
365-
%Nx.Tensor{data: %Nx.BinaryBackend{}} = t ->
366-
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} =
367-
Nx.backend_copy(t, target_backend)
368-
369-
{device, ref}
370-
371-
other ->
372-
%Nx.Tensor{data: %EMLX.Backend{ref: {device, ref}}} = Nx.to_tensor(other)
373-
{device, ref}
374-
end
375-
end)
376-
|> Enum.unzip()
377-
378-
device =
379-
Enum.reduce_while(devices, :cpu, fn
380-
:gpu, _ ->
381-
{:halt, :gpu}
382-
383-
_, acc ->
384-
{:cont, acc}
385-
end)
386-
387-
cache_key = {__MODULE__, :compiled_fun, key}
388-
389-
compiled_fun =
390-
case :persistent_term.get(cache_key, :not_found) do
391-
:not_found ->
392-
eval_fun = Nx.Defn.Evaluator.__compile__(key, vars, fun, opts)
393-
394-
EMLX.NIF.set_compile(true)
395-
396-
evaluator_pid = Process.whereis(EMLX.Runner)
397-
398-
if not Process.alive?(evaluator_pid) do
399-
raise "EMLX.Runner not alive"
400-
end
401-
402-
callback = fn args ->
403-
args = Enum.map(args, fn ref -> fn -> EMLX.Backend.to_nx({device, ref}) end end)
404-
405-
eval_fun.([args])
406-
|> Nx.Defn.Composite.flatten_list()
407-
|> Enum.map(fn %Nx.Tensor{data: %{ref: {_device, ref}}} -> ref end)
408-
end
409-
410-
fun = NifCall.run(EMLX.Runner, callback, &nif_compile(nif_args, &1))
411-
412-
EMLX.NIF.set_compile(false)
413-
414-
:persistent_term.put(cache_key, fun)
415-
fun
416-
417-
cached_fun ->
418-
cached_fun
419-
end
420-
421-
nif_result =
422-
case device do
423-
:cpu -> EMLX.NIF.call_compiled_cpu(compiled_fun, nif_args)
424-
:gpu -> EMLX.NIF.call_compiled_gpu(compiled_fun, nif_args)
425-
end
426-
427-
results =
428-
nif_result
429-
|> unwrap!()
430-
|> Enum.map(fn ref ->
431-
EMLX.Backend.to_nx({device, ref})
432-
end)
433-
434-
{result, []} =
435-
Nx.Defn.Composite.traverse(expr, results, fn _node, [h | t] ->
436-
{h, t}
437-
end)
438-
439-
[result]
440-
end
441-
end
442-
443-
defp nif_compile(nif_args, tag) do
444-
nif_args
445-
|> EMLX.NIF.compile(tag)
446-
|> unwrap!()
447-
end
332+
defdelegate __compile__(key, vars, fun, opts), to: Nx.Defn.Evaluator
448333

449334
@impl Nx.Defn.Compiler
450335
defdelegate __partitions_options__(opts), to: Nx.Defn.Evaluator

lib/emlx/application.ex

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ defmodule EMLX.Application do
22
use Application
33

44
def start(_type, _args) do
5-
children = [
6-
{NifCall.Runner, runner_opts: [nif_module: EMLX.NIF], name: EMLX.Runner}
7-
]
5+
children = []
86

97
Supervisor.start_link(children, strategy: :one_for_one, name: EMLX.Supervisor)
108
end

0 commit comments

Comments
 (0)