Skip to content

Commit 46e0f6b

Browse files
committed
transpose fix up
1 parent ef2e770 commit 46e0f6b

File tree

4 files changed

+63
-23
lines changed

4 files changed

+63
-23
lines changed

src/ProbProg.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ function addSampleToTraceLowered(
2727
elseif datatype_width == 64
2828
Float64
2929
else
30-
error("Unsupported datatype width: $datatype_width")
30+
@ccall printf("Unsupported datatype width: %d\n"::Cstring, datatype_width::Cint)::Cvoid
31+
return nothing
3132
end
3233

3334
typed_ptr = Ptr{julia_type}(sample_ptr)
@@ -65,6 +66,7 @@ end
6566
(),
6667
string(f),
6768
false;
69+
do_transpose=false,
6870
args_in_result=:all,
6971
argprefix,
7072
resprefix,
@@ -97,19 +99,19 @@ end
9799
resv = MLIR.IR.result(gen_op, i)
98100
if TracedUtils.has_idx(res, resprefix)
99101
path = TracedUtils.get_idx(res, resprefix)
100-
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
102+
TracedUtils.set!(result, path[2:end], resv)
101103
elseif TracedUtils.has_idx(res, argprefix)
102104
idx, path = TracedUtils.get_argidx(res, argprefix)
103105
if idx == 1 && fnwrap
104-
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
106+
TracedUtils.set!(f, path[3:end], resv)
105107
else
106108
if fnwrap
107109
idx -= 1
108110
end
109-
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
111+
TracedUtils.set!(args[idx], path[3:end], resv)
110112
end
111113
else
112-
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
114+
TracedUtils.set!(res, (), resv)
113115
end
114116
end
115117

@@ -130,8 +132,8 @@ end
130132
(),
131133
string(f),
132134
false;
135+
do_transpose=false,
133136
args_in_result=:all,
134-
do_transpose=false, # TODO: double check transpose
135137
argprefix,
136138
resprefix,
137139
resargprefix,
@@ -177,19 +179,19 @@ end
177179
resv = MLIR.IR.result(sample_op, i)
178180
if TracedUtils.has_idx(res, resprefix)
179181
path = TracedUtils.get_idx(res, resprefix)
180-
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
182+
TracedUtils.set!(result, path[2:end], resv)
181183
elseif TracedUtils.has_idx(res, argprefix)
182184
idx, path = TracedUtils.get_argidx(res, argprefix)
183185
if idx == 1 && fnwrap
184-
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
186+
TracedUtils.set!(f, path[3:end], resv)
185187
else
186188
if fnwrap
187189
idx -= 1
188190
end
189-
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
191+
TracedUtils.set!(args[idx], path[3:end], resv)
190192
end
191193
else
192-
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
194+
TracedUtils.set!(res, (), resv)
193195
end
194196
end
195197

@@ -210,6 +212,7 @@ end
210212
(),
211213
string(f),
212214
false;
215+
do_transpose=false,
213216
args_in_result=:all,
214217
argprefix,
215218
resprefix,
@@ -246,19 +249,19 @@ end
246249
resv = MLIR.IR.result(simulate_op, i)
247250
if TracedUtils.has_idx(res, resprefix)
248251
path = TracedUtils.get_idx(res, resprefix)
249-
TracedUtils.set!(result, path[2:end], TracedUtils.transpose_val(resv))
252+
TracedUtils.set!(result, path[2:end], resv)
250253
elseif TracedUtils.has_idx(res, argprefix)
251254
idx, path = TracedUtils.get_argidx(res, argprefix)
252255
if idx == 1 && fnwrap
253-
TracedUtils.set!(f, path[3:end], TracedUtils.transpose_val(resv))
256+
TracedUtils.set!(f, path[3:end], resv)
254257
else
255258
if fnwrap
256259
idx -= 1
257260
end
258-
TracedUtils.set!(args[idx], path[3:end], TracedUtils.transpose_val(resv))
261+
TracedUtils.set!(args[idx], path[3:end], resv)
259262
end
260263
else
261-
TracedUtils.set!(res, (), TracedUtils.transpose_val(resv))
264+
TracedUtils.set!(res, (), resv)
262265
end
263266
end
264267

@@ -271,7 +274,7 @@ function print_trace(trace::Dict{Symbol,Any})
271274
println(" $symbol:")
272275
println(" Sample: $(sample)")
273276
end
274-
println("### End of Trace ###")
277+
return println("### End of Trace ###")
275278
end
276279

277280
end

test/probprog/generate.jl

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ function generate_model(seed, μ, σ, shape)
1616
end
1717

1818
@testset "Generate" begin
19-
@testset "normal_deterministic" begin
19+
@testset "deterministic" begin
2020
shape = (10000,)
2121
seed1 = Reactant.to_rarray(UInt64[1, 4])
2222
seed2 = Reactant.to_rarray(UInt64[1, 4])
@@ -38,7 +38,7 @@ end
3838
Array(model_compiled(seed2, μ2, σ2, shape)),
3939
))
4040
end
41-
@testset "normal_hlo" begin
41+
@testset "hlo" begin
4242
shape = (10000,)
4343
seed = Reactant.to_rarray(UInt64[1, 4])
4444
μ = Reactant.ConcreteRNumber(0.0)
@@ -53,12 +53,28 @@ end
5353
@test !contains(repr(after), "enzyme.sample")
5454
end
5555

