Skip to content

Commit 4b0c0f1

Browse files
prabhat00155 authored and wenbingl committed
Fixed Linear classifier converter to resolve RS5 test error in LinearSVC (#194)
1 parent 89deca3 commit 4b0c0f1

File tree

2 files changed

+17
-8
lines changed

2 files changed

+17
-8
lines changed

onnxmltools/convert/common/shape_calculator.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -33,31 +33,30 @@ def calculate_linear_classifier_output_shapes(operator):
3333
N = operator.inputs[0].type.shape[0]
3434

3535
class_labels = operator.raw_operator.classes_
36-
number_of_classes = len(class_labels)
3736
if all(isinstance(i, np.ndarray) for i in class_labels):
3837
class_labels = np.concatenate(class_labels)
3938
if all(isinstance(i, (six.string_types, six.text_type)) for i in class_labels):
4039
operator.outputs[0].type = StringTensorType(shape=[N])
41-
if number_of_classes > 2 or operator.type != 'SklearnLinearSVC':
40+
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
4241
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
4342
if operator.target_opset < 7:
4443
operator.outputs[1].type = DictionaryType(StringTensorType([1]), FloatTensorType([1]))
4544
else:
4645
operator.outputs[1].type = SequenceType(DictionaryType(StringTensorType([]), FloatTensorType([])), N)
4746
else:
48-
# For binary LinearSVC, we produce the probability tensor
49-
operator.outputs[1].type = FloatTensorType(shape=[N, number_of_classes])
47+
# For binary LinearSVC, we produce probability of the positive class
48+
operator.outputs[1].type = FloatTensorType(shape=[N, 1])
5049
elif all(isinstance(i, (numbers.Real, bool, np.bool_)) for i in class_labels):
5150
operator.outputs[0].type = Int64TensorType(shape=[N])
52-
if number_of_classes > 2 or operator.type != 'SklearnLinearSVC':
51+
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
5352
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
5453
if operator.target_opset < 7:
5554
operator.outputs[1].type = DictionaryType(Int64TensorType([1]), FloatTensorType([1]))
5655
else:
5756
operator.outputs[1].type = SequenceType(DictionaryType(Int64TensorType([]), FloatTensorType([])), N)
5857
else:
59-
# For binary LinearSVC, we produce the probability tensor
60-
operator.outputs[1].type = FloatTensorType(shape=[N, number_of_classes])
58+
# For binary LinearSVC, we produce probability of the positive class
59+
operator.outputs[1].type = FloatTensorType(shape=[N, 1])
6160
else:
6261
raise ValueError('Unsupported or mixed label types')
6362

onnxmltools/convert/sklearn/operator_converters/LinearClassifier.py

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
import numpy as np
88
import six, numbers
99
from ...common._registration import register_converter
10+
from ....proto import onnx_proto
1011

1112

1213
def convert_sklearn_linear_classifier(scope, operator, container):
@@ -67,9 +68,18 @@ def convert_sklearn_linear_classifier(scope, operator, container):
6768
probability_tensor_name = scope.get_unique_variable_name('probability_tensor')
6869

6970
if op.__class__.__name__ == 'LinearSVC' and op.classes_.shape[0] <= 2:
71+
raw_scores_tensor_name = scope.get_unique_variable_name('raw_scores_tensor')
72+
positive_class_index_name = scope.get_unique_variable_name('positive_class_index')
73+
74+
container.add_initializer(positive_class_index_name, onnx_proto.TensorProto.INT64,
75+
[], [1])
76+
7077
container.add_node(classifier_type, operator.inputs[0].full_name,
71-
[label_name, operator.outputs[1].full_name],
78+
[label_name, raw_scores_tensor_name],
7279
op_domain='ai.onnx.ml', **classifier_attrs)
80+
container.add_node('ArrayFeatureExtractor', [raw_scores_tensor_name, positive_class_index_name],
81+
operator.outputs[1].full_name, name=scope.get_unique_operator_name('ArrayFeatureExtractor'),
82+
op_domain='ai.onnx.ml')
7383
else:
7484
container.add_node(classifier_type, operator.inputs[0].full_name,
7585
[label_name, probability_tensor_name],

0 commit comments

Comments
 (0)