Skip to content

Commit 73cb667

Browse files
authored
Adding the coefficient of determination and adjusted coefficient of determination to Basic and Window Regression Performance Evaluators. (#281)
1 parent f37dffd commit 73cb667

File tree

2 files changed

+128
-45
lines changed

2 files changed

+128
-45
lines changed

moa/src/main/java/moa/evaluation/BasicRegressionPerformanceEvaluator.java

Lines changed: 47 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
*
1616
* You should have received a copy of the GNU General Public License
1717
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18-
*
18+
*
1919
*/
2020
package moa.evaluation;
2121

@@ -24,7 +24,6 @@
2424
import moa.core.Measurement;
2525

2626
import com.yahoo.labs.samoa.instances.Instance;
27-
import com.yahoo.labs.samoa.instances.InstanceData;
2827
import com.yahoo.labs.samoa.instances.Prediction;
2928

3029
/**
@@ -45,11 +44,15 @@ public class BasicRegressionPerformanceEvaluator extends AbstractMOAObject
4544
protected double averageError;
4645

4746
protected double sumTarget;
48-
47+
4948
protected double squareTargetError;
50-
49+
5150
protected double averageTargetError;
5251

52+
protected double totalSumSquares;
53+
54+
protected double numAttributes;
55+
5356
@Override
5457
public void reset() {
5558
this.weightObserved = 0.0;
@@ -58,47 +61,67 @@ public void reset() {
5861
this.sumTarget = 0.0;
5962
this.averageTargetError = 0.0;
6063
this.squareTargetError = 0.0;
61-
64+
this.totalSumSquares = 0.0;
65+
this.numAttributes = 0.0;
6266
}
6367

6468
@Override
6569
public void addResult(Example<Instance> example, double[] prediction) {
66-
Instance inst = example.getData();
70+
Instance inst = example.getData();
6771
if (inst.weight() > 0.0) {
6872
if (prediction.length > 0) {
69-
double meanTarget = this.weightObserved != 0 ?
70-
this.sumTarget / this.weightObserved : 0.0;
73+
double meanTarget = this.weightObserved != 0 ?
74+
this.sumTarget / this.weightObserved : 0.0;
7175
this.squareError += (inst.classValue() - prediction[0]) * (inst.classValue() - prediction[0]);
7276
this.averageError += Math.abs(inst.classValue() - prediction[0]);
7377
this.squareTargetError += (inst.classValue() - meanTarget) * (inst.classValue() - meanTarget);
7478
this.averageTargetError += Math.abs(inst.classValue() - meanTarget);
7579
this.sumTarget += inst.classValue();
7680
this.weightObserved += inst.weight();
81+
this.numAttributes = inst.numAttributes();
7782
}
78-
//System.out.println(inst.classValue()+", "+prediction[0]);
7983
}
8084
}
8185

8286
@Override
8387
public Measurement[] getPerformanceMeasurements() {
8488
return new Measurement[]{
85-
new Measurement("classified instances",
86-
getTotalWeightObserved()),
87-
new Measurement("mean absolute error",
88-
getMeanError()),
89-
new Measurement("root mean squared error",
90-
getSquareError()),
91-
new Measurement("relative mean absolute error",
92-
getRelativeMeanError()),
93-
new Measurement("relative root mean squared error",
94-
getRelativeSquareError())
89+
new Measurement("classified instances",
90+
getTotalWeightObserved()),
91+
new Measurement("mean absolute error",
92+
getMeanError()),
93+
new Measurement("root mean squared error",
94+
getSquareError()),
95+
new Measurement("relative mean absolute error",
96+
getRelativeMeanError()),
97+
new Measurement("relative root mean squared error",
98+
getRelativeSquareError()),
99+
new Measurement("coefficient of determination",
100+
getCoefficientOfDetermination()),
101+
new Measurement("adjusted coefficient of determination",
102+
getAdjustedCoefficientOfDetermination())
95103
};
96104
}
97105

98106
public double getTotalWeightObserved() {
99107
return this.weightObserved;
100108
}
101109

110+
public double getCoefficientOfDetermination() {
111+
if(weightObserved > 0.0) {
112+
double SSres = squareError;
113+
double SStot = squareTargetError;
114+
115+
return 1 - (SSres / SStot);
116+
}
117+
return 0.0;
118+
}
119+
120+
public double getAdjustedCoefficientOfDetermination() {
121+
return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) /
122+
(getTotalWeightObserved() - numAttributes - 1);
123+
}
124+
102125
public double getMeanError() {
103126
return this.weightObserved > 0.0 ? this.averageError
104127
/ this.weightObserved : 0.0;
@@ -130,18 +153,18 @@ private double getRelativeMeanError() {
130153
//return targetMeanError > 0 ? getMeanError()/targetMeanError : 0.0;
131154
return this.averageTargetError> 0 ?
132155
this.averageError/this.averageTargetError : 0.0;
133-
}
156+
}
134157

135158
private double getRelativeSquareError() {
136159
//double targetSquareError = getTargetSquareError();
137160
//return targetSquareError > 0 ? getSquareError()/targetSquareError : 0.0;
138-
return Math.sqrt(this.squareTargetError> 0 ?
161+
return Math.sqrt(this.squareTargetError> 0 ?
139162
this.squareError/this.squareTargetError : 0.0);
140163
}
141-
164+
142165
@Override
143166
public void addResult(Example<Instance> example, Prediction prediction) {
144-
if(prediction!=null)
145-
addResult(example,prediction.getVotes(0));
167+
if(prediction!=null)
168+
addResult(example,prediction.getVotes(0));
146169
}
147170
}

moa/src/main/java/moa/evaluation/WindowRegressionPerformanceEvaluator.java

Lines changed: 81 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
*
1616
* You should have received a copy of the GNU General Public License
1717
* along with this program. If not, see <http://www.gnu.org/licenses/>.
18-
*
18+
*
1919
*/
2020
package moa.evaluation;
2121

@@ -29,7 +29,6 @@
2929
import moa.tasks.TaskMonitor;
3030

3131
import com.yahoo.labs.samoa.instances.Instance;
32-
import com.yahoo.labs.samoa.instances.InstanceData;
3332
import com.yahoo.labs.samoa.instances.Prediction;
3433

3534
/**
@@ -52,6 +51,14 @@ public class WindowRegressionPerformanceEvaluator extends AbstractOptionHandler
5251

5352
protected Estimator squareError;
5453

54+
protected Estimator squareTargetError;
55+
56+
protected Estimator sumTarget;
57+
58+
protected double numAttributes;
59+
60+
protected Estimator averageTargetError;
61+
5562
protected Estimator averageError;
5663

5764
protected int numClasses;
@@ -103,6 +110,9 @@ public void reset(int numClasses) {
103110
this.weightObserved = new Estimator(this.widthOption.getValue());
104111
this.squareError = new Estimator(this.widthOption.getValue());
105112
this.averageError = new Estimator(this.widthOption.getValue());
113+
this.squareTargetError = new Estimator(this.widthOption.getValue());
114+
this.sumTarget = new Estimator(this.widthOption.getValue());
115+
this.averageTargetError = new Estimator(this.widthOption.getValue());
106116
this.TotalweightObserved = 0;
107117
}
108118

@@ -118,7 +128,17 @@ public void addResult(Example<Instance> example, double[] prediction) {
118128
this.weightObserved.add(weight);
119129

120130
if (prediction.length > 0) {
131+
double meanTarget = this.weightObserved.total() != 0 ?
132+
this.sumTarget.total() / this.weightObserved.total() : 0.0;
133+
121134
this.squareError.add((inst.classValue() - prediction[0]) * (inst.classValue() - prediction[0]));
135+
136+
this.squareTargetError.add((inst.classValue() - meanTarget) * (inst.classValue() - meanTarget));
137+
this.sumTarget.add(inst.classValue());
138+
this.numAttributes = inst.numAttributes();
139+
140+
this.averageTargetError.add(Math.abs(inst.classValue() - meanTarget));
141+
122142
this.averageError.add(Math.abs(inst.classValue() - prediction[0]));
123143
}
124144
//System.out.println(inst.classValue()+", "+prediction[0]);
@@ -128,12 +148,52 @@ public void addResult(Example<Instance> example, double[] prediction) {
128148
@Override
129149
public Measurement[] getPerformanceMeasurements() {
130150
return new Measurement[]{
131-
new Measurement("classified instances",
132-
getTotalWeightObserved()),
133-
new Measurement("mean absolute error",
134-
getMeanError()),
135-
new Measurement("root mean squared error",
136-
getSquareError())};
151+
new Measurement("classified instances",
152+
getTotalWeightObserved()),
153+
new Measurement("mean absolute error",
154+
getMeanError()),
155+
new Measurement("root mean squared error",
156+
getSquareError()),
157+
new Measurement("relative mean absolute error",
158+
getRelativeMeanError()),
159+
new Measurement("relative root mean squared error",
160+
getRelativeSquareError()),
161+
new Measurement("coefficient of determination",
162+
getCoefficientOfDetermination()),
163+
new Measurement("adjusted coefficient of determination",
164+
getAdjustedCoefficientOfDetermination())
165+
};
166+
}
167+
168+
public double getCoefficientOfDetermination() {
169+
if(weightObserved.total() > 0.0) {
170+
double SSres = squareError.total();
171+
double SStot = squareTargetError.total();
172+
173+
return 1 - (SSres / SStot);
174+
}
175+
return 0.0;
176+
}
177+
178+
public double getAdjustedCoefficientOfDetermination() {
179+
return 1 - ((1-getCoefficientOfDetermination())*(getTotalWeightObserved() - 1)) /
180+
(getTotalWeightObserved() - numAttributes - 1);
181+
}
182+
183+
private double getRelativeMeanError() {
184+
//double targetMeanError = getTargetMeanError();
185+
//return targetMeanError > 0 ? getMeanError()/targetMeanError : 0.0;
186+
return this.averageTargetError.total() > 0 ?
187+
this.averageError.total() / this.averageTargetError.total() : 0.0;
188+
// //TODO: implement!
189+
// return -1.0;
190+
}
191+
192+
private double getRelativeSquareError() {
193+
//double targetSquareError = getTargetSquareError();
194+
//return targetSquareError > 0 ? getSquareError()/targetSquareError : 0.0;
195+
return Math.sqrt(this.squareTargetError.total() > 0 ?
196+
this.squareError.total() / this.squareTargetError.total() : 0.0);
137197
}
138198

139199
public double getTotalWeightObserved() {
@@ -158,18 +218,18 @@ public void getDescription(StringBuilder sb, int indent) {
158218

159219
@Override
160220
public void prepareForUseImpl(TaskMonitor monitor,
161-
ObjectRepository repository) {
221+
ObjectRepository repository) {
222+
}
223+
224+
225+
@Override
226+
public void addResult(Example<Instance> testInst, Prediction prediction) {
227+
double votes[];
228+
if(prediction==null)
229+
votes = new double[0];
230+
else
231+
votes=prediction.getVotes();
232+
addResult(testInst, votes);
233+
162234
}
163-
164-
165-
@Override
166-
public void addResult(Example<Instance> testInst, Prediction prediction) {
167-
double votes[];
168-
if(prediction==null)
169-
votes = new double[0];
170-
else
171-
votes=prediction.getVotes();
172-
addResult(testInst, votes);
173-
174-
}
175235
}

0 commit comments

Comments
 (0)