Skip to content

Commit 9d24b81

Browse files
committed
[UR] Refactor validation layer refcount feature
The leak checking feature has been expanded to include handles obtained via the ...Get() functions.
1 parent 32e2533 commit 9d24b81

File tree

3 files changed

+126
-83
lines changed

3 files changed

+126
-83
lines changed

scripts/templates/helper.py

Lines changed: 58 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,45 +1486,79 @@ def get_loader_epilogue(specs, namespace, tags, obj, meta):
14861486

14871487
return epilogue
14881488

1489+
1490+
def get_event_wait_list_functions(specs, namespace, tags):
1491+
funcs = []
1492+
for s in specs:
1493+
for obj in s['objects']:
1494+
if re.match(r"function", obj['type']):
1495+
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
1496+
x['name'] == 'numEventsInWaitList' for x in obj['params']):
1497+
funcs.append(make_func_name(namespace, tags, obj))
1498+
return funcs
1499+
1500+
14891501
"""
1490-
Public:
1491-
returns a dictionary with lists of create, retain and release functions
1502+
Private:
1503+
returns a dictionary with lists of create, get, retain and release functions
14921504
"""
1493-
def get_create_retain_release_functions(specs, namespace, tags):
1505+
def _get_create_get_retain_release_functions(specs, namespace, tags):
14941506
funcs = []
14951507
for s in specs:
14961508
for obj in s['objects']:
14971509
if re.match(r"function", obj['type']):
14981510
funcs.append(make_func_name(namespace, tags, obj))
14991511

1500-
create_suffixes = r"(Create[A-Za-z]*){1}"
1501-
retain_suffixes = r"(Retain){1}"
1502-
release_suffixes = r"(Release){1}"
1512+
create_suffixes = r"(Create[A-Za-z]*){1}$"
1513+
get_suffixes = r"(Get){1}$"
1514+
retain_suffixes = r"(Retain){1}$"
1515+
release_suffixes = r"(Release){1}$"
1516+
common_prefix = r"^" + namespace
15031517

1504-
create_exp = namespace + r"([A-Za-z]+)" + create_suffixes
1505-
retain_exp = namespace + r"([A-Za-z]+)" + retain_suffixes
1506-
release_exp = namespace + r"([A-Za-z]+)" + release_suffixes
1518+
create_exp = common_prefix + r"[A-Za-z]+" + create_suffixes
1519+
get_exp = common_prefix + r"[A-Za-z]+" + get_suffixes
1520+
retain_exp = common_prefix + r"[A-Za-z]+" + retain_suffixes
1521+
release_exp = common_prefix + r"[A-Za-z]+" + release_suffixes
15071522

