Skip to content

Commit 87990e9

Browse files
committed
#5 First version for R² as additional value of the statistics aggregation response
1 parent e967241 commit 87990e9

File tree

6 files changed

+42
-8
lines changed

6 files changed

+42
-8
lines changed

gradle.properties

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@ wagon-ssh-external.version=2.10
44
commons-math3.version=3.6.1
55
group=org.scaleborn.elasticsearch.plugin
66
name=elasticsearch-linear-regression
7-
version=5.5.1.1
7+
version=5.5.1.2-SNAPSHOT

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,13 +84,23 @@ public double getMse() {
8484
return this.results.statistics.getMse();
8585
}
8686

87+
@Override
88+
public double getR2() {
89+
if (this.results == null) {
90+
return Double.NaN;
91+
}
92+
return this.results.statistics.getR2();
93+
}
94+
8795
@Override
8896
public Object getDoProperty(final String element) {
8997
switch (element) {
9098
case "rss":
9199
return getRss();
92100
case "mse":
93101
return getMse();
102+
case "r2":
103+
return getR2();
94104
}
95105
return null;
96106
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ static class Fields {
3434

3535
public static final String RSS = "rss";
3636
public static final String MSE = "mse";
37+
public static final String R2 = "r2";
3738
}
3839

3940
final Statistics statistics;
@@ -46,14 +47,15 @@ public StatsResults(final SlopeCoefficients slopeCoefficients, final double inte
4647

4748
public StatsResults(final StreamInput in) throws IOException {
4849
super(in);
49-
this.statistics = new DefaultStatistics(in.readDouble(), in.readDouble());
50+
this.statistics = new DefaultStatistics(in.readDouble(), in.readDouble(), in.readDouble());
5051
}
5152

5253
@Override
5354
public void writeTo(final StreamOutput out) throws IOException {
5455
super.writeTo(out);
5556
out.writeDouble(this.statistics.getRss());
5657
out.writeDouble(this.statistics.getMse());
58+
out.writeDouble(this.statistics.getR2());
5759
}
5860

5961
@Override
@@ -63,6 +65,8 @@ public XContentBuilder toXContent(final XContentBuilder builder, final Params pa
6365
builder.field(Fields.RSS, this.statistics.getRss());
6466
// MSE
6567
builder.field(Fields.MSE, this.statistics.getMse());
68+
// R2
69+
builder.field(Fields.R2, this.statistics.getR2());
6670
return super.toXContent(builder, params);
6771
}
6872

src/main/java/org/scaleborn/linereg/calculation/statistics/Statistics.java

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,24 @@ public interface Statistics {
3333
*/
3434
double getMse();
3535

36+
/**
37+
* @return R², coefficient of determination
38+
*/
39+
double getR2();
40+
3641
/**
3742
* Default statistics bean.
3843
*/
3944
public class DefaultStatistics implements Statistics {
4045

4146
private final double rss;
4247
private final double mse;
48+
private final double r2;
4349

44-
public DefaultStatistics(final double rss, final double mse) {
50+
public DefaultStatistics(final double rss, final double mse, final double r2) {
4551
this.rss = rss;
4652
this.mse = mse;
53+
this.r2 = r2;
4754
}
4855

4956
@Override
@@ -56,11 +63,17 @@ public double getMse() {
5663
return this.mse;
5764
}
5865

66+
@Override
67+
public double getR2() {
68+
return this.r2;
69+
}
70+
5971
@Override
6072
public String toString() {
6173
return "DefaultStatistics{" +
6274
"rss=" + this.rss +
6375
", mse=" + this.mse +
76+
", r2=" + this.r2 +
6477
'}';
6578
}
6679
}

src/main/java/org/scaleborn/linereg/calculation/statistics/StatsCalculator.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ public Statistics calculate(final StatsModel model) {
2929
.getCovarianceLowerTriangularMatrix();
3030
final double[] slopeCoefficients = model.getSlopeCoefficients().getCoefficients();
3131

32-
double squaredError = model.getStatsSampling().getResponseVariance();
32+
final double responseVariance = model.getStatsSampling().getResponseVariance();
33+
double squaredError = responseVariance;
3334

3435
for (int i = 0; i < featuresCount; i++) {
3536
final double c = slopeCoefficients[i];
@@ -59,6 +60,11 @@ public double getRss() {
5960
public double getMse() {
6061
return rss / model.getStatsSampling().getCount();
6162
}
63+
64+
@Override
65+
public double getR2() {
66+
return 1 - (rss / responseVariance);
67+
}
6268
};
6369
}
6470
}

src/test/java/org/scaleborn/linereg/TestModels.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,7 @@ public void assertCoefficients(final double[] givenCoefficients, final double de
124124
public void assertStatistics(final Statistics statistics) {
125125
assertEquals("RSS not equal", this.expectedStatistics.getRss(), statistics.getRss(), 0.0001d);
126126
assertEquals("MSE not equal", this.expectedStatistics.getMse(), statistics.getMse(), 0.0001d);
127+
assertEquals("R² not equal", this.expectedStatistics.getR2(), statistics.getR2(), 0.0001d);
127128
}
128129
}
129130

@@ -141,7 +142,7 @@ public void assertStatistics(final Statistics statistics) {
141142
* Residual Sum of Squares: rss = 5.459016393
142143
* Coefficient of Determination: R2 = 0.787704918
143144
*/
144-
new DefaultStatistics(5.459016393, 5.459016393 / 7));
145+
new DefaultStatistics(5.459016393, 5.459016393 / 7, 0.787704918));
145146

146147
public static TestModel MULTI_FEATURES_2_MODEL_1 = new TestModel(3, 2, new double[][]{
147148
new double[]{-2, 3, 5},
@@ -157,7 +158,7 @@ public void assertStatistics(final Statistics statistics) {
157158
* Coefficient of Determination: R2 = 8.835223808·10⁻¹
158159
*/
159160
}, new double[]{-0.5496314882d, 0.3070409283d},
160-
new DefaultStatistics(2.99513878d, 2.99513878d / 7));
161+
new DefaultStatistics(2.99513878d, 2.99513878d / 7, 0.8835223808));
161162

162163
public static TestModel MULTI_FEATURES_3_MODEL_1 = new TestModel(4, 3, new double[][]{
163164
new double[]{4, -2, 3, 5},
@@ -173,7 +174,7 @@ public void assertStatistics(final Statistics statistics) {
173174
* Coefficient of Determination: R2 = 8.947697777·10⁻¹
174175
*/
175176
}, new double[]{-0.03116979852d, -0.6272993725d, 0.3079647314d},
176-
new DefaultStatistics(2.705920002d, 2.705920002d / 7));
177+
new DefaultStatistics(2.705920002d, 2.705920002d / 7, 0.8947697777));
177178

178179
/**
179180
* Reference data set Longley: from http://www.itl.nist.gov/div898/strd/lls/data/Longley.shtml
@@ -211,6 +212,6 @@ public void assertStatistics(final Statistics statistics) {
211212
*/
212213
, new double[]{15.0618722713733d, -0.358191792925910E-01d, -2.02022980381683d,
213214
-1.03322686717359d, -0.511041056535807E-01d, 1829.15146461355d},
214-
new DefaultStatistics(836424.055505915d, 836424.055505915d / 16));
215+
new DefaultStatistics(836424.055505915d, 836424.055505915d / 16, 0.995479004577296));
215216

216217
}

0 commit comments

Comments
 (0)