Skip to content

Commit 2351850

Browse files
authored
Implement option split to reduce discrepancies for lightgbm regressors (#496)
* Implement option split to reduce discrepancies for lightgbm regressors
* disable unit test on older version of onnxruntime
* update CI
* move hummingbirdml tests into a separate folder
* remove h2o from CI on Windows
1 parent 1767f05 commit 2351850

16 files changed

+678
-323
lines changed

.azure-pipelines/linux-CI-nightly.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ jobs:
4747
python -m pip install $(ONNX_PATH)
4848
python -m pip install hummingbird-ml --no-deps
4949
python -m pip install -r requirements.txt
50-
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
5150
python -m pip install -r requirements-dev.txt
5251
python -m pip install $(ORT_PATH)
5352
python -m pip install pytest

.azure-pipelines/linux-conda-CI.yml

Lines changed: 72 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ trigger:
99
jobs:
1010

1111
- job: 'Test'
12+
timeoutInMinutes: 25
1213
pool:
1314
vmImage: 'ubuntu-latest'
1415
strategy:
@@ -70,31 +71,28 @@ jobs:
7071
maxParallel: 3
7172

7273
steps:
73-
- script: sudo install -d -m 0777 /home/vsts/.conda/envs
74-
displayName: Fix Conda permissions
75-
76-
- task: CondaEnvironment@1
74+
- task: UsePythonVersion@0
7775
inputs:
78-
createCustomEnvironment: true
79-
environmentName: 'py$(python.version)'
80-
packageSpecs: 'python=$(python.version)'
76+
versionSpec: '$(python.version)'
77+
architecture: 'x64'
8178

8279
- script: |
8380
python -m pip install --upgrade pip
84-
conda config --set always_yes yes --set changeps1 no
85-
conda install -c conda-forge protobuf
86-
conda install -c conda-forge numpy
87-
conda install -c conda-forge cmake
88-
pip install $(COREML_PATH)
89-
pip install $(ONNX_PATH)
90-
pip install hummingbird-ml --no-deps
81+
pip install $(ONNX_PATH) $(ONNXRT_PATH) cython
9182
pip install -r requirements.txt
92-
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
83+
displayName: 'Install dependencies'
84+
85+
- script: |
9386
pip install -r requirements-dev.txt
87+
displayName: 'Install dependencies-dev'
88+
89+
- script: |
90+
python -m pip install --upgrade pip
9491
pip install xgboost$(xgboost.version)
92+
pip install $(ONNX_PATH)
9593
pip install $(ONNXRT_PATH)
96-
pip install pytest
97-
displayName: 'Install dependencies'
94+
pip install $(COREML_PATH)
95+
displayName: 'Install xgboost, onnxruntime'
9896
9997
- script: |
10098
pip install flake8
@@ -109,8 +107,63 @@ jobs:
109107
export PYTHONPATH=.
110108
python -c "import onnxconverter_common;print(onnxconverter_common.__version__)"
111109
python -c "import onnxruntime;print(onnxruntime.__version__)"
112-
pytest tests --doctest-modules --junitxml=junit/test-results.xml
113-
displayName: 'pytest - onnxmltools'
110+
displayName: 'version'
111+
112+
- script: |
113+
export PYTHONPATH=.
114+
pytest tests/baseline --durations=0
115+
displayName: 'pytest - baseline'
116+
117+
- script: |
118+
export PYTHONPATH=.
119+
pytest tests/catboost --durations=0
120+
displayName: 'pytest - catboost'
121+
122+
- script: |
123+
export PYTHONPATH=.
124+
pytest tests/coreml --durations=0
125+
displayName: 'pytest - coreml'
126+
127+
- script: |
128+
export PYTHONPATH=.
129+
pytest tests/lightgbm --durations=0
130+
displayName: 'pytest - lightgbm'
131+
132+
- script: |
133+
export PYTHONPATH=.
134+
pytest tests/sparkml --durations=0
135+
displayName: 'pytest - sparkml'
136+
137+
- script: |
138+
export PYTHONPATH=.
139+
pytest tests/svmlib --durations=0
140+
displayName: 'pytest - svmlib'
141+
142+
- script: |
143+
export PYTHONPATH=.
144+
pytest tests/utils --durations=0
145+
displayName: 'pytest - utils'
146+
147+
- script: |
148+
export PYTHONPATH=.
149+
pytest tests/xgboost --durations=0
150+
displayName: 'pytest - xgboost'
151+
152+
- script: |
153+
export PYTHONPATH=.
154+
pip install h2o
155+
pytest tests/h2o --durations=0
156+
displayName: 'pytest - h2o'
157+
158+
- script: |
159+
pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio==0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
160+
pip install hummingbird-ml --no-deps
161+
displayName: 'Install hummingbird-ml'
162+
163+
- script: |
164+
export PYTHONPATH=.
165+
pytest tests/hummingbirdml --durations=0
166+
displayName: 'pytest - hummingbirdml'
114167
115168
- task: PublishTestResults@2
116169
inputs:

.azure-pipelines/win32-CI-nightly.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ jobs:
4545
pip install %COREML_PATH% %ONNX_PATH%
4646
pip install humming-bird-ml --no-deps
4747
pip install -r requirements.txt
48-
python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
4948
pip install -r requirements-dev.txt
5049
pip install %ONNXRT_PATH%
5150
displayName: 'Install dependencies'

.azure-pipelines/win32-conda-CI.yml

Lines changed: 65 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ trigger:
99
jobs:
1010

1111
- job: 'Test'
12+
timeoutInMinutes: 30
1213
pool:
1314
vmImage: 'windows-latest'
1415
strategy:
@@ -18,79 +19,63 @@ jobs:
1819
ONNX_PATH: 'onnx==1.10.1' # '-i https://test.pypi.org/simple/ onnx==1.9.101'
1920
ONNXRT_PATH: onnxruntime==1.8.1
2021
COREML_PATH: git+https://github.com/apple/[email protected]
21-
sklearn.version: ''
2222

2323
Python39-190-RT181:
2424
python.version: '3.9'
2525
ONNX_PATH: 'onnx==1.9.0'
2626
ONNXRT_PATH: onnxruntime==1.8.1
2727
COREML_PATH: git+https://github.com/apple/[email protected]
28-
sklearn.version: ''
2928

3029
Python39-190-RT180:
3130
python.version: '3.9'
3231
ONNX_PATH: onnx==1.9.0
3332
ONNXRT_PATH: onnxruntime==1.8.0
3433
COREML_PATH: git+https://github.com/apple/[email protected]
35-
sklearn.version: ''
3634

3735
Python38-181-RT170:
3836
python.version: '3.8'
3937
ONNX_PATH: onnx==1.8.1
4038
ONNXRT_PATH: onnxruntime==1.7.0
4139
COREML_PATH: git+https://github.com/apple/[email protected]
42-
sklearn.version: ''
4340

4441
Python37-180-RT160:
4542
python.version: '3.7'
4643
ONNX_PATH: onnx==1.8.0
4744
ONNXRT_PATH: onnxruntime==1.6.0
4845
COREML_PATH: git+https://github.com/apple/[email protected]
49-
sklearn.version: ''
5046

5147
Python37-160-RT111:
5248
python.version: '3.7'
5349
ONNX_PATH: onnx==1.6.0
5450
ONNXRT_PATH: onnxruntime==1.1.1
5551
COREML_PATH: git+https://github.com/apple/[email protected]
56-
sklearn.version: ''
5752

5853
Python37-170-RT130:
5954
python.version: '3.7'
6055
ONNX_PATH: onnx==1.7.0
6156
ONNXRT_PATH: onnxruntime==1.3.0
6257
COREML_PATH: git+https://github.com/apple/[email protected]
63-
sklearn.version: ''
6458

6559
maxParallel: 3
6660

6761
steps:
68-
- task: UsePythonVersion@0
69-
inputs:
70-
versionSpec: '$(python.version)'
71-
architecture: 'x64'
72-
7362
- powershell: Write-Host "##vso[task.prependpath]$env:CONDA\Scripts"
7463
displayName: Add conda to PATH
7564

76-
- script: conda create --yes --quiet --name py$(python.version) -c conda-forge python=$(python.version) numpy protobuf
65+
- script: conda create --yes --quiet --name py$(python.version) -c conda-forge python=$(python.version) numpy protobuf scikit-learn scipy cython
7766
displayName: Create Anaconda environment
7867

7968
- script: |
8069
call activate py$(python.version)
8170
python -m pip install --upgrade pip numpy
8271
echo Test numpy installation... && python -c "import numpy"
83-
python -m pip install scikit-learn
84-
python -m pip install %ONNX_PATH%
85-
python -m pip install humming-bird-ml --no-deps
8672
python -m pip install -r requirements.txt
87-
python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
8873
displayName: 'Install dependencies (1)'
8974
9075
- script: |
9176
call activate py$(python.version)
9277
python -m pip install -r requirements-dev.txt
93-
displayName: 'Install dependencies (2)'
78+
displayName: 'Install dependencies-dev'
9479
9580
- script: |
9681
call activate py$(python.version)
@@ -99,14 +84,10 @@ jobs:
9984
10085
- script: |
10186
call activate py$(python.version)
87+
python -m pip install %ONNX_PATH%
10288
python -m pip install %ONNXRT_PATH%
10389
displayName: 'Install onnxruntime'
10490
105-
- script: |
106-
call activate py$(python.version)
107-
python -m pip install scikit-learn$(sklearn.version)
108-
displayName: 'Install scikit-learn'
109-
11091
- script: |
11192
call activate py$(python.version)
11293
python -m flake8 ./onnxmltools
@@ -118,8 +99,67 @@ jobs:
11899
export PYTHONPATH=.
119100
python -c "import onnxconverter_common;print(onnxconverter_common.__version__)"
120101
python -c "import onnxruntime;print(onnxruntime.__version__)"
121-
python -m pytest tests --doctest-modules --junitxml=junit/test-results.xml
122-
displayName: 'pytest - onnxmltools'
102+
displayName: 'version'
103+
104+
- script: |
105+
call activate py$(python.version)
106+
export PYTHONPATH=.
107+
python -m pytest tests/baseline --durations=0
108+
displayName: 'pytest baseline'
109+
110+
- script: |
111+
call activate py$(python.version)
112+
export PYTHONPATH=.
113+
python -m pytest tests/catboost --durations=0
114+
displayName: 'pytest catboost'
115+
116+
- script: |
117+
call activate py$(python.version)
118+
export PYTHONPATH=.
119+
python -m pytest tests/coreml --durations=0
120+
displayName: 'pytest coreml'
121+
122+
- script: |
123+
call activate py$(python.version)
124+
export PYTHONPATH=.
125+
python -m pytest tests/lightgbm --durations=0
126+
displayName: 'pytest lightgbm'
127+
128+
- script: |
129+
call activate py$(python.version)
130+
export PYTHONPATH=.
131+
python -m pytest tests/sparkml --durations=0
132+
displayName: 'pytest sparkml'
133+
134+
- script: |
135+
call activate py$(python.version)
136+
export PYTHONPATH=.
137+
python -m pytest tests/svmlib --durations=0
138+
displayName: 'pytest svmlib'
139+
140+
- script: |
141+
call activate py$(python.version)
142+
export PYTHONPATH=.
143+
python -m pytest tests/utils --durations=0
144+
displayName: 'pytest utils'
145+
146+
- script: |
147+
call activate py$(python.version)
148+
export PYTHONPATH=.
149+
python -m pytest tests/xgboost --durations=0
150+
displayName: 'pytest xgboost'
151+
152+
- script: |
153+
call activate py$(python.version)
154+
python -m pip install torch==1.8.1+cpu torchvision==0.9.1+cpu torchaudio===0.8.1 -f https://download.pytorch.org/whl/torch_stable.html
155+
python -m pip install hummingbird-ml --no-deps
156+
displayName: 'Install hummingbird-ml'
157+
158+
- script: |
159+
call activate py$(python.version)
160+
export PYTHONPATH=.
161+
python -m pytest tests/hummingbirdml --durations=0
162+
displayName: 'pytest hummingbirdml'
123163
124164
- task: PublishTestResults@2
125165
inputs:

onnxmltools/convert/lightgbm/_parse.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -81,17 +81,20 @@ def _get_lightgbm_operator_name(model):
8181
return lightgbm_operator_name_map[model_type]
8282

8383

84-
def _parse_lightgbm_simple_model(scope, model, inputs):
84+
def _parse_lightgbm_simple_model(scope, model, inputs, split=None):
8585
'''
8686
This function handles all non-pipeline models.
8787
8888
:param scope: Scope object
8989
:param model: A lightgbm object
9090
:param inputs: A list of variables
91+
:param split: split TreeEnsembleRegressor into multiple nodes to reduce
92+
discrepancies
9193
:return: A list of output variables which will be passed to next stage
9294
'''
9395
operator_name = _get_lightgbm_operator_name(model)
9496
this_operator = scope.declare_local_operator(operator_name, model)
97+
this_operator.split = split
9598
this_operator.inputs = inputs
9699

97100
if operator_name == 'LgbmClassifier':
@@ -151,27 +154,29 @@ def _parse_sklearn_classifier(scope, model, inputs, zipmap=True):
151154
return this_operator.outputs
152155

153156

154-
def _parse_lightgbm(scope, model, inputs, zipmap=True):
157+
def _parse_lightgbm(scope, model, inputs, zipmap=True, split=None):
155158
'''
156159
This is a delegate function. It does nothing but invoke the correct parsing function according to the input
157160
model's type.
158161
:param scope: Scope object
159162
:param model: A lightgbm object
160163
:param inputs: A list of variables
161164
:param zipmap: add operator ZipMap after operator TreeEnsembleClassifier
165+
:param split: split TreeEnsembleRegressor into multiple nodes to reduce
166+
discrepancies
162167
:return: The output variables produced by the input model
163168
'''
164169
if isinstance(model, LGBMClassifier):
165170
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
166171
if (isinstance(model, WrappedBooster) and
167172
model.operator_name == 'LgbmClassifier'):
168173
return _parse_sklearn_classifier(scope, model, inputs, zipmap=zipmap)
169-
return _parse_lightgbm_simple_model(scope, model, inputs)
174+
return _parse_lightgbm_simple_model(scope, model, inputs, split=split)
170175

171176

172177
def parse_lightgbm(model, initial_types=None, target_opset=None,
173178
custom_conversion_functions=None, custom_shape_calculators=None,
174-
zipmap=True):
179+
zipmap=True, split=None):
175180
raw_model_container = LightGbmModelContainer(model)
176181
topology = Topology(raw_model_container, default_batch_size='None',
177182
initial_types=initial_types, target_opset=target_opset,
@@ -186,7 +191,7 @@ def parse_lightgbm(model, initial_types=None, target_opset=None,
186191
for variable in inputs:
187192
raw_model_container.add_input(variable)
188193

189-
outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap)
194+
outputs = _parse_lightgbm(scope, model, inputs, zipmap=zipmap, split=split)
190195

191196
for variable in outputs:
192197
raw_model_container.add_output(variable)

0 commit comments

Comments
 (0)