@@ -422,7 +422,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
422
422
423
423
out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
424
424
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
425
- fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
425
+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
426
426
427
427
batch_inputs = MLIR. IR. Value[]
428
428
for a in linear_args
@@ -437,7 +437,7 @@ function call_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
437
437
end
438
438
end
439
439
440
- call_op = MLIR. Dialects. enzyme. untracedCall (batch_inputs; outputs= out_tys, fn= fname )
440
+ call_op = MLIR. Dialects. enzyme. untracedCall (batch_inputs; outputs= out_tys, fn= fn_attr )
441
441
442
442
for (i, res) in enumerate (linear_results)
443
443
resv = MLIR. IR. result (call_op, i)
@@ -504,7 +504,15 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
504
504
505
505
out_tys = [MLIR. IR. type (TracedUtils. get_mlir_data (res)) for res in linear_results]
506
506
fname = TracedUtils. get_attribute_by_name (func2, " sym_name" )
507
- fname = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
507
+ fn_attr = MLIR. IR. FlatSymbolRefAttribute (Base. String (fname))
508
+
509
+ # Specify which outputs to add to the trace.
510
+ traced_output_indices = Int[]
511
+ for (i, res) in enumerate (linear_results)
512
+ if TracedUtils. has_idx (res, resprefix)
513
+ push! (traced_output_indices, i - 1 )
514
+ end
515
+ end
508
516
509
517
batch_inputs = MLIR. IR. Value[]
510
518
for a in linear_args
@@ -524,7 +532,12 @@ function simulate_internal(f::Function, args::Vararg{Any,Nargs}) where {Nargs}
524
532
):: MLIR.IR.Type
525
533
weight_ty = MLIR. IR. TensorType (Int64[], MLIR. IR. Type (Float64))
526
534
simulate_op = MLIR. Dialects. enzyme. simulate (
527
- batch_inputs; trace= trace_ty, weight= weight_ty, outputs= out_tys, fn= fname
535
+ batch_inputs;
536
+ trace= trace_ty,
537
+ weight= weight_ty,
538
+ outputs= out_tys,
539
+ fn= fn_attr,
540
+ traced_output_indices= traced_output_indices,
528
541
)
529
542
530
543
for (i, res) in enumerate (linear_results)
0 commit comments