Skip to content

Commit 250c913

Browse files
committed
Generalize aggregation stuff to simplify newer aggregations based on the evaluated linear model
1 parent 3803b2b commit 250c913

File tree

20 files changed

+713
-205
lines changed

20 files changed

+713
-205
lines changed

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java

Lines changed: 40 additions & 127 deletions
Original file line numberDiff line numberDiff line change
@@ -16,186 +16,99 @@
1616

1717
package org.scaleborn.elasticsearch.linreg.aggregation.stats;
1818

19-
import static java.util.Collections.emptyMap;
20-
2119
import java.io.IOException;
22-
import java.util.ArrayList;
2320
import java.util.List;
2421
import java.util.Map;
2522
import org.apache.logging.log4j.Logger;
2623
import org.elasticsearch.common.io.stream.StreamInput;
27-
import org.elasticsearch.common.io.stream.StreamOutput;
2824
import org.elasticsearch.common.logging.Loggers;
29-
import org.elasticsearch.common.xcontent.XContentBuilder;
30-
import org.elasticsearch.search.aggregations.InternalAggregation;
3125
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
32-
import org.scaleborn.elasticsearch.linreg.aggregation.support.StateInputStreamAdapter;
33-
import org.scaleborn.elasticsearch.linreg.aggregation.support.StateOutputStreamAdapter;
34-
import org.scaleborn.linereg.evaluation.DerivationEquation;
35-
import org.scaleborn.linereg.evaluation.DerivationEquationBuilder;
26+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseInternalAggregation;
27+
import org.scaleborn.linereg.calculation.statistics.Statistics;
28+
import org.scaleborn.linereg.calculation.statistics.StatsCalculator;
29+
import org.scaleborn.linereg.calculation.statistics.StatsModel;
3630
import org.scaleborn.linereg.evaluation.SlopeCoefficients;
37-
import org.scaleborn.linereg.evaluation.commons.CommonsMathSolver;
38-
import org.scaleborn.linereg.statistics.Statistics;
39-
import org.scaleborn.linereg.statistics.StatsBuilder;
40-
import org.scaleborn.linereg.statistics.StatsModel;
41-
import org.scaleborn.linereg.statistics.StatsSampling;
4231