1508-
create_funcs, retain_funcs, release_funcs = (
1523+
create_funcs, get_funcs, retain_funcs, release_funcs = (
15091524
list(filter(lambda f: re.match(create_exp, f), funcs)),
1525+
list(filter(lambda f: re.match(get_exp, f), funcs)),
15101526
list(filter(lambda f: re.match(retain_exp, f), funcs)),
15111527
list(filter(lambda f: re.match(release_exp, f), funcs)),
15121528
)
15131529

1514-
create_funcs, retain_funcs = (
1515-
list(filter(lambda f: re.sub(create_suffixes, "Release", f) in release_funcs, create_funcs)),
1516-
list(filter(lambda f: re.sub(retain_suffixes, "Release", f) in release_funcs, retain_funcs)),
1517-
)
1530+
return {"create": create_funcs, "get": get_funcs, "retain": retain_funcs, "release": release_funcs}
15181531

1519-
return {"create": create_funcs, "retain": retain_funcs, "release": release_funcs}
15201532

1533+
"""
1534+
Public:
1535+
returns a list of dictionaries containing handle types and the corresponding create, get, retain and release functions
1536+
"""
1537+
def get_handle_create_get_retain_release_functions(specs, namespace, tags):
1538+
# Handles without release function
1539+
excluded_handles = ["$x_platform_handle_t", "$x_native_handle_t"]
1540+
# Handles from experimental features
1541+
exp_prefix = "$x_exp"
1542+
1543+
funcs = _get_create_get_retain_release_functions(specs, namespace, tags)
1544+
records = []
1545+
for h in get_adapter_handles(specs):
1546+
if h['name'] in excluded_handles or h['name'].startswith(exp_prefix):
1547+
continue
15211548

1522-
def get_event_wait_list_functions(specs, namespace, tags):
1523-
funcs = []
1524-
for s in specs:
1525-
for obj in s['objects']:
1526-
if re.match(r"function", obj['type']):
1527-
if any(x['name'] == 'phEventWaitList' for x in obj['params']) and any(
1528-
x['name'] == 'numEventsInWaitList' for x in obj['params']):
1529-
funcs.append(make_func_name(namespace, tags, obj))
1530-
return funcs
1549+
class_type = subt(namespace, tags, h['class'])
1550+
create_funcs = list(filter(lambda f: class_type in f, funcs['create']))
1551+
get_funcs = list(filter(lambda f: class_type in f, funcs['get']))
1552+
retain_funcs = list(filter(lambda f: class_type in f, funcs['retain']))
1553+
release_funcs = list(filter(lambda f: class_type in f, funcs['release']))
1554+
1555+
record = {}
1556+
record['handle'] = subt(namespace, tags, h['name'])
1557+
record['create'] = create_funcs
1558+
record['get'] = get_funcs
1559+
record['retain'] = retain_funcs
1560+
record['release'] = release_funcs
1561+
1562+
records.append(record)
1563+
1564+
return records

scripts/templates/valddi.cpp.mako

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ from templates import helper as th
77
88
x=tags['$x']
99
X=x.upper()
10-
create_retain_release_funcs=th.get_create_retain_release_functions(specs, n, tags)
10+
11+
handle_create_get_retain_release_funcs=th.get_handle_create_get_retain_release_functions(specs, n, tags)
1112
%>/*
1213
*
1314
* Copyright (C) 2023 Intel Corporation
@@ -27,11 +28,12 @@ namespace ur_validation_layer
2728
%for obj in th.get_adapter_functions(specs):
2829
<%
2930
func_name=th.make_func_name(n, tags, obj)
30-
object_param=th.make_param_lines(n, tags, obj, format=["name"])[-1]
31-
object_param_type=th.make_param_lines(n, tags, obj, format=["type"])[-1]
31+
3232
param_checks=th.make_param_checks(n, tags, obj, meta=meta).items()
3333
first_errors = [X + "_RESULT_ERROR_INVALID_NULL_POINTER", X + "_RESULT_ERROR_INVALID_NULL_HANDLE"]
3434
sorted_param_checks = sorted(param_checks, key=lambda pair: False if pair[0] in first_errors else True)
35+
36+
tracked_params = list(filter(lambda p: any(th.subt(n, tags, p['type']) in [hf['handle'], hf['handle'] + "*"] for hf in handle_create_get_retain_release_funcs), obj['params']))
3537
%>
3638
///////////////////////////////////////////////////////////////////////////////
3739
/// @brief Intercept function for ${th.make_func_name(n, tags, obj)}
@@ -74,37 +76,35 @@ namespace ur_validation_layer
7476

7577
${x}_result_t result = ${th.make_pfn_name(n, tags, obj)}( ${", ".join(th.make_param_lines(n, tags, obj, format=["name"]))} );
7678

77-
%if func_name == n + "AdapterRelease":
78-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
79-
{
80-
refCountContext.decrementRefCount(${object_param}, true);
81-
}
82-
%elif func_name == n + "AdapterRetain":
79+
%for tp in tracked_params:
80+
<%
81+
tp_handle_funcs = next((hf for hf in handle_create_get_retain_release_funcs if th.subt(n, tags, tp['type']) in [hf['handle'], hf['handle'] + "*"]), None)
82+
is_handle_to_adapter = ("_adapter_handle_t" in tp['type'])
83+
%>
84+
%if func_name in tp_handle_funcs['create']:
8385
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
8486
{
85-
refCountContext.incrementRefCount(${object_param}, true);
87+
refCountContext.createRefCount(*${tp['name']});
8688
}
87-
%elif func_name == n + "AdapterGet":
88-
if( context.enableLeakChecking && phAdapters && result == UR_RESULT_SUCCESS )
89+
%elif func_name in tp_handle_funcs['get']:
90+
if( context.enableLeakChecking && ${tp['name']} && result == UR_RESULT_SUCCESS )
8991
{
90-
refCountContext.createOrIncrementRefCount(*phAdapters, true);
91-
}
92-
%elif func_name in create_retain_release_funcs["create"]:
93-
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
94-
{
95-
refCountContext.createRefCount(*${object_param});
92+
for (uint32_t i = ${th.param_traits.range_start(tp)}; i < ${th.param_traits.range_end(tp)}; i++) {
93+
refCountContext.createOrIncrementRefCount(${tp['name']}[i], ${str(is_handle_to_adapter).lower()});
94+
}
9695
}
97-
%elif func_name in create_retain_release_funcs["retain"]:
96+
%elif func_name in tp_handle_funcs['retain']:
9897
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
9998
{
100-
refCountContext.incrementRefCount(${object_param});
99+
refCountContext.incrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
101100
}
102-
%elif func_name in create_retain_release_funcs["release"]:
101+
%elif func_name in tp_handle_funcs['release']:
103102
if( context.enableLeakChecking && result == UR_RESULT_SUCCESS )
104103
{
105-
refCountContext.decrementRefCount(${object_param});
104+
refCountContext.decrementRefCount(${tp['name']}, ${str(is_handle_to_adapter).lower()});
106105
}
107106
%endif
107+
%endfor
108108

109109
return result;
110110
}

0 commit comments

Comments
 (0)