Skip to content

Commit 0859252

Browse files
authored
Merge pull request #405 from kaituo/missing
Fix confidence adjustment when all input values are missing
2 parents 7158799 + 7eadb49 commit 0859252

File tree

4 files changed

+148
-2
lines changed

4 files changed

+148
-2
lines changed

Java/core/src/main/java/com/amazon/randomcutforest/preprocessor/ImputePreprocessor.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,31 @@ public float[] getScaledShingledInput(double[] inputPoint, long timestamp, int[]
8888
*/
8989
@Override
9090
protected void updateTimestamps(long timestamp) {
91-
if (previousTimeStamps[0] == previousTimeStamps[1]) {
91+
/*
92+
* For imputations done on timestamps other than the current one (specified by
93+
* the timestamp parameter), the timestamp of the imputed tuple matches that of
94+
* the input tuple, and we increment numberOfImputed. For imputations done at
95+
* the current timestamp (if all input values are missing), the timestamp of the
96+
* imputed tuple is the current timestamp, and we increment numberOfImputed.
97+
*
98+
* To check if imputed values are still present in the shingle, we use the first
99+
* condition (previousTimeStamps[0] == previousTimeStamps[1]). This works
100+
* because previousTimeStamps has a size equal to the shingle size and is filled
101+
* with the current timestamp. However, there are scenarios where we might miss
102+
* decrementing numberOfImputed:
103+
*
104+
* 1. Not all values in the shingle are imputed. 2. We accumulated
105+
* numberOfImputed when the current timestamp had missing values.
106+
*
107+
* As a result, this could cause the data quality measure to decrease
108+
* continuously since we are always counting missing values that should
109+
* eventually be reset to zero. The second condition <pre> timestamp >
110+
* previousTimeStamps[previousTimeStamps.length-1] && numberOfImputed > 0 </pre>
111+
* will decrement numberOfImputed when we move to a new timestamp, provided
112+
* numberOfImputed is greater than zero.
113+
*/
114+
if (previousTimeStamps[0] == previousTimeStamps[1]
115+
|| (timestamp > previousTimeStamps[previousTimeStamps.length - 1] && numberOfImputed > 0)) {
92116
numberOfImputed = numberOfImputed - 1;
93117
}
94118
super.updateTimestamps(timestamp);
@@ -333,7 +357,10 @@ protected float[] generateShingle(double[] inputTuple, long timestamp, int[] mis
333357
}
334358
}
335359

336-
updateForest(changeForest, input, timestamp, forest, false);
360+
// last parameter isFullyImputed = if we miss everything in inputTuple?
361+
// This would ensure dataQuality is decreasing if we impute whenever
362+
updateForest(changeForest, input, timestamp, forest,
363+
missingValues != null ? missingValues.length == inputTuple.length : false);
337364
if (changeForest) {
338365
updateTimeStampDeviations(timestamp, lastInputTimeStamp);
339366
transformer.updateDeviation(input, savedInput, missingValues);

Java/core/src/main/java/com/amazon/randomcutforest/state/preprocessor/PreprocessorMapper.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ public Preprocessor toModel(PreprocessorState state, long seed) {
6262
preprocessor.setPreviousTimeStamps(state.getPreviousTimeStamps());
6363
preprocessor.setNormalizeTime(state.isNormalizeTime());
6464
preprocessor.setFastForward(state.isFastForward());
65+
preprocessor.setNumberOfImputed(state.getNumberOfImputed());
6566
return preprocessor;
6667
}
6768

@@ -94,6 +95,7 @@ public PreprocessorState toState(Preprocessor model) {
9495
state.setTimeStampDeviationStates(getStates(model.getTimeStampDeviations(), deviationMapper));
9596
state.setDataQualityStates(getStates(model.getDataQuality(), deviationMapper));
9697
state.setFastForward(model.isFastForward());
98+
state.setNumberOfImputed(model.getNumberOfImputed());
9799
return state;
98100
}
99101

Java/core/src/main/java/com/amazon/randomcutforest/state/preprocessor/PreprocessorState.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,5 @@ public class PreprocessorState implements Serializable {
5959
private DeviationState[] dataQualityStates;
6060
private DeviationState[] timeStampDeviationStates;
6161
private boolean fastForward;
62+
private int numberOfImputed;
6263
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
/*
2+
* Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* A copy of the License is located at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* or in the "license" file accompanying this file. This file is distributed
11+
* on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
12+
* express or implied. See the License for the specific language governing
13+
* permissions and limitations under the License.
14+
*/
15+
16+
package com.amazon.randomcutforest.parkservices;
17+
18+
import static org.junit.jupiter.api.Assertions.assertEquals;
19+
import static org.junit.jupiter.api.Assertions.assertTrue;
20+
21+
import java.util.ArrayList;
22+
import java.util.List;
23+
import java.util.Random;
24+
25+
import org.junit.jupiter.params.ParameterizedTest;
26+
import org.junit.jupiter.params.provider.EnumSource;
27+
28+
import com.amazon.randomcutforest.config.ForestMode;
29+
import com.amazon.randomcutforest.config.ImputationMethod;
30+
import com.amazon.randomcutforest.config.Precision;
31+
import com.amazon.randomcutforest.config.TransformMethod;
32+
33+
public class MissingValueTest {
34+
@ParameterizedTest
35+
@EnumSource(ImputationMethod.class)
36+
public void testConfidence(ImputationMethod method) {
37+
// Create and populate a random cut forest
38+
39+
int shingleSize = 4;
40+
int numberOfTrees = 50;
41+
int sampleSize = 256;
42+
Precision precision = Precision.FLOAT_32;
43+
int baseDimensions = 1;
44+
45+
long count = 0;
46+
47+
int dimensions = baseDimensions * shingleSize;
48+
ThresholdedRandomCutForest forest = new ThresholdedRandomCutForest.Builder<>().compact(true)
49+
.dimensions(dimensions).randomSeed(0).numberOfTrees(numberOfTrees).shingleSize(shingleSize)
50+
.sampleSize(sampleSize).precision(precision).anomalyRate(0.01).imputationMethod(method)
51+
.fillValues(new double[] { 3 }).forestMode(ForestMode.STREAMING_IMPUTE)
52+
.transformMethod(TransformMethod.NORMALIZE).autoAdjust(true).build();
53+
54+
// Define the size and range
55+
int size = 400;
56+
double min = 200.0;
57+
double max = 240.0;
58+
59+
// Generate the list of doubles
60+
List<Double> randomDoubles = generateUniformRandomDoubles(size, min, max);
61+
62+
double lastConfidence = 0;
63+
for (double val : randomDoubles) {
64+
double[] point = new double[] { val };
65+
long newStamp = 100 * count;
66+
if (count >= 300 && count < 325) {
67+
// drop observations
68+
AnomalyDescriptor result = forest.process(new double[] { Double.NaN }, newStamp,
69+
generateIntArray(point.length));
70+
if (count > 300) {
71+
// confidence start decreasing after 1 missing point
72+
assertTrue(result.getDataConfidence() < lastConfidence, "count " + count);
73+
}
74+
lastConfidence = result.getDataConfidence();
75+
float[] rcfPoint = result.getRCFPoint();
76+
double scale = result.getScale()[0];
77+
double shift = result.getShift()[0];
78+
double[] actual = new double[] { (rcfPoint[3] * scale) + shift };
79+
if (method == ImputationMethod.ZERO) {
80+
assertEquals(0, actual[0], 0.001d);
81+
} else if (method == ImputationMethod.FIXED_VALUES) {
82+
assertEquals(3.0d, actual[0], 0.001d);
83+
}
84+
} else {
85+
AnomalyDescriptor result = forest.process(point, newStamp);
86+
if ((count > 100 && count < 300) || count >= 326) {
87+
// The first 65+ observations gives 0 confidence.
88+
// Confidence start increasing after 1 observed point
89+
assertTrue(result.getDataConfidence() > lastConfidence);
90+
}
91+
lastConfidence = result.getDataConfidence();
92+
}
93+
++count;
94+
}
95+
}
96+
97+
public static int[] generateIntArray(int size) {
98+
int[] intArray = new int[size];
99+
for (int i = 0; i < size; i++) {
100+
intArray[i] = i;
101+
}
102+
return intArray;
103+
}
104+
105+
public static List<Double> generateUniformRandomDoubles(int size, double min, double max) {
106+
List<Double> randomDoubles = new ArrayList<>(size);
107+
Random random = new Random(0);
108+
109+
for (int i = 0; i < size; i++) {
110+
double randomValue = min + (max - min) * random.nextDouble();
111+
randomDoubles.add(randomValue);
112+
}
113+
114+
return randomDoubles;
115+
}
116+
}

0 commit comments

Comments
 (0)