Skip to content

Commit d0a770a

Browse files
committed
TracedToTypes fixes
1 parent 1113a92 commit d0a770a

File tree

1 file changed

+41
-4
lines changed

1 file changed

+41
-4
lines changed

src/Tracing.jl

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -398,7 +398,10 @@ Base.@nospecializeinfer function traced_type_inner(
398398
}
399399
end
400400
error("Unsupported runtime $runtime")
401-
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
401+
elseif mode == TracedTrack ||
402+
mode == NoStopTracedTrack ||
403+
mode == TracedSetPath ||
404+
mode == TracedToTypes
402405
return T
403406
else
404407
throw("Abstract RArray cannot be made concrete in mode $mode")
@@ -444,7 +447,10 @@ Base.@nospecializeinfer function traced_type_inner(
444447
}
445448
end
446449
error("Unsupported runtime $runtime")
447-
elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
450+
elseif mode == TracedTrack ||
451+
mode == NoStopTracedTrack ||
452+
mode == TracedSetPath ||
453+
mode == TracedToTypes
448454
return T
449455
else
450456
throw("Abstract RNumber cannot be made concrete in mode $mode")
@@ -1188,6 +1194,23 @@ function make_tracer(
11881194
)
11891195
end
11901196

1197+
@static if VERSION >= v"1.11.0"
1198+
Base.@nospecializeinfer function make_tracer(
1199+
seen,
1200+
@nospecialize(prev::Memory),
1201+
@nospecialize(path),
1202+
mode;
1203+
@nospecialize(sharding = Sharding.NoSharding()),
1204+
kwargs...,
1205+
)
1206+
if mode == TracedToTypes
1207+
return nothing
1208+
end
1209+
# TODO: does anything more need to be done here?
1210+
return prev
1211+
end
1212+
end
1213+
11911214
Base.@nospecializeinfer function make_tracer(
11921215
seen,
11931216
@nospecialize(prev::ConcretePJRTArray{T,N}),
@@ -1279,7 +1302,14 @@ Base.@nospecializeinfer function make_tracer(
12791302
throw("Cannot trace existing trace type")
12801303
end
12811304
if mode == TracedToTypes
1282-
push!(path, MLIR.IR.type(prev.mlir_data))
1305+
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1306+
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1307+
if haskey(seen, objectid(prev))
1308+
push!(path, seen[objectid(prev)])
1309+
else
1310+
push!(path, MLIR.IR.type(prev.mlir_data))
1311+
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
1312+
end
12831313
return nothing
12841314
end
12851315
if mode == TracedTrack
@@ -1357,7 +1387,14 @@ Base.@nospecializeinfer function make_tracer(
13571387
throw("Cannot trace existing trace type")
13581388
end
13591389
if mode == TracedToTypes
1360-
push!(path, MLIR.IR.type(prev.mlir_data))
1390+
# for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1391+
# i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1392+
if haskey(seen, objectid(prev))
1393+
push!(path, seen[objectid(prev)])
1394+
else
1395+
push!(path, MLIR.IR.type(prev.mlir_data))
1396+
seen[objectid(prev)] = VisitedObject(length(seen) + 1)
1397+
end
13611398
return nothing
13621399
end
13631400
if mode == TracedTrack

0 commit comments

Comments
 (0)