@@ -32,6 +32,7 @@ defmodule EXLA.Defn do
32
32
33
33
@ doc false
34
34
def __stream__ ( key , input , acc , vars , fun , [ args ] , options ) do
35
+ { debug? , options } = Keyword . pop ( options , :debug , false )
35
36
{ run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
36
37
37
38
{ client_name , compile_options } =
@@ -51,24 +52,26 @@ defmodule EXLA.Defn do
51
52
comp_fun =
52
53
& to_stream_computation ( client , input_length , acc_length , & 1 , & 2 , & 3 , & 4 , compile_options )
53
54
54
- { executable , used_inputs , { output , acc_output } , outfeed , extra , debug? } =
55
+ { executable , used_inputs , { output , acc_output } , outfeed , input_typespecs } =
55
56
compile (
56
57
client ,
57
- { :stream , key } ,
58
+ key ,
58
59
vars ,
59
60
fun ,
60
61
compile_options ,
61
62
used_buffers ,
62
63
used_inputs ,
63
64
_stream = true ,
65
+ debug? ,
64
66
comp_fun
65
67
)
66
68
67
- { input_typespecs , input_indexes } = extra
69
+ # Now discard the infeed from used inputs, similar to how it is done to buffers.
70
+ # Note we discard all lazy transfers too, as they are not possible with streams.
71
+ used_inputs = for { i , nil } <- used_inputs , i >= used_buffers , do: { i , nil } , into: % { }
68
72
69
- # Also discard the stream inputs from used inputs, similar to how it is done to buffers
70
- # Note we discard all lazy transfers too, as they are not possible with streams
71
- used_inputs = Enum . sort ( for { i , nil } <- used_inputs , i >= used_buffers , do: i )
73
+ # And capture the typespecs for the infeed.
74
+ input_typespecs = Enum . take_while ( input_typespecs , fn { i , _ } -> i < input_length end )
72
75
73
76
# Execution of streams requires the coordination of
74
77
# multiple processes which is outlined below.
@@ -120,7 +123,6 @@ defmodule EXLA.Defn do
120
123
outfeed_pid ,
121
124
input ,
122
125
input_typespecs ,
123
- input_indexes ,
124
126
output ,
125
127
output_typespecs ,
126
128
acc_output
@@ -151,9 +153,6 @@ defmodule EXLA.Defn do
151
153
{ input_typespecs , used_typespecs } =
152
154
Enum . split_while ( used_typespecs , fn { i , _ } -> i < input_length end )
153
155
154
- # Get all input indexes and shape
155
- input_indexes = Enum . map ( input_typespecs , & elem ( & 1 , 0 ) )
156
-
157
156
# Drop all accumulator entries from used_typespecs as we will handle it separately.
158
157
{ acc_typespecs , used_typespecs } = Enum . split ( used_typespecs , acc_length )
159
158
@@ -166,13 +165,10 @@ defmodule EXLA.Defn do
166
165
# The input will be read as part of the infeed.
167
166
acc_typespecs_l = Enum . map ( acc_typespecs , & elem ( & 1 , 1 ) )
168
167
acc_typespec = List . to_tuple ( acc_typespecs_l )
169
-
170
168
flag_typespec = Typespec . tensor ( { :pred , 8 } , { } )
171
169
172
170
args = EXLA.MLIR.Function . get_arguments ( builder )
173
-
174
171
{ token , [ flag ] } = Value . infeed ( root_token , [ flag_typespec ] )
175
-
176
172
init = [ flag , token | args ]
177
173
178
174
arg_typespecs = Enum . map ( init , & Value . get_typespec / 1 )
@@ -186,11 +182,9 @@ defmodule EXLA.Defn do
186
182
{ body_computation , [ _flag , token | args ] } = Function . push_region ( builder , arg_typespecs )
187
183
188
184
{ acc , constant } = Enum . split ( args , acc_length )
189
-
190
- { indices , input_typespecs } = Enum . unzip ( input_typespecs )
185
+ { input_indices , input_typespecs } = Enum . unzip ( input_typespecs )
191
186
{ token , input } = Value . infeed ( token , input_typespecs )
192
-
193
- input_params = Enum . zip ( indices , input )
187
+ input_params = Enum . zip ( input_indices , input )
194
188
195
189
{ % Outfeed { token: token } = outfeed , acc } =
196
190
case expr do
@@ -226,9 +220,7 @@ defmodule EXLA.Defn do
226
220
227
221
# Emit the stream hook to signal loop output
228
222
{ token , [ flag ] } = Value . infeed ( token , [ flag_typespec ] )
229
-
230
223
Value . return ( flag . function , [ flag , token | acc ] ++ List . flatten ( constant ) )
231
-
232
224
Function . pop_region ( builder )
233
225
234
226
[ _flag , out_token | results ] = Value . while ( builder , pred_computation , body_computation , init )
@@ -238,8 +230,7 @@ defmodule EXLA.Defn do
238
230
239
231
outfeed = outfeed |> Outfeed . with_token ( out_token ) |> Outfeed . close ( builder )
240
232
Value . func_return ( builder , output )
241
-
242
- { { input_typespecs , input_indexes } , outfeed }
233
+ outfeed
243
234
end
244
235
245
236
@ doc false
@@ -249,6 +240,7 @@ defmodule EXLA.Defn do
249
240
250
241
@ doc false
251
242
def __compile__ ( key , vars , fun , options ) do
243
+ { debug? , options } = Keyword . pop ( options , :debug , false )
252
244
{ run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
253
245
254
246
{ client_name , compile_options } =
@@ -258,8 +250,8 @@ defmodule EXLA.Defn do
258
250
259
251
callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
260
252
261
- { executable , used_inputs , outputs , outfeed , :ok , debug ?} =
262
- compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
253
+ { executable , used_inputs , outputs , outfeed , _input_typespecs ?} =
254
+ compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , debug? , callback )
263
255
264
256
fn [ args ] ->
265
257
{ time , lock } =
@@ -306,10 +298,8 @@ defmodule EXLA.Defn do
306
298
307
299
{ res , cache } = recur_flatten ( expr , state , new_cache ( outfeed ) )
308
300
outfeed = cache |> get_outfeed ( ) |> Outfeed . close ( function )
309
-
310
301
Value . func_return ( function , res )
311
-
312
- { :ok , outfeed }
302
+ outfeed
313
303
end
314
304
315
305
defp maybe_outfeed ( lock , executable , args , used_inputs , outputs , outfeed , run_options )
@@ -367,6 +357,7 @@ defmodule EXLA.Defn do
367
357
used_buffers ,
368
358
used_inputs ,
369
359
stream? ,
360
+ debug? ,
370
361
to_computation
371
362
) do
372
363
{ { expr_cache_fun , comp_cache_fun } , options } =
@@ -379,8 +370,6 @@ defmodule EXLA.Defn do
379
370
{ { cache_fun , cache_fun } , options }
380
371
end
381
372
382
- { debug? , options } = Keyword . pop ( options , :debug , false )
383
-
384
373
{ args_key , reverse_args_identifiers } =
385
374
Enum . map_reduce ( vars , [ ] , fn var , acc ->
386
375
Nx.Defn.Composite . traverse ( var , acc , fn
@@ -396,7 +385,7 @@ defmodule EXLA.Defn do
396
385
397
386
{ eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
398
387
:timer . tc ( fn ->
399
- expr_cache_fun . ( { key , args_key , lazy_transfers } , fn ->
388
+ expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
400
389
expr = fun . ( vars )
401
390
inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
402
391
{ expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
@@ -412,12 +401,10 @@ defmodule EXLA.Defn do
412
401
end
413
402
414
403
{ hooks , options } = Keyword . pop ( options , :hooks , % { } )
415
-
416
404
outfeed = Outfeed . new ( hooks , defined_hooks )
417
-
418
405
comp_key = { ref , client . name , outfeed . used_hooks , lazy_transfers , options }
419
406
420
- { comp_time , { evaled , { xla_time , executable , extra , outfeed } } } =
407
+ { comp_time , { evaled , { xla_time , executable , inputs_and_typespecs , outfeed } } } =
421
408
:timer . tc ( fn ->
422
409
comp_cache_fun . ( comp_key , fn ->
423
410
{ reverse_inputs_and_typespecs , reverse_infeeds } =
@@ -430,7 +417,7 @@ defmodule EXLA.Defn do
430
417
431
418
inputs_and_typespecs = Enum . reverse ( reverse_inputs_and_typespecs )
432
419
433
- comp_arg_typespecs =
420
+ comp_typespecs =
434
421
for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
435
422
436
423
outputs =
@@ -451,7 +438,7 @@ defmodule EXLA.Defn do
451
438
|> then ( & Typespec . tensor ( & 1 . type , & 1 . shape ) )
452
439
end )
453
440
454
- EXLA.MLIR.Module . new ( comp_arg_typespecs , out_typespecs , fn builder ->
441
+ EXLA.MLIR.Module . new ( comp_typespecs , out_typespecs , fn builder ->
455
442
# Only create the token when we know it will actually be
456
443
# used, that is: streaming, lazy transfers or hooks
457
444
outfeed =
@@ -464,25 +451,20 @@ defmodule EXLA.Defn do
464
451
end
465
452
466
453
expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
467
-
468
- { extra , outfeed } =
469
- to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
454
+ outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
470
455
471
456
{ xla_time , executable } =
472
457
:timer . tc ( fn ->
473
- typespecs =
474
- for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
475
-
476
458
EXLA.MLIR.Module . compile (
477
459
builder . module ,
478
460
client ,
479
- typespecs ,
461
+ comp_typespecs ,
480
462
builder . return_typespecs ,
481
463
options
482
464
)
483
465
end )
484
466
485
- { :ok , { xla_time , executable , extra , % { outfeed | infeeds: [ ] } } }
467
+ { :ok , { xla_time , executable , inputs_and_typespecs , % { outfeed | infeeds: [ ] } } }
486
468
end )
487
469
end )
488
470
end )
@@ -511,7 +493,7 @@ defmodule EXLA.Defn do
511
493
end
512
494
513
495
outfeed = Outfeed . with_user_hooks ( outfeed , hooks )
514
- { executable , used_inputs , outputs , outfeed , extra , debug? }
496
+ { executable , used_inputs , outputs , outfeed , inputs_and_typespecs }
515
497
end
516
498
517
499
defp us_to_ms ( time ) , do: Float . round ( time / 1000 , 1 )
0 commit comments