Skip to content

Commit 000b7d6

Browse files
author
Mark Hall
committed
PMI Scoring now always includes the predicted class/cluster field in the output, regardless of whether probabilities are being output or not.
1 parent ae8bab8 commit 000b7d6

File tree

2 files changed

+49
-52
lines changed

2 files changed

+49
-52
lines changed

src/org/pentaho/di/trans/steps/pmi/PMIScoringData.java

Lines changed: 30 additions & 32 deletions
Original file line number | Diff line number | Diff line change
@@ -468,29 +468,28 @@ public Object[][] generatePredictions( RowMetaInterface inputMeta, RowMetaInterf
468468

469469
double[] prediction = preds[i];
470470

471-
if ( prediction.length == 1 || !outputProbs ) {
472-
if ( supervised ) {
473-
if ( classAtt.isNumeric() ) {
474-
resultRow[index++] = prediction[0];
475-
} else {
476-
int maxProb = Utils.maxIndex( prediction );
477-
if ( prediction[maxProb] > 0 ) {
478-
resultRow[index++] = classAtt.value( maxProb );
479-
} else {
480-
resultRow[index++] =
481-
BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredict" );
482-
}
483-
}
471+
if ( supervised ) {
472+
if ( classAtt.isNumeric() ) {
473+
resultRow[index++] = prediction[0];
484474
} else {
485475
int maxProb = Utils.maxIndex( prediction );
486476
if ( prediction[maxProb] > 0 ) {
487-
resultRow[index++] = maxProb;
477+
resultRow[index++] = classAtt.value( maxProb );
488478
} else {
489-
resultRow[index++] =
490-
BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredictCluster" );
479+
resultRow[index++] = BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredict" );
491480
}
492481
}
493482
} else {
483+
int maxProb = Utils.maxIndex( prediction );
484+
if ( prediction[maxProb] > 0 ) {
485+
resultRow[index++] = maxProb;
486+
} else {
487+
resultRow[index++] =
488+
BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredictCluster" );
489+
}
490+
}
491+
492+
if ( ( outputProbs && classAtt == null ) || ( outputProbs && !classAtt.isNumeric() ) ) {
494493
// output probability distribution
495494
for ( double j : prediction ) {
496495
resultRow[index++] = j;
@@ -544,29 +543,28 @@ public Object[] generatePrediction( RowMetaInterface inputMeta, RowMetaInterface
544543
int index = inputMeta.size();
545544

546545
// output for numeric class or discrete class value
547-
if ( prediction.length == 1 || !outputProbs ) {
548-
if ( supervised ) {
549-
if ( classAtt.isNumeric() ) {
550-
resultRow[index++] = prediction[0];
551-
} else {
552-
int maxProb = Utils.maxIndex( prediction );
553-
if ( prediction[maxProb] > 0 ) {
554-
resultRow[index++] = classAtt.value( maxProb );
555-
} else {
556-
resultRow[index++] =
557-
BaseMessages.getString( PMIScoringMeta.PKG, "WekaScoringData.Message.UnableToPredict" );
558-
}
559-
}
546+
if ( supervised ) {
547+
if ( classAtt.isNumeric() ) {
548+
resultRow[index++] = prediction[0];
560549
} else {
561550
int maxProb = Utils.maxIndex( prediction );
562551
if ( prediction[maxProb] > 0 ) {
563-
resultRow[index++] = maxProb;
552+
resultRow[index++] = classAtt.value( maxProb );
564553
} else {
565-
String newVal = BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredictCluster" );
566-
resultRow[index++] = newVal;
554+
resultRow[index++] = BaseMessages.getString( PMIScoringMeta.PKG, "WekaScoringData.Message.UnableToPredict" );
567555
}
568556
}
569557
} else {
558+
int maxProb = Utils.maxIndex( prediction );
559+
if ( prediction[maxProb] > 0 ) {
560+
resultRow[index++] = maxProb;
561+
} else {
562+
String newVal = BaseMessages.getString( PMIScoringMeta.PKG, "PMIScoringData.Message.UnableToPredictCluster" );
563+
resultRow[index++] = newVal;
564+
}
565+
}
566+
567+
if ( ( outputProbs && classAtt == null ) || ( outputProbs && !classAtt.isNumeric() ) ) {
570568
// output probability distribution
571569
for ( int i = 0; i < prediction.length; i++ ) {
572570
Double newVal = prediction[i];

src/org/pentaho/di/trans/steps/pmi/PMIScoringMeta.java

Lines changed: 19 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -833,33 +833,38 @@ public void saveRep( Repository rep, ObjectId id_transformation, ObjectId id_ste
833833
if ( supervised ) {
834834
classAttName = header.classAttribute().name();
835835

836-
if ( header.classAttribute().isNumeric() || !m_outputProbabilities ) {
837-
int
838-
valueType =
839-
( header.classAttribute().isNumeric() ) ? ValueMetaInterface.TYPE_NUMBER :
840-
ValueMetaInterface.TYPE_STRING;
841-
842-
ValueMetaInterface newVM = ValueMetaFactory.createValueMeta( classAttName + "_predicted", valueType );
843-
newVM.setOrigin( origin );
844-
row.addValueMeta( newVM );
845-
} else {
836+
int
837+
valueType =
838+
( header.classAttribute().isNumeric() ) ? ValueMetaInterface.TYPE_NUMBER : ValueMetaInterface.TYPE_STRING;
839+
840+
ValueMetaInterface newVM = ValueMetaFactory.createValueMeta( classAttName + "_predicted", valueType );
841+
newVM.setOrigin( origin );
842+
row.addValueMeta( newVM );
843+
844+
if ( m_outputProbabilities && !header.classAttribute().isNumeric() ) {
846845
for ( int i = 0; i < header.classAttribute().numValues(); i++ ) {
847846
String classVal = header.classAttribute().value( i );
848-
ValueMetaInterface
849-
newVM =
847+
// ValueMetaInterface
848+
newVM =
850849
ValueMetaFactory.createValueMeta( classAttName + ":" + classVal + "_predicted_prob",
851850
ValueMetaInterface.TYPE_NUMBER );
852851
newVM.setOrigin( origin );
853852
row.addValueMeta( newVM );
854853
}
855854
}
856855
} else {
856+
ValueMetaInterface
857+
newVM =
858+
ValueMetaFactory.createValueMeta( "cluster#_predicted", ValueMetaInterface.TYPE_NUMBER );
859+
newVM.setOrigin( origin );
860+
row.addValueMeta( newVM );
861+
857862
if ( m_outputProbabilities ) {
858863
try {
859864
int numClusters = ( (PMIScoringClusterer) m_model ).numberOfClusters();
860865
for ( int i = 0; i < numClusters; i++ ) {
861-
ValueMetaInterface
862-
newVM =
866+
// ValueMetaInterface
867+
newVM =
863868
ValueMetaFactory
864869
.createValueMeta( "cluster_" + i + "_predicted_prob", ValueMetaInterface.TYPE_NUMBER );
865870
newVM.setOrigin( origin );
@@ -869,12 +874,6 @@ public void saveRep( Repository rep, ObjectId id_transformation, ObjectId id_ste
869874
throw new KettleStepException(
870875
BaseMessages.getString( PKG, "PMIScoringMeta.Error.UnableToGetNumberOfClusters" ), ex );
871876
}
872-
} else {
873-
ValueMetaInterface
874-
newVM =
875-
ValueMetaFactory.createValueMeta( "cluster#_predicted", ValueMetaInterface.TYPE_NUMBER );
876-
newVM.setOrigin( origin );
877-
row.addValueMeta( newVM );
878877
}
879878
}
880879
} catch ( KettlePluginException e ) {

0 commit comments

Comments
 (0)