-
Notifications
You must be signed in to change notification settings - Fork 34
Open
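Description

Calling a `@compile`d function that uses `randn!` returns the same "random" values on every invocation. In the session below, the two calls to `rloop!(random_field)` produce identical output. The HLO printed by `@code_hlo` suggests why. The RNG seed is embedded as a compile-time constant (`%c_3`) passed to `stablehlo.rng_bit_generator`. The updated `%output_state` it returns is never used or written back (only `%8` is returned), so every run starts from the same RNG state.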
(base) wmoses@MacBook-Pro-18 Reactant.jl % julia --project
               _
   _       _ _(_)_     |  Documentation: https://docs.julialang.org
  (_)     | (_) (_)    |
   _ _   _| |_  __ _   |  Type "?" for help, "]?" for Pkg help.
  | | | | | | |/ _` |  |
  | | |_| | | | (_| |  |  Version 1.10.10 (2025-06-27)
 _/ |\__'_|_|_|\__'_|  |  Official https://julialang.org/ release
|__/                   |
julia> using Random
julia> using Reactant
julia> function loop!(random_field)
randn!(random_field)
return nothing
end
loop! (generic function with 1 method)
julia> random_field = Reactant.to_rarray(zeros(5, 1))
5×1 ConcretePJRTArray{Float64,2}:
0.0
0.0
0.0
0.0
0.0
julia> rloop! = @compile raise_first=true raise=true sync=true loop!(random_field)
Reactant compiled function loop! (with tag ##loop!_reactant#229)
julia> rloop!(random_field)
julia> @show random_field
random_field = ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([-2.024062402609222; 0.4787963512522881; -0.09207401450871815; -0.7751920252093334; -0.20832415620781702;;])
5×1 ConcretePJRTArray{Float64,2}:
-2.024062402609222
0.4787963512522881
-0.09207401450871815
-0.7751920252093334
-0.20832415620781702
julia> rloop!(random_field)
julia> @show random_field
random_field = ConcretePJRTArray{Float64, 2, 1, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}}([-2.024062402609222; 0.4787963512522881; -0.09207401450871815; -0.7751920252093334; -0.20832415620781702;;])
5×1 ConcretePJRTArray{Float64,2}:
-2.024062402609222
0.4787963512522881
-0.09207401450871815
-0.7751920252093334
-0.20832415620781702
julia> @code_hlo loop!(random_field)
module @"reactant_loop!" attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<1x5xf64> {enzymexla.memory_effects = [], tf.aliasing_output = 0 : i32}) -> tensor<1x5xf64> attributes {enzymexla.memory_effects = []} {
%cst = stablehlo.constant dense<2.000000e+00> : tensor<1x5xf64>
%cst_0 = stablehlo.constant dense<1.000000e+00> : tensor<1x5xf64>
%cst_1 = stablehlo.constant dense<1.4142135623730951> : tensor<1x5xf64>
%c = stablehlo.constant dense<4607182418800017408> : tensor<5x1xui64>
%c_2 = stablehlo.constant dense<12> : tensor<5x1xui64>
%c_3 = stablehlo.constant dense<[16598013818565186068, 100540593900791911]> : tensor<2xui64>
%output_state, %output = stablehlo.rng_bit_generator %c_3, algorithm = DEFAULT : (tensor<2xui64>) -> (tensor<2xui64>, tensor<5x1xui64>)
%0 = stablehlo.shift_right_logical %output, %c_2 : tensor<5x1xui64>
%1 = stablehlo.or %0, %c : tensor<5x1xui64>
%2 = stablehlo.bitcast_convert %1 : (tensor<5x1xui64>) -> tensor<5x1xf64>
%3 = stablehlo.reshape %2 : (tensor<5x1xf64>) -> tensor<1x5xf64>
%4 = stablehlo.subtract %3, %cst_0 : tensor<1x5xf64>
%5 = stablehlo.multiply %4, %cst : tensor<1x5xf64>
%6 = stablehlo.subtract %5, %cst_0 : tensor<1x5xf64>
%7 = chlo.erf_inv %6 : tensor<1x5xf64> -> tensor<1x5xf64>
%8 = stablehlo.multiply %7, %cst_1 : tensor<1x5xf64>
return %8 : tensor<1x5xf64>
}
}
Metadata
Assignees
Labels
No labels