diff --git a/src/Tracing.jl b/src/Tracing.jl index aa39f08141..de89fe9dae 100644 --- a/src/Tracing.jl +++ b/src/Tracing.jl @@ -381,7 +381,10 @@ Base.@nospecializeinfer function traced_type_inner( } end error("Unsupported runtime $runtime") - elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath + elseif mode == TracedTrack || + mode == NoStopTracedTrack || + mode == TracedSetPath || + mode == TracedToTypes return T else throw("Abstract RArray cannot be made concrete in mode $mode") @@ -427,7 +430,10 @@ Base.@nospecializeinfer function traced_type_inner( } end error("Unsupported runtime $runtime") - elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath + elseif mode == TracedTrack || + mode == NoStopTracedTrack || + mode == TracedSetPath || + mode == TracedToTypes return T else throw("Abstract RNumber cannot be made concrete in mode $mode") @@ -1239,7 +1245,14 @@ Base.@nospecializeinfer function make_tracer( throw("Cannot trace existing trace type") end if mode == TracedToTypes - push!(path, MLIR.IR.type(prev.mlir_data)) + # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays. + # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays. + if haskey(seen, objectid(prev)) + push!(path, seen[objectid(prev)]) + else + push!(path, MLIR.IR.type(prev.mlir_data)) + seen[objectid(prev)] = VisitedObject(length(seen) + 1) + end return nothing end if mode == TracedTrack @@ -1317,7 +1330,14 @@ Base.@nospecializeinfer function make_tracer( throw("Cannot trace existing trace type") end if mode == TracedToTypes - push!(path, MLIR.IR.type(prev.mlir_data)) + # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays. + # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays. + if haskey(seen, objectid(prev)) + push!(path, seen[objectid(prev)]) + else + push!(path, MLIR.IR.type(prev.mlir_data)) + seen[objectid(prev)] = VisitedObject(length(seen) + 1) + end return nothing end if mode == TracedTrack