Skip to content

Commit 752f564

Browse files
committed
Change count_temps_in_call to return hashmap of types
1 parent d8cddb9 commit 752f564

File tree

1 file changed

+9
-4
lines changed

1 file changed

+9
-4
lines changed

pythonbpf/functions/functions_pass.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
def count_temps_in_call(call_node, local_sym_tab):
3434
"""Count the number of temporary variables needed for a function call."""
3535

36-
count = 0
36+
count = {}
3737
is_helper = False
3838

3939
# NOTE: We exclude print calls for now
@@ -43,21 +43,26 @@ def count_temps_in_call(call_node, local_sym_tab):
4343
and call_node.func.id != "print"
4444
):
4545
is_helper = True
46+
func_name = call_node.func.id
4647
elif isinstance(call_node.func, ast.Attribute):
4748
if HelperHandlerRegistry.has_handler(call_node.func.attr):
4849
is_helper = True
50+
func_name = call_node.func.attr
4951

5052
if not is_helper:
5153
return 0
5254

53-
for arg in call_node.args:
55+
for arg_idx in range(len(call_node.args)):
5456
# NOTE: Count all non-name arguments
5557
# For struct fields, if it is being passed as an argument,
5658
# The struct object should already exist in the local_sym_tab
57-
if not isinstance(arg, ast.Name) and not (
59+
arg = call_node.args[arg_idx]
60+
if isinstance(arg, ast.Name) or (
5861
isinstance(arg, ast.Attribute) and arg.value.id in local_sym_tab
5962
):
60-
count += 1
63+
continue
64+
param_type = HelperHandlerRegistry.get_param_type(func_name, arg_idx)
65+
count[param_type] = count.get(param_type, 0) + 1
6166

6267
return count
6368

0 commit comments

Comments
 (0)