|
1 | | -from skl2onnx.common.data_types import Int64TensorType, FloatTensorType, StringTensorType |
2 | 1 | from . import context |
3 | 2 | from . import convert |
4 | 3 |
|
5 | 4 | import onnx |
6 | 5 |
|
7 | 6 |
|
8 | | -def ebm_output_shape_calculator(operator): |
9 | | - op = operator.raw_operator |
| 7 | +try: |
| 8 | + from skl2onnx.common.data_types import Int64TensorType, FloatTensorType, StringTensorType |
10 | 9 |
|
11 | | - operator.outputs[0].type = Int64TensorType([None]) # label |
12 | | - operator.outputs[1].type = FloatTensorType([None, len(op.classes_)]) # probabilities |
13 | 10 |
|
| 11 | + def ebm_output_shape_calculator(operator): |
| 12 | + op = operator.raw_operator |
14 | 13 |
|
15 | | -def convert_ebm_classifier(scope, operator, container): |
16 | | - """Converts an EBM model to ONNX with sklearn-onnx |
17 | | - """ |
18 | | - op = operator.raw_operator |
| 14 | + operator.outputs[0].type = Int64TensorType([None]) # label |
| 15 | + operator.outputs[1].type = FloatTensorType([None, len(op.classes_)]) # probabilities |
19 | 16 |
|
20 | | - input_name = operator.inputs[0].onnx_name |
21 | | - ctx = context.create( |
22 | | - generate_variable_name=scope.get_unique_variable_name, |
23 | | - generate_operator_name=scope.get_unique_operator_name, |
24 | | - ) |
25 | 17 |
|
26 | | - g = convert.to_graph( |
27 | | - op, dtype=(input_name, 'float'), |
28 | | - name="ebm", |
29 | | - predict_proba=True, |
30 | | - prediction_name="label", |
31 | | - probabilities_name="probabilities", |
32 | | - context=ctx |
33 | | - ) |
| 18 | + def convert_ebm_classifier(scope, operator, container): |
| 19 | + """Converts an EBM model to ONNX with sklearn-onnx |
| 20 | + """ |
| 21 | + op = operator.raw_operator |
34 | 22 |
|
35 | | - for node in g.nodes: |
36 | | - v = container._get_op_version(node.domain, node.op_type) |
37 | | - container.node_domain_version_pair_sets.add((node.domain, v)) |
| 23 | + input_name = operator.inputs[0].onnx_name |
| 24 | + ctx = context.create( |
| 25 | + generate_variable_name=scope.get_unique_variable_name, |
| 26 | + generate_operator_name=scope.get_unique_operator_name, |
| 27 | + ) |
38 | 28 |
|
39 | | - container.nodes.extend(g.nodes) |
| 29 | + g = convert.to_graph( |
| 30 | + op, dtype=(input_name, 'float'), |
| 31 | + name="ebm", |
| 32 | + predict_proba=True, |
| 33 | + prediction_name="label", |
| 34 | + probabilities_name="probabilities", |
| 35 | + context=ctx |
| 36 | + ) |
40 | 37 |
|
41 | | - for i in g.initializers: |
42 | | - content = i.SerializeToString() |
43 | | - container.initializers_strings[content] = i.name |
44 | | - container.initializers.append(i) |
| 38 | + for node in g.nodes: |
| 39 | + v = container._get_op_version(node.domain, node.op_type) |
| 40 | + container.node_domain_version_pair_sets.add((node.domain, v)) |
| 41 | + |
| 42 | + container.nodes.extend(g.nodes) |
| 43 | + |
| 44 | + for i in g.initializers: |
| 45 | + content = i.SerializeToString() |
| 46 | + container.initializers_strings[content] = i.name |
| 47 | + container.initializers.append(i) |
| 48 | + |
| 49 | +except Exception: |
| 50 | + def ebm_output_shape_calculator(operator): |
| 51 | + raise ImportError('skl2onnx not found. Please install it to use serialize a model via scikit-learn') |
| 52 | + |
| 53 | + def convert_ebm_classifier(scope, operator, container): |
| 54 | + raise ImportError('skl2onnx not found. Please install it to use serialize a model via scikit-learn') |
0 commit comments