
Commit 762d3ee

fix: broadcast vectors for grad calculation (#1535)
Co-authored-by: José Valim <[email protected]>
1 parent 8102cd9 commit 762d3ee
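
For context, a minimal sketch of the situation this commit fixes, adapted from the test added in grad_test.exs below (variable names are illustrative): taking a gradient with respect to a plain, non-vectorized value while the expression also involves a vectorized tensor, so the resulting gradient must be broadcast across the vectorized axes.

    # x is vectorized over :x; y is an ordinary (non-vectorized) number.
    x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
    y = 1

    # The gradient with respect to y must come back vectorized over :x,
    # one entry per vectorized slice of x.
    grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
    # => Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])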

File tree

4 files changed: +189 −37 lines

  nx/lib/nx.ex
  nx/lib/nx/defn/expr.ex
  nx/lib/nx/defn/grad.ex
  nx/test/nx/defn/grad_test.exs

nx/lib/nx.ex

Lines changed: 7 additions & 3 deletions
@@ -5420,9 +5420,13 @@ defmodule Nx do
       {_, [], 0} ->
         fun.(left, right)

-      {[devec_left, devec_right], canonical_vectorized_axes, _offset} ->
-        devec_left
-        |> fun.(devec_right)
+      {[devec_left, devec_right], canonical_vectorized_axes, offset} ->
+        leading_names = Keyword.keys(canonical_vectorized_axes)
+        l = %{devec_left | names: leading_names ++ Enum.drop(devec_left.names, offset)}
+        r = %{devec_right | names: leading_names ++ Enum.drop(devec_right.names, offset)}
+
+        l
+        |> fun.(r)
         |> vectorize(canonical_vectorized_axes)
     end
   end
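
The change above rewrites the devectorized operands so their leading dimensions carry the canonical vectorized axis names before `fun` is applied, which lets named broadcasting align those axes. A usage-level illustration of the behaviour this enables, taken from the scenario exercised in the new tests below (the values are the ones asserted there):

    # Tensors vectorized over different axes broadcast across both axes.
    x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)  # entries of shape {3}
    y = Nx.tensor([10, 20]) |> Nx.vectorize(:y)                # scalar entries

    z = Nx.multiply(x, y)
    # z is vectorized over [:x, :y]; entry (i, j) equals x[i] * y[j]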

nx/lib/nx/defn/expr.ex

Lines changed: 5 additions & 1 deletion
@@ -1394,6 +1394,10 @@ defmodule Nx.Defn.Expr do

   ## Constant helpers and related optimizations

+  defp constant(%{vectorized_axes: [_ | _]} = out, number) do
+    tensor(Nx.fill(out, number, type: out.type))
+  end
+
   defp constant(%{shape: shape, type: type} = out, number) do
     number =
       cond do
@@ -1661,7 +1665,7 @@ defmodule Nx.Defn.Expr do

   defp counter_to_name(counter), do: [?a + counter]

-  defp to_type_shape(%{type: type, shape: shape}) do
+  defp to_type_shape(%{vectorized_axes: [], type: type, shape: shape}) do
     brackets =
       shape
       |> Tuple.to_list()
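
The new `constant/2` clause materializes constants for outputs that still carry vectorized axes via `Nx.fill/3` instead of emitting a bare scalar constant, since a scalar literal has no vectorized axes to carry. A small standalone sketch of the idea (the tensor is illustrative; the clause above relies on `Nx.fill/3` preserving the vectorized axes of its input):

    out = Nx.tensor([[1, 2], [3, 4]]) |> Nx.vectorize(:x)
    one = Nx.fill(out, 1, type: out.type)
    # `one` mirrors `out`: same vectorized axis :x and inner shape, filled with 1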

nx/lib/nx/defn/grad.ex

Lines changed: 96 additions & 28 deletions
@@ -19,10 +19,10 @@ defmodule Nx.Defn.Grad do
     {:env, env} = Function.info(fun, :env)
     ids = stop_grads(env, ids)

-    # save vectorized axes before devectorizing
-    expr = to_grad |> fun.()
+    expr = fun.(to_grad)

-    transformed_expr = transform.(expr) |> validate_expr!() |> Nx.devectorize(keep_names: false)
+    transformed_expr =
+      expr |> transform.() |> validate_expr!()

     {parents, nodes} = parents_tree(transformed_expr, ids)

@@ -33,23 +33,17 @@ defmodule Nx.Defn.Grad do
       Composite.traverse(
         to_grad,
         {nodes, grads},
-        fn %{vectorized_axes: vectorized_axes} = node, acc ->
-          node
-          |> Nx.devectorize(keep_names: false)
-          |> to_grad(to_grad_ids, parents, acc)
-          |> then(fn {node, acc} ->
-            {Nx.vectorize(node, vectorized_axes), acc}
-          end)
+        fn node, acc ->
+          to_grad(node, to_grad_ids, parents, acc)
         end
       )

     {expr, graded}
   end

-  defp constant(float, shape) do
-    shape = Nx.shape(shape)
+  defp constant(float, %T{shape: shape} = t) do
     names = List.duplicate(nil, tuple_size(shape))
-    Expr.constant(%T{shape: shape, type: {:f, 32}, names: names}, float, [])
+    Expr.constant(%T{t | names: names, type: {:f, 32}}, float, [])
   end

   defp validate_expr!(%T{data: %Expr{}} = expr) do
@@ -94,47 +88,88 @@ defmodule Nx.Defn.Grad do
     [:equal, :greater, :greater_equal, :less, :less_equal, :not_equal, :argsort]

   defp parents_tree(expr, nodes) do
-    Composite.reduce(expr, {%{}, nodes}, &recur_parents_tree/2)
+    Composite.reduce(
+      expr,
+      {%{}, nodes},
+      &recur_parents_tree(
+        Nx.devectorize(&1, keep_names: true),
+        &2,
+        Keyword.keys(&1.vectorized_axes)
+      )
+    )
   end

-  defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}) do
+  defp recur_parents_tree(%T{data: %Expr{id: id, op: op}} = t, {parents, nodes}, vectorized_names) do
     case nodes do
-      %{^id => _} -> {parents, nodes}
-      %{} -> parents_args(op, t, id, {parents, Map.put(nodes, id, t)})
+      %{^id => _} ->
+        {parents, nodes}
+
+      %{} ->
+        # We use this to compute the proper axis sizes for the tensor
+        nodes = Map.put(nodes, id, {t, vectorized_names})
+
+        parents_args(op, t, id, {parents, nodes}, vectorized_names)
     end
   end

-  defp parents_args(:metadata, %{data: %{args: [_, %{stop_grad: true}]}}, _id, acc) do
+  defp parents_args(
+         :metadata,
+         %{data: %{args: [_, %{stop_grad: true}]}},
+         _id,
+         acc,
+         _parent_vectorized_names
+       ) do
     acc
   end

-  defp parents_args(:optional, %{data: %{args: [call, _expr, callback]}} = t, id, acc) do
+  defp parents_args(
+         :optional,
+         %{data: %{args: [call, _expr, callback]}} = t,
+         id,
+         acc,
+         parent_vectorized_names
+       ) do
     expr = apply(callback, call.data.args)

     # Now traverse over the optional expression where args are the new parameters.
     # Once we access the parameter itself, we point the parameter to the arg.
-    {parents, nodes} =
-      Composite.reduce(expr, acc, fn expr, {parents, nodes} ->
-        parents = Map.update(parents, expr.data.id, [id], &[id | &1])
-        recur_parents_tree(expr, {parents, nodes})
+    {{parents, nodes}, _} =
+      Composite.reduce(expr, {acc, parent_vectorized_names}, fn
+        expr, {{parents, nodes}, expr_vectorized_names} ->
+          arg_vectorized_names = compute_arg_vectorized_names(expr, expr_vectorized_names)
+          parents = Map.update(parents, expr.data.id, [id], &[id | &1])
+
+          acc =
+            recur_parents_tree(
+              expr,
+              {parents, nodes},
+              arg_vectorized_names
+            )
+
+          {acc, expr_vectorized_names}
       end)

-    {parents, Map.put(nodes, id, put_in(t.data.args, [call, expr, callback]))}
+    updated_node =
+      {put_in(t.data.args, [call, expr, callback]), parent_vectorized_names}
+
+    {parents, Map.put(nodes, id, updated_node)}
   end

   # We register cond as a special node to avoid pretraversing it.
   # Instead we traverse it early on on the grad computation.
-  defp parents_args(:cond, _, id, {parents, nodes}) do
+  defp parents_args(:cond, _, id, {parents, nodes}, _parent_vectorized_names) do
     {Map.update(parents, __MODULE__, [id], &[id | &1]), nodes}
   end

-  defp parents_args(op, t, parent_id, acc) do
+  defp parents_args(op, t, parent_id, acc, parent_vectorized_names) do
     reduce_args(op, t, acc, fn arg, {parents, nodes} ->
       if arg.data.op in @constants do
         {parents, nodes}
       else
+        arg_vectorized_names = compute_arg_vectorized_names(t, parent_vectorized_names)
         parents = Map.update(parents, arg.data.id, [parent_id], &[parent_id | &1])
-        recur_parents_tree(arg, {parents, nodes})
+
+        recur_parents_tree(arg, {parents, nodes}, arg_vectorized_names)
       end
     end)
   end
@@ -191,10 +226,27 @@ defmodule Nx.Defn.Grad do
     case nodes do
       %{^id => _} ->
         {nodes, grads} = traverse_parents(id, to_grad_ids, parents, {nodes, grads})
-        {ans, nodes} = Map.pop!(nodes, id)
+        {{ans, vectorized_names}, nodes} = Map.pop!(nodes, id)
         %T{data: %Expr{op: op, args: args}} = ans
         {gs, grads} = Map.pop(grads, id)

+        {args, ans} =
+          if vectorized_names != [] do
+            args =
+              Enum.map(args, fn
+                %T{} = arg ->
+                  revectorize_node(arg, vectorized_names)
+
+                opt ->
+                  opt
+              end)
+
+            ans = Nx.vectorize(ans, vectorized_names)
+            {args, ans}
+          else
+            {args, ans}
+          end
+
         case gs do
           nil ->
             {nodes, grads}
@@ -213,6 +265,22 @@ defmodule Nx.Defn.Grad do
     end
   end

+  defp compute_arg_vectorized_names(%{vectorized_axes: vectorized_axes}, []),
+    do: Keyword.keys(vectorized_axes)
+
+  defp compute_arg_vectorized_names(
+         %{vectorized_axes: vectorized_axes, names: names},
+         parent_names
+       ) do
+    Keyword.keys(vectorized_axes) ++ Enum.filter(names, &(&1 in parent_names))
+  end
+
+  defp revectorize_node(node, vectorized_names) do
+    vectorized_names = compute_arg_vectorized_names(node, vectorized_names)
+
+    Nx.vectorize(node, vectorized_names)
+  end
+
   defp update_grads(:elem, [%{type: {:tuple, size}} = tuple, pos], _ans, g, _to_grad_ids, grads) do
     update_in(grads[tuple.data.id], fn tuple ->
       tuple = tuple || Tuple.duplicate([], size)
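
Taken together, the grad.ex changes record each node along with its vectorized axis names: the expression is devectorized with `keep_names: true` while the parents tree is built, and those saved names are used to revectorize arguments and answers (`revectorize_node/2`, `Nx.vectorize/2`) when gradients are accumulated. A hedged sketch of that round trip in isolation (tensor and axis names are illustrative):

    t = Nx.iota({2, 3}, names: [nil, :cols]) |> Nx.vectorize(:batch)

    devec = Nx.devectorize(t, keep_names: true)  # axis names become [:batch, :cols]
    names = Keyword.keys(t.vectorized_axes)      # [:batch]

    revec = Nx.vectorize(devec, names)
    # revec carries the same vectorized axes as the original t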

nx/test/nx/defn/grad_test.exs

Lines changed: 81 additions & 5 deletions
@@ -4256,21 +4256,97 @@ defmodule Nx.Defn.GradTest do
   end

   describe "vectorization" do
-    test "supports vectorization" do
+    test "supports combination of vectorized and non-vectorized tensors" do
+      x = Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x)
+      y = 1
+
+      grad = Nx.Defn.grad(y, fn y -> Nx.add(x, y) end)
+
+      assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
+    end
+
+    test "supports combination of vectorized and non-vectorized tensors over composed function" do
+      x = Nx.tensor([[1, 2, 3], [4, 5, 6]], names: [:x, :y]) |> Nx.vectorize(:x)
+      y = 1
+
+      grad = Nx.Defn.grad(y, fn y -> Nx.add(y, Nx.sin(x)) end)
+      assert grad == Nx.tensor([3.0, 3.0]) |> Nx.vectorize([:x])
+
+      grad = Nx.Defn.grad(x, fn x -> Nx.add(y, Nx.sin(x)) end)
+      assert grad == Nx.cos(x)
+    end
+
+    # Skipping this as it's not supported yet.
+    @tag :skip
+    test "edge case where the same name changes meaning" do
+      x = Nx.tensor([[1], [2], [3]]) |> Nx.vectorize(x: 3)
+
+      grad =
+        Nx.Defn.grad(x, fn t ->
+          devec = Nx.devectorize(t, keep_names: true)
+          new_axis = Nx.reshape(devec, {1, 3, 1}, names: [:x, nil, nil])
+
+          Nx.vectorize(new_axis, x: 1)
+        end)
+
+      assert grad == Nx.tensor([[1], [1], [1]]) |> Nx.vectorize(x: 3)
+    end
+
+    test "supports heterogenous vectorization combinations" do
       x = Nx.tensor([[1, 2, 3], [4, 5, 6]])
       y = Nx.tensor([10, 20])

       # first case: y is vectorized scalar, x is vectorized vectors, different vectorized axis names
       # expected result: equivalent to fully broadcasting one tensor onto the other
       x_vec = Nx.vectorize(x, :x)
       y_vec = Nx.vectorize(y, :y)
-      {grad_x_vec, grad_y_vec} = Nx.Defn.grad({x_vec, y_vec}, fn {a, b} -> Nx.multiply(a, b) end)

+      grad_fun = fn x, y ->
+        Nx.Defn.grad({x, y}, fn {a, b} -> Nx.multiply(a, b) end)
+      end
+
+      {grad_x_vec, grad_y_vec} = grad_fun.(x_vec, y_vec)
+
+      # Explicit assertion on the results
       assert grad_x_vec ==
-               Nx.tensor([[30.0, 30.0, 30.0], [30.0, 30.0, 30.0]])
-               |> Nx.vectorize(x_vec.vectorized_axes)
+               Nx.tensor([
+                 [
+                   [10.0, 10.0, 10.0],
+                   [20.0, 20.0, 20.0]
+                 ],
+                 [
+                   [10.0, 10.0, 10.0],
+                   [20.0, 20.0, 20.0]
+                 ]
+               ])
+               |> Nx.vectorize([:x, :y])
+
+      assert grad_y_vec ==
+               Nx.tensor([
+                 [6.0, 6.0],
+                 [15.0, 15.0]
+               ])
+               |> Nx.vectorize([:x, :y])

-      assert grad_y_vec == Nx.tensor([21.0, 21.0]) |> Nx.vectorize(y_vec.vectorized_axes)
+      # Conceptual assertion: the result should be equivalent to calling Nx.Defn.grad with
+      # each cross-entry of the combined vectors [(x0, y0), (x0, y1), (x1, y0), (x1, y1)]
+
+      {x0y0_wrt_x, x0y0_wrt_y} = grad_fun.(x[0], y[0])
+      {x0y1_wrt_x, x0y1_wrt_y} = grad_fun.(x[0], y[1])
+      {x1y0_wrt_x, x1y0_wrt_y} = grad_fun.(x[1], y[0])
+      {x1y1_wrt_x, x1y1_wrt_y} = grad_fun.(x[1], y[1])
+
+      assert grad_x_vec ==
+               [x0y0_wrt_x, x0y1_wrt_x, x1y0_wrt_x, x1y1_wrt_x]
+               |> Nx.stack()
+               |> Nx.reshape({2, 2, 3})
+               |> Nx.vectorize([:x, :y])
+
+      assert grad_y_vec ==
+               [x0y0_wrt_y, x0y1_wrt_y, x1y0_wrt_y, x1y1_wrt_y]
+               |> Nx.stack()
+               |> Nx.reshape({2, 2})
+               |> Nx.vectorize([:x, :y])

       # second case: y is vectorized scalar, x is vectorized vectors, same vectorized axis name
       # expected result: equivalent to "row-wise" broadcasting