Commit e5c82f8

Add stack as a callback (#1482)
1 parent 510e689 commit e5c82f8

File tree: 11 files changed, +160 -55 lines changed
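For context: `Nx.stack/2` joins a list of tensors along a new axis. Before this commit it was rewritten eagerly into `new_axis/3` plus `concatenate/2` inside `Nx`; this commit turns `stack` into a first-class backend callback so backends can lower it directly. A quick sketch of the user-facing behavior (illustrative; output assumes default s64 tensors):

    iex> Nx.stack([Nx.tensor([1, 2]), Nx.tensor([3, 4])])
    #Nx.Tensor<
      s64[2][2]
      [
        [1, 2],
        [3, 4]
      ]
    >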

exla/lib/exla/backend.ex

Lines changed: 11 additions & 0 deletions

@@ -243,6 +243,17 @@ defmodule EXLA.Backend do
     jit([], expr_fun, tensors, [List.to_tuple(tensors)])
   end

+  @impl true
+  def stack(out, tensors, axis) do
+    out = Nx.to_template(out)
+
+    expr_fun = fn tensors ->
+      Nx.Defn.Expr.stack(out, Tuple.to_list(tensors), axis)
+    end
+
+    jit([], expr_fun, tensors, [List.to_tuple(tensors)])
+  end
+
   @impl true
   def slice(out, tensor, start_indices, lengths, strides) do
     out = Nx.to_template(out)

exla/lib/exla/defn.ex

Lines changed: 6 additions & 3 deletions

@@ -1293,10 +1293,13 @@ defmodule EXLA.Defn do
   end

   defp to_operator(:concatenate, [[%Value{} | _rest] = tensors, axis], ans, _state) do
-    tensors =
-      tensors
-      |> Enum.map(&to_type(&1, ans.type))
+    tensors = Enum.map(tensors, &to_type(&1, ans.type))
+    Value.concatenate(tensors, axis, expr_to_typespec(ans))
+  end

+  defp to_operator(:stack, [[%Value{} | _rest] = tensors, axis], ans, _state) do
+    reshape_typespec = Typespec.tensor(ans.type, put_elem(ans.shape, axis, 1))
+    tensors = Enum.map(tensors, &(&1 |> to_type(ans.type) |> Value.reshape(reshape_typespec)))
     Value.concatenate(tensors, axis, expr_to_typespec(ans))
   end
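The EXLA lowering mirrors the old eager rewrite: each input is reshaped so a size-1 axis appears at the stack position, then the results are concatenated. A sketch of that equivalence in plain Nx (variable names are illustrative, not from the commit):

    a = Nx.tensor([1, 2])
    b = Nx.tensor([3, 4])

    # stack == reshape-to-unit-axis + concatenate
    stacked = Nx.stack([a, b], axis: 0)
    via_concat = Nx.concatenate([Nx.new_axis(a, 0), Nx.new_axis(b, 0)], axis: 0)
    Nx.equal(stacked, via_concat) |> Nx.all() |> Nx.to_number()
    #=> 1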

nx/lib/nx.ex

Lines changed: 42 additions & 36 deletions

@@ -3251,20 +3251,9 @@ defmodule Nx do
   """
   @doc type: :shape, from_backend: false
   def new_axis(tensor, axis, name \\ nil) when is_integer(axis) do
-    apply_vectorized(tensor, fn tensor, offset ->
-      %{shape: shape, names: names} = tensor = to_tensor(tensor)
-      rank = tuple_size(shape)
-      norm = if axis < 0, do: axis + rank + 1, else: axis + offset
-
-      if norm not in offset..tuple_size(shape) do
-        raise ArgumentError,
-              "new axis position for shape #{inspect(shape)} must be " <>
-                "a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
-      end
-
-      new_shape = Tuple.insert_at(shape, norm, 1)
-      new_names = List.insert_at(names, norm, name)
-      impl!(tensor).reshape(%{tensor | shape: new_shape, names: new_names}, tensor)
+    apply_vectorized(tensor, fn %{shape: shape, names: names} = tensor, offset ->
+      {shape, names, _axis} = Nx.Shape.new_axis(shape, names, axis, name, 1, offset)
+      impl!(tensor).reshape(%{tensor | shape: shape, names: names}, tensor)
     end)
   end

@@ -14668,28 +14657,35 @@ defmodule Nx do
         t

       [_ | _] = tensors ->
-        [%T{vectorized_axes: vectorized_axes} | _] =
-          tensors = broadcast_vectors(tensors, align_ranks: true)
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.concatenate(shapes, names, axis, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).concatenate(out, tensors, axis) end
+        )
+    end
+  end

-        offset = length(vectorized_axes)
-        tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
+  defp concatenate_or_stack(tensors, shape_and_name, callback) do
+    [%T{vectorized_axes: vectorized_axes} | _] =
+      tensors = broadcast_vectors(tensors, align_ranks: true)

-        {types, [s1 | _] = shapes, [n1 | _] = names} =
-          Enum.reduce(tensors, {[], [], []}, fn
-            %T{type: t, shape: s, names: n}, {types, shapes, names} ->
-              {[t | types], [s | shapes], [n | names]}
-          end)
+    offset = length(vectorized_axes)
+    tensors = if vectorized_axes != [], do: Enum.map(tensors, &devectorize/1), else: tensors
+
+    {types, shapes, names} =
+      Enum.reduce(tensors, {[], [], []}, fn
+        %T{type: t, shape: s, names: n}, {types, shapes, names} ->
+          {[t | types], [s | shapes], [n | names]}
+      end)

-        axis = Nx.Shape.normalize_axis(s1, axis, n1, offset)
-        output_type = Enum.reduce(types, &Nx.Type.merge/2)
+    output_type = Enum.reduce(types, &Nx.Type.merge/2)

-        {output_shape, output_names} =
-          Nx.Shape.concatenate(Enum.reverse(shapes), Enum.reverse(names), axis)
+    {output_shape, output_names, axis} =
+      shape_and_name.(Enum.reverse(shapes), Enum.reverse(names), offset)

-        out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
-        result = list_impl!(tensors).concatenate(out, tensors, axis)
-        vectorize(result, vectorized_axes)
-    end
+    out = %{hd(tensors) | type: output_type, shape: output_shape, names: output_names}
+    result = callback.(out, tensors, axis)
+    vectorize(result, vectorized_axes)
   end

@@ -14807,16 +14803,26 @@ defmodule Nx do
   >

   """
-  @doc type: :ndim, from_backend: false
+  @doc type: :ndim
   def stack(tensors, opts \\ []) do
     opts = keyword!(opts, axis: 0, name: nil)
     axis = opts[:axis]
     name = opts[:name]

-    tensors
-    |> flatten_list_or_container()
-    |> Enum.map(&Nx.new_axis(&1, axis, name))
-    |> Nx.concatenate(axis: axis)
+    case flatten_list_or_container(tensors) do
+      [] ->
+        raise ArgumentError, "no tensors were given to stack"
+
+      [t] ->
+        Nx.new_axis(t, axis, name)
+
+      [_ | _] = tensors ->
+        concatenate_or_stack(
+          tensors,
+          fn shapes, names, offset -> Nx.Shape.stack(shapes, names, axis, name, offset) end,
+          fn out, tensors, axis -> list_impl!(tensors).stack(out, tensors, axis) end
+        )
+    end
   end

   @doc """

nx/lib/nx/backend.ex

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ defmodule Nx.Backend do
   @callback put_slice(out :: tensor, tensor, tensor, list) :: tensor
   @callback gather(out :: tensor, input :: tensor, indices :: tensor, keyword) :: tensor
   @callback concatenate(out :: tensor, tensor, axis) :: tensor
+  @callback stack(out :: tensor, tensor, axis) :: tensor
   @callback select(out :: tensor, tensor, tensor, tensor) :: tensor

   @callback conv(out :: tensor, tensor, kernel :: tensor, keyword) :: tensor
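Third-party backends implementing the `Nx.Backend` behaviour now need a `stack/3` callback next to `concatenate/3`. A minimal, hypothetical sketch that reuses a concatenate implementation by giving each input a unit axis first (`MyBackend` and the delegation strategy are assumptions, not part of this commit; remaining callbacks are omitted):

    defmodule MyBackend do
      @behaviour Nx.Backend

      # Assumed concatenate implementation for this sketch: delegate to the
      # binary backend, whose concatenate/3 is shown in this commit's diff.
      @impl true
      def concatenate(out, tensors, axis) do
        Nx.BinaryBackend.concatenate(out, tensors, axis)
      end

      @impl true
      def stack(out, tensors, axis) do
        # The new callback: insert a size-1 axis per input, then concatenate.
        tensors = Enum.map(tensors, &Nx.new_axis(&1, axis))
        concatenate(out, tensors, axis)
      end
    end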

nx/lib/nx/binary_backend.ex

Lines changed: 13 additions & 0 deletions

@@ -1999,6 +1999,19 @@ defmodule Nx.BinaryBackend do
     offset
   end

+  @impl true
+  def stack(out, tensors, axis) do
+    %{shape: output_shape, type: {_, size} = output_type} = out
+
+    tensors
+    |> Enum.map(fn %{shape: shape} = t ->
+      t = as_type(%{t | type: output_type}, t)
+      {to_binary(t), Tuple.insert_at(shape, axis, 1)}
+    end)
+    |> bin_concatenate(size, axis, output_shape)
+    |> then(&from_binary(out, &1))
+  end
+
   @impl true
   def concatenate(out, tensors, axis) do
     %{shape: output_shape, type: {_, size} = output_type} = out
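The binary backend reuses `bin_concatenate/4` directly: each input binary is paired with its shape after inserting a size-1 axis at `axis`, so the stack becomes an ordinary concatenation of unit-width chunks. A sketch of the shape bookkeeping when stacking three {3, 2} tensors on axis 0:

    # each {3, 2} input is treated as {1, 3, 2}...
    Tuple.insert_at({3, 2}, 0, 1)
    #=> {1, 3, 2}
    # ...and concatenating three such chunks on axis 0 yields out.shape == {3, 3, 2}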

nx/lib/nx/defn/evaluator.ex

Lines changed: 1 addition & 1 deletion

@@ -20,7 +20,7 @@ defmodule Nx.Defn.Evaluator do
   alias Nx.Defn.{Composite, Expr, Tree}

   @creation_ops [:eye, :iota, :from_binary]
-  @list_ops [:concatenate]
+  @list_ops [:concatenate, :stack]
   @indices_ops [:slice, :put_slice]

   @impl true

nx/lib/nx/defn/expr.ex

Lines changed: 6 additions & 0 deletions

@@ -1201,6 +1201,12 @@ defmodule Nx.Defn.Expr do
     expr(out, context, :concatenate, [tensors, axis])
   end

+  @impl true
+  def stack(out, tensors, axis) do
+    {tensors, context} = to_exprs(tensors)
+    expr(out, context, :stack, [tensors, axis])
+  end
+
   @impl true
   def triangular_solve(out, a, b, opts) do
     {[a, b], context} = to_exprs([a, b])

nx/lib/nx/defn/grad.ex

Lines changed: 15 additions & 0 deletions

@@ -614,6 +614,21 @@ defmodule Nx.Defn.Grad do
     [{x, g}]
   end

+  defp grad(:stack, [tensors, axis], ans, g) do
+    zero_axes = List.duplicate(0, Nx.rank(ans))
+    ans_shape_list = Tuple.to_list(ans.shape)
+
+    {pairs, _} =
+      Enum.map_reduce(tensors, 0, fn t, limit ->
+        current_limit = 1 + limit
+        start = List.replace_at(zero_axes, axis, limit)
+        len = List.replace_at(ans_shape_list, axis, 1)
+        {{t, Nx.slice(g, start, len)}, current_limit}
+      end)
+
+    pairs
+  end
+
   defp grad(:concatenate, [tensors, axis], ans, g) do
     zero_axes = List.duplicate(0, Nx.rank(ans))
     ans_shape_list = Tuple.to_list(ans.shape)
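The new `:stack` gradient rule slices `g` back into one unit-width slab per input along the stacked axis, a simpler special case of the `:concatenate` rule below, where the per-input widths vary. A hedged end-to-end check (module and function names are illustrative):

    defmodule StackGradExample do
      import Nx.Defn

      # d/da of sum(stack([a, b])) should be all ones, in a's shape.
      defn sum_stack_grad(a, b) do
        grad(a, fn a -> Nx.sum(Nx.stack([a, b])) end)
      end
    end

    StackGradExample.sum_stack_grad(Nx.tensor([1.0, 2.0]), Nx.tensor([3.0, 4.0]))
    #=> f32[2] tensor of [1.0, 1.0]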

nx/lib/nx/defn/tree.ex

Lines changed: 2 additions & 1 deletion

@@ -202,7 +202,8 @@ defmodule Nx.Defn.Tree do
     {[%{token | hooks: hooks}], acc}
   end

-  def apply_args(%T{data: %Expr{op: :concatenate, args: [list | args]}}, _type, acc, fun) do
+  def apply_args(%T{data: %Expr{op: op, args: [list | args]}}, _type, acc, fun)
+      when op in [:concatenate, :stack] do
     {list, acc} = Enum.map_reduce(list, acc, fun)
     {[list | args], acc}
   end

nx/lib/nx/shape.ex

Lines changed: 53 additions & 14 deletions

@@ -1648,17 +1648,65 @@ defmodule Nx.Shape do
     {shape, names}
   end

+  @doc """
+  Returns the shape and name of new axis.
+  """
+  def new_axis(shape, names, axis, name, size, offset) do
+    rank = tuple_size(shape)
+    norm = if axis < 0, do: axis + rank + 1, else: axis + offset
+
+    if norm not in offset..tuple_size(shape) do
+      raise ArgumentError,
+            "new axis position for shape #{inspect(shape)} must be " <>
+              "a number between #{-rank - 1 + offset} and #{rank - offset}, got: #{axis}"
+    end
+
+    new_shape = Tuple.insert_at(shape, norm, size)
+    new_names = List.insert_at(names, norm, name)
+    {new_shape, new_names, norm}
+  end
+
+  @doc """
+  Returns the shape and names after a stack.
+
+  ## Examples
+
+      iex> Nx.Shape.stack([{3, 2}, {3, 2}, {3, 2}], [[nil, nil], [nil, :z], [:y, nil]], 0, :x, 0)
+      {{3, 3, 2}, [:x, :y, :z], 0}
+  """
+  def stack(shapes, names, axis, name, offset) do
+    names =
+      Enum.zip_with(names, fn zipped ->
+        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
+      end)
+
+    case Enum.uniq(shapes) do
+      [shape] ->
+        new_axis(shape, names, axis, name, length(shapes), offset)
+
+      shapes ->
+        raise ArgumentError,
+              "can only stack tensors of the same shape, got distinct shapes: #{inspect(shapes)}"
+    end
+  end
+
   @doc """
   Returns the shape and names after a concat.

   ## Examples

-      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0)
-      {{7, 3, 2}, [:x, :y, :z]}
+      iex> Nx.Shape.concatenate([{2, 3, 2}, {1, 3, 2}, {4, 3, 2}], [[:x, :y, :z], [:x, :y, :z], [:x, :y, :z]], 0, 0)
+      {{7, 3, 2}, [:x, :y, :z], 0}
   """
-  def concatenate(shapes, names, axis) do
-    names = validate_concat_names!(names, axis)
-    {concat_dims(shapes, axis), names}
+  def concatenate([s1 | _] = shapes, [n1 | _] = names, axis, offset) do
+    axis = normalize_axis(s1, axis, n1, offset)
+
+    names =
+      Enum.zip_with(names, fn zipped ->
+        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
+      end)
+
+    {concat_dims(shapes, axis), names, axis}
   end

   defp concat_dims([s1 | shapes] = all_shapes, axis) do

@@ -2120,15 +2168,6 @@ defmodule Nx.Shape do
     )
   end

-  defp validate_concat_names!(names, axis) do
-    _ =
-      Enum.zip_with(names, fn zipped ->
-        Enum.reduce(zipped, &merge_names!(&1, &2, axis, axis))
-      end)
-
-    hd(names)
-  end
-
   def fft({}) do
     raise ArgumentError, "expected a tensor with rank > 0, got tensor with rank 0"
   end
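Unlike `concatenate`, which only requires agreement off the concat axis, `stack` demands identical input shapes, as the `Enum.uniq/1` check above enforces. A sketch of the failure mode implied by that clause:

    iex> Nx.Shape.stack([{3, 2}, {2, 2}], [[nil, nil], [nil, nil]], 0, nil, 0)
    ** (ArgumentError) can only stack tensors of the same shape, got distinct shapes: [{3, 2}, {2, 2}]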
