99#include " ur_validation_layer.hpp"
1010
1111#include < mutex>
12+ #include < typeindex>
1213#include < unordered_map>
1314#include < utility>
1415
@@ -20,7 +21,12 @@ struct RefCountContext {
2021 private:
2122 struct RefRuntimeInfo {
2223 int64_t refCount;
24+ std::type_index type;
2325 std::vector<BacktraceLine> backtrace;
26+
27+ RefRuntimeInfo (int64_t refCount, std::type_index type,
28+ std::vector<BacktraceLine> backtrace)
29+ : refCount(refCount), type(type), backtrace(backtrace) {}
2430 };
2531
2632 enum RefCountUpdateType {
@@ -34,26 +40,32 @@ struct RefCountContext {
3440 std::unordered_map<void *, struct RefRuntimeInfo > counts;
3541 int64_t adapterCount = 0 ;
3642
37- void updateRefCount (void *ptr, enum RefCountUpdateType type,
43+ template <typename T>
44+ void updateRefCount (T handle, enum RefCountUpdateType type,
3845 bool isAdapterHandle = false ) {
3946 std::unique_lock<std::mutex> ulock (mutex);
4047
48+ void *ptr = static_cast <void *>(handle);
4149 auto it = counts.find (ptr);
4250
4351 switch (type) {
4452 case REFCOUNT_CREATE_OR_INCREASE:
4553 if (it == counts.end ()) {
46- counts[ptr] = {1 , getCurrentBacktrace ()};
54+ std::tie (it, std::ignore) = counts.emplace (
55+ ptr, RefRuntimeInfo{1 , std::type_index (typeid (handle)),
56+ getCurrentBacktrace ()});
4757 if (isAdapterHandle) {
4858 adapterCount++;
4959 }
5060 } else {
51- counts[ptr] .refCount ++;
61+ it-> second .refCount ++;
5262 }
5363 break ;
5464 case REFCOUNT_CREATE:
5565 if (it == counts.end ()) {
56- counts[ptr] = {1 , getCurrentBacktrace ()};
66+ std::tie (it, std::ignore) = counts.emplace (
67+ ptr, RefRuntimeInfo{1 , std::type_index (typeid (handle)),
68+ getCurrentBacktrace ()});
5769 } else {
5870 context.logger .error (" Handle {} already exists" , ptr);
5971 return ;
@@ -65,29 +77,31 @@ struct RefCountContext {
6577 " Attempting to retain nonexistent handle {}" , ptr);
6678 return ;
6779 } else {
68- counts[ptr] .refCount ++;
80+ it-> second .refCount ++;
6981 }
7082 break ;
7183 case REFCOUNT_DECREASE:
7284 if (it == counts.end ()) {
73- counts[ptr] = {-1 , getCurrentBacktrace ()};
85+ std::tie (it, std::ignore) = counts.emplace (
86+ ptr, RefRuntimeInfo{-1 , std::type_index (typeid (handle)),
87+ getCurrentBacktrace ()});
7488 } else {
75- counts[ptr] .refCount --;
89+ it-> second .refCount --;
7690 }
7791
78- if (counts[ptr] .refCount < 0 ) {
92+ if (it-> second .refCount < 0 ) {
7993 context.logger .error (
8094 " Attempting to release nonexistent handle {}" , ptr);
81- } else if (counts[ptr] .refCount == 0 && isAdapterHandle) {
95+ } else if (it-> second .refCount == 0 && isAdapterHandle) {
8296 adapterCount--;
8397 }
8498 break ;
8599 }
86100
87101 context.logger .debug (" Reference count for handle {} changed to {}" , ptr,
88- counts[ptr] .refCount );
102+ it-> second .refCount );
89103
90- if (counts[ptr] .refCount == 0 ) {
104+ if (it-> second .refCount == 0 ) {
91105 counts.erase (ptr);
92106 }
93107
@@ -99,23 +113,35 @@ struct RefCountContext {
99113 }
100114
101115 public:
102- void createRefCount (void *ptr) { updateRefCount (ptr, REFCOUNT_CREATE); }
116+ template <typename T> void createRefCount (T handle) {
117+ updateRefCount<T>(handle, REFCOUNT_CREATE);
118+ }
103119
104- void incrementRefCount (void *ptr, bool isAdapterHandle = false ) {
105- updateRefCount (ptr, REFCOUNT_INCREASE, isAdapterHandle);
120+ template <typename T>
121+ void incrementRefCount (T handle, bool isAdapterHandle = false ) {
122+ updateRefCount (handle, REFCOUNT_INCREASE, isAdapterHandle);
106123 }
107124
108- void decrementRefCount (void *ptr, bool isAdapterHandle = false ) {
109- updateRefCount (ptr, REFCOUNT_DECREASE, isAdapterHandle);
125+ template <typename T>
126+ void decrementRefCount (T handle, bool isAdapterHandle = false ) {
127+ updateRefCount (handle, REFCOUNT_DECREASE, isAdapterHandle);
110128 }
111129
112- void createOrIncrementRefCount (void *ptr, bool isAdapterHandle = false ) {
113- updateRefCount (ptr, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
130+ template <typename T>
131+ void createOrIncrementRefCount (T handle, bool isAdapterHandle = false ) {
132+ updateRefCount (handle, REFCOUNT_CREATE_OR_INCREASE, isAdapterHandle);
114133 }
115134
116135 void clear () { counts.clear (); }
117136
118- bool isReferenceValid (void *ptr) { return counts.count (ptr) > 0 ; }
137+ template <typename T> bool isReferenceValid (T handle) {
138+ auto it = counts.find (static_cast <void *>(handle));
139+ if (it == counts.end () || it->second .refCount < 1 ) {
140+ return false ;
141+ }
142+
143+ return (it->second .type == std::type_index (typeid (handle)));
144+ }
119145
120146 void logInvalidReferences () {
121147 for (auto &[ptr, refRuntimeInfo] : counts) {
0 commit comments