@@ -30,190 +30,6 @@ defmodule EXLA.Defn do
30
30
{ EXLA.Backend , [ client: client_name , device_id: device_id ] }
31
31
end
32
32
33
- @ doc false
34
- def __stream__ ( key , input , acc , vars , fun , [ args ] , options ) do
35
- { run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
36
- debug? = Keyword . get ( compile_options , :debug , false )
37
- compile_options = Keyword . put ( compile_options , :lazy_transfers , :never )
38
-
39
- input_length = length ( Nx.Defn.Composite . flatten_list ( [ input ] ) )
40
- acc_length = length ( Nx.Defn.Composite . flatten_list ( [ acc ] ) )
41
-
42
- # The input vars should not be converted to buffers as they come from infeed
43
- # Accs are always considered as used
44
- used_buffers = input_length
45
- used_inputs = Enum . to_list ( input_length .. ( input_length + acc_length - 1 ) // 1 )
46
-
47
- comp_fun =
48
- & to_stream_computation ( input_length , acc_length , & 1 , & 2 , & 3 , & 4 , & 5 , compile_options )
49
-
50
- { executable , { used_inputs , { output , acc_output } , outfeed , input_typespecs } } =
51
- compile ( key , vars , fun , compile_options , used_buffers , used_inputs , true , comp_fun )
52
-
53
- # Now discard the infeed from used inputs, similar to how it is done to buffers.
54
- # Note we discard all lazy transfers too, as they are not possible with streams.
55
- used_inputs = for { i , nil } <- used_inputs , i >= used_buffers , do: { i , nil } , into: % { }
56
-
57
- # And capture the typespecs for the infeed.
58
- input_typespecs = Enum . take_while ( input_typespecs , fn { i , _ } -> i < input_length end )
59
-
60
- # Execution of streams requires the coordination of
61
- # multiple processes which is outlined below.
62
-
63
- # First, we get a lock on the executable, because we want
64
- # to avoid transfer to the device unless we know we are
65
- # ready to use the device.
66
- { time , lock } =
67
- :timer . tc ( fn ->
68
- EXLA.Defn.Lock . lock ( run_key ( executable ) )
69
- end )
70
-
71
- debug? && Logger . debug ( "EXLA device #{ executable . device_id } lock in #{ us_to_ms ( time ) } ms" )
72
-
73
- { time , streams } =
74
- :timer . tc ( fn ->
75
- buffers =
76
- EXLA.Defn.Buffers . filter_by_indexes ( args , used_inputs , fn arg , _ ->
77
- EXLA.Defn.Buffers . from_nx! ( arg , executable )
78
- end )
79
-
80
- # Now that we have transferred to device, we spawn a runner process
81
- # to execute the stream. We use a runner instead of a task to avoid
82
- # leaking messages in the inbox. We also don't use a supervisor
83
- # to keep them linked, which is safe because the agent is not used
84
- # outside the scope of the current process.
85
- #
86
- # Finally, note the runner cannot start immediately, we need to
87
- # setup the outfeed reader and register the on_unlock callback
88
- # that cancels the stream atomically. This is done inside
89
- # EXLA.Defn.Stream.run.
90
- { :ok , runner } =
91
- EXLA.Defn.Runner . start_link ( lock , fn ->
92
- EXLA.Executable . run ( executable , [ buffers ] , run_options )
93
- end )
94
-
95
- # The outfeed reader will redirect all outputs with flag 1 to the current
96
- # process. Once flag 0 is emitted, we know the stream is done.
97
- { output_typespecs , outfeed } = Outfeed . configure_stream_hook ( outfeed , self ( ) , lock )
98
- { :ok , outfeed_pid } = Outfeed . start_child ( executable , outfeed , Process . group_leader ( ) )
99
-
100
- stream =
101
- EXLA.Defn.Stream . run (
102
- executable ,
103
- lock ,
104
- runner ,
105
- outfeed_pid ,
106
- input ,
107
- input_typespecs ,
108
- output ,
109
- output_typespecs ,
110
- acc_output
111
- )
112
-
113
- [ stream ]
114
- end )
115
-
116
- debug? &&
117
- Logger . debug ( "EXLA stream start on device #{ executable . device_id } in #{ us_to_ms ( time ) } ms" )
118
-
119
- streams
120
- end
121
-
122
- defp to_stream_computation (
123
- input_length ,
124
- acc_length ,
125
- % Function { } = builder ,
126
- expr ,
127
- used_typespecs ,
128
- outfeed ,
129
- client ,
130
- options
131
- ) do
132
- % { token: root_token , infeeds: [ ] } = outfeed
133
-
134
- { input_typespecs , used_typespecs } =
135
- Enum . split_while ( used_typespecs , fn { i , _ } -> i < input_length end )
136
-
137
- # Drop all accumulator entries from used_typespecs as we will handle it separately.
138
- { acc_typespecs , used_typespecs } = Enum . split ( used_typespecs , acc_length )
139
-
140
- # The stream loop will be a three element tuple:
141
- #
142
- # The result of calling infeed.
143
- # The looping accumulator.
144
- # The looping constants.
145
- #
146
- # The input will be read as part of the infeed.
147
- acc_typespecs_l = Enum . map ( acc_typespecs , & elem ( & 1 , 1 ) )
148
- acc_typespec = List . to_tuple ( acc_typespecs_l )
149
- flag_typespec = Typespec . tensor ( { :pred , 8 } , { } )
150
-
151
- args = EXLA.MLIR.Function . get_arguments ( builder )
152
- { token , [ flag ] } = Value . infeed ( root_token , [ flag_typespec ] )
153
- init = [ flag , token | args ]
154
-
155
- arg_typespecs = Enum . map ( init , & Value . get_typespec / 1 )
156
- { pred_computation , [ flag | _ ] } = Function . push_region ( builder , arg_typespecs )
157
- typespec = Typespec . tensor ( { :pred , 8 } , { } )
158
- r0 = Value . constant ( builder , [ 1 ] , typespec )
159
- pred_op = Value . equal ( flag , r0 , typespec )
160
- Value . return ( builder , [ pred_op ] )
161
- Function . pop_region ( builder )
162
-
163
- { body_computation , [ _flag , token | args ] } = Function . push_region ( builder , arg_typespecs )
164
-
165
- { acc , constant } = Enum . split ( args , acc_length )
166
- { input_indices , input_typespecs } = Enum . unzip ( input_typespecs )
167
- { token , input } = Value . infeed ( token , input_typespecs )
168
- input_params = Enum . zip ( input_indices , input )
169
-
170
- { % Outfeed { token: token } = outfeed , acc } =
171
- case expr do
172
- { output_expr , acc_expr } ->
173
- acc_params =
174
- Enum . map ( acc_typespecs , fn { pos , _typespec } ->
175
- { pos , Enum . fetch! ( acc , pos - input_length ) }
176
- end )
177
-
178
- constant_params =
179
- Enum . with_index ( used_typespecs , fn { pos , _typespec } , index ->
180
- { pos , Enum . fetch! ( constant , index ) }
181
- end )
182
-
183
- state = % {
184
- client: client ,
185
- builder: builder ,
186
- precision: Keyword . get ( options , :precision , :default ) ,
187
- params: Map . new ( input_params ++ acc_params ++ constant_params ) ,
188
- scope_ids: Tree . scope_ids ( expr )
189
- }
190
-
191
- outfeed = Outfeed . with_token ( outfeed , token )
192
- { output , cache } = recur_flatten ( output_expr , state , new_cache ( outfeed ) )
193
- { acc , cache } = recur_flatten ( acc_expr , state , cache )
194
- outfeed = cache |> get_outfeed ( ) |> Outfeed . add_stream_hook ( builder , output )
195
- { outfeed , acc }
196
-
197
- _ ->
198
- raise "expected the function given to Nx.stream/3 to return a two-element tuple, got: " <>
199
- inspect ( expr )
200
- end
201
-
202
- # Emit the stream hook to signal loop output
203
- { token , [ flag ] } = Value . infeed ( token , [ flag_typespec ] )
204
- Value . return ( flag . function , [ flag , token | acc ] ++ List . flatten ( constant ) )
205
- Function . pop_region ( builder )
206
-
207
- [ _flag , out_token | results ] = Value . while ( builder , pred_computation , body_computation , init )
208
-
209
- acc = Enum . take ( results , acc_length )
210
- output = wrap_tuple_result ( acc , acc_typespec )
211
-
212
- outfeed = outfeed |> Outfeed . with_token ( out_token ) |> Outfeed . close ( builder )
213
- Value . func_return ( builder , output )
214
- outfeed
215
- end
216
-
217
33
@ doc false
218
34
def __jit__ ( key , vars , fun , args_list , options ) do
219
35
__compile__ ( key , vars , fun , options ) . ( args_list )
@@ -223,10 +39,10 @@ defmodule EXLA.Defn do
223
39
def __compile__ ( key , vars , fun , options ) do
224
40
{ run_options , compile_options } = Keyword . pop ( options , :run_options , [ ] )
225
41
debug? = Keyword . get ( compile_options , :debug , false )
226
- callback = & to_root_computation ( & 1 , & 2 , & 3 , & 4 , & 5 , compile_options )
42
+ callback = & to_computation ( & 1 , & 2 , & 3 , & 4 , & 5 , compile_options )
227
43
228
44
{ executable , { used_inputs , outputs , outfeed , _input_typespecs? } } =
229
- compile ( key , vars , fun , compile_options , 0 , [ ] , _stream = false , callback )
45
+ compile ( key , vars , fun , compile_options , 0 , [ ] , callback )
230
46
231
47
if compile_options [ :module_compilation ] == :to_mlir do
232
48
throw ( { :mlir_module , executable . ref , MapSet . new ( Map . keys ( used_inputs ) ) , outputs } )
@@ -252,7 +68,7 @@ defmodule EXLA.Defn do
252
68
end
253
69
end
254
70
255
- defp to_root_computation ( % Function { } = function , expr , used_typespecs , outfeed , client , options ) do
71
+ defp to_computation ( % Function { } = function , expr , used_typespecs , outfeed , client , options ) do
256
72
params =
257
73
Enum . zip_with ( used_typespecs , Function . get_arguments ( function ) , fn { pos , _typespec } , arg ->
258
74
{ pos , arg }
@@ -322,7 +138,7 @@ defmodule EXLA.Defn do
322
138
323
139
## Compile
324
140
325
- defp compile ( key , vars , fun , options , used_buffers , used_inputs , stream? , to_computation ) do
141
+ defp compile ( key , vars , fun , options , used_buffers , used_inputs , to_computation ) do
326
142
{ cache , options } = Keyword . pop ( options , :cache , true )
327
143
{ hooks , options } = Keyword . pop ( options , :hooks , % { } )
328
144
{ debug? , options } = Keyword . pop ( options , :debug , false )
@@ -361,7 +177,7 @@ defmodule EXLA.Defn do
361
177
362
178
{ eval_time , { expr , { ref , outputs , { used_inputs , defined_hooks } } } } =
363
179
:timer . tc ( fn ->
364
- expr_cache_fun . ( { key , stream? , args_key , lazy_transfers } , fn ->
180
+ expr_cache_fun . ( { key , args_key , lazy_transfers } , fn ->
365
181
expr = fun . ( vars )
366
182
inputs_and_hooks = Outfeed . used_inputs_and_hooks ( expr , used_inputs , lazy_transfers )
367
183
{ expr , { make_ref ( ) , Nx . to_template ( expr ) , inputs_and_hooks } }
@@ -395,15 +211,6 @@ defmodule EXLA.Defn do
395
211
comp_typespecs =
396
212
for { i , typespec } <- inputs_and_typespecs , i >= used_buffers , do: typespec
397
213
398
- outputs =
399
- if stream? do
400
- # The computation returns the final accumulator value
401
- { _chunk_result , acc } = outputs
402
- acc
403
- else
404
- outputs
405
- end
406
-
407
214
out_typespecs =
408
215
[ outputs ]
409
216
|> Nx.Defn.Composite . flatten_list ( )
@@ -417,7 +224,7 @@ defmodule EXLA.Defn do
417
224
# Only create the token when we know it will actually be
418
225
# used, that is: streaming, lazy transfers or hooks
419
226
outfeed =
420
- if stream? or reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
227
+ if reverse_infeeds != [ ] or hooks != % { } or defined_hooks != % { } do
421
228
outfeed
422
229
|> Outfeed . with_token ( Value . create_token ( builder ) )
423
230
|> Outfeed . add_infeeds ( builder , reverse_infeeds )
0 commit comments