@@ -468,29 +468,28 @@ public Object[][] generatePredictions( RowMetaInterface inputMeta, RowMetaInterf
468468
469469 double [] prediction = preds [i ];
470470
471- if ( prediction .length == 1 || !outputProbs ) {
472- if ( supervised ) {
473- if ( classAtt .isNumeric () ) {
474- resultRow [index ++] = prediction [0 ];
475- } else {
476- int maxProb = Utils .maxIndex ( prediction );
477- if ( prediction [maxProb ] > 0 ) {
478- resultRow [index ++] = classAtt .value ( maxProb );
479- } else {
480- resultRow [index ++] =
481- BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredict" );
482- }
483- }
471+ if ( supervised ) {
472+ if ( classAtt .isNumeric () ) {
473+ resultRow [index ++] = prediction [0 ];
484474 } else {
485475 int maxProb = Utils .maxIndex ( prediction );
486476 if ( prediction [maxProb ] > 0 ) {
487- resultRow [index ++] = maxProb ;
477+ resultRow [index ++] = classAtt . value ( maxProb ) ;
488478 } else {
489- resultRow [index ++] =
490- BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredictCluster" );
479+ resultRow [index ++] = BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredict" );
491480 }
492481 }
493482 } else {
483+ int maxProb = Utils .maxIndex ( prediction );
484+ if ( prediction [maxProb ] > 0 ) {
485+ resultRow [index ++] = maxProb ;
486+ } else {
487+ resultRow [index ++] =
488+ BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredictCluster" );
489+ }
490+ }
491+
492+ if ( ( outputProbs && classAtt == null ) || ( outputProbs && !classAtt .isNumeric () ) ) {
494493 // output probability distribution
495494 for ( double j : prediction ) {
496495 resultRow [index ++] = j ;
@@ -544,29 +543,28 @@ public Object[] generatePrediction( RowMetaInterface inputMeta, RowMetaInterface
544543 int index = inputMeta .size ();
545544
546545 // output for numeric class or discrete class value
547- if ( prediction .length == 1 || !outputProbs ) {
548- if ( supervised ) {
549- if ( classAtt .isNumeric () ) {
550- resultRow [index ++] = prediction [0 ];
551- } else {
552- int maxProb = Utils .maxIndex ( prediction );
553- if ( prediction [maxProb ] > 0 ) {
554- resultRow [index ++] = classAtt .value ( maxProb );
555- } else {
556- resultRow [index ++] =
557- BaseMessages .getString ( PMIScoringMeta .PKG , "WekaScoringData.Message.UnableToPredict" );
558- }
559- }
546+ if ( supervised ) {
547+ if ( classAtt .isNumeric () ) {
548+ resultRow [index ++] = prediction [0 ];
560549 } else {
561550 int maxProb = Utils .maxIndex ( prediction );
562551 if ( prediction [maxProb ] > 0 ) {
563- resultRow [index ++] = maxProb ;
552+ resultRow [index ++] = classAtt . value ( maxProb ) ;
564553 } else {
565- String newVal = BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredictCluster" );
566- resultRow [index ++] = newVal ;
554+ resultRow [index ++] = BaseMessages .getString ( PMIScoringMeta .PKG , "WekaScoringData.Message.UnableToPredict" );
567555 }
568556 }
569557 } else {
558+ int maxProb = Utils .maxIndex ( prediction );
559+ if ( prediction [maxProb ] > 0 ) {
560+ resultRow [index ++] = maxProb ;
561+ } else {
562+ String newVal = BaseMessages .getString ( PMIScoringMeta .PKG , "PMIScoringData.Message.UnableToPredictCluster" );
563+ resultRow [index ++] = newVal ;
564+ }
565+ }
566+
567+ if ( ( outputProbs && classAtt == null ) || ( outputProbs && !classAtt .isNumeric () ) ) {
570568 // output probability distribution
571569 for ( int i = 0 ; i < prediction .length ; i ++ ) {
572570 Double newVal = prediction [i ];
0 commit comments