@@ -381,7 +381,10 @@ Base.@nospecializeinfer function traced_type_inner(
381
381
}
382
382
end
383
383
error (" Unsupported runtime $runtime " )
384
- elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
384
+ elseif mode == TracedTrack ||
385
+ mode == NoStopTracedTrack ||
386
+ mode == TracedSetPath ||
387
+ mode == TracedToTypes
385
388
return T
386
389
else
387
390
throw (" Abstract RArray cannot be made concrete in mode $mode " )
@@ -427,7 +430,10 @@ Base.@nospecializeinfer function traced_type_inner(
427
430
}
428
431
end
429
432
error (" Unsupported runtime $runtime " )
430
- elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
433
+ elseif mode == TracedTrack ||
434
+ mode == NoStopTracedTrack ||
435
+ mode == TracedSetPath ||
436
+ mode == TracedToTypes
431
437
return T
432
438
else
433
439
throw (" Abstract RNumber cannot be made concrete in mode $mode " )
@@ -1148,6 +1154,23 @@ function make_tracer(
1148
1154
)
1149
1155
end
1150
1156
1157
+ @static if VERSION >= v " 1.11.0"
1158
+ Base. @nospecializeinfer function make_tracer (
1159
+ seen,
1160
+ @nospecialize (prev:: Memory ),
1161
+ @nospecialize (path),
1162
+ mode;
1163
+ @nospecialize (sharding = Sharding. NoSharding ()),
1164
+ kwargs... ,
1165
+ )
1166
+ if mode == TracedToTypes
1167
+ return nothing
1168
+ end
1169
+ # TODO : does anything more need to be done here?
1170
+ return prev
1171
+ end
1172
+ end
1173
+
1151
1174
Base. @nospecializeinfer function make_tracer (
1152
1175
seen,
1153
1176
@nospecialize (prev:: ConcretePJRTArray{T,N} ),
@@ -1239,7 +1262,14 @@ Base.@nospecializeinfer function make_tracer(
1239
1262
throw (" Cannot trace existing trace type" )
1240
1263
end
1241
1264
if mode == TracedToTypes
1242
- push! (path, MLIR. IR. type (prev. mlir_data))
1265
+ # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1266
+ # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1267
+ if haskey (seen, objectid (prev))
1268
+ push! (path, seen[objectid (prev)])
1269
+ else
1270
+ push! (path, MLIR. IR. type (prev. mlir_data))
1271
+ seen[objectid (prev)] = VisitedObject (length (seen) + 1 )
1272
+ end
1243
1273
return nothing
1244
1274
end
1245
1275
if mode == TracedTrack
@@ -1317,7 +1347,14 @@ Base.@nospecializeinfer function make_tracer(
1317
1347
throw (" Cannot trace existing trace type" )
1318
1348
end
1319
1349
if mode == TracedToTypes
1320
- push! (path, MLIR. IR. type (prev. mlir_data))
1350
+ # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1351
+ # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1352
+ if haskey (seen, objectid (prev))
1353
+ push! (path, seen[objectid (prev)])
1354
+ else
1355
+ push! (path, MLIR. IR. type (prev. mlir_data))
1356
+ seen[objectid (prev)] = VisitedObject (length (seen) + 1 )
1357
+ end
1321
1358
return nothing
1322
1359
end
1323
1360
if mode == TracedTrack
0 commit comments