Skip to content

Commit 3c26020

Browse files
committed
#3 Correct variance and covariance usage in preparation for replacement by another algorithm
1 parent 87990e9 commit 3c26020

File tree

3 files changed

+28
-27
lines changed

3 files changed

+28
-27
lines changed

src/main/java/org/scaleborn/linereg/calculation/statistics/StatsCalculator.java

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,40 +30,40 @@ public Statistics calculate(final StatsModel model) {
3030
final double[] slopeCoefficients = model.getSlopeCoefficients().getCoefficients();
3131

3232
final double responseVariance = model.getStatsSampling().getResponseVariance();
33-
double squaredError = responseVariance;
33+
double squaredMeanError = responseVariance;
3434

3535
for (int i = 0; i < featuresCount; i++) {
3636
final double c = slopeCoefficients[i];
3737
final double c2 = c * c;
3838
// Subtract twice the feature-response covariance multiplied by the slope coefficient
39-
squaredError -= 2 * featuresResponseCovariance[i] * c;
39+
squaredMeanError -= 2 * featuresResponseCovariance[i] * c;
4040

4141
// Add values from covariance matrix of the derivation matrix
4242
for (int j = 0; j <= i; j++) {
4343
if (i == j) {
4444
// Variance term
45-
squaredError += c2 * covarianceLowerTriangularMatrix[i][j];
45+
squaredMeanError += c2 * covarianceLowerTriangularMatrix[i][j];
4646
} else {
4747
// Covariance term
48-
squaredError += 2 * c * slopeCoefficients[j] * covarianceLowerTriangularMatrix[i][j];
48+
squaredMeanError += 2 * c * slopeCoefficients[j] * covarianceLowerTriangularMatrix[i][j];
4949
}
5050
}
5151
}
52-
final double rss = squaredError;
52+
final double mse = squaredMeanError;
5353
return new Statistics() {
5454
@Override
5555
public double getRss() {
56-
return rss;
56+
return mse * model.getStatsSampling().getCount();
5757
}
5858

5959
@Override
6060
public double getMse() {
61-
return rss / model.getStatsSampling().getCount();
61+
return mse;
6262
}
6363

6464
@Override
6565
public double getR2() {
66-
return 1 - (rss / responseVariance);
66+
return 1 - (getRss() / (responseVariance * model.getStatsSampling().getCount()));
6767
}
6868
};
6969
}

src/main/java/org/scaleborn/linereg/sampling/exact/ExactCoefficientSquareTermSampling.java

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ public class ExactCoefficientSquareTermSampling implements
3636
* instability as well as to arithmetic overflow.
3737
*/
3838
private double[][] featuresProductSums;
39-
private ExactSamplingContext context;
39+
private final ExactSamplingContext context;
4040

41-
public ExactCoefficientSquareTermSampling(ExactSamplingContext context) {
41+
public ExactCoefficientSquareTermSampling(final ExactSamplingContext context) {
4242
this.context = context;
43-
int featuresCount = context.getFeaturesCount();
43+
final int featuresCount = context.getFeaturesCount();
4444
this.featuresProductSums = new double[featuresCount][];
4545
for (int i = 0; i < featuresCount; i++) {
4646
this.featuresProductSums[i] = new double[featuresCount];
@@ -49,31 +49,31 @@ public ExactCoefficientSquareTermSampling(ExactSamplingContext context) {
4949

5050
@Override
5151
public double[][] getCovarianceLowerTriangularMatrix() {
52-
int featuresCount = this.context.getFeaturesCount();
53-
long count = this.context.getCount();
54-
double[][] covMatrix = new double[featuresCount][];
55-
double[] averages = this.context.getFeaturesMean();
56-
double[] featureSums = this.context.featureSums;
52+
final int featuresCount = this.context.getFeaturesCount();
53+
final long count = this.context.getCount();
54+
final double[][] covMatrix = new double[featuresCount][];
55+
final double[] averages = this.context.getFeaturesMean();
56+
final double[] featureSums = this.context.featureSums;
5757
for (int i = 0; i < featuresCount; i++) {
58-
double avgI = averages[i];
58+
final double avgI = averages[i];
5959
covMatrix[i] = new double[featuresCount];
6060
// Iterate only up to "i": the covariance matrix is symmetric, so we
6161
// build only the lower triangle
6262
for (int j = 0; j <= i; j++) {
63-
double avgJ = averages[j];
64-
covMatrix[i][j] =
63+
final double avgJ = averages[j];
64+
covMatrix[i][j] = (
6565
this.featuresProductSums[i][j] - avgI * featureSums[j] - avgJ * featureSums[i]
66-
+ count * avgI * avgJ;
66+
+ count * avgI * avgJ) / count;
6767
}
6868
}
6969
return covMatrix;
7070
}
7171

7272
@Override
7373
public void sample(final double[] featureValues, final double responseValue) {
74-
int featuresCount = this.context.getFeaturesCount();
74+
final int featuresCount = this.context.getFeaturesCount();
7575
for (int i = 0; i < featuresCount; i++) {
76-
double vi = featureValues[i];
76+
final double vi = featureValues[i];
7777
for (int j = 0; j < featuresCount; j++) {
7878
this.featuresProductSums[i][j] += vi * featureValues[j];
7979
}
@@ -82,7 +82,7 @@ public void sample(final double[] featureValues, final double responseValue) {
8282

8383
@Override
8484
public void merge(final ExactCoefficientSquareTermSampling fromSample) {
85-
int featuresCount = this.context.getFeaturesCount();
85+
final int featuresCount = this.context.getFeaturesCount();
8686
for (int i = 0; i < featuresCount; i++) {
8787
for (int j = 0; j < featuresCount; j++) {
8888
this.featuresProductSums[i][j] += fromSample.featuresProductSums[i][j];

src/main/java/org/scaleborn/linereg/sampling/exact/ExactModelSamplingFactory.java

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ public ExactResponseVarianceTermSampling(
5353

5454
@Override
5555
public double getResponseVariance() {
56-
return this.context.responseSquareSum
57-
- this.context.responseSum / this.context.getCount() * this.context.responseSum;
56+
return (this.context.responseSquareSum
57+
- this.context.responseSum / this.context.getCount() * this.context.responseSum)
58+
/ this.context.getCount();
5859
}
5960

6061
@Override
@@ -106,10 +107,10 @@ public double[] getFeaturesResponseCovariance() {
106107
final double[] featuresMean = this.context.getFeaturesMean();
107108
final double responseMean = this.context.getResponseMean();
108109
for (int i = 0; i < featuresCount; i++) {
109-
covariance[i] =
110+
covariance[i] = (
110111
this.context.featuresResponseProductSum[i] - featuresMean[i] * this.context.responseSum
111112
- responseMean * this.context.featureSums[i]
112-
+ count * featuresMean[i] * responseMean;
113+
+ count * featuresMean[i] * responseMean) / count;
113114
}
114115
return covariance;
115116
}

0 commit comments

Comments
 (0)