@@ -33,30 +33,31 @@ def calculate_linear_classifier_output_shapes(operator):
3333 N = operator .inputs [0 ].type .shape [0 ]
3434
3535 class_labels = operator .raw_operator .classes_
36+ number_of_classes = len (class_labels )
3637 if all (isinstance (i , np .ndarray ) for i in class_labels ):
3738 class_labels = np .concatenate (class_labels )
3839 if all (isinstance (i , (six .string_types , six .text_type )) for i in class_labels ):
3940 operator .outputs [0 ].type = StringTensorType (shape = [N ])
40- if len ( class_labels ) > 2 or operator .type != 'SklearnLinearSVC' :
41+ if number_of_classes > 2 or operator .type != 'SklearnLinearSVC' :
4142 # For multi-class classifier, we produce a map for encoding the probabilities of all classes
4243 if operator .target_opset < 7 :
4344 operator .outputs [1 ].type = DictionaryType (StringTensorType ([1 ]), FloatTensorType ([1 ]))
4445 else :
4546 operator .outputs [1 ].type = SequenceType (DictionaryType (StringTensorType ([]), FloatTensorType ([])), N )
4647 else :
47- # For binary classifier , we produce the probability of the positive class
48- operator .outputs [1 ].type = FloatTensorType (shape = [N , 1 ])
48+ # For binary LinearSVC , we produce the probability tensor
49+ operator .outputs [1 ].type = FloatTensorType (shape = [N , number_of_classes ])
4950 elif all (isinstance (i , (numbers .Real , bool , np .bool_ )) for i in class_labels ):
5051 operator .outputs [0 ].type = Int64TensorType (shape = [N ])
51- if len ( class_labels ) > 2 or operator .type != 'SklearnLinearSVC' :
52+ if number_of_classes > 2 or operator .type != 'SklearnLinearSVC' :
5253 # For multi-class classifier, we produce a map for encoding the probabilities of all classes
5354 if operator .target_opset < 7 :
5455 operator .outputs [1 ].type = DictionaryType (Int64TensorType ([1 ]), FloatTensorType ([1 ]))
5556 else :
5657 operator .outputs [1 ].type = SequenceType (DictionaryType (Int64TensorType ([]), FloatTensorType ([])), N )
5758 else :
58- # For binary classifier , we produce the probability of the positive class
59- operator .outputs [1 ].type = FloatTensorType (shape = [N , 1 ])
59+ # For binary LinearSVC , we produce the probability tensor
60+ operator .outputs [1 ].type = FloatTensorType (shape = [N , number_of_classes ])
6061 else :
6162 raise ValueError ('Unsupported or mixed label types' )
6263
0 commit comments