44
55from m2cgen import ast
66from m2cgen .interpreters .interpreter import BaseToCodeInterpreter
7+ from m2cgen .interpreters .utils import chunks
78
89
910class BinExpressionDepthTrackingMixin (BaseToCodeInterpreter ):
@@ -90,7 +91,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
9091 * extra_func_args )
9192
9293
93- Subroutine = namedtuple ('Subroutine' , ['name' , 'expr' ])
94+ Subroutine = namedtuple ('Subroutine' , ['name' , 'idx' , ' expr' ])
9495
9596
9697class SubroutinesMixin (BaseToCodeInterpreter ):
@@ -103,6 +104,8 @@ class SubroutinesMixin(BaseToCodeInterpreter):
103104 Their code generators should implement 3 methods:
104105 - function_definition;
105106 - function_invocation;
107+ - module_definition;
108+ - module_function_invocation;
106109 - add_return_statement.
107110
108111 Interpreter should prepare at least one subroutine using method
@@ -113,6 +116,7 @@ class SubroutinesMixin(BaseToCodeInterpreter):
113116 # disabled by default
114117 ast_size_check_frequency = sys .maxsize
115118 ast_size_per_subroutine_threshold = sys .maxsize
119+ subroutine_per_group_threshold = sys .maxsize
116120
117121 def __init__ (self , * args , ** kwargs ):
118122 self ._subroutine_idx = 0
@@ -125,15 +129,33 @@ def process_subroutine_queue(self, top_code_generator):
125129 subroutine queue.
126130 """
127131 self ._subroutine_idx = 0
132+ subroutines = []
128133
129- while len ( self .subroutine_expr_queue ) :
134+ while self .subroutine_expr_queue :
130135 self ._reset_reused_expr_cache ()
131136 subroutine = self .subroutine_expr_queue .pop (0 )
132137 subroutine_code = self ._process_subroutine (subroutine )
138+ subroutines .append ((subroutine , subroutine_code ))
139+
140+ subroutines .sort (key = lambda subroutine : subroutine [0 ].idx )
141+
142+ groups = chunks (subroutines , self .subroutine_per_group_threshold )
143+ for _ , subroutine_code in next (groups ):
133144 top_code_generator .add_code_lines (subroutine_code )
134145
135- def enqueue_subroutine (self , name , expr ):
136- self .subroutine_expr_queue .append (Subroutine (name , expr ))
146+ for index , subroutine_group in enumerate (groups ):
147+ cg = self .create_code_generator ()
148+
149+ with cg .module_definition (
150+ module_name = self ._format_group_name (index + 1 )):
151+ for _ , subroutine_code in subroutine_group :
152+ cg .add_code_lines (subroutine_code )
153+
154+ top_code_generator .add_code_lines (
155+ cg .finalize_and_get_generated_code ())
156+
157+ def enqueue_subroutine (self , name , idx , expr ):
158+ self .subroutine_expr_queue .append (Subroutine (name , idx , expr ))
137159
138160 def _pre_interpret_hook (self , expr , ast_size_check_counter = 0 , ** kwargs ):
139161 if isinstance (expr , ast .BinExpr ) and not expr .to_reuse :
@@ -147,7 +169,18 @@ def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
147169 ast_size = ast .count_exprs (expr )
148170 if ast_size > self .ast_size_per_subroutine_threshold :
149171 function_name = self ._get_subroutine_name ()
150- self .enqueue_subroutine (function_name , expr )
172+
173+ self .enqueue_subroutine (
174+ function_name , self ._subroutine_idx , expr )
175+
176+ group_idx = (self ._subroutine_idx //
177+ self .subroutine_per_group_threshold )
178+ if group_idx != 0 :
179+ return self ._cg .module_function_invocation (
180+ self ._format_group_name (group_idx ),
181+ function_name ,
182+ self ._feature_array_name ), kwargs
183+
151184 return self ._cg .function_invocation (
152185 function_name , self ._feature_array_name ), kwargs
153186
@@ -194,6 +227,10 @@ def _get_subroutine_name(self):
194227 self ._subroutine_idx += 1
195228 return subroutine_name
196229
230+ @staticmethod
231+ def _format_group_name (group_idx ):
232+ return f"SubroutineGroup{ group_idx } "
233+
197234 # Methods to be implemented by subclasses.
198235
199236 def create_code_generator (self ):
0 commit comments