File tree Expand file tree Collapse file tree 3 files changed +25
-14
lines changed Expand file tree Collapse file tree 3 files changed +25
-14
lines changed Original file line number Diff line number Diff line change @@ -80,22 +80,33 @@ defmodule Nx.Random do
80
80
[0, 12]
81
81
>
82
82
83
- A single key effectively consists of 64 bits, so all possible values
84
- of a 64-bit integer result in a different key. However, when passing
85
- an integer literal, Nx implicitly assumes its type to be `:s32`,
86
- which would result in an overflow for large integers. Therefore,
87
- when dealing with large seeds, make sure to explicitly use a 64 bit
88
- type:
89
-
90
- iex> Nx.Random.key(Nx.u64(999999999999))
83
+ iex> Nx.Random.key(999999999999)
91
84
#Nx.Tensor<
92
85
u32[2]
93
86
[232, 3567587327]
94
87
>
95
88
"""
96
- defn key ( seed ) do
97
- seed = Nx . as_type ( seed , :u64 )
89
+ deftransform key ( seed ) do
90
+ seed =
91
+ case seed do
92
+ seed when is_integer ( seed ) ->
93
+ Nx . u64 ( seed )
98
94
95
+ % Nx.Tensor { } = seed when seed . type == { :u , 64 } ->
96
+ seed
97
+
98
+ % Nx.Tensor { } = seed when seed . type == { :s , 64 } ->
99
+ Nx . bitcast ( seed , { :u , 64 } )
100
+
101
+ other ->
102
+ raise ArgumentError ,
103
+ "expected seed to be an integer, u64 tensor or s64 tensor, got: #{ inspect ( other ) } "
104
+ end
105
+
106
+ key_n ( seed )
107
+ end
108
+
109
+ defnp key_n ( seed ) do
99
110
k1 = Nx . right_shift ( seed , 32 )
100
111
k2 = Nx . bitwise_and ( seed , Nx . u64 ( 0xFFFFFFFF ) )
101
112
Original file line number Diff line number Diff line change @@ -67,7 +67,7 @@ defmodule Nx.RandomTest do
67
67
describe "distributions" do
68
68
defp distribution_case ( name , args: args , expected: expected ) do
69
69
seed = :erlang . adler32 ( "#{ name } threefry2x32" )
70
- key = Nx.Random . key ( Nx . u64 ( seed ) )
70
+ key = Nx.Random . key ( seed )
71
71
actual = apply ( Nx.Random , name , [ key | args ] )
72
72
73
73
assert_all_close ( actual , expected )
@@ -261,7 +261,7 @@ defmodule Nx.RandomTest do
261
261
|> assert_all_close ( apply ( expected_func , expected_args ) , rtol: 0.1 )
262
262
263
263
seed = :erlang . adler32 ( "uniformthreefry2x32" )
264
- key = Nx.Random . key ( Nx . tensor ( seed , type: :u64 ) )
264
+ key = Nx.Random . key ( seed )
265
265
t = apply ( Nx.Random , name , [ key | args ] )
266
266
267
267
apply ( Nx , moment , [ t ] )
Original file line number Diff line number Diff line change @@ -55,7 +55,7 @@ defmodule Torchx.Nx.RandomTest do
55
55
describe "distributions" do
56
56
defp distribution_case ( name , args: args , expected: expected ) do
57
57
seed = :erlang . adler32 ( "#{ name } threefry2x32" )
58
- key = Nx.Random . key ( Nx . u64 ( seed ) )
58
+ key = Nx.Random . key ( seed )
59
59
actual = apply ( Nx.Random , name , [ key | args ] )
60
60
61
61
assert_all_close ( actual , expected )
@@ -209,7 +209,7 @@ defmodule Torchx.Nx.RandomTest do
209
209
|> assert_all_close ( apply ( expected_func , expected_args ) , rtol: 0.1 )
210
210
211
211
seed = :erlang . adler32 ( "#{ name } threefry2x32" )
212
- key = Nx.Random . key ( Nx . tensor ( seed , type: :u64 ) )
212
+ key = Nx.Random . key ( seed )
213
213
t = apply ( Nx.Random , name , [ key | args ] )
214
214
215
215
apply ( Nx , moment , [ t ] )
You can’t perform that action at this time.
0 commit comments