1515 *
1616 * You should have received a copy of the GNU General Public License
1717 * along with this program. If not, see <http://www.gnu.org/licenses/>.
18- *
18+ *
1919 */
2020package moa .evaluation ;
2121
2929import moa .tasks .TaskMonitor ;
3030
3131import com .yahoo .labs .samoa .instances .Instance ;
32- import com .yahoo .labs .samoa .instances .InstanceData ;
3332import com .yahoo .labs .samoa .instances .Prediction ;
3433
3534/**
@@ -52,6 +51,14 @@ public class WindowRegressionPerformanceEvaluator extends AbstractOptionHandler
5251
5352 protected Estimator squareError ;
5453
54+ protected Estimator squareTargetError ;
55+
56+ protected Estimator sumTarget ;
57+
58+ protected double numAttributes ;
59+
60+ protected Estimator averageTargetError ;
61+
5562 protected Estimator averageError ;
5663
5764 protected int numClasses ;
@@ -103,6 +110,9 @@ public void reset(int numClasses) {
103110 this .weightObserved = new Estimator (this .widthOption .getValue ());
104111 this .squareError = new Estimator (this .widthOption .getValue ());
105112 this .averageError = new Estimator (this .widthOption .getValue ());
113+ this .squareTargetError = new Estimator (this .widthOption .getValue ());
114+ this .sumTarget = new Estimator (this .widthOption .getValue ());
115+ this .averageTargetError = new Estimator (this .widthOption .getValue ());
106116 this .TotalweightObserved = 0 ;
107117 }
108118
@@ -118,7 +128,17 @@ public void addResult(Example<Instance> example, double[] prediction) {
118128 this .weightObserved .add (weight );
119129
120130 if (prediction .length > 0 ) {
131+ double meanTarget = this .weightObserved .total () != 0 ?
132+ this .sumTarget .total () / this .weightObserved .total () : 0.0 ;
133+
121134 this .squareError .add ((inst .classValue () - prediction [0 ]) * (inst .classValue () - prediction [0 ]));
135+
136+ this .squareTargetError .add ((inst .classValue () - meanTarget ) * (inst .classValue () - meanTarget ));
137+ this .sumTarget .add (inst .classValue ());
138+ this .numAttributes = inst .numAttributes ();
139+
140+ this .averageTargetError .add (Math .abs (inst .classValue () - meanTarget ));
141+
122142 this .averageError .add (Math .abs (inst .classValue () - prediction [0 ]));
123143 }
124144 //System.out.println(inst.classValue()+", "+prediction[0]);
@@ -128,12 +148,52 @@ public void addResult(Example<Instance> example, double[] prediction) {
128148 @ Override
129149 public Measurement [] getPerformanceMeasurements () {
130150 return new Measurement []{
131- new Measurement ("classified instances" ,
132- getTotalWeightObserved ()),
133- new Measurement ("mean absolute error" ,
134- getMeanError ()),
135- new Measurement ("root mean squared error" ,
136- getSquareError ())};
151+ new Measurement ("classified instances" ,
152+ getTotalWeightObserved ()),
153+ new Measurement ("mean absolute error" ,
154+ getMeanError ()),
155+ new Measurement ("root mean squared error" ,
156+ getSquareError ()),
157+ new Measurement ("relative mean absolute error" ,
158+ getRelativeMeanError ()),
159+ new Measurement ("relative root mean squared error" ,
160+ getRelativeSquareError ()),
161+ new Measurement ("coefficient of determination" ,
162+ getCoefficientOfDetermination ()),
163+ new Measurement ("adjusted coefficient of determination" ,
164+ getAdjustedCoefficientOfDetermination ())
165+ };
166+ }
167+
168+ public double getCoefficientOfDetermination () {
169+ if (weightObserved .total () > 0.0 ) {
170+ double SSres = squareError .total ();
171+ double SStot = squareTargetError .total ();
172+
173+ return 1 - (SSres / SStot );
174+ }
175+ return 0.0 ;
176+ }
177+
178+ public double getAdjustedCoefficientOfDetermination () {
179+ return 1 - ((1 -getCoefficientOfDetermination ())*(getTotalWeightObserved () - 1 )) /
180+ (getTotalWeightObserved () - numAttributes - 1 );
181+ }
182+
183+ private double getRelativeMeanError () {
184+ //double targetMeanError = getTargetMeanError();
185+ //return targetMeanError > 0 ? getMeanError()/targetMeanError : 0.0;
186+ return this .averageTargetError .total () > 0 ?
187+ this .averageError .total () / this .averageTargetError .total () : 0.0 ;
188+ // //TODO: implement!
189+ // return -1.0;
190+ }
191+
192+ private double getRelativeSquareError () {
193+ //double targetSquareError = getTargetSquareError();
194+ //return targetSquareError > 0 ? getSquareError()/targetSquareError : 0.0;
195+ return Math .sqrt (this .squareTargetError .total () > 0 ?
196+ this .squareError .total () / this .squareTargetError .total () : 0.0 );
137197 }
138198
139199 public double getTotalWeightObserved () {
@@ -158,18 +218,18 @@ public void getDescription(StringBuilder sb, int indent) {
158218
159219 @ Override
160220 public void prepareForUseImpl (TaskMonitor monitor ,
161- ObjectRepository repository ) {
221+ ObjectRepository repository ) {
222+ }
223+
224+
225+ @ Override
226+ public void addResult (Example <Instance > testInst , Prediction prediction ) {
227+ double votes [];
228+ if (prediction ==null )
229+ votes = new double [0 ];
230+ else
231+ votes =prediction .getVotes ();
232+ addResult (testInst , votes );
233+
162234 }
163-
164-
165- @ Override
166- public void addResult (Example <Instance > testInst , Prediction prediction ) {
167- double votes [];
168- if (prediction ==null )
169- votes = new double [0 ];
170- else
171- votes =prediction .getVotes ();
172- addResult (testInst , votes );
173-
174- }
175235}
0 commit comments