Skip to content

Commit d1237d7

Browse files
committed
#1 Structure for the prediction aggregation
1 parent 250c913 commit d1237d7

14 files changed

+670
-125
lines changed

src/main/java/org/scaleborn/elasticsearch/linreg/LinearRegressionPlugin.java

Lines changed: 12 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -16,21 +16,28 @@
1616

1717
package org.scaleborn.elasticsearch.linreg;
1818

19-
import java.util.Collections;
19+
import java.util.ArrayList;
2020
import java.util.List;
2121
import org.elasticsearch.plugins.Plugin;
2222
import org.elasticsearch.plugins.SearchPlugin;
23+
import org.scaleborn.elasticsearch.linreg.aggregation.predict.InternalPrediction;
24+
import org.scaleborn.elasticsearch.linreg.aggregation.predict.PredictionAggregationBuilder;
25+
import org.scaleborn.elasticsearch.linreg.aggregation.predict.PredictionAggregationParser;
2326
import org.scaleborn.elasticsearch.linreg.aggregation.stats.InternalStats;
2427
import org.scaleborn.elasticsearch.linreg.aggregation.stats.StatsAggregationBuilder;
25-
import org.scaleborn.elasticsearch.linreg.aggregation.stats.StatsParser;
28+
import org.scaleborn.elasticsearch.linreg.aggregation.stats.StatsAggregationParser;
2629

2730
/**
 * Plugin entry point that contributes the linear regression aggregations to the search module
 * through {@link SearchPlugin#getAggregations()}.
 */
