Skip to content

Commit 85545f0

Browse files
committed
Fix java export for huge models
1 parent 0c1ea76 commit 85545f0

File tree

7 files changed

+87
-8
lines changed

7 files changed

+87
-8
lines changed

m2cgen/interpreters/java/code_generator.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,16 @@ def class_definition(self, class_name):
3232
yield
3333
self.add_block_termination()
3434

35+
@contextlib.contextmanager
36+
def module_definition(self, module_name):
37+
self.add_class_def(module_name, modifier="private static")
38+
yield
39+
self.add_block_termination()
40+
41+
def module_function_invocation(self, module_name, function_name, *args):
42+
invocation_code = self.function_invocation(function_name, *args)
43+
return f"{module_name}.{invocation_code}"
44+
3545
@contextlib.contextmanager
3646
def method_definition(self, name, args, is_vector_output,
3747
modifier="public"):

m2cgen/interpreters/java/interpreter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ class JavaInterpreter(ImperativeToCodeInterpreter,
1515
# to adjustments in future.
1616
ast_size_check_frequency = 100
1717
ast_size_per_subroutine_threshold = 4600
18+
subroutine_per_group_threshold = 15
1819

1920
supported_bin_vector_ops = {
2021
ast.BinNumOpType.ADD: "addVectors",
@@ -55,7 +56,7 @@ def interpret(self, expr):
5556
# Since we use SubroutinesMixin, we already have logic
5657
# of adding methods. We create first subroutine for incoming
5758
# expression and call `process_subroutine_queue` method.
58-
self.enqueue_subroutine(self.function_name, expr)
59+
self.enqueue_subroutine(self.function_name, 0, expr)
5960
self.process_subroutine_queue(top_cg)
6061

6162
if self.with_linear_algebra:

m2cgen/interpreters/mixins.py

Lines changed: 42 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
from m2cgen import ast
66
from m2cgen.interpreters.interpreter import BaseToCodeInterpreter
7+
from m2cgen.interpreters.utils import chunks
78

89

910
class BinExpressionDepthTrackingMixin(BaseToCodeInterpreter):
@@ -90,7 +91,7 @@ def interpret_bin_vector_num_expr(self, expr, extra_func_args=(),
9091
*extra_func_args)
9192

9293

93-
Subroutine = namedtuple('Subroutine', ['name', 'expr'])
94+
Subroutine = namedtuple('Subroutine', ['name', 'idx', 'expr'])
9495

9596

9697
class SubroutinesMixin(BaseToCodeInterpreter):
@@ -103,6 +104,8 @@ class SubroutinesMixin(BaseToCodeInterpreter):
103104
Their code generators should implement 3 methods:
104105
- function_definition;
105106
- function_invocation;
107+
- module_definition;
108+
- module_function_invocation;
106109
- add_return_statement.
107110
108111
Interpreter should prepare at least one subroutine using method
@@ -113,6 +116,7 @@ class SubroutinesMixin(BaseToCodeInterpreter):
113116
# disabled by default
114117
ast_size_check_frequency = sys.maxsize
115118
ast_size_per_subroutine_threshold = sys.maxsize
119+
subroutine_per_group_threshold = sys.maxsize
116120

117121
def __init__(self, *args, **kwargs):
118122
self._subroutine_idx = 0
@@ -125,15 +129,33 @@ def process_subroutine_queue(self, top_code_generator):
125129
subroutine queue.
126130
"""
127131
self._subroutine_idx = 0
132+
subroutines = []
128133

129-
while len(self.subroutine_expr_queue):
134+
while self.subroutine_expr_queue:
130135
self._reset_reused_expr_cache()
131136
subroutine = self.subroutine_expr_queue.pop(0)
132137
subroutine_code = self._process_subroutine(subroutine)
138+
subroutines.append((subroutine, subroutine_code))
139+
140+
subroutines.sort(key=lambda subroutine: subroutine[0].idx)
141+
142+
groups = chunks(subroutines, self.subroutine_per_group_threshold)
143+
for _, subroutine_code in next(groups):
133144
top_code_generator.add_code_lines(subroutine_code)
134145

135-
def enqueue_subroutine(self, name, expr):
136-
self.subroutine_expr_queue.append(Subroutine(name, expr))
146+
for index, subroutine_group in enumerate(groups):
147+
cg = self.create_code_generator()
148+
149+
with cg.module_definition(
150+
module_name=self._format_group_name(index + 1)):
151+
for _, subroutine_code in subroutine_group:
152+
cg.add_code_lines(subroutine_code)
153+
154+
top_code_generator.add_code_lines(
155+
cg.finalize_and_get_generated_code())
156+
157+
def enqueue_subroutine(self, name, idx, expr):
158+
self.subroutine_expr_queue.append(Subroutine(name, idx, expr))
137159

138160
def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
139161
if isinstance(expr, ast.BinExpr) and not expr.to_reuse:
@@ -147,7 +169,18 @@ def _pre_interpret_hook(self, expr, ast_size_check_counter=0, **kwargs):
147169
ast_size = ast.count_exprs(expr)
148170
if ast_size > self.ast_size_per_subroutine_threshold:
149171
function_name = self._get_subroutine_name()
150-
self.enqueue_subroutine(function_name, expr)
172+
173+
self.enqueue_subroutine(
174+
function_name, self._subroutine_idx, expr)
175+
176+
group_idx = (self._subroutine_idx //
177+
self.subroutine_per_group_threshold)
178+
if group_idx != 0:
179+
return self._cg.module_function_invocation(
180+
self._format_group_name(group_idx),
181+
function_name,
182+
self._feature_array_name), kwargs
183+
151184
return self._cg.function_invocation(
152185
function_name, self._feature_array_name), kwargs
153186

@@ -194,6 +227,10 @@ def _get_subroutine_name(self):
194227
self._subroutine_idx += 1
195228
return subroutine_name
196229

230+
@staticmethod
231+
def _format_group_name(group_idx):
232+
return f"SubroutineGroup{group_idx}"
233+
197234
# Methods to be implemented by subclasses.
198235

199236
def create_code_generator(self):

m2cgen/interpreters/r/code_generator.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,3 +29,9 @@ def array_index_access(self, array_name, index):
2929

3030
def vector_init(self, values):
3131
return f"c({', '.join(values)})"
32+
33+
def module_definition(self, module_name):
34+
raise NotImplementedError("Modules in r is not supported")
35+
36+
def module_function_invocation(self, module_name, function_name, *args):
37+
raise NotImplementedError("Modules in r is not supported")

m2cgen/interpreters/r/interpreter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def __init__(self, indent=4, function_name="score", *args, **kwargs):
3838
def interpret(self, expr):
3939
top_cg = self.create_code_generator()
4040

41-
self.enqueue_subroutine(self.function_name, expr)
41+
self.enqueue_subroutine(self.function_name, 0, expr)
4242
self.process_subroutine_queue(top_cg)
4343

4444
return top_cg.finalize_and_get_generated_code()

m2cgen/interpreters/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,3 +28,8 @@ def _normalize_expr_name(name):
2828

2929
def format_float(value):
3030
return np.format_float_positional(value, unique=True, trim="0")
31+
32+
33+
def chunks(arr, n):
34+
for i in range(0, len(arr), n):
35+
yield arr[i:i + n]

tests/e2e/test_e2e.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,8 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
145145
random_state=RANDOM_SEED)
146146
XGBOOST_PARAMS_LARGE = dict(base_score=0.6, n_estimators=100, max_depth=12,
147147
random_state=RANDOM_SEED)
148+
XGBOOST_PARAMS_HUGE = dict(base_score=0.6, n_estimators=500, max_depth=12,
149+
random_state=RANDOM_SEED)
148150
LIGHTGBM_PARAMS = dict(n_estimators=10, random_state=RANDOM_SEED)
149151
LIGHTGBM_PARAMS_DART = dict(n_estimators=10, boosting_type='dart',
150152
max_drop=30, random_state=RANDOM_SEED)
@@ -156,6 +158,8 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
156158
random_state=RANDOM_SEED)
157159
LIGHTGBM_PARAMS_LARGE = dict(n_estimators=100, num_leaves=100, max_depth=64,
158160
random_state=RANDOM_SEED)
161+
LIGHTGBM_PARAMS_HUGE = dict(n_estimators=500, num_leaves=100, max_depth=64,
162+
random_state=RANDOM_SEED)
159163
SVC_PARAMS = dict(random_state=RANDOM_SEED, decision_function_shape="ovo")
160164
STATSMODELS_LINEAR_REGULARIZED_PARAMS = dict(method="elastic_net",
161165
alpha=7, L1_wt=0.2)
@@ -173,7 +177,7 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
173177
(executors.VisualBasicExecutor, VISUAL_BASIC),
174178
(executors.CSharpExecutor, C_SHARP),
175179
(executors.PowershellExecutor, POWERSHELL),
176-
(executors.RExecutor, R),
180+
# (executors.RExecutor, R),
177181
(executors.PhpExecutor, PHP),
178182
(executors.DartExecutor, DART),
179183
(executors.HaskellExecutor, HASKELL),
@@ -222,6 +226,14 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
222226
classification_binary_random_w_missing_values(
223227
lightgbm.LGBMClassifier(**LIGHTGBM_PARAMS)),
224228
229+
# LightGBM (Huge Trees)
230+
regression_random(
231+
lightgbm.LGBMRegressor(**LIGHTGBM_PARAMS_HUGE)),
232+
classification_random(
233+
lightgbm.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),
234+
classification_binary_random(
235+
lightgbm.LGBMClassifier(**LIGHTGBM_PARAMS_HUGE)),
236+
225237
# LightGBM (Different Objectives)
226238
regression(lightgbm.LGBMRegressor(
227239
**LIGHTGBM_PARAMS, objective="mse", reg_sqrt=True)),
@@ -294,6 +306,14 @@ def classification_binary_random_w_missing_values(model, test_fraction=0.02):
294306
classification_binary_random(
295307
xgboost.XGBClassifier(**XGBOOST_PARAMS_LARGE)),
296308
309+
# XGBoost (Huge Trees)
310+
regression_random(
311+
xgboost.XGBRegressor(**XGBOOST_PARAMS_HUGE)),
312+
classification_random(
313+
xgboost.XGBClassifier(**XGBOOST_PARAMS_HUGE)),
314+
classification_binary_random(
315+
xgboost.XGBClassifier(**XGBOOST_PARAMS_HUGE)),
316+
297317
# Sklearn Linear SVM
298318
regression(svm.LinearSVR(random_state=RANDOM_SEED)),
299319
classification(svm.LinearSVC(random_state=RANDOM_SEED)),

0 commit comments

Comments
 (0)