@@ -398,7 +398,10 @@ Base.@nospecializeinfer function traced_type_inner(
398
398
}
399
399
end
400
400
error (" Unsupported runtime $runtime " )
401
- elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
401
+ elseif mode == TracedTrack ||
402
+ mode == NoStopTracedTrack ||
403
+ mode == TracedSetPath ||
404
+ mode == TracedToTypes
402
405
return T
403
406
else
404
407
throw (" Abstract RArray cannot be made concrete in mode $mode " )
@@ -444,7 +447,10 @@ Base.@nospecializeinfer function traced_type_inner(
444
447
}
445
448
end
446
449
error (" Unsupported runtime $runtime " )
447
- elseif mode == TracedTrack || mode == NoStopTracedTrack || mode == TracedSetPath
450
+ elseif mode == TracedTrack ||
451
+ mode == NoStopTracedTrack ||
452
+ mode == TracedSetPath ||
453
+ mode == TracedToTypes
448
454
return T
449
455
else
450
456
throw (" Abstract RNumber cannot be made concrete in mode $mode " )
@@ -1188,6 +1194,23 @@ function make_tracer(
1188
1194
)
1189
1195
end
1190
1196
1197
+ @static if VERSION >= v " 1.11.0"
1198
+ Base. @nospecializeinfer function make_tracer (
1199
+ seen,
1200
+ @nospecialize (prev:: Memory ),
1201
+ @nospecialize (path),
1202
+ mode;
1203
+ @nospecialize (sharding = Sharding. NoSharding ()),
1204
+ kwargs... ,
1205
+ )
1206
+ if mode == TracedToTypes
1207
+ return nothing
1208
+ end
1209
+ # TODO : does anything more need to be done here?
1210
+ return prev
1211
+ end
1212
+ end
1213
+
1191
1214
Base. @nospecializeinfer function make_tracer (
1192
1215
seen,
1193
1216
@nospecialize (prev:: ConcretePJRTArray{T,N} ),
@@ -1279,7 +1302,14 @@ Base.@nospecializeinfer function make_tracer(
1279
1302
throw (" Cannot trace existing trace type" )
1280
1303
end
1281
1304
if mode == TracedToTypes
1282
- push! (path, MLIR. IR. type (prev. mlir_data))
1305
+ # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1306
+ # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1307
+ if haskey (seen, objectid (prev))
1308
+ push! (path, seen[objectid (prev)])
1309
+ else
1310
+ push! (path, MLIR. IR. type (prev. mlir_data))
1311
+ seen[objectid (prev)] = VisitedObject (length (seen) + 1 )
1312
+ end
1283
1313
return nothing
1284
1314
end
1285
1315
if mode == TracedTrack
@@ -1357,7 +1387,14 @@ Base.@nospecializeinfer function make_tracer(
1357
1387
throw (" Cannot trace existing trace type" )
1358
1388
end
1359
1389
if mode == TracedToTypes
1360
- push! (path, MLIR. IR. type (prev. mlir_data))
1390
+ # for TracedRArrays, we check for objectid equality because make_mlir_fn gets rid of duplicate TracedRArrays.
1391
+ # i.e. (a, a) should hash differently than (a, b) when a and b are different TracedRArrays.
1392
+ if haskey (seen, objectid (prev))
1393
+ push! (path, seen[objectid (prev)])
1394
+ else
1395
+ push! (path, MLIR. IR. type (prev. mlir_data))
1396
+ seen[objectid (prev)] = VisitedObject (length (seen) + 1 )
1397
+ end
1361
1398
return nothing
1362
1399
end
1363
1400
if mode == TracedTrack
0 commit comments