Skip to content

Commit 127126d

Browse files
committed
adding traced_output_indices attr to simulate op
1 parent 38e33de commit 127126d

File tree

1 file changed

+17
-4
lines changed

1 file changed

+17
-4
lines changed

src/ProbProg.jl

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -422,7 +422,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
422422

423423
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
424424
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
425-
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
425+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
426426

427427
batch_inputs = MLIR.IR.Value[]
428428
for a in linear_args
@@ -437,7 +437,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
437437
end
438438
end
439439

440-
call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fname)
440+
call_op = MLIR.Dialects.enzyme.untracedCall(batch_inputs; outputs=out_tys, fn=fn_attr)
441441

442442
for (i, res) in enumerate(linear_results)
443443
resv = MLIR.IR.result(call_op, i)
@@ -504,7 +504,15 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
504504

505505
out_tys = [MLIR.IR.type(TracedUtils.get_mlir_data(res)) for res in linear_results]
506506
fname = TracedUtils.get_attribute_by_name(func2, "sym_name")
507-
fname = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
507+
fn_attr = MLIR.IR.FlatSymbolRefAttribute(Base.String(fname))
508+
509+
# Specify which outputs to add to the trace.
510+
traced_output_indices = Int[]
511+
for (i, res) in enumerate(linear_results)
512+
if TracedUtils.has_idx(res, resprefix)
513+
push!(traced_output_indices, i - 1)
514+
end
515+
end
508516

509517
batch_inputs = MLIR.IR.Value[]
510518
for a in linear_args
@@ -524,7 +532,12 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
524532
)::MLIR.IR.Type
525533
weight_ty = MLIR.IR.TensorType(Int64[], MLIR.IR.Type(Float64))
526534
simulate_op = MLIR.Dialects.enzyme.simulate(
527-
batch_inputs; trace=trace_ty, weight=weight_ty, outputs=out_tys, fn=fname
535+
batch_inputs;
536+
trace=trace_ty,
537+
weight=weight_ty,
538+
outputs=out_tys,
539+
fn=fn_attr,
540+
traced_output_indices=traced_output_indices,
528541
)
529542

530543
for (i, res) in enumerate(linear_results)

0 commit comments

Comments
 (0)