Skip to content

Commit a21d30b

Browse files
fix: Nx.Random.shuffle repeating a single value in certain cases on GPU (#1552)
Co-authored-by: Jonatan Klosko <[email protected]>
1 parent 7af065e commit a21d30b

File tree

2 files changed

+45
-24
lines changed

2 files changed

+45
-24
lines changed

exla/lib/exla/defn.ex

Lines changed: 25 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1367,30 +1367,42 @@ defmodule EXLA.Defn do
13671367

13681368
## Computation helpers
13691369

1370-
defp sort_computation(op, type, arg_typespecs, %{builder: %EXLA.MLIR.Function{} = function}) do
1370+
defp sort_computation(operator, type, arg_typespecs, %{
1371+
builder: %EXLA.MLIR.Function{} = function
1372+
}) do
13711373
{region, [lhs, rhs | _]} = Function.push_region(function, arg_typespecs)
13721374

13731375
typespec = Typespec.tensor({:pred, 8}, {})
13741376

1375-
op =
1376-
cond do
1377-
Nx.Type.integer?(type) ->
1378-
apply(Value, op, [lhs, rhs, typespec])
1379-
1380-
op == :less ->
1381-
is_nan = Value.is_nan(rhs, typespec)
1382-
Value.bitwise_or(is_nan, Value.less(lhs, rhs, typespec), typespec)
1383-
1384-
op == :greater ->
1385-
is_nan = Value.is_nan(lhs, typespec)
1386-
Value.bitwise_or(is_nan, Value.greater(lhs, rhs, typespec), typespec)
1377+
{lhs, rhs} =
1378+
if Nx.Type.integer?(type) do
1379+
{lhs, rhs}
1380+
else
1381+
{sort_computation_canonicalize_float(lhs), sort_computation_canonicalize_float(rhs)}
13871382
end
13881383

1384+
op = apply(Value, operator, [lhs, rhs, typespec, [total_order: true]])
1385+
13891386
Value.return(function, [op])
13901387
Function.pop_region(function)
13911388
region
13921389
end
13931390

1391+
defp sort_computation_canonicalize_float(%Value{function: func} = op) do
1392+
# Standardize the representation of NaNs (-NaN, NaN) and zeros (-0, 0).
1393+
# See https://github.com/google/jax/blob/e81c82605f0e1813080cfe1037d043b27b38291d/jax/_src/lax/lax.py#L4248-L4253
1394+
1395+
op_typespec = Value.get_typespec(op)
1396+
1397+
zero = Value.constant(func, [0], Typespec.to_shape(op_typespec, {}))
1398+
zeros = Value.constant(func, [0], op_typespec)
1399+
nans = Value.constant(func, [:nan], op_typespec)
1400+
1401+
pred_typespec = Typespec.tensor({:pred, 8}, {})
1402+
op = Value.select(Value.equal(op, zero, pred_typespec), zeros, op, op_typespec)
1403+
Value.select(Value.is_nan(op, pred_typespec), nans, op, op_typespec)
1404+
end
1405+
13941406
defp op_computation(
13951407
op,
13961408
arg_typespecs,

exla/lib/exla/mlir/value.ex

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -54,31 +54,40 @@ defmodule EXLA.MLIR.Value do
5454
}
5555

5656
for {op, direction} <- @bin_comparison_ops do
57-
def unquote(op)(%Value{function: func} = lhs, %Value{function: func} = rhs, typespec) do
58-
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction))
57+
def unquote(op)(
58+
%Value{function: func} = lhs,
59+
%Value{function: func} = rhs,
60+
typespec,
61+
opts \\ []
62+
) do
63+
compare_and_return_bool(func, lhs, rhs, typespec, unquote(direction), opts[:total_order])
5964
end
6065
end
6166

62-
defp compare_and_return_bool(func, lhs, rhs, typespec, direction) do
67+
defp compare_and_return_bool(func, lhs, rhs, typespec, direction, total_order? \\ false) do
6368
%{type: lhs_type} = get_typespec(lhs)
6469
%{type: rhs_type} = get_typespec(rhs)
6570

6671
comparison_type =
6772
cond do
6873
Nx.Type.complex?(lhs_type) or Nx.Type.complex?(rhs_type) ->
69-
attr_comparison_type(:float)
74+
[compare_type: attr_comparison_type(:float)]
7075

7176
Nx.Type.float?(lhs_type) or Nx.Type.float?(rhs_type) ->
72-
attr_comparison_type(:float)
77+
attr =
78+
if total_order? do
79+
attr_comparison_type(:totalorder)
80+
else
81+
attr_comparison_type(:float)
82+
end
83+
84+
[compare_type: attr]
7385

7486
true ->
75-
attr_comparison_type(:notype)
87+
[]
7688
end
7789

78-
attributes = [
79-
comparison_direction: attr_comparison_direction(direction),
80-
compare_type: comparison_type
81-
]
90+
attributes = [comparison_direction: attr_comparison_direction(direction)] ++ comparison_type
8291

8392
result_types = typespecs_to_mlir_types([Typespec.to_type(typespec, {:pred, 8})])
8493

@@ -1072,7 +1081,7 @@ defmodule EXLA.MLIR.Value do
10721081
defp attr_comparison_direction(value) when value in [:eq, :lt, :le, :gt, :ge, :ne],
10731082
do: attr_enum("stablehlo", "comparison_direction", value)
10741083

1075-
defp attr_comparison_type(value) when value in [:float, :totalorder, :notype],
1084+
defp attr_comparison_type(value) when value in [:float, :totalorder],
10761085
do: attr_enum("stablehlo", "comparison_type", value)
10771086

10781087
defp attr_precision(value) when value in [:default, :high, :highest],

0 commit comments

Comments (0)