Skip to content

Commit 7cf358f

Browse files
jumerckxwsmoses
authored andcommitted
TracedToTypes fixes
1 parent b25d1d3 commit 7cf358f

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

src/Tracing.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,10 @@ Base.@nospecializeinfer function traced_type_inner(
381381
}
382382
end
383383
error("Unsupported runtime $runtime")
384-
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
384+
elseif mode == TracedTrack ||
385+
mode == NoStopTracedTrack ||
386+
mode == TracedSetPath ||
387+
mode == TracedToTypes
385388
return T
386389
else
387390
throw("Abstract RArray cannot be made concrete in mode $mode")
@@ -427,7 +430,10 @@ Base.@nospecializeinfer function traced_type_inner(
427430
}
428431
end
429432
error("Unsupported runtime $runtime")
430-
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
433+
elseif mode == TracedTrack ||
434+
mode == NoStopTracedTrack ||
435+
mode == TracedSetPath ||
436+
mode == TracedToTypes
431437
return T
432438
else
433439
throw("Abstract RNumber cannot be made concrete in mode $mode")
@@ -1148,6 +1154,23 @@ function make_tracer(
11481154
)
11491155
end
11501156

1157+
@static if VERSION >= v"1.11.0"
1158+
Base.@nospecializeinfer function make_tracer(
1159+
seen,
1160+
@nospecialize(prev::Memory),
1161+
@nospecialize(path),
1162+
mode;
1163+
@nospecialize(sharding = Sharding.NoSharding()),
1164+
kwargs...,
1165+
)
1166+
if mode == TracedToTypes
1167+
return nothing
1168+
end
1169+
# TODO: does anything more need to be done here?
1170+
return prev
1171+
end
1172+
end
1173+
11511174
Base.@nospecializeinfer function make_tracer(
11521175
seen,
11531176
@nospecialize(prev::ConcretePJRTArray{T,N}),
@@ -1239,7 +1262,14 @@ Base.@nospecializeinfer function make_tracer(
12391262
throw("Cannot trace existing trace type")
12401263
end
12411264
if mode == TracedToTypes
1242-
push!(path, MLIR.IR.type(prev.mlir_data))
1265+
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1266+
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1267+
if haskey(seen, objectid(prev))
1268+
push!(path, seen[objectid(prev)])
1269+
else
1270+
push!(path, MLIR.IR.type(prev.mlir_data))
1271+
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
1272+
end
12431273
return nothing
12441274
end
12451275
if mode == TracedTrack
@@ -1317,7 +1347,14 @@ Base.@nospecializeinfer function make_tracer(
13171347
throw("Cannot trace existing trace type")
13181348
end
13191349
if mode == TracedToTypes
1320-
push!(path, MLIR.IR.type(prev.mlir_data))
1350+
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1351+
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1352+
if haskey(seen, objectid(prev))
1353+
push!(path, seen[objectid(prev)])
1354+
else
1355+
push!(path, MLIR.IR.type(prev.mlir_data))
1356+
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
1357+
end
13211358
return nothing
13221359
end
13231360
if mode == TracedTrack

0 commit comments

Comments
 (0)