@@ -19,10 +19,10 @@ defmodule Nx.Defn.Grad do
19
19
{ :env , env } = Function . info ( fun , :env )
20
20
ids = stop_grads ( env , ids )
21
21
22
- # save vectorized axes before devectorizing
23
- expr = to_grad |> fun . ( )
22
+ expr = fun . ( to_grad )
24
23
25
- transformed_expr = transform . ( expr ) |> validate_expr! ( ) |> Nx . devectorize ( keep_names: false )
24
+ transformed_expr =
25
+ expr |> transform . ( ) |> validate_expr! ( )
26
26
27
27
{ parents , nodes } = parents_tree ( transformed_expr , ids )
28
28
@@ -33,23 +33,17 @@ defmodule Nx.Defn.Grad do
33
33
Composite . traverse (
34
34
to_grad ,
35
35
{ nodes , grads } ,
36
- fn % { vectorized_axes: vectorized_axes } = node , acc ->
37
- node
38
- |> Nx . devectorize ( keep_names: false )
39
- |> to_grad ( to_grad_ids , parents , acc )
40
- |> then ( fn { node , acc } ->
41
- { Nx . vectorize ( node , vectorized_axes ) , acc }
42
- end )
36
+ fn node , acc ->
37
+ to_grad ( node , to_grad_ids , parents , acc )
43
38
end
44
39
)
45
40
46
41
{ expr , graded }
47
42
end
48
43
49
# Builds a constant expression holding `float` through Expr.constant/3.
#
# Diff note: "-" lines are the previous implementation, "+" lines the new one.
#   * Previously the second argument was passed through Nx.shape/1 and a fresh
#     %T{} was built from only shape, type and names.
#   * Now the second argument must match %T{shape: shape}; the struct-update
#     syntax `%T{t | ...}` keeps every other field of `t` and overrides only
#     `names` and `type`.
#
# In both versions the type is fixed to {:f, 32} and `names` is a list of
# `nil` with one entry per dimension of `shape`.
- defp constant ( float , shape ) do
50
- shape = Nx . shape ( shape )
44
+ defp constant ( float , % T { shape: shape } = t ) do
51
45
# One `nil` name per axis: tuple_size(shape) is the rank of the tensor.
names = List . duplicate ( nil , tuple_size ( shape ) )
52
- Expr . constant ( % T { shape: shape , type: { :f , 32 } , names: names } , float , [ ] )
46
+ Expr . constant ( % T { t | names: names , type: { :f , 32 } } , float , [ ] )
53
47
end
54
48
55
49
defp validate_expr! ( % T { data: % Expr { } } = expr ) do
@@ -94,47 +88,88 @@ defmodule Nx.Defn.Grad do
94
88
[ :equal , :greater , :greater_equal , :less , :less_equal , :not_equal , :argsort ]
95
89
96
90
# Entry point of the parents-tree construction: folds over `expr` with
# Composite.reduce/3, starting from `{%{}, nodes}` (an empty parents map and
# the caller-supplied `nodes` map). The caller destructures the result as
# `{parents, nodes}`.
#
# Diff note: the previous version ("-") passed each element straight to
# recur_parents_tree/2. The new version ("+") devectorizes each element with
# `keep_names: true` and passes the keys of its `vectorized_axes` as a third
# argument.
defp parents_tree ( expr , nodes ) do
97
- Composite . reduce ( expr , { % { } , nodes } , & recur_parents_tree / 2 )
91
+ Composite . reduce (
92
+ expr ,
93
+ { % { } , nodes } ,
94
+ & recur_parents_tree (
95
+ Nx . devectorize ( & 1 , keep_names: true ) ,
96
+ & 2 ,
97
+ Keyword . keys ( & 1 . vectorized_axes )
98
+ )
99
+ )
98
100
end
99
101
100
# Visits one expression node `t`, identified by its Expr `id`.
#
#   * If `id` is already a key of `nodes`, the accumulator `{parents, nodes}`
#     is returned unchanged, so the node is not visited again.
#   * Otherwise the node is stored in `nodes` under `id` and its arguments are
#     walked via parents_args, dispatched on `op`.
#
# Diff note: the previous version ("-") had arity 2 and stored `t` itself.
# The new version ("+") takes `vectorized_names` and stores the pair
# `{t, vectorized_names}`, forwarding the names to parents_args/5.
- defp recur_parents_tree ( % T { data: % Expr { id: id , op: op } } = t , { parents , nodes } ) do
102
+ defp recur_parents_tree ( % T { data: % Expr { id: id , op: op } } = t , { parents , nodes } , vectorized_names ) do
101
103
case nodes do
102
- % { ^ id => _ } -> { parents , nodes }
103
- % { } -> parents_args ( op , t , id , { parents , Map . put ( nodes , id , t ) } )
104
+ % { ^ id => _ } ->
105
+ { parents , nodes }
106
+
107
+ % { } ->
108
+ # We use this to compute the proper axis sizes for the tensor
109
+ nodes = Map . put ( nodes , id , { t , vectorized_names } )
110
+
111
+ parents_args ( op , t , id , { parents , nodes } , vectorized_names )
104
112
end
105
113
end
106
114
107
# parents_args clause for a :metadata node whose second argument carries
# `stop_grad: true`. It returns `acc` untouched, so the arguments of such a
# node are not traversed and no parent links are recorded for them.
#
# Diff note: the new version ("+") only adds a fifth parameter, which is
# ignored (`_parent_vectorized_names`).
- defp parents_args ( :metadata , % { data: % { args: [ _ , % { stop_grad: true } ] } } , _id , acc ) do
115
+ defp parents_args (
116
+ :metadata ,
117
+ % { data: % { args: [ _ , % { stop_grad: true } ] } } ,
118
+ _id ,
119
+ acc ,
120
+ _parent_vectorized_names
121
+ ) do
108
122
acc
109
123
end
110
124
111
# parents_args clause for :optional nodes, whose args are
# `[call, _expr, callback]`.
#
# The callback is applied to the arguments of `call` to obtain a fresh `expr`.
# That expression is then traversed: each element gets `id` (this optional
# node) prepended to its parents list and is visited with recur_parents_tree.
# Finally the node stored under `id` is replaced by a copy of `t` whose args
# hold the freshly computed `expr`.
#
# Diff note: the previous version ("-") threaded only `{parents, nodes}` and
# stored the updated tensor directly. The new version ("+"):
#   * threads `{{parents, nodes}, parent_vectorized_names}` through
#     Composite.reduce/3 (the names are passed along unchanged and discarded
#     at the end via the `_` in the match);
#   * derives per-element names with compute_arg_vectorized_names/2;
#   * stores `{updated_tensor, parent_vectorized_names}` in `nodes`.
- defp parents_args ( :optional , % { data: % { args: [ call , _expr , callback ] } } = t , id , acc ) do
125
+ defp parents_args (
126
+ :optional ,
127
+ % { data: % { args: [ call , _expr , callback ] } } = t ,
128
+ id ,
129
+ acc ,
130
+ parent_vectorized_names
131
+ ) do
112
132
expr = apply ( callback , call . data . args )
113
133
114
134
# Now traverse over the optional expression where args are the new parameters.
115
135
# Once we access the parameter itself, we point the parameter to the arg.
116
- { parents , nodes } =
117
- Composite . reduce ( expr , acc , fn expr , { parents , nodes } ->
118
- parents = Map . update ( parents , expr . data . id , [ id ] , & [ id | & 1 ] )
119
- recur_parents_tree ( expr , { parents , nodes } )
136
+ { { parents , nodes } , _ } =
137
+ Composite . reduce ( expr , { acc , parent_vectorized_names } , fn
138
+ expr , { { parents , nodes } , expr_vectorized_names } ->
139
+ arg_vectorized_names = compute_arg_vectorized_names ( expr , expr_vectorized_names )
140
+ parents = Map . update ( parents , expr . data . id , [ id ] , & [ id | & 1 ] )
141
+
142
+ acc =
143
+ recur_parents_tree (
144
+ expr ,
145
+ { parents , nodes } ,
146
+ arg_vectorized_names
147
+ )
148
+
149
+ { acc , expr_vectorized_names }
120
150
end )
121
151
122
- { parents , Map . put ( nodes , id , put_in ( t . data . args , [ call , expr , callback ] ) ) }
152
+ updated_node =
153
+ { put_in ( t . data . args , [ call , expr , callback ] ) , parent_vectorized_names }
154
+
155
+ { parents , Map . put ( nodes , id , updated_node ) }
123
156
end
124
157
125
158
# We register cond as a special node to avoid pretraversing it.
126
159
# Instead we traverse it early on in the grad computation.
127
# parents_args clause for :cond nodes. The node's `id` is prepended to the
# list stored under the `__MODULE__` key of `parents` (created as `[id]` when
# absent); `nodes` is returned unchanged and the cond's arguments are not
# walked here.
#
# Diff note: the new version ("+") only adds an ignored fifth parameter.
- defp parents_args ( :cond , _ , id , { parents , nodes } ) do
160
+ defp parents_args ( :cond , _ , id , { parents , nodes } , _parent_vectorized_names ) do
128
161
{ Map . update ( parents , __MODULE__ , [ id ] , & [ id | & 1 ] ) , nodes }
129
162
end
130
163
131
# Catch-all parents_args clause: walks every argument of `t` via reduce_args.
#
#   * Arguments whose op is listed in @constants are skipped; the accumulator
#     is returned as-is.
#   * For any other argument, `parent_id` is prepended to that argument's
#     parents list and the argument is visited with recur_parents_tree.
#
# Diff note: the new version ("+") additionally computes vectorized names and
# forwards them to recur_parents_tree/3.
#
# NOTE(review): `arg_vectorized_names` is computed from the parent tensor `t`,
# not from `arg`, so it is the same for every argument in this loop. Confirm
# this is intended rather than a slip for `arg`.
- defp parents_args ( op , t , parent_id , acc ) do
164
+ defp parents_args ( op , t , parent_id , acc , parent_vectorized_names ) do
132
165
reduce_args ( op , t , acc , fn arg , { parents , nodes } ->
133
166
if arg . data . op in @ constants do
134
167
{ parents , nodes }
135
168
else
169
+ arg_vectorized_names = compute_arg_vectorized_names ( t , parent_vectorized_names )
136
170
parents = Map . update ( parents , arg . data . id , [ parent_id ] , & [ parent_id | & 1 ] )
137
- recur_parents_tree ( arg , { parents , nodes } )
171
+
172
+ recur_parents_tree ( arg , { parents , nodes } , arg_vectorized_names )
138
173
end
139
174
end )
140
175
end
@@ -191,10 +226,27 @@ defmodule Nx.Defn.Grad do
191
226
case nodes do
192
227
% { ^ id => _ } ->
193
228
{ nodes , grads } = traverse_parents ( id , to_grad_ids , parents , { nodes , grads } )
194
- { ans , nodes } = Map . pop! ( nodes , id )
229
+ { { ans , vectorized_names } , nodes } = Map . pop! ( nodes , id )
195
230
% T { data: % Expr { op: op , args: args } } = ans
196
231
{ gs , grads } = Map . pop ( grads , id )
197
232
233
+ { args , ans } =
234
+ if vectorized_names != [ ] do
235
+ args =
236
+ Enum . map ( args , fn
237
+ % T { } = arg ->
238
+ revectorize_node ( arg , vectorized_names )
239
+
240
+ opt ->
241
+ opt
242
+ end )
243
+
244
+ ans = Nx . vectorize ( ans , vectorized_names )
245
+ { args , ans }
246
+ else
247
+ { args , ans }
248
+ end
249
+
198
250
case gs do
199
251
nil ->
200
252
{ nodes , grads }
@@ -213,6 +265,22 @@ defmodule Nx.Defn.Grad do
213
265
end
214
266
end
215
267
268
# Computes the list of axis names to treat as vectorized for a tensor, given
# the names inherited from its parent.
#
#   * First clause (parent names are `[]`): just the keys of the tensor's own
#     `vectorized_axes`.
#   * Second clause: the tensor's own vectorized-axis keys, followed by those
#     entries of its `names` that also appear in `parent_names`, in the order
#     they occur in `names`.
#
# The first clause yields the same result the second one would for an empty
# parent list, because no name can be a member of `[]`.
+ defp compute_arg_vectorized_names ( % { vectorized_axes: vectorized_axes } , [ ] ) ,
269
+ do: Keyword . keys ( vectorized_axes )
270
+
271
+ defp compute_arg_vectorized_names (
272
+ % { vectorized_axes: vectorized_axes , names: names } ,
273
+ parent_names
274
+ ) do
275
+ Keyword . keys ( vectorized_axes ) ++ Enum . filter ( names , & ( & 1 in parent_names ) )
276
+ end
277
+
278
# Re-vectorizes `node`: first narrows/extends `vectorized_names` to the names
# that apply to this particular node (via compute_arg_vectorized_names/2),
# then calls Nx.vectorize/2 with the resulting list.
# Note that `vectorized_names` is rebound; the incoming list is only used as
# the "parent names" input of the computation.
+ defp revectorize_node ( node , vectorized_names ) do
279
+ vectorized_names = compute_arg_vectorized_names ( node , vectorized_names )
280
+
281
+ Nx . vectorize ( node , vectorized_names )
282
+ end
283
+
216
284
defp update_grads ( :elem , [ % { type: { :tuple , size } } = tuple , pos ] , _ans , g , _to_grad_ids , grads ) do
217
285
update_in ( grads [ tuple . data . id ] , fn tuple ->
218
286
tuple = tuple || Tuple . duplicate ( [ ] , size )
0 commit comments