@@ -38,30 +38,35 @@ namespace validation_layer
3838 }
3939 }
4040
41- // The format of this table is such that each row accounts for a
42- // specific type of objects, and all elements in the raw except the last
43- // one are allocating objects of that type, while the last element is known
44- // to deallocate objects of that type.
45- //
46- static std::vector<std::vector<std::string> > createDestroySet () {
41+ struct ctorsAndDtors {
42+ std::vector<std::string> ctors;
43+ std::vector<std::string> dtors;
44+ };
45+
46+ static std::vector<ctorsAndDtors > createDestroySet () {
4747 return {
48- {" zeContextCreate" , " zeContextDestroy" },
49- {" zeCommandQueueCreate" , " zeCommandQueueDestroy" },
50- {" zeModuleCreate" , " zeModuleDestroy" },
51- {" zeKernelCreate" , " zeKernelDestroy" },
52- {" zeEventPoolCreate" , " zeEventPoolDestroy" },
53- {" zeCommandListCreateImmediate" , " zeCommandListCreate" , " zeCommandListDestroy" },
54- {" zeEventCreate" , " zeEventDestroy" },
55- {" zeFenceCreate" , " zeFenceDestroy" },
56- {" zeImageCreate" , " zeImageViewCreateExt" , " zeImageDestroy" },
57- {" zeSamplerCreate" , " zeSamplerDestroy" },
58- {" zeMemAllocDevice" , " zeMemAllocHost" , " zeMemAllocShared" , " zeMemFree" }};
48+ {{" zeContextCreate" }, {" zeContextDestroy" }},
49+ {{" zeCommandQueueCreate" }, {" zeCommandQueueDestroy" }},
50+ {{" zeModuleCreate" }, {" zeModuleDestroy" }},
51+ {{" zeKernelCreate" }, {" zeKernelDestroy" }},
52+ {{" zeEventPoolCreate" }, {" zeEventPoolDestroy" }},
53+ {{" zeCommandListCreateImmediate" , " zeCommandListCreate" }, {" zeCommandListDestroy" }},
54+ {{" zeEventCreate" }, {" zeEventDestroy" }},
55+ {{" zeFenceCreate" }, {" zeFenceDestroy" }},
56+ {{" zeImageCreate" , " zeImageViewCreateExt" }, {" zeImageDestroy" }},
57+ {{" zeSamplerCreate" }, {" zeSamplerDestroy" }},
58+ {{" zeMemAllocDevice" , " zeMemAllocHost" , " zeMemAllocShared" }, {" zeMemFree" , " zeMemFreeExt" }}
59+ };
5960 }
6061
6162 basic_leakChecker::ZEbasic_leakChecker::ZEbasic_leakChecker () {
6263 // initialize counts for all functions that should be tracked
63- for (const auto &row : createDestroySet ()) {
64- for (const auto &name : row) {
64+ auto set = createDestroySet ();
65+ for (const auto &s : set) {
66+ for (auto &name : s.ctors ) {
67+ counts[name] = 0 ;
68+ }
69+ for (auto &name : s.dtors ) {
6570 counts[name] = 0 ;
6671 }
6772 }
@@ -249,6 +254,13 @@ namespace validation_layer
249254 return result;
250255 }
251256
257+ ze_result_t basic_leakChecker::ZEbasic_leakChecker::zeMemFreeExtEpilogue (ze_context_handle_t , const ze_memory_free_ext_desc_t *, void *, ze_result_t result) {
258+ if (result == ZE_RESULT_SUCCESS) {
259+ countFunctionCall (" zeMemFreeExt" );
260+ }
261+ return result;
262+ }
263+
252264 void basic_leakChecker::ZEbasic_leakChecker::countFunctionCall (const std::string &functionName)
253265 {
254266 auto it = counts.find (functionName);
@@ -265,40 +277,47 @@ namespace validation_layer
265277 basic_leakChecker::ZEbasic_leakChecker::~ZEbasic_leakChecker () {
266278 std::cerr << " Check balance of create/destroy calls\n " ;
267279 std::cerr << " ----------------------------------------------------------\n " ;
268- std::stringstream ss;
269- for (const auto &Row : createDestroySet ()) {
280+ auto set = createDestroySet ();
281+ for (const auto &s : set) {
282+ auto &ctors = s.ctors ;
283+ auto &dtors = s.dtors ;
270284 int64_t diff = 0 ;
271- for (auto I = Row.begin (); I != Row.end ();) {
272- const char *ZeName = (*I).c_str ();
273- const auto &ZeCount = (counts)[*I];
274-
275- bool First = (I == Row.begin ());
276- bool Last = (++I == Row.end ());
277-
278- if (Last) {
279- ss << " \\ --->" ;
280- diff -= ZeCount;
281- } else {
282- diff += ZeCount;
283- if (!First) {
284- ss << " | " ;
285- std::cerr << ss.str () << " \n " ;
286- ss.str (" " );
287- ss.clear ();
288- }
285+ for (size_t i = 0 ; i < ctors.size (); i++) {
286+ auto name = ctors[i];
287+ auto zeCount = counts[name].load ();
288+ diff += zeCount;
289+
290+ if (i > 0 ) {
291+ std::cerr << " |\n " ;
292+ }
293+
294+ std::cerr << std::setw (30 ) << std::right << name;
295+ std::cerr << " = " ;
296+ std::cerr << std::setw (5 ) << std::left << zeCount;
297+ }
298+
299+ std::cerr << " \\ --->" ;
300+
301+ for (size_t i = 0 ; i < dtors.size (); i++) {
302+ auto name = dtors[i];
303+ auto zeCount = counts[name].load ();
304+ diff -= zeCount;
305+
306+ if (i > 0 ) {
307+ std::cerr << " \n " ;
308+ std::cerr << std::setw (44 ) << std::right << " \\ --->" ;
289309 }
290- ss << std::setw (30 ) << std::right << ZeName;
291- ss << " = " ;
292- ss << std::setw (5 ) << std::left << ZeCount;
310+
311+ std::cerr << std::setw (30 ) << std::right << name;
312+ std::cerr << " = " ;
313+ std::cerr << std::setw (5 ) << std::left << zeCount;
293314 }
294315
295316 if (diff) {
296- ss << " ---> LEAK = " << diff;
317+ std::cerr << " ---> LEAK = " << diff;
297318 }
298319
299- std::cerr << ss.str () << ' \n ' ;
300- ss.str (" " );
301- ss.clear ();
320+ std::cerr << std::endl;
302321 }
303322 }
304323}
0 commit comments