Skip to content

Commit bb61c58

Browse files
Fix argmax/argmin behaviour with NaNs (#1499)
1 parent 7990b7e commit bb61c58

File tree

5 files changed

+79
-69
lines changed

5 files changed

+79
-69
lines changed

exla/lib/exla/lib.ex

Lines changed: 55 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ defmodule EXLA.Lib do
3434
def argmax(builder, op, type, opts \\ [])
3535

3636
def argmax(%Function{} = builder, %Value{} = op, type, opts) do
37-
argmin_or_max(builder, op, false, type, opts)
37+
argmin_or_max(builder, op, :max, type, opts)
3838
end
3939

4040
@doc """
@@ -49,37 +49,43 @@ defmodule EXLA.Lib do
4949
def argmin(builder, op, type, opts \\ [])
5050

5151
def argmin(%Function{} = builder, %Value{} = op, type, opts) do
52-
argmin_or_max(builder, op, true, type, opts)
52+
argmin_or_max(builder, op, :min, type, opts)
5353
end
5454

55-
defp argmin_or_max(builder, %Value{} = op, is_min?, type, opts) do
55+
defp argmin_or_max(builder, %Value{} = op, variant, type, opts) do
5656
tie_break = opts[:tie_break] || :low
5757
keep_axis = opts[:keep_axis] || false
58+
axis = opts[:axis]
5859

5960
op_typespec = Value.get_typespec(op)
6061

62+
{op, op_typespec} =
63+
if axis == nil and Nx.rank(op_typespec.shape) != 1 do
64+
# When no axis is given, we flatten the tensor and reduce over
65+
# the first axis
66+
typespec = Typespec.to_shape(op_typespec, {Nx.size(op_typespec.shape)})
67+
{Value.reshape(op, typespec), typespec}
68+
else
69+
{op, op_typespec}
70+
end
71+
72+
axis = axis || 0
73+
6174
init_value =
62-
if is_min?,
63-
do: max_number(builder, op_typespec.type),
64-
else: min_number(builder, op_typespec.type)
75+
case variant do
76+
:min -> max_number(builder, op_typespec.type)
77+
:max -> min_number(builder, op_typespec.type)
78+
end
6579

66-
axis = opts[:axis]
6780
index_init_value = Value.constant(builder, [0], Typespec.tensor(type, {}))
6881
iota = iota(builder, axis, Typespec.to_type(op_typespec, type))
69-
reduction = create_min_max_computation(builder, op_typespec.type, type, is_min?, tie_break)
82+
reduction = create_min_max_computation(builder, op_typespec.type, type, variant, tie_break)
7083

71-
dims =
72-
if axis do
73-
[axis]
74-
else
75-
Nx.axes(op_typespec.shape)
76-
end
77-
78-
shape = remove_axes(op_typespec.shape, dims)
84+
shape = Tuple.delete_at(op_typespec.shape, axis)
7985
typespecs = [Typespec.tensor(op_typespec.type, shape), Typespec.tensor(type, shape)]
8086

8187
[_, result] =
82-
Value.reduce(reduction, [init_value, index_init_value], [op, iota], dims, typespecs)
88+
Value.reduce(reduction, [init_value, index_init_value], [op, iota], [axis], typespecs)
8389

8490
if keep_axis do
8591
Value.reshape(result, Typespec.tensor(type, put_elem(op_typespec.shape, axis, 1)))
@@ -88,13 +94,7 @@ defmodule EXLA.Lib do
8894
end
8995
end
9096

91-
defp remove_axes(shape, axes) do
92-
axes
93-
|> Enum.reverse()
94-
|> Enum.reduce(shape, &Tuple.delete_at(&2, &1))
95-
end
96-
97-
defp create_min_max_computation(%Function{} = function, type, index_type, is_min?, tie_break) do
97+
defp create_min_max_computation(%Function{} = function, type, index_type, variant, tie_break) do
9898
arg_typespecs = [
9999
Typespec.tensor(type, {}),
100100
Typespec.tensor(index_type, {}),
@@ -109,27 +109,42 @@ defmodule EXLA.Lib do
109109
value_typespec = Typespec.tensor(type, {})
110110
idx_typespec = Typespec.tensor(index_type, {})
111111

112-
cmp =
113-
if is_min?,
114-
do: Value.less_equal(lhs_value, rhs_value, pred_typespec),
115-
else: Value.greater_equal(lhs_value, rhs_value, pred_typespec)
112+
comparator =
113+
case variant do
114+
:min -> &Value.less/3
115+
:max -> &Value.greater/3
116+
end
117+
118+
# Pick lhs if strictly before or if it is NaN
119+
pick_lhs_value =
120+
Value.bitwise_or(
121+
comparator.(lhs_value, rhs_value, pred_typespec),
122+
Value.is_nan(lhs_value, pred_typespec),
123+
pred_typespec
124+
)
116125

117-
max = Value.select(cmp, lhs_value, rhs_value, value_typespec)
118-
arg_max = Value.select(cmp, lhs_index, rhs_index, idx_typespec)
126+
max = Value.select(pick_lhs_value, lhs_value, rhs_value, value_typespec)
119127

120-
arg_max =
128+
idx_comparator =
121129
case tie_break do
122-
:low ->
123-
eq? = Value.equal(lhs_value, rhs_value, pred_typespec)
124-
id = Value.min(lhs_index, rhs_index, idx_typespec)
125-
Value.select(eq?, id, arg_max, idx_typespec)
126-
127-
:high ->
128-
eq? = Value.equal(lhs_value, rhs_value, pred_typespec)
129-
id = Value.max(lhs_index, rhs_index, idx_typespec)
130-
Value.select(eq?, id, arg_max, idx_typespec)
130+
:low -> &Value.less/3
131+
:high -> &Value.greater/3
131132
end
132133

134+
# If lhs and rhs are equal (and not NaN), then pick index based on tie_break
135+
pick_lhs_idx =
136+
Value.bitwise_or(
137+
pick_lhs_value,
138+
Value.bitwise_and(
139+
Value.equal(lhs_value, rhs_value, pred_typespec),
140+
idx_comparator.(lhs_index, rhs_index, pred_typespec),
141+
pred_typespec
142+
),
143+
pred_typespec
144+
)
145+
146+
arg_max = Value.select(pick_lhs_idx, lhs_index, rhs_index, idx_typespec)
147+
133148
Value.return(function, [max, arg_max])
134149
Function.pop_region(function)
135150
region

exla/lib/exla/mlir/value.ex

Lines changed: 3 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -157,34 +157,11 @@ defmodule EXLA.MLIR.Value do
157157
end
158158
end
159159

160-
def is_nan(%Value{function: func} = operand, out_typespec) do
161-
%{type: type} = get_typespec(operand)
162-
160+
def is_nan(%Value{} = operand, out_typespec) do
163161
typespec = Typespec.to_type(out_typespec, {:pred, 8})
164162

165-
result =
166-
cond do
167-
Nx.Type.complex?(type) ->
168-
float_typespec = Typespec.to_type(typespec, complex_part_type(type))
169-
real = real(operand, float_typespec)
170-
imag = imag(operand, float_typespec)
171-
is_nan_real = is_nan(real, typespec)
172-
is_nan_imag = is_nan(imag, typespec)
173-
bitwise_or(is_nan_real, is_nan_imag, typespec)
174-
175-
Nx.Type.integer?(type) ->
176-
# Integers are never nan. We use inequality to make sure
177-
# the operand is still a part of the computation
178-
not_equal(operand, operand, typespec)
179-
180-
true ->
181-
result_types = typespecs_to_mlir_types([typespec])
182-
is_inf = op(func, "chlo.is_inf", [operand], result_types) |> one!()
183-
is_finite = op(func, "stablehlo.is_finite", [operand], result_types) |> one!()
184-
is_not_inf = bitwise_not(is_inf, typespec)
185-
is_not_finite = bitwise_not(is_finite, typespec)
186-
bitwise_and(is_not_inf, is_not_finite, typespec)
187-
end
163+
# Only NaN is not equal to itself
164+
result = not_equal(operand, operand, typespec)
188165

189166
if out_typespec.type == typespec.type do
190167
result

nx/lib/nx.ex

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10008,6 +10008,15 @@ defmodule Nx do
1000810008
1
1000910009
>
1001010010
10011+
If the tensor includes any NaNs, returns the index of one of them
10012+
(NaN never compares equal, not even to itself, so tie-break does not apply):
10013+
10014+
iex> Nx.argmax(Nx.tensor([2.0, :nan, 4.0]))
10015+
#Nx.Tensor<
10016+
s64
10017+
1
10018+
>
10019+
1001110020
### Aggregating over an axis
1001210021
1001310022
iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]])
@@ -10147,6 +10156,15 @@ defmodule Nx do
1014710156
0
1014810157
>
1014910158
10159+
If the tensor includes any NaNs, returns the index of one of them
10160+
(NaN never compares equal, not even to itself, so tie-break does not apply):
10161+
10162+
iex> Nx.argmin(Nx.tensor([2.0, :nan, 4.0]))
10163+
#Nx.Tensor<
10164+
s64
10165+
1
10166+
>
10167+
1015010168
### Aggregating over an axis
1015110169
1015210170
iex> t = Nx.tensor([[[4, 2, 3], [1, -5, 3]], [[6, 2, 3], [4, 8, 3]]])

nx/lib/nx/binary_backend.ex

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1461,7 +1461,7 @@ defmodule Nx.BinaryBackend do
14611461
bin, {i, cur_extreme_x, cur_extreme_i} ->
14621462
x = binary_to_number(bin, type)
14631463

1464-
if cur_extreme_x == :first or comparator.(x, cur_extreme_x) do
1464+
if cur_extreme_x == :first or x == :nan or comparator.(x, cur_extreme_x) do
14651465
{i, {i + 1, x, i}}
14661466
else
14671467
{cur_extreme_i, {i + 1, cur_extreme_x, cur_extreme_i}}

nx/test/nx_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1443,7 +1443,7 @@ defmodule NxTest do
14431443
[:nan, 0, 1]
14441444
])
14451445

1446-
assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 0, 0, 0, 2, 2, 1, 1, 0, 0, 0, 0])
1446+
assert Nx.argmin(t, axis: 1) == Nx.tensor([0, 1, 0, 0, 2, 1, 1, 1, 0, 1, 0, 0])
14471447
end
14481448

14491449
test "raises for invalid :tie_break option" do
@@ -1475,7 +1475,7 @@ defmodule NxTest do
14751475
[:nan, 0, 1]
14761476
])
14771477

1478-
assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0])
1478+
assert Nx.argmax(t, axis: 1) == Nx.tensor([1, 1, 2, 2, 0, 1, 0, 0, 0, 1, 0, 0])
14791479
end
14801480
end
14811481

0 commit comments

Comments
 (0)