@@ -326,125 +326,10 @@ defmodule EMLX do
326326 @ behaviour Nx.Defn.Compiler
327327
328328 @ impl Nx.Defn.Compiler
329- def __jit__ ( key , vars , fun , args_list , opts ) do
330- __compile__ ( key , vars , fun , opts ) . ( args_list )
331- end
329+ defdelegate __jit__ ( key , vars , fun , args_list , opts ) , to: Nx.Defn.Evaluator
332330
333331 @ impl Nx.Defn.Compiler
334- def __compile__ ( key , vars , fun , opts ) do
335- backend = Nx . default_backend ( )
336-
337- target_backend =
338- case backend do
339- EMLX.Backend ->
340- backend
341-
342- { EMLX.Backend , _ } ->
343- backend
344-
345- Nx.BinaryBackend ->
346- EMLX.Backend
347-
348- { Nx.BinaryBackend , _ } ->
349- EMLX.Backend
350-
351- other ->
352- raise ArgumentError ,
353- "EMLX can only be used with the EMLX.Backend or Nx.BinaryBackend, got: #{ inspect ( other ) } "
354- end
355-
356- expr = fun . ( vars )
357-
358- fn [ args ] ->
359- { devices , nif_args } =
360- Enum . map ( args , fn arg ->
361- case arg . ( ) do
362- % Nx.Tensor { data: % EMLX.Backend { ref: { device , ref } } } ->
363- { device , ref }
364-
365- % Nx.Tensor { data: % Nx.BinaryBackend { } } = t ->
366- % Nx.Tensor { data: % EMLX.Backend { ref: { device , ref } } } =
367- Nx . backend_copy ( t , target_backend )
368-
369- { device , ref }
370-
371- other ->
372- % Nx.Tensor { data: % EMLX.Backend { ref: { device , ref } } } = Nx . to_tensor ( other )
373- { device , ref }
374- end
375- end )
376- |> Enum . unzip ( )
377-
378- device =
379- Enum . reduce_while ( devices , :cpu , fn
380- :gpu , _ ->
381- { :halt , :gpu }
382-
383- _ , acc ->
384- { :cont , acc }
385- end )
386-
387- cache_key = { __MODULE__ , :compiled_fun , key }
388-
389- compiled_fun =
390- case :persistent_term . get ( cache_key , :not_found ) do
391- :not_found ->
392- eval_fun = Nx.Defn.Evaluator . __compile__ ( key , vars , fun , opts )
393-
394- EMLX.NIF . set_compile ( true )
395-
396- evaluator_pid = Process . whereis ( EMLX.Runner )
397-
398- if not Process . alive? ( evaluator_pid ) do
399- raise "EMLX.Runner not alive"
400- end
401-
402- callback = fn args ->
403- args = Enum . map ( args , fn ref -> fn -> EMLX.Backend . to_nx ( { device , ref } ) end end )
404-
405- eval_fun . ( [ args ] )
406- |> Nx.Defn.Composite . flatten_list ( )
407- |> Enum . map ( fn % Nx.Tensor { data: % { ref: { _device , ref } } } -> ref end )
408- end
409-
410- fun = NifCall . run ( EMLX.Runner , callback , & nif_compile ( nif_args , & 1 ) )
411-
412- EMLX.NIF . set_compile ( false )
413-
414- :persistent_term . put ( cache_key , fun )
415- fun
416-
417- cached_fun ->
418- cached_fun
419- end
420-
421- nif_result =
422- case device do
423- :cpu -> EMLX.NIF . call_compiled_cpu ( compiled_fun , nif_args )
424- :gpu -> EMLX.NIF . call_compiled_gpu ( compiled_fun , nif_args )
425- end
426-
427- results =
428- nif_result
429- |> unwrap! ( )
430- |> Enum . map ( fn ref ->
431- EMLX.Backend . to_nx ( { device , ref } )
432- end )
433-
434- { result , [ ] } =
435- Nx.Defn.Composite . traverse ( expr , results , fn _node , [ h | t ] ->
436- { h , t }
437- end )
438-
439- [ result ]
440- end
441- end
442-
443- defp nif_compile ( nif_args , tag ) do
444- nif_args
445- |> EMLX.NIF . compile ( tag )
446- |> unwrap! ( )
447- end
332+ defdelegate __compile__ ( key , vars , fun , opts ) , to: Nx.Defn.Evaluator
448333
449334 @ impl Nx.Defn.Compiler
450335 defdelegate __partitions_options__ ( opts ) , to: Nx.Defn.Evaluator
0 commit comments