Skip to content

Commit 773f4d6

Browse files
Remove Nx.map/2 (#1493)
1 parent 1563988 commit 773f4d6

File tree

11 files changed

+113
-291
lines changed

11 files changed

+113
-291
lines changed

exla/lib/exla/backend.ex

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -353,7 +353,6 @@ defmodule EXLA.Backend do
353353
{:window_product, [:tensor, :shape, :opts], [:tensor]},
354354
{:window_max, [:tensor, :shape, :opts], [:tensor]},
355355
{:window_min, [:tensor, :shape, :opts], [:tensor]},
356-
{:map, [:tensor, :opts, :fun], [:tensor]},
357356
{:sort, [:tensor, :opts], [:tensor]},
358357
{:argsort, [:tensor, :opts], [:tensor]},
359358
{:window_scatter_max, [:tensor, :source, :init_value, :window_dims, :opts],

exla/lib/exla/defn.ex

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1207,11 +1207,6 @@ defmodule EXLA.Defn do
12071207
mlir_scatter(tensors, out, :put)
12081208
end
12091209

1210-
defp to_operator(:map, [%Value{} = arg, _opts, fun], ans, _state) do
1211-
arg = to_type(arg, ans.type)
1212-
Value.map(fun, [arg], Nx.axes(ans.shape), expr_to_typespec(ans))
1213-
end
1214-
12151210
defp to_operator(op, [arg, opts], ans, state) when op in [:argmax, :argmin] do
12161211
apply(EXLA.Lib, op, [state.builder, arg, ans.type, opts])
12171212
end

exla/lib/exla/mlir/value.ex

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -660,24 +660,6 @@ defmodule EXLA.MLIR.Value do
660660
)
661661
end
662662

663-
def map(
664-
%Region{ref: mapper},
665-
[%Value{function: func} | _] = inputs,
666-
dimensions,
667-
typespec
668-
) do
669-
result_types = typespecs_to_mlir_types([typespec])
670-
671-
attributes = [
672-
dimensions: attr_array_i64_elements(dimensions)
673-
]
674-
675-
regions = [mapper]
676-
677-
op(func, "stablehlo.map", inputs, result_types, attributes: attributes, regions: regions)
678-
|> one!()
679-
end
680-
681663
def if_op(
682664
%Value{function: func} = pred,
683665
%Region{ref: on_true},

exla/test/exla/defn/expr_test.exs

Lines changed: 14 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,64 +1499,23 @@ defmodule EXLA.Defn.ExprTest do
14991499
end
15001500
end
15011501

1502-
describe "map" do
1503-
defn map_plus(t), do: Nx.map(t, fn x -> x + 1 end)
1504-
defn map_equal(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.equal(x, 1) end)
1505-
defn map_exp(t), do: Nx.map(t, [type: {:f, 64}], fn x -> Nx.exp(x) end)
1506-
1507-
@tag :unsupported_64_bit_op
1508-
test "maps a function over the tensor" do
1509-
assert_equal(map_plus(Nx.tensor([[1, 2, 3], [4, 5, 6]])), Nx.tensor([[2, 3, 4], [5, 6, 7]]))
1510-
end
1511-
1512-
@tag :unsupported_64_bit_op
1513-
test "maps a function with an output type" do
1514-
assert_equal(
1515-
map_equal(Nx.tensor([[1, 2, 3], [4, 5, 6]])),
1516-
Nx.tensor([[1.0, 0.0, 0.0], [0.0, 0.0, 0.0]], type: {:f, 64})
1517-
)
1518-
1519-
assert_equal(
1520-
map_exp(Nx.tensor([[1, 2, 3], [4, 5, 6]])),
1521-
Nx.tensor(
1522-
[
1523-
[2.718281828459045, 7.38905609893065, 20.085536923187668],
1524-
[54.598150033144236, 148.4131591025766, 403.4287934927351]
1525-
],
1526-
type: {:f, 64}
1527-
)
1528-
)
1529-
end
1530-
1531-
defn map_conditional(t), do: Nx.map(t, fn x -> if x > 0, do: x, else: -x end)
1532-
1533-
@tag :conditional_inside_map_reduce
1534-
@tag :unsupported_64_bit_op
1535-
test "maps a function with conditional" do
1536-
assert_equal(
1537-
map_conditional(Nx.tensor([-2, -1, 0, 1, 2])),
1538-
Nx.tensor([2, 1, 0, 1, 2])
1539-
)
1540-
end
1541-
1542-
defn while_inside_if(pred, x) do
1543-
if pred do
1544-
{x, _} =
1545-
while {x, i = 0}, i < 10 do
1546-
{x, i + 1}
1547-
end
1502+
defn while_inside_if(pred, x) do
1503+
if pred do
1504+
{x, _} =
1505+
while {x, i = 0}, i < 10 do
1506+
{x, i + 1}
1507+
end
15481508

1549-
x
1550-
else
1551-
x
1552-
end
1509+
x
1510+
else
1511+
x
15531512
end
1513+
end
15541514

1555-
test "while inside if" do
1556-
assert %{a: a, b: b} = while_inside_if(1, %{a: 1, b: 2.0})
1557-
assert_all_close(a, 1)
1558-
assert_all_close(b, 2.0)
1559-
end
1515+
test "while inside if" do
1516+
assert %{a: a, b: b} = while_inside_if(1, %{a: 1, b: 2.0})
1517+
assert_all_close(a, 1)
1518+
assert_all_close(b, 2.0)
15601519
end
15611520

15621521
describe "reduce" do

nx/lib/nx.ex

Lines changed: 0 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -11912,94 +11912,6 @@ defmodule Nx do
1191211912
end)
1191311913
end
1191411914

11915-
@doc """
11916-
Maps the given scalar function over the entire
11917-
tensor.
11918-
11919-
The type of the returned tensor will be of the same type
11920-
as the input tensor, unless the `:type` option is given.
11921-
Therefore, you may need to explicitly cast the tensor to
11922-
avoid errors. For example, if you have an integer tensor
11923-
and you convert it to a float, as below, it will fail:
11924-
11925-
tensor = Nx.tensor([[1, 2, 3], [4, 5, 6]]),
11926-
Nx.map(tensor, fn x -> Nx.multiply(x, 1.0) end)
11927-
11928-
You need to explicitly pass the output type in such cases:
11929-
11930-
iex> tensor = Nx.tensor([[1, 2, 3], [4, 5, 6]])
11931-
iex> Nx.map(tensor, [type: :f32], fn x -> Nx.multiply(x, 1.0) end)
11932-
#Nx.Tensor<
11933-
f32[2][3]
11934-
[
11935-
[1.0, 2.0, 3.0],
11936-
[4.0, 5.0, 6.0]
11937-
]
11938-
>
11939-
11940-
## Limitations
11941-
11942-
Given this function relies on anonymous functions, it
11943-
may not be available or efficient on all Nx backends.
11944-
Therefore, you should avoid using `map/2` whenever possible
11945-
and use other functions in the `Nx` module to achieve the
11946-
desired result.
11947-
11948-
Inside `defn`, consider using `Nx.Defn.Kernel.while/4` instead.
11949-
11950-
## Examples
11951-
11952-
iex> Nx.map(Nx.tensor([[1, 2, 3], [4, 5, 6]]), fn x -> Nx.add(x, 1) end)
11953-
#Nx.Tensor<
11954-
s64[2][3]
11955-
[
11956-
[2, 3, 4],
11957-
[5, 6, 7]
11958-
]
11959-
>
11960-
11961-
iex> Nx.map(Nx.tensor(1), fn x -> Nx.add(x, 1) end)
11962-
#Nx.Tensor<
11963-
s64
11964-
2
11965-
>
11966-
11967-
iex> Nx.map(Nx.tensor([[1, 2, 3], [4, 5, 6]]), [type: :f64], fn x -> Nx.add(x, 1) end)
11968-
#Nx.Tensor<
11969-
f64[2][3]
11970-
[
11971-
[2.0, 3.0, 4.0],
11972-
[5.0, 6.0, 7.0]
11973-
]
11974-
>
11975-
11976-
## Vectorized tensors
11977-
11978-
`map/3` behaves the same as with non-vectorized tensors, applying
11979-
`fun` in an element-wise fashion.
11980-
11981-
iex> Nx.map(Nx.tensor([[1, 2, 3], [4, 5, 6]]) |> Nx.vectorize(:x), [type: :f64], &Nx.add(&1, 1))
11982-
#Nx.Tensor<
11983-
vectorized[x: 2]
11984-
f64[3]
11985-
[
11986-
[2.0, 3.0, 4.0],
11987-
[5.0, 6.0, 7.0]
11988-
]
11989-
>
11990-
"""
11991-
@doc type: :element
11992-
def map(tensor, opts \\ [], fun) do
11993-
apply_vectorized(tensor, fn tensor ->
11994-
%T{type: type} = tensor
11995-
11996-
opts = keyword!(opts, type: type)
11997-
output_type = Nx.Type.normalize!(opts[:type])
11998-
out = %{tensor | type: output_type}
11999-
impl!(tensor).map(out, tensor, opts, fun)
12000-
end)
12001-
end
12002-
1200311915
## Matrix ops
1200411916

1200511917
@doc """

nx/lib/nx/backend.ex

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,6 @@ defmodule Nx.Backend do
9393
@callback window_product(out :: tensor, tensor, shape, keyword) :: tensor
9494
@callback window_max(out :: tensor, tensor, shape, keyword) :: tensor
9595
@callback window_min(out :: tensor, tensor, shape, keyword) :: tensor
96-
@callback map(out :: tensor, tensor, keyword, fun) :: tensor
9796
@callback sort(out :: tensor, tensor, keyword) :: tensor
9897
@callback argsort(out :: tensor, tensor, keyword) :: tensor
9998
@callback window_scatter_max(out :: tensor, tensor, tensor, tensor, shape, keyword) :: tensor

nx/lib/nx/binary_backend.ex

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1561,20 +1561,6 @@ defmodule Nx.BinaryBackend do
15611561
window_reduce(out, tensor, init_value, window_dimensions, opts, fun)
15621562
end
15631563

1564-
@impl true
1565-
def map(%{type: output_type} = out, %{type: {_, size}} = tensor, _opts, fun) do
1566-
data = to_binary(tensor)
1567-
template = %{tensor | shape: {}}
1568-
1569-
output_data =
1570-
for <<bin::size(size)-bitstring <- data>>, into: <<>> do
1571-
tensor = put_in(template.data.state, bin)
1572-
number_to_binary(scalar_to_number(fun.(tensor)), output_type)
1573-
end
1574-
1575-
from_binary(out, output_data)
1576-
end
1577-
15781564
@impl true
15791565
def window_scatter_max(out, tensor, source, init_value, window_dimensions, opts) do
15801566
select_and_scatter(

nx/lib/nx/defn/expr.ex

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -985,13 +985,6 @@ defmodule Nx.Defn.Expr do
985985
expr(out, context, :window_reduce, [tensor, acc, window_dims, opts, fun])
986986
end
987987

988-
@impl true
989-
def map(%{type: type} = out, tensor, opts, fun) do
990-
args = [parameter(new_context(:map), type, {}, 0)]
991-
%{data: %{context: context}} = tensor = to_expr(tensor)
992-
expr(out, context, :map, [tensor, opts, apply_fun(context, fun, args, type)])
993-
end
994-
995988
@impl true
996989
def window_scatter_max(out, tensor, source, init_value, window_dims, opts) do
997990
{[tensor, source, init_value], context} = to_exprs([tensor, source, init_value])

nx/test/nx/defn/evaluator_test.exs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -269,16 +269,6 @@ defmodule Nx.Defn.EvaluatorTest do
269269
test "calls external anonymous function via reduce" do
270270
assert calls_reduce_fun(&Nx.add/2, Nx.tensor([1, 2, 3])) == Nx.tensor(6)
271271
end
272-
273-
defn calls_map_fun(t) do
274-
Nx.map(t, fn x ->
275-
if Nx.equal(x, 0), do: 1, else: -x
276-
end)
277-
end
278-
279-
test "calls internal anonymous function via map" do
280-
assert calls_map_fun(Nx.tensor([0, 1, 2])) == Nx.tensor([1, -1, -2])
281-
end
282272
end
283273

284274
describe "access" do
@@ -691,7 +681,7 @@ defmodule Nx.Defn.EvaluatorTest do
691681
t = Nx.iota({2, 3}, vectorized_axes: [a: 1], type: :s64)
692682

693683
message = """
694-
test/nx/defn/evaluator_test.exs:660: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
684+
test/nx/defn/evaluator_test.exs:650: the do-block in while must return tensors with the same shape, type, and names as the initial arguments.
695685
696686
{\e[32m
697687
<<<<< Body (do-block) <<<<<
@@ -723,7 +713,7 @@ defmodule Nx.Defn.EvaluatorTest do
723713

724714
error =
725715
"""
726-
test/nx/defn/evaluator_test.exs:660: condition must be a scalar tensor, got: #Nx.Tensor<
716+
test/nx/defn/evaluator_test.exs:650: condition must be a scalar tensor, got: #Nx.Tensor<
727717
vectorized[x: 1]
728718
u8[1]
729719
\s\s

0 commit comments

Comments
 (0)