56-
@testset "normal_generate" begin
56+
@testset "normal" begin
5757
shape = (10000,)
5858
seed = Reactant.to_rarray(UInt64[1, 4])
5959
μ = Reactant.ConcreteRNumber(0.0)
6060
σ = Reactant.ConcreteRNumber(1.0)
6161
X = Array(@jit optimize = :probprog generate_model(seed, μ, σ, shape))
6262
@test mean(X) 0.0 atol = 0.05 rtol = 0.05
6363
end
64+
65+
@testset "correctness" begin
66+
op(x, y) = x * y'
67+
68+
function fake_model(x, y)
69+
return ProbProg.sample!(op, x, y)
70+
end
71+
72+
x = reshape(collect(Float64, 1:12), (4, 3))
73+
y = reshape(collect(Float64, 1:12), (4, 3))
74+
x_ra = Reactant.to_rarray(x)
75+
y_ra = Reactant.to_rarray(y)
76+
77+
@test Array(@jit optimize = :probprog ProbProg.generate!(fake_model, x_ra, y_ra)) ==
78+
op(x, y)
79+
end
6480
end

test/probprog/sample.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ function sample2(seed, μ, σ, shape)
1818
function model(seed, μ, σ, shape)
1919
rng = Random.default_rng()
2020
Random.seed!(rng, seed)
21-
s = ProbProg.sample!(normal, rng, μ, σ, shape)
21+
_ = ProbProg.sample!(normal, rng, μ, σ, shape)
2222
t = ProbProg.sample!(normal, rng, μ, σ, shape)
2323
return t
2424
end

test/probprog/simulate.jl

Lines changed: 25 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,32 @@ using Reactant: ProbProg
4242

4343
trace = ProbProg.createTrace()
4444

45-
result = Array(
46-
@jit optimize = :probprog sync = true simulate_model(trace, seed, μ, σ, shape)
47-
)
45+
result = Array(@jit optimize = :probprog simulate_model(trace, seed, μ, σ, shape))
4846

49-
ProbProg.print_trace(trace)
5047
@test size(result) == shape
48+
@test haskey(trace, :s)
49+
@test haskey(trace, :t)
50+
@test size(trace[:s]) == shape
51+
@test size(trace[:t]) == shape
52+
end
53+
54+
@testset "correctness" begin
55+
op(x, y) = x * y'
56+
function fake_model(x, y)
57+
return ProbProg.sample!(op, x, y; symbol=:matmul)
58+
end
59+
60+
trace = ProbProg.createTrace()
61+
x = reshape(collect(Float64, 1:12), (4, 3))
62+
y = reshape(collect(Float64, 1:12), (4, 3))
63+
x_ra = Reactant.to_rarray(x)
64+
y_ra = Reactant.to_rarray(y)
65+
66+
@test Array(
67+
@jit optimize = :probprog ProbProg.simulate!(fake_model, x_ra, y_ra; trace)
68+
) == op(x, y)
69+
70+
@test haskey(trace, :matmul)
71+
@test trace[:matmul] == op(x, y)
5172
end
5273
end

0 commit comments

Comments
 (0)