3939def count_temps_in_call (call_node , local_sym_tab ):
4040 """Count the number of temporary variables needed for a function call."""
4141
42- count = 0
42+ count = {}
4343 is_helper = False
4444
4545 # NOTE: We exclude print calls for now
@@ -49,21 +49,28 @@ def count_temps_in_call(call_node, local_sym_tab):
4949 and call_node .func .id != "print"
5050 ):
5151 is_helper = True
52+ func_name = call_node .func .id
5253 elif isinstance (call_node .func , ast .Attribute ):
5354 if HelperHandlerRegistry .has_handler (call_node .func .attr ):
5455 is_helper = True
56+ func_name = call_node .func .attr
5557
5658 if not is_helper :
57- return 0
59+ return {} # No temps needed
5860
59- for arg in call_node .args :
61+ for arg_idx in range ( len ( call_node .args )) :
6062 # NOTE: Count all non-name arguments
6163 # For struct fields, if it is being passed as an argument,
6264 # The struct object should already exist in the local_sym_tab
63- if not isinstance (arg , ast .Name ) and not (
65+ arg = call_node .args [arg_idx ]
66+ if isinstance (arg , ast .Name ) or (
6467 isinstance (arg , ast .Attribute ) and arg .value .id in local_sym_tab
6568 ):
66- count += 1
69+ continue
70+ param_type = HelperHandlerRegistry .get_param_type (func_name , arg_idx )
71+ if isinstance (param_type , ir .PointerType ):
72+ pointee_type = param_type .pointee
73+ count [pointee_type ] = count .get (pointee_type , 0 ) + 1
6774
6875 return count
6976
@@ -99,11 +106,15 @@ def handle_if_allocation(
99106def allocate_mem (
100107 module , builder , body , func , ret_type , map_sym_tab , local_sym_tab , structs_sym_tab
101108):
102- max_temps_needed = 0
109+ max_temps_needed = {}
110+
111+ def merge_type_counts (count_dict ):
112+ nonlocal max_temps_needed
113+ for typ , cnt in count_dict .items ():
114+ max_temps_needed [typ ] = max (max_temps_needed .get (typ , 0 ), cnt )
103115
104116 def update_max_temps_for_stmt (stmt ):
105117 nonlocal max_temps_needed
106- temps_needed = 0
107118
108119 if isinstance (stmt , ast .If ):
109120 for s in stmt .body :
@@ -112,10 +123,13 @@ def update_max_temps_for_stmt(stmt):
112123 update_max_temps_for_stmt (s )
113124 return
114125
126+ stmt_temps = {}
115127 for node in ast .walk (stmt ):
116128 if isinstance (node , ast .Call ):
117- temps_needed += count_temps_in_call (node , local_sym_tab )
118- max_temps_needed = max (max_temps_needed , temps_needed )
129+ call_temps = count_temps_in_call (node , local_sym_tab )
130+ for typ , cnt in call_temps .items ():
131+ stmt_temps [typ ] = stmt_temps .get (typ , 0 ) + cnt
132+ merge_type_counts (stmt_temps )
119133
120134 for stmt in body :
121135 update_max_temps_for_stmt (stmt )
0 commit comments