Skip to content

Commit e9b3d73

Browse files
refactor: Nx.Random.key as deftransform (#1525)
Co-authored-by: Jonatan Kłosko <[email protected]>
1 parent 1013a52 commit e9b3d73

File tree

3 files changed

+25
-14
lines changed

3 files changed

+25
-14
lines changed

nx/lib/nx/random.ex

Lines changed: 21 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -80,22 +80,33 @@ defmodule Nx.Random do
8080
[0, 12]
8181
>
8282
83-
A single key effectively consists of 64 bits, so all possible values
84-
of a 64-bit integer result in a different key. However, when passing
85-
an integer literal, Nx implicitly assumes its type to be `:s32`,
86-
which would result in an overflow for large integers. Therefore,
87-
when dealing with large seeds, make sure to explicitly use a 64 bit
88-
type:
89-
90-
iex> Nx.Random.key(Nx.u64(999999999999))
83+
iex> Nx.Random.key(999999999999)
9184
#Nx.Tensor<
9285
u32[2]
9386
[232, 3567587327]
9487
>
9588
"""
96-
defn key(seed) do
97-
seed = Nx.as_type(seed, :u64)
89+
deftransform key(seed) do
90+
seed =
91+
case seed do
92+
seed when is_integer(seed) ->
93+
Nx.u64(seed)
9894

95+
%Nx.Tensor{} = seed when seed.type == {:u, 64} ->
96+
seed
97+
98+
%Nx.Tensor{} = seed when seed.type == {:s, 64} ->
99+
Nx.bitcast(seed, {:u, 64})
100+
101+
other ->
102+
raise ArgumentError,
103+
"expected seed to be an integer, u64 tensor or s64 tensor, got: #{inspect(other)}"
104+
end
105+
106+
key_n(seed)
107+
end
108+
109+
defnp key_n(seed) do
99110
k1 = Nx.right_shift(seed, 32)
100111
k2 = Nx.bitwise_and(seed, Nx.u64(0xFFFFFFFF))
101112

nx/test/nx/random_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ defmodule Nx.RandomTest do
6767
describe "distributions" do
6868
defp distribution_case(name, args: args, expected: expected) do
6969
seed = :erlang.adler32("#{name}threefry2x32")
70-
key = Nx.Random.key(Nx.u64(seed))
70+
key = Nx.Random.key(seed)
7171
actual = apply(Nx.Random, name, [key | args])
7272

7373
assert_all_close(actual, expected)
@@ -261,7 +261,7 @@ defmodule Nx.RandomTest do
261261
|> assert_all_close(apply(expected_func, expected_args), rtol: 0.1)
262262

263263
seed = :erlang.adler32("uniformthreefry2x32")
264-
key = Nx.Random.key(Nx.tensor(seed, type: :u64))
264+
key = Nx.Random.key(seed)
265265
t = apply(Nx.Random, name, [key | args])
266266

267267
apply(Nx, moment, [t])

torchx/test/torchx/random_test.exs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ defmodule Torchx.Nx.RandomTest do
5555
describe "distributions" do
5656
defp distribution_case(name, args: args, expected: expected) do
5757
seed = :erlang.adler32("#{name}threefry2x32")
58-
key = Nx.Random.key(Nx.u64(seed))
58+
key = Nx.Random.key(seed)
5959
actual = apply(Nx.Random, name, [key | args])
6060

6161
assert_all_close(actual, expected)
@@ -209,7 +209,7 @@ defmodule Torchx.Nx.RandomTest do
209209
|> assert_all_close(apply(expected_func, expected_args), rtol: 0.1)
210210

211211
seed = :erlang.adler32("#{name}threefry2x32")
212-
key = Nx.Random.key(Nx.tensor(seed, type: :u64))
212+
key = Nx.Random.key(seed)
213213
t = apply(Nx.Random, name, [key | args])
214214

215215
apply(Nx, moment, [t])

0 commit comments

Comments
 (0)