@@ -52,7 +52,7 @@ defmodule EXLA.Defn do
52
52
comp_fun =
53
53
& to_stream_computation ( client , input_length , acc_length , & 1 , & 2 , & 3 , & 4 , compile_options )
54
54
55
- { executable , used_inputs , { output , acc_output } , outfeed , input_typespecs } =
55
+ { executable , { used_inputs , { output , acc_output } , outfeed , input_typespecs } } =
56
56
compile (
57
57
client ,
58
58
key ,
@@ -84,9 +84,7 @@ defmodule EXLA.Defn do
84
84
EXLA.Defn.Lock . lock ( run_key ( executable ) )
85
85
end )
86
86
87
- if debug? do
88
- Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
89
- end
87
+ debug? && Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
90
88
91
89
{ time , streams } =
92
90
:timer . tc ( fn ->
@@ -131,9 +129,8 @@ defmodule EXLA.Defn do
131
129
[ stream ]
132
130
end )
133
131
134
- if debug? do
132
+ debug? &&
135
133
Logger . debug ( "EXLA stream start on device #{ executable . device_id } in #{ us_to_ms ( time ) } ms" )
136
- end
137
134
138
135
streams
139
136
end
@@ -250,7 +247,7 @@ defmodule EXLA.Defn do
250
247
251
248
callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
252
249
253
- { executable , used_inputs , outputs , outfeed , _input_typespecs? } =
250
+ { executable , { used_inputs , outputs , outfeed , _input_typespecs? } } =
254
251
compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , debug? , callback )
255
252
256
253
fn [ args ] ->
@@ -259,18 +256,15 @@ defmodule EXLA.Defn do
259
256
EXLA.Defn.Lock . lock ( run_key ( executable ) )
260
257
end )
261
258
262
- if debug? do
263
- Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
264
- end
259
+ debug? && Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
265
260
266
261
{ time , res } =
267
262
:timer . tc ( fn ->
268
263
maybe_outfeed ( lock , executable , args , used_inputs , outputs , outfeed , run_options )
269
264
end )
270
265
271
- if debug? do
266
+ debug? &&
272
267
Logger . debug ( "EXLA execution on device #{ executable . device_id } in #{ us_to_ms ( time ) } ms" )
273
- end
274
268
275
269
res
276
270
end
@@ -360,15 +354,9 @@ defmodule EXLA.Defn do
360
354
debug? ,
361
355
to_computation
362
356
) do
363
- { { expr_cache_fun , comp_cache_fun } , options } =
364
- case Keyword . pop ( options , :cache , true ) do
365
- { true , options } ->
366
- Keyword . pop ( options , EXLA , { & EXLA.Defn.LockedCache . run / 2 , & EXLA.Defn.LockedCache . run / 2 } )
367
-
368
- { false , options } ->
369
- cache_fun = fn _key , fun -> fun . ( ) end
370
- { { cache_fun , cache_fun } , options }
371
- end
357
+ { cache , options } = Keyword . pop ( options , :cache , true )
358
+ { hooks , options } = Keyword . pop ( options , :hooks , % { } )
359
+ { lazy_transfers , options } = Keyword . pop ( options , :lazy_transfers , :opt_in )
372
360
373
361
{ args_key , reverse_args_identifiers } =
374
362
Enum . map_reduce ( vars , [ ] , fn var , acc ->
@@ -381,119 +369,134 @@ defmodule EXLA.Defn do
381
369
end )
382
370
end )
383
371
384
- { lazy_transfers , options } = Keyword . pop ( options , :lazy_transfers , :opt_in )
372
+ disk_key = % {
373
+ client: client . name ,
374
+ args: args_key ,
375
+ lazy_transfers: lazy_transfers ,
376
+ hooks: Map . keys ( hooks ) ,
377
+ options: options
378
+ }
385
379
386
- { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
387
- :timer . tc ( fn ->
388
- expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
389
- expr = fun . ( vars )
390
- inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
391
- { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
380
+ EXLA.Defn.Disk . cache ( cache , client , disk_key , debug? , fn ->
381
+ { { expr_cache_fun , comp_cache_fun } , options } =
382
+ if cache do
383
+ Keyword . pop ( options , EXLA , { & EXLA.Defn.LockedCache . run / 2 , & EXLA.Defn.LockedCache . run / 2 } )
384
+ else
385
+ cache_fun = fn _key , fun -> fun . ( ) end
386
+ { { cache_fun , cache_fun } , Keyword . delete ( options , EXLA ) }
387
+ end
388
+
389
+ { eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
390
+ :timer . tc ( fn ->
391
+ expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
392
+ expr = fun . ( vars )
393
+ inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
394
+ { expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
395
+ end )
392
396
end )
393
- end )
394
397
395
- if debug? do
396
- hit_or_miss = if expr , do: "miss" , else: "hit"
398
+ if debug? do
399
+ hit_or_miss = if expr , do: "miss" , else: "hit"
397
400
398
- Logger . debug (
399
- "EXLA defn evaluation #{ inspect ( key ) } cache #{ hit_or_miss } in #{ us_to_ms ( eval_time ) } ms"
400
- )
401
- end
401
+ Logger . debug (
402
+ "EXLA defn evaluation #{ inspect ( key ) } cache #{ hit_or_miss } in #{ us_to_ms ( eval_time ) } ms"
403
+ )
404
+ end
402
405
403
- { hooks , options } = Keyword . pop ( options , :hooks , % { } )
404
- outfeed = Outfeed . new ( hooks , defined_hooks )
405
- comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
406
+ outfeed = Outfeed . new ( hooks , defined_hooks )
407
+ comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
406
408
407
- { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
408
- :timer . tc ( fn ->
409
- comp_cache_fun . ( comp_key , fn ->
410
- { reverse_inputs_and_typespecs , reverse_infeeds } =
411
- reverse_args_identifiers
412
- |> Enum . reverse ( )
413
- |> EXLA.Defn.Buffers . split_by_value ( used_inputs , fn
414
- { type , shape , _names } , i , nil -> { i , Typespec . tensor ( type , shape ) }
415
- { type , shape , _names } , i , depth -> { i , depth , Typespec . tensor ( type , shape ) }
416
- end )
409
+ { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
410
+ :timer . tc ( fn ->
411
+ comp_cache_fun . ( comp_key , fn ->
412
+ { reverse_inputs_and_typespecs , reverse_infeeds } =
413
+ reverse_args_identifiers
414
+ |> Enum . reverse ( )
415
+ |> EXLA.Defn.Buffers . split_by_value ( used_inputs , fn
416
+ { type , shape , _names } , i , nil -> { i , Typespec . tensor ( type , shape ) }
417
+ { type , shape , _names } , i , depth -> { i , depth , Typespec . tensor ( type , shape ) }
418
+ end )
417
419
418
- inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
419
-
420
- comp_typespecs =
421
- for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
422
-
423
- outputs =
424
- if stream? do
425
- # The computation returns the final accumulator value
426
- { _chunk_result , acc } = outputs
427
- acc
428
- else
429
- outputs
430
- end
431
-
432
- out_typespecs =
433
- [ outputs ]
434
- |> Nx.Defn.Composite . flatten_list ( )
435
- |> Enum . map ( fn t ->
436
- t
437
- |> Nx . devectorize ( )
438
- |> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
439
- end )
420
+ inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
421
+
422
+ comp_typespecs =
423
+ for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
440
424
441
- EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
442
- # Only create the token when we know it will actually be
443
- # used, that is: streaming, lazy transfers or hooks
444
- outfeed =
445
- if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
446
- outfeed
447
- |> Outfeed . with_token ( Value . create_token ( builder ) )
448
- |> Outfeed . add_infeeds ( builder , reverse_infeeds )
425
+ outputs =
426
+ if stream? do
427
+ # The computation returns the final accumulator value
428
+ { _chunk_result , acc } = outputs
429
+ acc
449
430
else
450
- outfeed
431
+ outputs
451
432
end
452
433
453
- expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
454
- outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
455
-
456
- { xla_time , executable } =
457
- :timer . tc ( fn ->
458
- EXLA.MLIR.Module . compile (
459
- builder . module ,
460
- client ,
461
- comp_typespecs ,
462
- builder . return_typespecs ,
463
- options
464
- )
434
+ out_typespecs =
435
+ [ outputs ]
436
+ |> Nx.Defn.Composite . flatten_list ( )
437
+ |> Enum . map ( fn t ->
438
+ t
439
+ |> Nx . devectorize ( )
440
+ |> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
465
441
end )
466
442
467
- { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
443
+ EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
444
+ # Only create the token when we know it will actually be
445
+ # used, that is: streaming, lazy transfers or hooks
446
+ outfeed =
447
+ if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
448
+ outfeed
449
+ |> Outfeed . with_token ( Value . create_token ( builder ) )
450
+ |> Outfeed . add_infeeds ( builder , reverse_infeeds )
451
+ else
452
+ outfeed
453
+ end
454
+
455
+ expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
456
+ outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
457
+
458
+ { xla_time , executable } =
459
+ :timer . tc ( fn ->
460
+ EXLA.MLIR.Module . compile (
461
+ builder . module ,
462
+ client ,
463
+ comp_typespecs ,
464
+ builder . return_typespecs ,
465
+ options
466
+ )
467
+ end )
468
+
469
+ { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
470
+ end )
468
471
end )
469
472
end )
470
- end )
471
473
472
- cond do
473
- not debug? ->
474
- :ok
474
+ cond do
475
+ not debug? ->
476
+ :ok
475
477
476
- evaled ->
477
- Logger . debug (
478
- "EXLA compilation #{ inspect ( key ) } cache miss in #{ us_to_ms ( comp_time ) } ms (#{ us_to_ms ( xla_time ) } ms in XLA)"
479
- )
478
+ evaled ->
479
+ Logger . debug (
480
+ "EXLA compilation #{ inspect ( key ) } cache miss in #{ us_to_ms ( comp_time ) } ms (#{ us_to_ms ( xla_time ) } ms in XLA)"
481
+ )
480
482
481
- true ->
482
- Logger . debug ( "EXLA compilation #{ inspect ( key ) } cache hit in #{ us_to_ms ( comp_time ) } ms" )
483
- end
483
+ true ->
484
+ Logger . debug ( "EXLA compilation #{ inspect ( key ) } cache hit in #{ us_to_ms ( comp_time ) } ms" )
485
+ end
484
486
485
- if expr || evaled do
486
- measurements = % {
487
- eval_time: eval_time ,
488
- compile_time: comp_time ,
489
- total_time: eval_time + comp_time
490
- }
487
+ if expr || evaled do
488
+ measurements = % {
489
+ eval_time: eval_time ,
490
+ compile_time: comp_time ,
491
+ total_time: eval_time + comp_time
492
+ }
491
493
492
- :telemetry . execute ( [ :exla , :compilation ] , measurements , % { key: key } )
493
- end
494
+ :telemetry . execute ( [ :exla , :compilation ] , measurements , % { key: key } )
495
+ end
494
496
495
- outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
496
- { executable , used_inputs , outputs , outfeed , inputs_and_typespecs }
497
+ outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
498
+ { executable , { used_inputs , outputs , outfeed , inputs_and_typespecs } }
499
+ end )
497
500
end
498
501
499
502
defp us_to_ms ( time ) , do: Float . round ( time / 1000 , 1 )
0 commit comments