Skip to content

Commit 0f97fe6

Browse files
committed
#1 First working version
1 parent d1237d7 commit 0f97fe6

File tree

11 files changed

+146
-33
lines changed

11 files changed

+146
-33
lines changed

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/InternalPrediction.java

Lines changed: 26 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,13 @@
1717
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
1818

1919
import java.io.IOException;
20+
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.Map;
23+
import org.apache.logging.log4j.Logger;
2224
import org.elasticsearch.common.io.stream.StreamInput;
25+
import org.elasticsearch.common.io.stream.StreamOutput;
26+
import org.elasticsearch.common.logging.Loggers;
2327
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2428
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseInternalAggregation;
2529
import org.scaleborn.linereg.evaluation.SlopeCoefficients;
@@ -31,16 +35,28 @@ public class InternalPrediction extends
3135
BaseInternalAggregation<PredictionSampling, PredictionResults, InternalPrediction> implements
3236
Prediction {
3337

38+
private static final Logger LOGGER = Loggers.getLogger(InternalPrediction.class);
39+
40+
private final double[] inputs;
41+
3442
protected InternalPrediction(final String name, final int featuresCount,
3543
final PredictionSampling sampling,
3644
final PredictionResults results,
37-
final List<PipelineAggregator> pipelineAggregators,
45+
final double[] inputs, final List<PipelineAggregator> pipelineAggregators,
3846
final Map<String, Object> metaData) {
3947
super(name, featuresCount, sampling, results, pipelineAggregators, metaData);
48+
this.inputs = inputs;
4049
}
4150

4251
public InternalPrediction(final StreamInput in) throws IOException {
4352
super(in, PredictionResults::new);
53+
this.inputs = in.readDoubleArray();
54+
}
55+
56+
@Override
57+
protected void doWriteTo(final StreamOutput out) throws IOException {
58+
super.doWriteTo(out);
59+
out.writeDoubleArray(this.inputs);
4460
}
4561

4662
@Override
@@ -68,15 +84,20 @@ protected Object getDoProperty(final String path) {
6884
protected InternalPrediction buildInternalAggregation(final String name, final int featuresCount,
6985
final PredictionSampling linRegSampling, final PredictionResults results,
7086
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) {
71-
return new InternalPrediction(name, featuresCount, linRegSampling, results, pipelineAggregators,
87+
return new InternalPrediction(name, featuresCount, linRegSampling, results, this.inputs,
88+
pipelineAggregators,
7289
metaData);
7390
}
7491

7592
@Override
7693
protected PredictionResults buildResults(final PredictionSampling composedSampling,
77-
final SlopeCoefficients slopeCoefficients) {
78-
// TODO calculate predicated value
79-
return new PredictionResults(2, slopeCoefficients);
94+
final SlopeCoefficients slopeCoefficients, final double intercept) {
95+
double predictedValue = intercept;
96+
LOGGER.info("Predicting values for inputs: {}", Arrays.toString(this.inputs));
97+
for (int i = 0; i < this.featuresCount; i++) {
98+
predictedValue += slopeCoefficients.getCoefficients()[i] * this.inputs[i];
99+
}
100+
return new PredictionResults(predictedValue, slopeCoefficients, intercept);
80101
}
81102

82103
@Override

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionAggregationBuilder.java

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@
1818

1919
import java.io.IOException;
2020
import java.util.List;
21+
import java.util.Objects;
2122
import org.elasticsearch.common.io.stream.StreamInput;
23+
import org.elasticsearch.common.io.stream.StreamOutput;
2224
import org.elasticsearch.search.MultiValueMode;
2325
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
2426
import org.elasticsearch.search.aggregations.AggregatorFactory;
@@ -39,6 +41,7 @@ public class PredictionAggregationBuilder extends
3941
public static final String NAME = "linreg_predict";
4042

4143
private static final ExactModelSamplingFactory MODEL_SAMPLING_FACTORY = new ExactModelSamplingFactory();
44+
private double[] inputs;
4245

4346
public PredictionAggregationBuilder(final String name) {
4447
super(name);
@@ -47,15 +50,30 @@ public PredictionAggregationBuilder(final String name) {
4750
public PredictionAggregationBuilder(final StreamInput in)
4851
throws IOException {
4952
super(in);
53+
this.inputs = in.readDoubleArray();
5054
}
5155

56+
@Override
57+
protected void innerWriteTo(final StreamOutput out) throws IOException {
58+
super.innerWriteTo(out);
59+
out.writeDoubleArray(this.inputs);
60+
}
5261

5362
@Override
5463
protected MultiValuesSourceAggregatorFactory<Numeric, ?> innerInnerBuild(
5564
final SearchContext context,
5665
final List<NamedValuesSourceConfigSpec<Numeric>> configs, final MultiValueMode multiValueMode,
5766
final AggregatorFactory<?> parent, final Builder subFactoriesBuilder) throws IOException {
58-
return new PredictionAggregatorFactory(this.name, configs, multiValueMode, context, parent,
67+
if (this.inputs == null || this.inputs.length != configs.size() - 1) {
68+
throw new IllegalArgumentException(
69+
"[inputs] must have [" + (configs.size() - 1)
70+
+ "] values as much as the number of feature fields: ["
71+
+ this.name
72+
+ "]");
73+
}
74+
return new PredictionAggregatorFactory(this.name, configs, multiValueMode, this.inputs,
75+
context,
76+
parent,
5977
subFactoriesBuilder, this.metaData);
6078
}
6179

@@ -73,4 +91,23 @@ static PredictionSampling buildSampling(final int featuresCount) {
7391
MODEL_SAMPLING_FACTORY.createInterceptSampling(context));
7492
return predictionSampling;
7593
}
94+
95+
public void inputs(final double[] inputs) {
96+
this.inputs = inputs;
97+
}
98+
99+
public double[] inputs() {
100+
return this.inputs;
101+
}
102+
103+
@Override
104+
protected int innerHashCode() {
105+
return Objects.hash(this.inputs);
106+
}
107+
108+
@Override
109+
protected boolean innerEquals(final Object obj) {
110+
final PredictionAggregationBuilder other = (PredictionAggregationBuilder) obj;
111+
return Objects.equals(this.inputs, other.inputs);
112+
}
76113
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionAggregationParser.java

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,66 @@
1616

1717
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
1818

19+
import java.io.IOException;
20+
import java.util.ArrayList;
21+
import java.util.List;
1922
import java.util.Map;
23+
import org.apache.logging.log4j.Logger;
2024
import org.elasticsearch.common.ParseField;
25+
import org.elasticsearch.common.ParsingException;
26+
import org.elasticsearch.common.logging.Loggers;
27+
import org.elasticsearch.common.xcontent.XContentParser;
28+
import org.elasticsearch.common.xcontent.XContentParser.Token;
2129
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseParser;
2230

2331
/**
2432
* Created by mbok on 11.04.17.
2533
*/
2634
public class PredictionAggregationParser extends BaseParser<PredictionAggregationBuilder> {
2735

36+
private static final Logger LOGGER = Loggers.getLogger(PredictionAggregationParser.class);
37+
private static final ParseField INPUTS = new ParseField("inputs");
38+
39+
2840
@Override
2941
protected PredictionAggregationBuilder createInnerFactory(final String aggregationName,
3042
final Map<ParseField, Object> otherOptions) {
31-
return new PredictionAggregationBuilder(aggregationName);
43+
final PredictionAggregationBuilder builder = new PredictionAggregationBuilder(aggregationName);
44+
if (otherOptions.containsKey(INPUTS)) {
45+
final List<Double> inputsList = (List<Double>) otherOptions.get(INPUTS);
46+
final double[] inputs = new double[inputsList.size()];
47+
int i = 0;
48+
for (final Double input : inputsList) {
49+
inputs[i++] = input;
50+
}
51+
builder.inputs(inputs);
52+
}
53+
return builder;
54+
}
55+
56+
@Override
57+
protected boolean token(final String aggregationName, final String currentFieldName,
58+
Token token, final XContentParser parser, final Map<ParseField, Object> otherOptions)
59+
throws IOException {
60+
List<Double> inputFields = null;
61+
if (super.token(aggregationName, currentFieldName, token, parser, otherOptions)) {
62+
return true;
63+
} else if (INPUTS.match(currentFieldName)) {
64+
inputFields = new ArrayList<>();
65+
while ((token = parser.nextToken()) != XContentParser.Token.END_ARRAY) {
66+
if (token == Token.VALUE_NUMBER) {
67+
inputFields.add(parser.numberValue().doubleValue());
68+
} else {
69+
throw new ParsingException(parser.getTokenLocation(),
70+
"Number value expected, but got token " + token + " [" + currentFieldName + "] in ["
71+
+ aggregationName
72+
+ "].");
73+
}
74+
}
75+
otherOptions.put(INPUTS, inputFields);
76+
LOGGER.info("Parsed input fields: {}", inputFields);
77+
return true;
78+
}
79+
return false;
3280
}
3381
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionAggregator.java

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,14 +33,17 @@
3333
*/
3434
public class PredictionAggregator extends BaseSamplingAggregator<PredictionSampling> {
3535

36+
private final double[] inputs;
37+
3638
public PredictionAggregator(final String name,
3739
final List<NamedValuesSourceSpec<Numeric>> valuesSources,
3840
final SearchContext context,
3941
final Aggregator parent,
4042
final MultiValueMode multiValueMode,
41-
final List<PipelineAggregator> pipelineAggregators,
43+
final double[] inputs, final List<PipelineAggregator> pipelineAggregators,
4244
final Map<String, Object> metaData) throws IOException {
4345
super(name, valuesSources, context, parent, multiValueMode, pipelineAggregators, metaData);
46+
this.inputs = inputs;
4447
}
4548

4649
@Override
@@ -54,12 +57,13 @@ protected InternalAggregation doBuildAggregation(final String name, final int fe
5457
final List<PipelineAggregator> pipelineAggregators,
5558
final Map<String, Object> stringObjectMap) {
5659
return new InternalPrediction(this.name, this.valuesSources.fieldNames().length - 1,
57-
predictionSampling, null,
60+
predictionSampling, null, this.inputs,
5861
pipelineAggregators(), metaData());
5962
}
6063

6164
@Override
6265
public InternalAggregation buildEmptyAggregation() {
63-
return new InternalPrediction(this.name, 0, null, null, pipelineAggregators(), metaData());
66+
return new InternalPrediction(this.name, 0, null, null, this.inputs, pipelineAggregators(),
67+
metaData());
6468
}
6569
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionAggregatorFactory.java

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import java.util.Map;
2222
import org.elasticsearch.search.MultiValueMode;
2323
import org.elasticsearch.search.aggregations.Aggregator;
24-
import org.elasticsearch.search.aggregations.AggregatorFactories;
24+
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
2525
import org.elasticsearch.search.aggregations.AggregatorFactory;
2626
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
2727
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
@@ -37,21 +37,24 @@ public class PredictionAggregatorFactory extends
3737
MultiValuesSourceAggregatorFactory<Numeric, PredictionAggregatorFactory> {
3838

3939
private final MultiValueMode multiValueMode;
40+
private final double[] inputs;
4041

4142
public PredictionAggregatorFactory(final String name,
4243
final List<NamedValuesSourceConfigSpec<Numeric>> configs, final MultiValueMode multiValueMode,
43-
final SearchContext context, final AggregatorFactory<?> parent,
44-
final AggregatorFactories.Builder subFactoriesBuilder,
44+
final double[] inputs, final SearchContext context, final AggregatorFactory<?> parent,
45+
final Builder subFactoriesBuilder,
4546
final Map<String, Object> metaData) throws IOException {
4647
super(name, configs, context, parent, subFactoriesBuilder, metaData);
4748
this.multiValueMode = multiValueMode;
49+
this.inputs = inputs;
4850
}
4951

5052
@Override
5153
protected Aggregator createUnmapped(final Aggregator parent,
5254
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData)
5355
throws IOException {
5456
return new PredictionAggregator(this.name, null, this.context, parent, this.multiValueMode,
57+
this.inputs,
5558
pipelineAggregators, metaData);
5659
}
5760

@@ -61,7 +64,7 @@ protected Aggregator doCreateInternal(final List<NamedValuesSourceSpec<Numeric>>
6164
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData)
6265
throws IOException {
6366
return new PredictionAggregator(this.name, valuesSources, this.context, parent,
64-
this.multiValueMode,
67+
this.multiValueMode, this.inputs,
6568
pipelineAggregators, metaData);
6669
}
6770
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/predict/PredictionResults.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,8 @@ public class PredictionResults extends ModelResults {
3232
private final double predictedValue;
3333

3434
public PredictionResults(final double predictedValue,
35-
final SlopeCoefficients slopeCoefficients) {
36-
super(slopeCoefficients);
35+
final SlopeCoefficients slopeCoefficients, final double intercept) {
36+
super(slopeCoefficients, intercept);
3737
this.predictedValue = predictedValue;
3838
}
3939

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/InternalStats.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ protected InternalStats buildInternalAggregation(final String name, final int fe
105105

106106
@Override
107107
protected StatsResults buildResults(final StatsAggregationSampling composedSampling,
108-
final SlopeCoefficients slopeCoefficients) {
108+
final SlopeCoefficients slopeCoefficients, final double intercept) {
109109
final Statistics stats = statsCalculator
110110
.calculate(new StatsModel(composedSampling, slopeCoefficients));
111-
return new StatsResults(slopeCoefficients, stats);
111+
return new StatsResults(slopeCoefficients, intercept, stats);
112112
}
113113

114114
}

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/stats/StatsResults.java

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,9 @@ static class Fields {
3838

3939
final Statistics statistics;
4040

41-
public StatsResults(final SlopeCoefficients slopeCoefficients, final Statistics statistics) {
42-
super(slopeCoefficients);
41+
public StatsResults(final SlopeCoefficients slopeCoefficients, final double intercept,
42+
final Statistics statistics) {
43+
super(slopeCoefficients, intercept);
4344
this.statistics = statistics;
4445
}
4546

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseAggregationBuilder.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ public S fields(final List<String> fields) {
100100
}
101101

102102
@Override
103-
protected void innerWriteTo(final StreamOutput out) {
103+
protected void innerWriteTo(final StreamOutput out) throws IOException {
104104
// Do nothing, no extra state to write to stream
105105
}
106106

src/main/java/org/scaleborn/elasticsearch/linreg/aggregation/support/BaseInternalAggregation.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public abstract class BaseInternalAggregation<S extends BaseSampling<S>, M exten
5656
/**
5757
* Features count
5858
*/
59-
private final int featuresCount;
59+
protected final int featuresCount;
6060

6161
protected M results;
6262

@@ -167,7 +167,8 @@ protected abstract A buildInternalAggregation(final String name, final int featu
167167
final M results,
168168
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData);
169169

170-
protected abstract M buildResults(S composedSampling, SlopeCoefficients slopeCoefficients);
170+
protected abstract M buildResults(S composedSampling, SlopeCoefficients slopeCoefficients,
171+
double intercept);
171172

172173

173174
private M evaluateResults(final S composedSampling) {
@@ -176,8 +177,7 @@ private M evaluateResults(final S composedSampling) {
176177
.buildDerivationEquation(composedSampling);
177178
final SlopeCoefficients slopeCoefficients = derivationEquationSolver
178179
.solveCoefficients(derivationEquation);
179-
final M buildResults = buildResults(composedSampling, slopeCoefficients);
180-
buildResults.setIntercept(
180+
final M buildResults = buildResults(composedSampling, slopeCoefficients,
181181
interceptCalculator.calculate(slopeCoefficients, composedSampling, composedSampling));
182182
return buildResults;
183183
}

0 commit comments

Comments
 (0)