|
16 | 16 |
|
17 | 17 | package org.scaleborn.elasticsearch.linreg.aggregation.stats; |
18 | 18 |
|
19 | | -import static java.util.Collections.emptyMap; |
20 | | - |
21 | 19 | import java.io.IOException; |
22 | | -import java.util.ArrayList; |
23 | 20 | import java.util.List; |
24 | 21 | import java.util.Map; |
25 | 22 | import org.apache.logging.log4j.Logger; |
26 | 23 | import org.elasticsearch.common.io.stream.StreamInput; |
27 | | -import org.elasticsearch.common.io.stream.StreamOutput; |
28 | 24 | import org.elasticsearch.common.logging.Loggers; |
29 | | -import org.elasticsearch.common.xcontent.XContentBuilder; |
30 | | -import org.elasticsearch.search.aggregations.InternalAggregation; |
31 | 25 | import org.elasticsearch.search.aggregations.pipeline.PipelineAggregator; |
32 | | -import org.scaleborn.elasticsearch.linreg.aggregation.support.StateInputStreamAdapter; |
33 | | -import org.scaleborn.elasticsearch.linreg.aggregation.support.StateOutputStreamAdapter; |
34 | | -import org.scaleborn.linereg.evaluation.DerivationEquation; |
35 | | -import org.scaleborn.linereg.evaluation.DerivationEquationBuilder; |
| 26 | +import org.scaleborn.elasticsearch.linreg.aggregation.support.BaseInternalAggregation; |
| 27 | +import org.scaleborn.linereg.calculation.statistics.Statistics; |
| 28 | +import org.scaleborn.linereg.calculation.statistics.StatsCalculator; |
| 29 | +import org.scaleborn.linereg.calculation.statistics.StatsModel; |
36 | 30 | import org.scaleborn.linereg.evaluation.SlopeCoefficients; |
37 | | -import org.scaleborn.linereg.evaluation.commons.CommonsMathSolver; |
38 | | -import org.scaleborn.linereg.statistics.Statistics; |
39 | | -import org.scaleborn.linereg.statistics.StatsBuilder; |
40 | | -import org.scaleborn.linereg.statistics.StatsModel; |
41 | | -import org.scaleborn.linereg.statistics.StatsSampling; |
42 | 31 |
|
43 | 32 | /** |
44 | 33 | * Created by mbok on 21.03.17. |
45 | 34 | */ |
46 | | -public class InternalStats extends InternalAggregation implements Stats { |
| 35 | +public class InternalStats extends |
| 36 | + BaseInternalAggregation<StatsAggregationSampling, StatsResults, InternalStats> implements |
| 37 | + Stats { |
47 | 38 |
|
48 | 39 | private static final Logger LOGGER = Loggers.getLogger(InternalStats.class); |
49 | | - /** |
50 | | - * per shard sampling needed to compute stats |
51 | | - */ |
52 | | - private StatsSampling<?> sampling; |
53 | | - /** |
54 | | - * final result |
55 | | - */ |
56 | | - private Statistics results; |
57 | 40 |
|
58 | | - /** |
59 | | - * Features count |
60 | | - */ |
61 | | - private int featuresCount; |
| 41 | + private static final StatsCalculator statsCalculator = new StatsCalculator(); |
62 | 42 |
|
63 | 43 | /** |
64 | 44 | * per shard ctor |
65 | 45 | */ |
66 | | - protected InternalStats(String name, int featuresCount, StatsSampling<?> linRegSampling, |
67 | | - Statistics results, |
68 | | - List<PipelineAggregator> pipelineAggregators, Map<String, Object> metaData) { |
69 | | - super(name, pipelineAggregators, metaData); |
70 | | - this.featuresCount = featuresCount; |
71 | | - this.sampling = linRegSampling; |
72 | | - this.results = results; |
| 46 | + protected InternalStats(final String name, final int featuresCount, |
| 47 | + final StatsAggregationSampling linRegSampling, |
| 48 | + final StatsResults results, |
| 49 | + final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) { |
| 50 | + super(name, featuresCount, linRegSampling, results, pipelineAggregators, metaData); |
73 | 51 | } |
74 | 52 |
|
75 | 53 | /** |
76 | 54 | * Read from a stream. |
77 | 55 | */ |
78 | | - public InternalStats(StreamInput in) throws IOException { |
79 | | - super(in); |
80 | | - this.featuresCount = in.readInt(); |
81 | | - if (in.readBoolean()) { |
82 | | - this.sampling = StatsAggregationBuilder.buildSampling(this.featuresCount); |
83 | | - StateInputStreamAdapter streamAdapter = new StateInputStreamAdapter(in); |
84 | | - this.sampling.loadState(streamAdapter); |
85 | | - } |
86 | | - if (in.readBoolean()) { |
87 | | - this.results = new DefaultStatistics(in.readDouble(), in.readDouble()); |
88 | | - } |
| 56 | + public InternalStats(final StreamInput in) throws IOException { |
| 57 | + super(in, StatsResults::new); |
89 | 58 | } |
90 | 59 |
|
91 | 60 | @Override |
92 | | - protected void doWriteTo(StreamOutput out) throws IOException { |
93 | | - out.writeInt(this.featuresCount); |
94 | | - out.writeBoolean(this.sampling != null); |
95 | | - if (this.sampling != null) { |
96 | | - StateOutputStreamAdapter outAdapter = new StateOutputStreamAdapter(out); |
97 | | - this.sampling.saveState(outAdapter); |
98 | | - } |
99 | | - out.writeBoolean(this.results != null); |
100 | | - if (this.results != null) { |
101 | | - out.writeDouble(this.results.getRss()); |
102 | | - out.writeDouble(this.results.getMse()); |
103 | | - } |
| 61 | + protected StatsAggregationSampling buildSampling(final int featuresCount) { |
| 62 | + return StatsAggregationBuilder.buildSampling(featuresCount); |
104 | 63 | } |
105 | 64 |
|
| 65 | + |
106 | 66 | @Override |
107 | 67 | public String getWriteableName() { |
108 | 68 | return StatsAggregationBuilder.NAME; |
109 | 69 | } |
110 | 70 |
|
111 | | - static class Fields { |
112 | | - |
113 | | - public static final String RSS = "rss"; |
114 | | - public static final String MSE = "mse"; |
115 | | - } |
116 | | - |
117 | 71 | @Override |
118 | 72 | public double getRss() { |
119 | | - if (results == null) { |
| 73 | + if (this.results == null) { |
120 | 74 | return Double.NaN; |
121 | 75 | } |
122 | | - return results.getRss(); |
| 76 | + return this.results.statistics.getRss(); |
123 | 77 | } |
124 | 78 |
|
125 | 79 | @Override |
126 | 80 | public double getMse() { |
127 | | - if (results == null) { |
| 81 | + if (this.results == null) { |
128 | 82 | return Double.NaN; |
129 | 83 | } |
130 | | - return results.getMse(); |
| 84 | + return this.results.statistics.getMse(); |
131 | 85 | } |
132 | 86 |
|
133 | 87 | @Override |
134 | | - public XContentBuilder doXContentBody(XContentBuilder builder, Params params) throws IOException { |
135 | | - if (results != null) { |
136 | | - // RSS |
137 | | - builder.field(Fields.RSS, results.getRss()); |
138 | | - // MSE |
139 | | - builder.field(Fields.MSE, results.getMse()); |
| 88 | + public Object getDoProperty(final String element) { |
| 89 | + switch (element) { |
| 90 | + case "rss": |
| 91 | + return getRss(); |
| 92 | + case "mse": |
| 93 | + return getMse(); |
140 | 94 | } |
141 | | - return builder; |
| 95 | + return null; |
142 | 96 | } |
143 | 97 |
|
144 | 98 | @Override |
145 | | - public Object getProperty(List<String> path) { |
146 | | - if (path.isEmpty()) { |
147 | | - return this; |
148 | | - } else if (path.size() == 1) { |
149 | | - String element = path.get(0); |
150 | | - if (results == null) { |
151 | | - return emptyMap(); |
152 | | - } |
153 | | - switch (element) { |
154 | | - case "rss": |
155 | | - return results.getRss(); |
156 | | - case "mse": |
157 | | - return results.getMse(); |
158 | | - default: |
159 | | - throw new IllegalArgumentException( |
160 | | - "Found unknown path element [" + element + "] in [" + getName() + "]"); |
161 | | - } |
162 | | - } else { |
163 | | - throw new IllegalArgumentException("path not supported for [" + getName() + "]: " + path); |
164 | | - } |
| 99 | + protected InternalStats buildInternalAggregation(final String name, final int featuresCount, |
| 100 | + final StatsAggregationSampling linRegSampling, final StatsResults results, |
| 101 | + final List<PipelineAggregator> pipelineAggregators, final Map<String, Object> metaData) { |
| 102 | + return new InternalStats(name, featuresCount, linRegSampling, results, pipelineAggregators, |
| 103 | + metaData); |
165 | 104 | } |
166 | 105 |
|
167 | | - @SuppressWarnings("unchecked") |
168 | 106 | @Override |
169 | | - public InternalAggregation doReduce(List<InternalAggregation> aggregations, |
170 | | - ReduceContext reduceContext) { |
171 | | - // merge samples across all shards |
172 | | - List<InternalAggregation> aggs = new ArrayList<>(aggregations); |
173 | | - aggs.removeIf(p -> ((InternalStats) p).sampling == null); |
174 | | - |
175 | | - // return empty result iff all samples are null |
176 | | - if (aggs.isEmpty()) { |
177 | | - return new InternalStats(name, featuresCount, null, new DefaultStatistics(0, 0), |
178 | | - pipelineAggregators(), |
179 | | - getMetaData()); |
180 | | - } |
181 | | - |
182 | | - StatsSampling composedSampling = StatsAggregationBuilder |
183 | | - .buildSampling(featuresCount); |
184 | | - for (int i = 0; i < aggs.size(); ++i) { |
185 | | - LOGGER.info("Merging sampling={}: {}", i, ((InternalStats) aggs.get(i)).sampling); |
186 | | - composedSampling.merge(((InternalStats) aggs.get(i)).sampling); |
187 | | - } |
188 | | - |
189 | | - // Linear regression evaluation |
190 | | - DerivationEquation derivationEquation = new DerivationEquationBuilder() |
191 | | - .buildDerivationEquation(composedSampling); |
192 | | - CommonsMathSolver commonsMathSolver = new CommonsMathSolver(); |
193 | | - SlopeCoefficients slopeCoefficients = commonsMathSolver.solveCoefficients(derivationEquation); |
194 | | - StatsBuilder statsBuilder = new StatsBuilder(); |
195 | | - Statistics statsResult = statsBuilder |
196 | | - .buildStats(new StatsModel(composedSampling, slopeCoefficients)); |
197 | | - LOGGER.info("Evaluated linear with {} and stats {}", slopeCoefficients, statsResult); |
198 | | - return new InternalStats(name, featuresCount, composedSampling, statsResult, |
199 | | - pipelineAggregators(), getMetaData()); |
| 107 | + protected StatsResults buildResults(final StatsAggregationSampling composedSampling, |
| 108 | + final SlopeCoefficients slopeCoefficients) { |
| 109 | + final Statistics stats = statsCalculator |
| 110 | + .calculate(new StatsModel(composedSampling, slopeCoefficients)); |
| 111 | + return new StatsResults(slopeCoefficients, stats); |
200 | 112 | } |
| 113 | + |
201 | 114 | } |
0 commit comments