@@ -32,13 +32,8 @@ defmodule EXLA.Defn do
32
32
33
33
@ doc false
34
34
def __stream__ ( key , input , acc , vars , fun , [ args ] , options ) do
35
- { debug? , options } = Keyword . pop ( options , :debug , false )
36
35
{ run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
37
-
38
- { client_name , compile_options } =
39
- Keyword . pop_lazy ( compile_options , :client , & EXLA.Client . default_name / 0 )
40
-
41
- client = EXLA.Client . fetch! ( client_name )
36
+ debug? = Keyword . get ( compile_options , :debug , false )
42
37
compile_options = Keyword . put ( compile_options , :lazy_transfers , :never )
43
38
44
39
input_length = length ( Nx.Defn.Composite . flatten_list ( [ input ] ) )
@@ -50,21 +45,10 @@ defmodule EXLA.Defn do
50
45
used_inputs = Enum . to_list ( input_length .. ( input_length + acc_length - 1 ) // 1 )
51
46
52
47
comp_fun =
53
- & to_stream_computation ( client , input_length , acc_length , & 1 , & 2 , & 3 , & 4 , compile_options )
48
+ & to_stream_computation ( input_length , acc_length , & 1 , & 2 , & 3 , & 4 , & 5 , compile_options )
54
49
55
50
{ executable , { used_inputs , { output , acc_output } , outfeed , input_typespecs } } =
56
- compile (
57
- client ,
58
- key ,
59
- vars ,
60
- fun ,
61
- compile_options ,
62
- used_buffers ,
63
- used_inputs ,
64
- _stream = true ,
65
- debug? ,
66
- comp_fun
67
- )
51
+ compile ( key , vars , fun , compile_options , used_buffers , used_inputs , true , comp_fun )
68
52
69
53
# Now discard the infeed from used inputs, similar to how it is done to buffers.
70
54
# Note we discard all lazy transfers too, as they are not possible with streams.
@@ -136,13 +120,13 @@ defmodule EXLA.Defn do
136
120
end
137
121
138
122
defp to_stream_computation (
139
- client ,
140
123
input_length ,
141
124
acc_length ,
142
125
% Function { } = builder ,
143
126
expr ,
144
127
used_typespecs ,
145
128
outfeed ,
129
+ client ,
146
130
options
147
131
) do
148
132
% { token: root_token , infeeds: [ ] } = outfeed
@@ -237,18 +221,12 @@ defmodule EXLA.Defn do
237
221
238
222
@ doc false
239
223
def __compile__ ( key , vars , fun , options ) do
240
- { debug? , options } = Keyword . pop ( options , :debug , false )
241
224
{ run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
242
-
243
- { client_name , compile_options } =
244
- Keyword . pop_lazy ( compile_options , :client , & EXLA.Client . default_name / 0 )
245
-
246
- client = EXLA.Client . fetch! ( client_name )
247
-
248
- callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , Keyword . put ( compile_options , :client , client ) )
225
+ debug? = Keyword . get ( compile_options , :debug , false )
226
+ callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , & 5 , compile_options )
249
227
250
228
{ executable , { used_inputs , outputs , outfeed , _input_typespecs? } } =
251
- compile ( client , key , vars , fun , compile_options , 0 , [ ] , _stream = false , debug? , callback )
229
+ compile ( key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
252
230
253
231
fn [ args ] ->
254
232
{ time , lock } =
@@ -270,14 +248,12 @@ defmodule EXLA.Defn do
270
248
end
271
249
end
272
250
273
- defp to_root_computation ( % Function { } = function , expr , used_typespecs , outfeed , options ) do
251
+ defp to_root_computation ( % Function { } = function , expr , used_typespecs , outfeed , client , options ) do
274
252
params =
275
253
Enum . zip_with ( used_typespecs , Function . get_arguments ( function ) , fn { pos , _typespec } , arg ->
276
254
{ pos , arg }
277
255
end )
278
256
279
- client = Keyword . fetch! ( options , :client )
280
-
281
257
unless client do
282
258
raise ArgumentError , "missing client"
283
259
end
@@ -342,22 +318,15 @@ defmodule EXLA.Defn do
342
318
343
319
## Compile
344
320
345
- defp compile (
346
- client ,
347
- key ,
348
- vars ,
349
- fun ,
350
- options ,
351
- used_buffers ,
352
- used_inputs ,
353
- stream? ,
354
- debug? ,
355
- to_computation
356
- ) do
321
+ defp compile ( key , vars , fun , options , used_buffers , used_inputs , stream? , to_computation ) do
357
322
{ cache , options } = Keyword . pop ( options , :cache , true )
358
323
{ hooks , options } = Keyword . pop ( options , :hooks , % { } )
324
+ { debug? , options } = Keyword . pop ( options , :debug , false )
359
325
{ lazy_transfers , options } = Keyword . pop ( options , :lazy_transfers , :opt_in )
360
326
327
+ { client_name , options } = Keyword . pop_lazy ( options , :client , & EXLA.Client . default_name / 0 )
328
+ client = EXLA.Client . fetch! ( client_name )
329
+
361
330
{ args_key , reverse_args_identifiers } =
362
331
Enum . map_reduce ( vars , [ ] , fn var , acc ->
363
332
Nx.Defn.Composite . traverse ( var , acc , fn
@@ -453,7 +422,7 @@ defmodule EXLA.Defn do
453
422
end
454
423
455
424
expr = Nx.Defn.Composite . traverse ( expr || fun . ( vars ) , & Nx . devectorize / 1 )
456
- outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed )
425
+ outfeed = to_computation . ( builder , expr , inputs_and_typespecs , outfeed , client )
457
426
458
427
{ xla_time , executable } =
459
428
:timer . tc ( fn ->
0 commit comments