@@ -1016,19 +1016,150 @@ end
10161016end
10171017
10181018# random ops
1019+ """
1020+ rng_bit_generator(
1021+ ::Type{T},
1022+ seed::TracedRArray{UInt64,1},
1023+ shape;
1024+ algorithm::String="DEFAULT",
1025+ location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1026+ )
1027+
1028+ Generate a random array of type `T` with the given shape and seed from a uniform random
1029+ distribution between 0 and 1. Returns a NamedTuple with the following fields:
1030+
1031+ - `output_state`: The state of the random number generator after the operation.
1032+ - `output`: The generated array.
1033+
1034+ # Arguments
1035+
1036+ - `T`: The type of the generated array.
1037+ - `seed`: The seed for the random number generator.
1038+ - `shape`: The shape of the generated array.
1039+ - `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1040+ "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1041+ """
10191042@noinline function rng_bit_generator (
1043+ :: Type{T} ,
10201044 seed:: TracedRArray{UInt64,1} ,
10211045 shape;
10221046 algorithm:: String = " DEFAULT" ,
10231047 location= mlir_stacktrace (" rng_bit_generator" , @__FILE__ , @__LINE__ ),
1024- )
1025- output = MLIR. IR. TensorType (TracedRArray{UInt64,1 }, shape)
1048+ ) where {T<: Integer }
1049+ @assert algorithm in (" DEFAULT" , " PHILOX" , " THREE_FRY" )
1050+ if algorithm == " PHILOX"
1051+ @assert length (seed) ∈ (2 , 3 )
1052+ elseif algorithm == " THREE_FRY"
1053+ @assert length (seed) == 2
1054+ end
1055+
1056+ output = MLIR. IR. TensorType (shape, MLIR. IR. Type (T))
1057+ output_state = MLIR. IR. TensorType (size (seed), MLIR. IR. Type (UInt64))
10261058 rng_algorithm = MLIR. API. stablehloRngAlgorithmAttrGet (MLIR. IR. context (), algorithm)
1027- op = stablehlo. rng_bit_generator (seed. mlir_data; output, rng_algorithm, location)
1059+ op = stablehlo. rng_bit_generator (
1060+ seed. mlir_data; output, output_state, rng_algorithm, location
1061+ )
10281062 return (;
1029- output_state= TracedRArray {UInt64,1} ((), MLIR. IR. result (op, 1 ), MLIR. IR. size (seed)),
1030- output= TracedRArray {T,length(shape)} ((), MLIR. IR. result (op, 2 ), shape),
1063+ output_state= TracedRArray {UInt64,1} ((), MLIR. IR. result (op, 1 ), size (seed)),
1064+ output= TracedRArray {T,length(shape)} ((), MLIR. IR. result (op, 2 ), Tuple (shape)),
1065+ )
1066+ end
1067+
1068+ @noinline function rng_bit_generator (
1069+ :: Type{T} ,
1070+ seed:: TracedRArray{UInt64,1} ,
1071+ shape;
1072+ algorithm:: String = " DEFAULT" ,
1073+ location= mlir_stacktrace (" rng_bit_generator" , @__FILE__ , @__LINE__ ),
1074+ ) where {T<: AbstractFloat }
1075+ nbits = sizeof (T) * 8
1076+ uT = nbits == 16 ? UInt16 : (nbits == 32 ? UInt32 : UInt64)
1077+ (; output_state, output) = rng_bit_generator (uT, seed, shape; algorithm, location)
1078+ output = divide (
1079+ convert (TracedRArray{T,ndims (output)}, output),
1080+ constant (fill (T (typemax (uT)), Tuple (shape)); location),
1081+ )
1082+ return (; output_state, output)
1083+ end
1084+
1085+ """
1086+ randn(
1087+ ::Type{T},
1088+ seed::TracedRArray{UInt64,1},
1089+ shape;
1090+ algorithm::String="DEFAULT",
1091+ location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
10311092 )
1093+
1094+ Generate a random array of type `T` with the given shape and seed from a standard normal
1095+ distribution of mean 0 and standard deviation 1. Returns a NamedTuple with the following
1096+ fields:
1097+
1098+ - `output_state`: The state of the random number generator after the operation.
1099+ - `output`: The generated array.
1100+
1101+ # Arguments
1102+
1103+ - `T`: The type of the generated array.
1104+ - `seed`: The seed for the random number generator.
1105+ - `shape`: The shape of the generated array.
1106+ - `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1107+ "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1108+ """
1109+ @noinline function randn (
1110+ :: Type{T} ,
1111+ seed:: TracedRArray{UInt64,1} ,
1112+ shape;
1113+ algorithm:: String = " DEFAULT" ,
1114+ location= mlir_stacktrace (" rand" , @__FILE__ , @__LINE__ ),
1115+ ) where {T}
1116+ res = rng_bit_generator (T, seed, shape; algorithm, location)
1117+ rand_uniform = res. output
1118+ seed = res. output_state
1119+ scaled_uniform = subtract (
1120+ multiply (rand_uniform, constant (fill (T (2 ), size (rand_uniform)))),
1121+ constant (fill (T (1 ), size (rand_uniform))),
1122+ )
1123+ probit = erf_inv (scaled_uniform)
1124+ rand_normal = multiply (probit, constant (fill (Base. sqrt (T (2 )), size (rand_uniform))))
1125+ return (; output_state= seed, output= rand_normal)
1126+ end
1127+
1128+ """
1129+ randexp(
1130+ ::Type{T},
1131+ seed::TracedRArray{UInt64,1},
1132+ shape;
1133+ algorithm::String="DEFAULT",
1134+ location=mlir_stacktrace("rand", @__FILE__, @__LINE__),
1135+ )
1136+
1137+ Generate a random array of type `T` with the given shape and seed from an exponential
1138+ distribution with rate 1. Returns a NamedTuple with the following fields:
1139+
1140+ - `output_state`: The state of the random number generator after the operation.
1141+ - `output`: The generated array.
1142+
1143+ # Arguments
1144+
1145+ - `T`: The type of the generated array.
1146+ - `seed`: The seed for the random number generator.
1147+ - `shape`: The shape of the generated array.
1148+ - `algorithm`: The algorithm to use for generating the random numbers. Defaults to
1149+ "DEFAULT". Other options include "PHILOX" and "THREE_FRY".
1150+ """
1151+ @noinline function randexp (
1152+ :: Type{T} ,
1153+ seed:: TracedRArray{UInt64,1} ,
1154+ shape;
1155+ algorithm:: String = " DEFAULT" ,
1156+ location= mlir_stacktrace (" rand" , @__FILE__ , @__LINE__ ),
1157+ ) where {T}
1158+ res = rng_bit_generator (T, seed, shape; algorithm, location)
1159+ rand_uniform = res. output
1160+ seed = res. output_state
1161+ rand_exp = negate (log_plus_one (negate (rand_uniform)))
1162+ return (; output_state= seed, output= rand_exp)
10321163end
10331164
10341165# functional ops
0 commit comments