4332
/**
4433
* Created by mbok on 21.03.17.
4534
*/
46-
public class InternalStats extends InternalAggregation implements Stats {
35+
public class InternalStats extends
36+
BaseInternalAggregation<StatsAggregationSampling, StatsResults, InternalStats> implements
37+
Stats {
4738

4839
private static final Logger LOGGER = Loggers.getLogger(InternalStats.class);
49-
/**
50-
* per shard sampling needed to compute stats
51-
*/
52-
private StatsSampling<?> sampling;
53-
/**
54-
* final result
55-
*/
56-
private Statistics results;
5740

58-
/**
59-
* Features count
60-
*/
61-
private int featuresCount;
41+
private static final StatsCalculator statsCalculator = new StatsCalculator();
6242

6343
/**
6444
* per shard ctor
6545
*/
66-
protected InternalStats(String name, int featuresCount, StatsSampling<?> linRegSampling,
67-
Statistics results,
68-
List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) {
69-
super(name, pipelineAggregators, metaData);
70-
this.featuresCount = featuresCount;
71-
this.sampling = linRegSampling;
72-
this.results = results;
46+
protected InternalStats(final String name, final int featuresCount,
47+
final StatsAggregationSampling linRegSampling,
48+
final StatsResults results,
49+
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) {
50+
super(name, featuresCount, linRegSampling, results, pipelineAggregators, metaData);
7351
}
7452

7553
/**
7654
* Read from a stream.
7755
*/
78-
public InternalStats(StreamInput in) throws IOException {
79-
super(in);
80-
this.featuresCount = in.readInt();
81-
if (in.readBoolean()) {
82-
this.sampling = StatsAggregationBuilder.buildSampling(this.featuresCount);
83-
StateInputStreamAdapter streamAdapter = new StateInputStreamAdapter(in);
84-
this.sampling.loadState(streamAdapter);
85-
}
86-
if (in.readBoolean()) {
87-
this.results = new DefaultStatistics(in.readDouble(), in.readDouble());
88-
}
56+
public InternalStats(final StreamInput in) throws IOException {
57+
super(in, StatsResults::new);
8958
}
9059

9160
@Override
92-
protected void doWriteTo(StreamOutput out) throws IOException {
93-
out.writeInt(this.featuresCount);
94-
out.writeBoolean(this.sampling != null);
95-
if (this.sampling != null) {
96-
StateOutputStreamAdapter outAdapter = new StateOutputStreamAdapter(out);
97-
this.sampling.saveState(outAdapter);
98-
}
99-
out.writeBoolean(this.results != null);
100-
if (this.results != null) {
101-
out.writeDouble(this.results.getRss());
102-
out.writeDouble(this.results.getMse());
103-
}
61+
protected StatsAggregationSampling buildSampling(final int featuresCount) {
62+
return StatsAggregationBuilder.buildSampling(featuresCount);
10463
}
10564

65+
10666
@Override
10767
public String getWriteableName() {
10868
return StatsAggregationBuilder.NAME;
10969
}
11070

111-
static class Fields {
112-
113-
public static final String RSS = "rss";
114-
public static final String MSE = "mse";
115-
}
116-
11771
@Override
11872
public double getRss() {
119-
if (results == null) {
73+
if (this.results == null) {
12074
return Double.NaN;
12175
}
122-
return results.getRss();
76+
return this.results.statistics.getRss();
12377
}
12478

12579
@Override
12680
public double getMse() {
127-
if (results == null) {
81+
if (this.results == null) {
12882
return Double.NaN;
12983
}
130-
return results.getMse();
84+
return this.results.statistics.getMse();
13185
}
13286

13387
@Override
134-
public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException {
135-
if (results != null) {
136-
// RSS
137-
builder.field(Fields.RSS, results.getRss());
138-
// MSE
139-
builder.field(Fields.MSE, results.getMse());
88+
public Object getDoProperty(final String element) {
89+
switch (element) {
90+
case "rss":
91+
return getRss();
92+
case "mse":
93+
return getMse();
14094
}
141-
return builder;
95+
return null;
14296
}
14397

14498
@Override
145-
public Object getProperty(List<String> path) {
146-
if (path.isEmpty()) {
147-
return this;
148-
} else if (path.size() == 1) {
149-
String element = path.get(0);
150-
if (results == null) {
151-
return emptyMap();
152-
}
153-
switch (element) {
154-
case "rss":
155-
return results.getRss();
156-
case "mse":
157-
return results.getMse();
158-
default:
159-
throw new IllegalArgumentException(
160-
"Found unknown path element [" + element + "] in [" + getName() + "]");
161-
}
162-
} else {
163-
throw new IllegalArgumentException("path not supported for [" + getName() + "]: " + path);
164-
}
99+
protected InternalStats buildInternalAggregation(final String name, final int featuresCount,
100+
final StatsAggregationSampling linRegSampling, final StatsResults results,
101+
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) {
102+
return new InternalStats(name, featuresCount, linRegSampling, results, pipelineAggregators,
103+
metaData);
165104
}
166105

167-
@SuppressWarnings("unchecked")
168106
@Override
169-
public InternalAggregation doReduce(List<InternalAggregation> aggregations,
170-
ReduceContext reduceContext) {
171-
// merge samples across all shards
172-
List<InternalAggregation> aggs = new ArrayList<>(aggregations);
173-
aggs.removeIf(p -> ((InternalStats) p).sampling == null);
174-
175-
// return empty result iff all samples are null
176-
if (aggs.isEmpty()) {
177-
return new InternalStats(name, featuresCount, null, new DefaultStatistics(0, 0),
178-
pipelineAggregators(),
179-
getMetaData());
180-
}
181-
182-
StatsSampling composedSampling = StatsAggregationBuilder
183-
.buildSampling(featuresCount);
184-
for (int i = 0; i < aggs.size(); ++i) {
185-
LOGGER.info("Merging sampling={}: {}", i, ((InternalStats) aggs.get(i)).sampling);
186-
composedSampling.merge(((InternalStats) aggs.get(i)).sampling);
187-
}
188-
189-
// Linear regression evaluation
190-
DerivationEquation derivationEquation = new DerivationEquationBuilder()
191-
.buildDerivationEquation(composedSampling);
192-
CommonsMathSolver commonsMathSolver = new CommonsMathSolver();
193-
SlopeCoefficients slopeCoefficients = commonsMathSolver.solveCoefficients(derivationEquation);
194-
StatsBuilder statsBuilder = new StatsBuilder();
195-
Statistics statsResult = statsBuilder
196-
.buildStats(new StatsModel(composedSampling, slopeCoefficients));
197-
LOGGER.info("Evaluated linear with {} and stats {}", slopeCoefficients, statsResult);
198-
return new InternalStats(name, featuresCount, composedSampling, statsResult,
199-
pipelineAggregators(), getMetaData());
107+
protected StatsResults buildResults(final StatsAggregationSampling composedSampling,
108+
final SlopeCoefficients slopeCoefficients) {
109+
final Statistics stats = statsCalculator
110+
.calculate(new StatsModel(composedSampling, slopeCoefficients));
111+
return new StatsResults(slopeCoefficients, stats);
200112
}
113+
201114
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/Stats.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
package org.scaleborn.elasticsearch.linreg.aggregation.stats;
1818

19-
import org.scaleborn.linereg.statistics.Statistics;
19+
import org.scaleborn.linereg.calculation.statistics.Statistics;
2020

2121
/**
2222
* Created by mbok on 21.03.17.

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsAggregationBuilder.java

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,6 @@
2828
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseAggregationBuilder;
2929
import org.scaleborn.linereg.sampling.exact.ExactModelSamplingFactory;
3030
import org.scaleborn.linereg.sampling.exact.ExactSamplingContext;
31-
import org.scaleborn.linereg.statistics.StatsSampling;
32-
import org.scaleborn.linereg.statistics.StatsSampling.StatsSamplingProxy;
3331

3432
/**
3533
* Created by mbok on 21.03.17.
@@ -42,21 +40,21 @@ public class StatsAggregationBuilder extends
4240
private static final ExactModelSamplingFactory MODEL_SAMPLING_FACTORY = new ExactModelSamplingFactory();
4341

4442

45-
public StatsAggregationBuilder(String name) {
43+
public StatsAggregationBuilder(final String name) {
4644
super(name);
4745
}
4846

49-
public StatsAggregationBuilder(StreamInput in) throws IOException {
47+
public StatsAggregationBuilder(final StreamInput in) throws IOException {
5048
super(in);
5149
}
5250

5351
@Override
54-
protected StatsAggregatorFactory innerInnerBuild(SearchContext context,
55-
List<NamedValuesSourceConfigSpec<Numeric>> configs, MultiValueMode multiValueMode,
56-
AggregatorFactory<?> parent, AggregatorFactories.Builder subFactoriesBuilder)
52+
protected StatsAggregatorFactory innerInnerBuild(final SearchContext context,
53+
final List<NamedValuesSourceConfigSpec<Numeric>> configs, final MultiValueMode multiValueMode,
54+
final AggregatorFactory<?> parent, final AggregatorFactories.Builder subFactoriesBuilder)
5755
throws IOException {
58-
return new StatsAggregatorFactory(name, configs, multiValueMode, context, parent,
59-
subFactoriesBuilder, metaData);
56+
return new StatsAggregatorFactory(this.name, configs, multiValueMode, context, parent,
57+
subFactoriesBuilder, this.metaData);
6058
}
6159

6260

@@ -65,13 +63,14 @@ public String getType() {
6563
return NAME;
6664
}
6765

68-
static StatsSampling<?> buildSampling(final int featuresCount) {
69-
ExactSamplingContext context = MODEL_SAMPLING_FACTORY
66+
static StatsAggregationSampling buildSampling(final int featuresCount) {
67+
final ExactSamplingContext context = MODEL_SAMPLING_FACTORY
7068
.createContext(featuresCount);
71-
StatsSamplingProxy<?> statsSampling = new StatsSamplingProxy<>(context,
72-
MODEL_SAMPLING_FACTORY.createResponseVarianceTermSampling(context),
69+
final StatsAggregationSampling statsSampling = new StatsAggregationSampling(context,
7370
MODEL_SAMPLING_FACTORY.createCoefficientLinearTermSampling(context),
74-
MODEL_SAMPLING_FACTORY.createCoefficientSquareTermSampling(context));
71+
MODEL_SAMPLING_FACTORY.createCoefficientSquareTermSampling(context),
72+
MODEL_SAMPLING_FACTORY.createInterceptSampling(context),
73+
MODEL_SAMPLING_FACTORY.createResponseVarianceTermSampling(context));
7574
return statsSampling;
7675
}
7776
}
src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsAggregationSampling.java

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.stats;
18+
19+
import java.io.IOException;
20+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseSampling;
21+
import org.scaleborn.linereg.calculation.statistics.StatsSampling;
22+
import org.scaleborn.linereg.sampling.io.StateInputStream;
23+
import org.scaleborn.linereg.sampling.io.StateOutputStream;
24+
25+
/**
26+
* Created by mbok on 08.04.17.
27+
*/
28+
public class StatsAggregationSampling extends BaseSampling<StatsAggregationSampling> implements
29+
StatsSampling<StatsAggregationSampling> {
30+
31+
private final ResponseVarianceTermSampling responseVarianceTermSampling;
32+
33+
public StatsAggregationSampling(
34+
final SamplingContext<?> samplingContext,
35+
final CoefficientLinearTermSampling<?> coefficientLinearTermSampling,
36+
final CoefficientSquareTermSampling<?> coefficientSquareTermSampling,
37+
final InterceptSampling<?> interceptSampling,
38+
final ResponseVarianceTermSampling<?> responseVarianceTermSampling) {
39+
super(samplingContext, coefficientLinearTermSampling, coefficientSquareTermSampling,
40+
interceptSampling);
41+
this.responseVarianceTermSampling = responseVarianceTermSampling;
42+
}
43+
44+
@Override
45+
public void merge(final StatsAggregationSampling fromSample) {
46+
super.merge(fromSample);
47+
this.responseVarianceTermSampling.merge(fromSample.responseVarianceTermSampling);
48+
}
49+
50+
@Override
51+
public void sample(final double[] featureValues, final double responseValue) {
52+
super.sample(featureValues, responseValue);
53+
this.responseVarianceTermSampling.sample(featureValues, responseValue);
54+
}
55+
56+
@Override
57+
public void saveState(final StateOutputStream destination) throws IOException {
58+
super.saveState(destination);
59+
this.responseVarianceTermSampling.saveState(destination);
60+
}
61+
62+
@Override
63+
public void loadState(final StateInputStream source) throws IOException {
64+
super.loadState(source);
65+
this.responseVarianceTermSampling.loadState(source);
66+
}
67+
68+
@Override
69+
public double getResponseVariance() {
70+
return this.responseVarianceTermSampling.getResponseVariance();
71+
}
72+
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsAggregator.java

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,6 @@
3636
import org.elasticsearch.search.aggregations.support.NamedValuesSourceSpec;
3737
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
3838
import org.elasticsearch.search.internal.SearchContext;
39-
import org.scaleborn.linereg.statistics.StatsSampling;
4039

4140
/**
4241
* Created by mbok on 21.03.17.
@@ -50,7 +49,7 @@ public class StatsAggregator extends MetricsAggregator {
5049
*/
5150
final NumericMultiValuesSource valuesSources;
5251

53-
protected ObjectArray<StatsSampling<?>> samplings;
52+
protected ObjectArray<StatsAggregationSampling> samplings;
5453

5554
public StatsAggregator(final String name,
5655
final List<NamedValuesSourceSpec<Numeric>> valuesSources,
@@ -95,7 +94,7 @@ public void collect(final int doc, final long bucket) throws IOException {
9594
if (includeDocument(doc) == true) {
9695
StatsAggregator.this.samplings = bigArrays
9796
.grow(StatsAggregator.this.samplings, bucket + 1);
98-
StatsSampling<?> sampling = StatsAggregator.this.samplings.get(bucket);
97+
StatsAggregationSampling sampling = StatsAggregator.this.samplings.get(bucket);
9998
// add document fields to correlation stats
10099
if (sampling == null) {
101100
sampling = StatsAggregationBuilder.buildSampling(this.fieldNames.length - 1);

0 commit comments

Comments
 (0)