Skip to content

Commit 6e4dc0c

Browse files
committed
bug fix for alising outputs
1 parent b13f8bf commit 6e4dc0c

File tree

1 file changed

+21
-0
lines changed

1 file changed

+21
-0
lines changed

src/ProbProg.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,25 @@ function sample(
139139

140140
symbol_addr = reinterpret(UInt64, pointer_from_objref(symbol))
141141

142+
# (out_idx1, in_idx1, out_idx2, in_idx2, ...)
143+
alias_pairs = Int64[]
144+
for (out_idx, res) in enumerate(linear_results)
145+
if TracedUtils.has_idx(res, argprefix)
146+
in_idx = nothing
147+
for (i, arg) in enumerate(linear_args)
148+
if TracedUtils.has_idx(arg, argprefix) &&
149+
TracedUtils.get_idx(arg, argprefix) == TracedUtils.get_idx(res, argprefix)
150+
in_idx = i - 1
151+
break
152+
end
153+
end
154+
@assert in_idx !== nothing "Unable to find operand for aliased result"
155+
push!(alias_pairs, out_idx - 1)
156+
push!(alias_pairs, in_idx)
157+
end
158+
end
159+
alias_attr = MLIR.IR.DenseArrayAttribute(alias_pairs)
160+
142161
# Construct MLIR attribute if Julia logpdf function is provided.
143162
logpdf_attr = nothing
144163
if logpdf !== nothing
@@ -175,6 +194,8 @@ function sample(
175194
symbol=symbol_addr,
176195
traced_input_indices=traced_input_indices,
177196
traced_output_indices=traced_output_indices,
197+
alias_map=alias_attr,
198+
name=Base.String(symbol),
178199
)
179200

180201
for (i, res) in enumerate(linear_results)

0 commit comments

Comments
 (0)