Skip to content

Commit 89deca3

Browse files
authored
Fixed LinearSVC output shape bug (#193)
1 parent c10d361 commit 89deca3

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

onnxmltools/convert/common/shape_calculator.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,30 +33,31 @@ def calculate_linear_classifier_output_shapes(operator):
3333
N = operator.inputs[0].type.shape[0]
3434

3535
class_labels = operator.raw_operator.classes_
36+
number_of_classes = len(class_labels)
3637
if all(isinstance(i, np.ndarray) for i in class_labels):
3738
class_labels = np.concatenate(class_labels)
3839
if all(isinstance(i, (six.string_types, six.text_type)) for i in class_labels):
3940
operator.outputs[0].type = StringTensorType(shape=[N])
40-
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
41+
if number_of_classes > 2 or operator.type != 'SklearnLinearSVC':
4142
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
4243
if operator.target_opset < 7:
4344
operator.outputs[1].type = DictionaryType(StringTensorType([1]), FloatTensorType([1]))
4445
else:
4546
operator.outputs[1].type = SequenceType(DictionaryType(StringTensorType([]), FloatTensorType([])), N)
4647
else:
47-
# For binary classifier, we produce the probability of the positive class
48-
operator.outputs[1].type = FloatTensorType(shape=[N, 1])
48+
# For binary LinearSVC, we produce the probability tensor
49+
operator.outputs[1].type = FloatTensorType(shape=[N, number_of_classes])
4950
elif all(isinstance(i, (numbers.Real, bool, np.bool_)) for i in class_labels):
5051
operator.outputs[0].type = Int64TensorType(shape=[N])
51-
if len(class_labels) > 2 or operator.type != 'SklearnLinearSVC':
52+
if number_of_classes > 2 or operator.type != 'SklearnLinearSVC':
5253
# For multi-class classifier, we produce a map for encoding the probabilities of all classes
5354
if operator.target_opset < 7:
5455
operator.outputs[1].type = DictionaryType(Int64TensorType([1]), FloatTensorType([1]))
5556
else:
5657
operator.outputs[1].type = SequenceType(DictionaryType(Int64TensorType([]), FloatTensorType([])), N)
5758
else:
58-
# For binary classifier, we produce the probability of the positive class
59-
operator.outputs[1].type = FloatTensorType(shape=[N, 1])
59+
# For binary LinearSVC, we produce the probability tensor
60+
operator.outputs[1].type = FloatTensorType(shape=[N, number_of_classes])
6061
else:
6162
raise ValueError('Unsupported or mixed label types')
6263

0 commit comments

Comments
 (0)