@@ -139,6 +139,25 @@ function sample(
139
139
140
140
symbol_addr = reinterpret (UInt64, pointer_from_objref (symbol))
141
141
142
+ # (out_idx1, in_idx1, out_idx2, in_idx2, ...)
143
+ alias_pairs = Int64[]
144
+ for (out_idx, res) in enumerate (linear_results)
145
+ if TracedUtils. has_idx (res, argprefix)
146
+ in_idx = nothing
147
+ for (i, arg) in enumerate (linear_args)
148
+ if TracedUtils. has_idx (arg, argprefix) &&
149
+ TracedUtils. get_idx (arg, argprefix) == TracedUtils. get_idx (res, argprefix)
150
+ in_idx = i - 1
151
+ break
152
+ end
153
+ end
154
+ @assert in_idx != = nothing " Unable to find operand for aliased result"
155
+ push! (alias_pairs, out_idx - 1 )
156
+ push! (alias_pairs, in_idx)
157
+ end
158
+ end
159
+ alias_attr = MLIR. IR. DenseArrayAttribute (alias_pairs)
160
+
142
161
# Construct MLIR attribute if Julia logpdf function is provided.
143
162
logpdf_attr = nothing
144
163
if logpdf != = nothing
@@ -175,6 +194,8 @@ function sample(
175
194
symbol= symbol_addr,
176
195
traced_input_indices= traced_input_indices,
177
196
traced_output_indices= traced_output_indices,
197
+ alias_map= alias_attr,
198
+ name= Base. String (symbol),
178
199
)
179
200
180
201
for (i, res) in enumerate (linear_results)
0 commit comments