Skip to content

Commit 3b33a46

Browse files
committed
make skl2onnx optional
The support for scikit-learn pipelines is optional, so we raise an error only if one tries to use it while skl2onnx is not installed. Fixes #21
1 parent 4b94d41 commit 3b33a46

File tree

3 files changed

+142
-166
lines changed

3 files changed

+142
-166
lines changed

ebm2onnx/sklearn.py

Lines changed: 40 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,54 @@
1-
from skl2onnx.common.data_types import Int64TensorType, FloatTensorType, StringTensorType
21
from . import context
32
from . import convert
43

54
import onnx
65

76

8-
def ebm_output_shape_calculator(operator):
9-
op = operator.raw_operator
7+
try:
8+
from skl2onnx.common.data_types import Int64TensorType, FloatTensorType, StringTensorType
109

11-
operator.outputs[0].type = Int64TensorType([None]) # label
12-
operator.outputs[1].type = FloatTensorType([None, len(op.classes_)]) # probabilities
1310

11+
def ebm_output_shape_calculator(operator):
12+
op = operator.raw_operator
1413

15-
def convert_ebm_classifier(scope, operator, container):
16-
"""Converts an EBM model to ONNX with sklearn-onnx
17-
"""
18-
op = operator.raw_operator
14+
operator.outputs[0].type = Int64TensorType([None]) # label
15+
operator.outputs[1].type = FloatTensorType([None, len(op.classes_)]) # probabilities
1916

20-
input_name = operator.inputs[0].onnx_name
21-
ctx = context.create(
22-
generate_variable_name=scope.get_unique_variable_name,
23-
generate_operator_name=scope.get_unique_operator_name,
24-
)
2517

26-
g = convert.to_graph(
27-
op, dtype=(input_name, 'float'),
28-
name="ebm",
29-
predict_proba=True,
30-
prediction_name="label",
31-
probabilities_name="probabilities",
32-
context=ctx
33-
)
18+
def convert_ebm_classifier(scope, operator, container):
19+
"""Converts an EBM model to ONNX with sklearn-onnx
20+
"""
21+
op = operator.raw_operator
3422

35-
for node in g.nodes:
36-
v = container._get_op_version(node.domain, node.op_type)
37-
container.node_domain_version_pair_sets.add((node.domain, v))
23+
input_name = operator.inputs[0].onnx_name
24+
ctx = context.create(
25+
generate_variable_name=scope.get_unique_variable_name,
26+
generate_operator_name=scope.get_unique_operator_name,
27+
)
3828

39-
container.nodes.extend(g.nodes)
29+
g = convert.to_graph(
30+
op, dtype=(input_name, 'float'),
31+
name="ebm",
32+
predict_proba=True,
33+
prediction_name="label",
34+
probabilities_name="probabilities",
35+
context=ctx
36+
)
4037

41-
for i in g.initializers:
42-
content = i.SerializeToString()
43-
container.initializers_strings[content] = i.name
44-
container.initializers.append(i)
38+
for node in g.nodes:
39+
v = container._get_op_version(node.domain, node.op_type)
40+
container.node_domain_version_pair_sets.add((node.domain, v))
41+
42+
container.nodes.extend(g.nodes)
43+
44+
for i in g.initializers:
45+
content = i.SerializeToString()
46+
container.initializers_strings[content] = i.name
47+
container.initializers.append(i)
48+
49+
except Exception:
50+
def ebm_output_shape_calculator(operator):
51+
raise ImportError('skl2onnx not found. Please install it to use serialize a model via scikit-learn')
52+
53+
def convert_ebm_classifier(scope, operator, container):
54+
raise ImportError('skl2onnx not found. Please install it to use serialize a model via scikit-learn')

examples/convert.ipynb

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,23 @@
11
{
22
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": null,
6+
"id": "fd9863cf-daee-4d80-8b1c-2b6fe086b81d",
7+
"metadata": {},
8+
"outputs": [],
9+
"source": [
10+
"\"\"\" Uncomment this if your environment does not have all runtime dependencies installed\n",
11+
"import pip\n",
12+
"pip.main([\n",
13+
" 'install',\n",
14+
" 'interpret-core[notebook,dash,ploly]',\n",
15+
" 'onnxruntime',\n",
16+
" 'matplotlib',\n",
17+
"])\n",
18+
"\"\"\""
19+
]
20+
},
321
{
422
"cell_type": "code",
523
"execution_count": null,
@@ -15,6 +33,19 @@
1533
"from sklearn.metrics import precision_recall_fscore_support"
1634
]
1735
},
36+
{
37+
"cell_type": "code",
38+
"execution_count": null,
39+
"id": "4441327e-f892-4e52-aa1c-c92e43a45696",
40+
"metadata": {},
41+
"outputs": [],
42+
"source": [
43+
"from interpret.provider import InlineProvider\n",
44+
"from interpret import set_visualize_provider\n",
45+
"\n",
46+
"set_visualize_provider(InlineProvider())"
47+
]
48+
},
1849
{
1950
"cell_type": "code",
2051
"execution_count": null,
@@ -38,6 +69,27 @@
3869
"# Binary classification"
3970
]
4071
},
72+
{
73+
"cell_type": "code",
74+
"execution_count": null,
75+
"id": "01be0d4d-c47a-47a9-b953-b5b7c6a9a67a",
76+
"metadata": {},
77+
"outputs": [],
78+
"source": [
79+
"from onnx import defs"
80+
]
81+
},
82+
{
83+
"cell_type": "code",
84+
"execution_count": null,
85+
"id": "020a4e17-cbaf-4159-b1e5-11e42b33f49f",
86+
"metadata": {},
87+
"outputs": [],
88+
"source": [
89+
"#rt.get_available_providers() \n",
90+
"onnx.__version__\n"
91+
]
92+
},
4193
{
4294
"cell_type": "markdown",
4395
"id": "powerful-desktop",
@@ -92,9 +144,7 @@
92144
"cell_type": "code",
93145
"execution_count": null,
94146
"id": "canadian-telephone",
95-
"metadata": {
96-
"scrolled": false
97-
},
147+
"metadata": {},
98148
"outputs": [],
99149
"source": [
100150
"# A lookup at the generated model\n",
@@ -121,9 +171,30 @@
121171
" model=ebm,\n",
122172
" dtype=ebm2onnx.get_dtype_from_pandas(x_train),\n",
123173
" name=\"ebm\",\n",
174+
" #target_opset=10,\n",
124175
")"
125176
]
126177
},
178+
{
179+
"cell_type": "code",
180+
"execution_count": null,
181+
"id": "8e40b3d0-c8c7-4178-8c10-826a3d1363ca",
182+
"metadata": {},
183+
"outputs": [],
184+
"source": [
185+
"onnx_model.ir_version = 10"
186+
]
187+
},
188+
{
189+
"cell_type": "code",
190+
"execution_count": null,
191+
"id": "bb80a006-8542-485e-a0fa-3473ac214018",
192+
"metadata": {},
193+
"outputs": [],
194+
"source": [
195+
"onnx_model"
196+
]
197+
},
127198
{
128199
"cell_type": "markdown",
129200
"id": "italic-authorization",
@@ -190,7 +261,7 @@
190261
"name": "python",
191262
"nbconvert_exporter": "python",
192263
"pygments_lexer": "ipython3",
193-
"version": "3.10.10"
264+
"version": "3.13.8"
194265
}
195266
},
196267
"nbformat": 4,

0 commit comments

Comments
 (0)