public class LinearRegressionPlugin extends Plugin implements SearchPlugin {
2831

2932
/**
 * Registers the aggregations provided by this plugin.
 *
 * <p>After this change two specs are registered: the stats aggregation
 * ({@code StatsAggregationBuilder.NAME}, parsed by {@code StatsAggregationParser}, results read
 * via {@code InternalStats::new}) and the prediction aggregation
 * ({@code PredictionAggregationBuilder.NAME}, parsed by {@code PredictionAggregationParser},
 * results read via {@code InternalPrediction::new}).
 *
 * @return the list of aggregation specs of this plugin
 */
@Override
3033
public List<AggregationSpec> getAggregations() {
31-
return Collections.singletonList(
32-
new AggregationSpec(StatsAggregationBuilder.NAME, StatsAggregationBuilder::new,
33-
new StatsParser()).addResultReader(InternalStats::new));
34+
final List<AggregationSpec> aggregations = new ArrayList<>();
35+
aggregations.add(new AggregationSpec(StatsAggregationBuilder.NAME, StatsAggregationBuilder::new,
36+
new StatsAggregationParser()).addResultReader(InternalStats::new));
37+
aggregations.add(
38+
new AggregationSpec(PredictionAggregationBuilder.NAME, PredictionAggregationBuilder::new,
39+
new PredictionAggregationParser()).addResultReader(InternalPrediction::new));
40+
return aggregations;
3441
}
3542

3643
}
Lines changed: 86 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,86 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.Map;
22+
import org.elasticsearch.common.io.stream.StreamInput;
23+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
24+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseInternalAggregation;
25+
import org.scaleborn.linereg.evaluation.SlopeCoefficients;
26+
27+
/**
28+
* Created by mbok on 11.04.17.
29+
*/
30+
public class InternalPrediction extends
31+
BaseInternalAggregation<PredictionSampling, PredictionResults, InternalPrediction> implements
32+
Prediction {
33+
34+
protected InternalPrediction(final String name, final int featuresCount,
35+
final PredictionSampling sampling,
36+
final PredictionResults results,
37+
final List<PipelineAggregator> pipelineAggregators,
38+
final Map<String, Object> metaData) {
39+
super(name, featuresCount, sampling, results, pipelineAggregators, metaData);
40+
}
41+
42+
public InternalPrediction(final StreamInput in) throws IOException {
43+
super(in, PredictionResults::new);
44+
}
45+
46+
@Override
47+
public double getValue() {
48+
if (this.results == null) {
49+
return Double.NaN;
50+
}
51+
return this.results.getPredictedValue();
52+
}
53+
54+
@Override
55+
protected PredictionSampling buildSampling(final int featuresCount) {
56+
return PredictionAggregationBuilder.buildSampling(featuresCount);
57+
}
58+
59+
@Override
60+
protected Object getDoProperty(final String path) {
61+
if ("value".equals(path)) {
62+
return getValue();
63+
}
64+
return null;
65+
}
66+
67+
@Override
68+
protected InternalPrediction buildInternalAggregation(final String name, final int featuresCount,
69+
final PredictionSampling linRegSampling, final PredictionResults results,
70+
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) {
71+
return new InternalPrediction(name, featuresCount, linRegSampling, results, pipelineAggregators,
72+
metaData);
73+
}
74+
75+
@Override
76+
protected PredictionResults buildResults(final PredictionSampling composedSampling,
77+
final SlopeCoefficients slopeCoefficients) {
78+
// TODO calculate predicated value
79+
return new PredictionResults(2, slopeCoefficients);
80+
}
81+
82+
@Override
83+
public String getWriteableName() {
84+
return PredictionAggregationBuilder.NAME;
85+
}
86+
}
Lines changed: 30 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,30 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
/**
20+
* An aggregation that computes the predicted response value
21+
* for the given input data regarding the linear model evaluated for the current bucket.
22+
* Created by mbok on 11.04.17.
23+
*/
24+
public interface Prediction {
25+
26+
/**
27+
* @return the predicted value
28+
*/
29+
double getValue();
30+
}
Lines changed: 76 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,76 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import org.elasticsearch.common.io.stream.StreamInput;
22+
import org.elasticsearch.search.MultiValueMode;
23+
import org.elasticsearch.search.aggregations.AggregatorFactories.Builder;
24+
import org.elasticsearch.search.aggregations.AggregatorFactory;
25+
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
26+
import org.elasticsearch.search.aggregations.support.NamedValuesSourceConfigSpec;
27+
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
28+
import org.elasticsearch.search.internal.SearchContext;
29+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseAggregationBuilder;
30+
import org.scaleborn.linereg.sampling.exact.ExactModelSamplingFactory;
31+
import org.scaleborn.linereg.sampling.exact.ExactSamplingContext;
32+
33+
/**
34+
* Created by mbok on 11.04.17.
35+
*/
36+
public class PredictionAggregationBuilder extends
37+
BaseAggregationBuilder<PredictionAggregationBuilder> {
38+
39+
public static final String NAME = "linreg_predict";
40+
41+
private static final ExactModelSamplingFactory MODEL_SAMPLING_FACTORY = new ExactModelSamplingFactory();
42+
43+
public PredictionAggregationBuilder(final String name) {
44+
super(name);
45+
}
46+
47+
public PredictionAggregationBuilder(final StreamInput in)
48+
throws IOException {
49+
super(in);
50+
}
51+
52+
53+
@Override
54+
protected MultiValuesSourceAggregatorFactory<Numeric, ?> innerInnerBuild(
55+
final SearchContext context,
56+
final List<NamedValuesSourceConfigSpec<Numeric>> configs, final MultiValueMode multiValueMode,
57+
final AggregatorFactory<?> parent, final Builder subFactoriesBuilder) throws IOException {
58+
return new PredictionAggregatorFactory(this.name, configs, multiValueMode, context, parent,
59+
subFactoriesBuilder, this.metaData);
60+
}
61+
62+
@Override
63+
public String getType() {
64+
return NAME;
65+
}
66+
67+
static PredictionSampling buildSampling(final int featuresCount) {
68+
final ExactSamplingContext context = MODEL_SAMPLING_FACTORY
69+
.createContext(featuresCount);
70+
final PredictionSampling predictionSampling = new PredictionSampling(context,
71+
MODEL_SAMPLING_FACTORY.createCoefficientLinearTermSampling(context),
72+
MODEL_SAMPLING_FACTORY.createCoefficientSquareTermSampling(context),
73+
MODEL_SAMPLING_FACTORY.createInterceptSampling(context));
74+
return predictionSampling;
75+
}
76+
}
Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
import java.util.Map;
20+
import org.elasticsearch.common.ParseField;
21+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseParser;
22+
23+
/**
24+
* Created by mbok on 11.04.17.
25+
*/
26+
public class PredictionAggregationParser extends BaseParser<PredictionAggregationBuilder> {
27+
28+
@Override
29+
protected PredictionAggregationBuilder createInnerFactory(final String aggregationName,
30+
final Map<ParseField, Object> otherOptions) {
31+
return new PredictionAggregationBuilder(aggregationName);
32+
}
33+
}
Lines changed: 65 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.Map;
22+
import org.elasticsearch.search.MultiValueMode;
23+
import org.elasticsearch.search.aggregations.Aggregator;
24+
import org.elasticsearch.search.aggregations.InternalAggregation;
25+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
26+
import org.elasticsearch.search.aggregations.support.NamedValuesSourceSpec;
27+
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
28+
import org.elasticsearch.search.internal.SearchContext;
29+
import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseSamplingAggregator;
30+
31+
/**
32+
* Created by mbok on 11.04.17.
33+
*/
34+
public class PredictionAggregator extends BaseSamplingAggregator<PredictionSampling> {
35+
36+
public PredictionAggregator(final String name,
37+
final List<NamedValuesSourceSpec<Numeric>> valuesSources,
38+
final SearchContext context,
39+
final Aggregator parent,
40+
final MultiValueMode multiValueMode,
41+
final List<PipelineAggregator> pipelineAggregators,
42+
final Map<String, Object> metaData) throws IOException {
43+
super(name, valuesSources, context, parent, multiValueMode, pipelineAggregators, metaData);
44+
}
45+
46+
@Override
47+
protected PredictionSampling buildSampling(final int featuresCount) {
48+
return PredictionAggregationBuilder.buildSampling(featuresCount);
49+
}
50+
51+
@Override
52+
protected InternalAggregation doBuildAggregation(final String name, final int featuresCount,
53+
final PredictionSampling predictionSampling,
54+
final List<PipelineAggregator> pipelineAggregators,
55+
final Map<String, Object> stringObjectMap) {
56+
return new InternalPrediction(this.name, this.valuesSources.fieldNames().length - 1,
57+
predictionSampling, null,
58+
pipelineAggregators(), metaData());
59+
}
60+
61+
@Override
62+
public InternalAggregation buildEmptyAggregation() {
63+
return new InternalPrediction(this.name, 0, null, null, pipelineAggregators(), metaData());
64+
}
65+
}
Lines changed: 67 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,67 @@
1+
/*
2+
* Copyright (c) 2017 Scaleborn UG, www.scaleborn.com
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.scaleborn.elasticsearch.linreg.aggregation.predict;
18+
19+
import java.io.IOException;
20+
import java.util.List;
21+
import java.util.Map;
22+
import org.elasticsearch.search.MultiValueMode;
23+
import org.elasticsearch.search.aggregations.Aggregator;
24+
import org.elasticsearch.search.aggregations.AggregatorFactories;
25+
import org.elasticsearch.search.aggregations.AggregatorFactory;
26+
import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator;
27+
import org.elasticsearch.search.aggregations.support.MultiValuesSourceAggregatorFactory;
28+
import org.elasticsearch.search.aggregations.support.NamedValuesSourceConfigSpec;
29+
import org.elasticsearch.search.aggregations.support.NamedValuesSourceSpec;
30+
import org.elasticsearch.search.aggregations.support.ValuesSource.Numeric;
31+
import org.elasticsearch.search.internal.SearchContext;
32+
33+
/**
34+
* Created by mbok on 11.04.17.
35+
*/
36+
public class PredictionAggregatorFactory extends
37+
MultiValuesSourceAggregatorFactory<Numeric, PredictionAggregatorFactory> {
38+
39+
private final MultiValueMode multiValueMode;
40+
41+
public PredictionAggregatorFactory(final String name,
42+
final List<NamedValuesSourceConfigSpec<Numeric>> configs, final MultiValueMode multiValueMode,
43+
final SearchContext context, final AggregatorFactory<?> parent,
44+
final AggregatorFactories.Builder subFactoriesBuilder,
45+
final Map<String, Object> metaData) throws IOException {
46+
super(name, configs, context, parent, subFactoriesBuilder, metaData);
47+
this.multiValueMode = multiValueMode;
48+
}
49+
50+
@Override
51+
protected Aggregator createUnmapped(final Aggregator parent,
52+
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData)
53+
throws IOException {
54+
return new PredictionAggregator(this.name, null, this.context, parent, this.multiValueMode,
55+
pipelineAggregators, metaData);
56+
}
57+
58+
@Override
59+
protected Aggregator doCreateInternal(final List<NamedValuesSourceSpec<Numeric>> valuesSources,
60+
final Aggregator parent, final boolean collectsFromSingleBucket,
61+
final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData)
62+
throws IOException {
63+
return new PredictionAggregator(this.name, valuesSources, this.context, parent,
64+
this.multiValueMode,
65+
pipelineAggregators, metaData);
66+
}
67+
}

0 commit comments

Comments
 (0)