Skip to content

Commit 4e017d0

Browse files
committed
fix up copy
1 parent 6c7ffa3 commit 4e017d0

File tree

2 files changed

+10
-6
lines changed

2 files changed

+10
-6
lines changed

src/ProbProg.jl

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,13 @@ function addSampleToTraceLowered(
4747
trace[symbol] = unsafe_load(reinterpret(Ptr{element_type}, sample_ptr))
4848
else
4949
trace[symbol] = Base.deepcopy(
50-
unsafe_wrap(
51-
Array{element_type}, reinterpret(Ptr{element_type}, sample_ptr), prod(shape)
50+
reshape(
51+
unsafe_wrap(
52+
Array{element_type},
53+
reinterpret(Ptr{element_type}, sample_ptr),
54+
prod(shape),
55+
),
56+
shape,
5257
),
5358
)
5459
end

test/probprog/simulate.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,14 @@ function simulate_model(seed, μ, σ, shape)
88
function model(seed, μ, σ, shape)
99
rng = Random.default_rng()
1010
Random.seed!(rng, seed)
11-
s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol = :s)
12-
t = ProbProg.sample!(normal, rng, s, σ, shape; symbol = :t)
11+
s = ProbProg.sample!(normal, rng, μ, σ, shape; symbol=:s)
12+
t = ProbProg.sample!(normal, rng, s, σ, shape; symbol=:t)
1313
return t
1414
end
1515

1616
return ProbProg.simulate!(model, seed, μ, σ, shape)
1717
end
1818

19-
2019
@testset "Simulate" begin
2120
@testset "normal_hlo" begin
2221
shape = (10000,)
@@ -36,7 +35,7 @@ end
3635
end
3736

3837
@testset "normal_simulate" begin
39-
shape = (10,)
38+
shape = (3, 3, 3)
4039
seed = Reactant.to_rarray(UInt64[1, 4])
4140
μ = Reactant.ConcreteRArray(0.0)
4241
σ = Reactant.ConcreteRArray(1.0)

0 commit comments

Comments
 (0)