Skip to content

Commit 1cb5fd8

Browse files
authored
Allow processSequentially to accept explicit timestamps etc (#413)
* Allow processSequentially to accept explicit timestamps and fix threshold auto-adjust behavior Updated ThresholdedRandomCutForest.processSequentially to accept explicit timestamps, enabling its use in STREAMING_IMPUTE mode. Additionally, corrected threshold initialization in ThresholdedRandomCutForest constructor: the lower RCF score threshold is now set only when autoAdjust is disabled. Previously, setting this threshold regardless of autoAdjust inadvertently disabled automatic adjustments in BasicThresholder. Testing: * added UT Signed-off-by: Kaituo Li <[email protected]>
1 parent 35f4cf6 commit 1cb5fd8

File tree

13 files changed

+271
-27
lines changed

13 files changed

+271
-27
lines changed

Java/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ vector data point, scores the data point, and then updates the model with this
157157
point. The program output appends a column of anomaly scores to the input:
158158

159159
```text
160-
$ java -cp core/target/randomcutforest-core-4.2.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner < ../example-data/rcf-paper.csv > example_output.csv
160+
$ java -cp core/target/randomcutforest-core-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner < ../example-data/rcf-paper.csv > example_output.csv
161161
$ tail example_output.csv
162162
-5.0029,0.0170,-0.0057,0.8129401629464965
163163
-4.9975,-0.0102,-0.0065,0.6591046054520615
@@ -176,8 +176,8 @@ read additional usage instructions, including options for setting model
176176
hyperparameters, using the `--help` flag:
177177

178178
```text
179-
$ java -cp core/target/randomcutforest-core-4.2.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner --help
180-
Usage: java -cp target/random-cut-forest-4.2.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner [options] < input_file > output_file
179+
$ java -cp core/target/randomcutforest-core-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner --help
180+
Usage: java -cp target/random-cut-forest-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner [options] < input_file > output_file
181181
182182
Compute scalar anomaly scores from the input rows and append them to the output rows.
183183
@@ -239,14 +239,14 @@ framework. Build an executable jar containing the benchmark code by running
239239
To invoke the full benchmark suite:
240240

241241
```text
242-
% java -jar benchmark/target/randomcutforest-benchmark-4.2.0-jar-with-dependencies.jar
242+
% java -jar benchmark/target/randomcutforest-benchmark-4.3.0-jar-with-dependencies.jar
243243
```
244244

245245
The full benchmark suite takes a long time to run. You can also pass a regex at the command-line, then only matching
246246
benchmark methods will be executed.
247247

248248
```text
249-
% java -jar benchmark/target/randomcutforest-benchmark-4.2.0-jar-with-dependencies.jar RandomCutForestBenchmark\.updateAndGetAnomalyScore
249+
% java -jar benchmark/target/randomcutforest-benchmark-4.3.0-jar-with-dependencies.jar RandomCutForestBenchmark\.updateAndGetAnomalyScore
250250
```
251251

252252
[rcf-paper]: http://proceedings.mlr.press/v48/guha16.pdf

Java/benchmark/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>software.amazon.randomcutforest</groupId>
88
<artifactId>randomcutforest-parent</artifactId>
9-
<version>4.2.0</version>
9+
<version>4.3.0</version>
1010
</parent>
1111

1212
<artifactId>randomcutforest-benchmark</artifactId>

Java/core/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>software.amazon.randomcutforest</groupId>
88
<artifactId>randomcutforest-parent</artifactId>
9-
<version>4.2.0</version>
9+
<version>4.3.0</version>
1010
</parent>
1111

1212
<artifactId>randomcutforest-core</artifactId>

Java/core/src/test/java/com/amazon/randomcutforest/SampleSummaryTest.java

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -345,10 +345,8 @@ public void ParallelTest(BiFunction<float[], float[], Double> distance) {
345345
assertEquals(summary2.summaryPoints.length, summary1.summaryPoints.length,
346346
" incorrect length of typical points");
347347
// due to randomization, they might not equal
348-
assertTrue(
349-
Math.abs(clusters.size() - summary1.summaryPoints.length) <= 1,
350-
"The difference between clusters.size() and summary1.summaryPoints.length should be at most 1"
351-
);
348+
assertTrue(Math.abs(clusters.size() - summary1.summaryPoints.length) <= 1,
349+
"The difference between clusters.size() and summary1.summaryPoints.length should be at most 1");
352350
double total = clusters.stream().map(ICluster::getWeight).reduce(0.0, Double::sum);
353351
assertEquals(total, summary1.weightOfSamples, 1e-3);
354352
// parallelization can produce reordering of merges

Java/examples/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
<parent>
88
<groupId>software.amazon.randomcutforest</groupId>
99
<artifactId>randomcutforest-parent</artifactId>
10-
<version>4.2.0</version>
10+
<version>4.3.0</version>
1111
</parent>
1212

1313
<artifactId>randomcutforest-examples</artifactId>

Java/parkservices/pom.xml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
<parent>
77
<groupId>software.amazon.randomcutforest</groupId>
88
<artifactId>randomcutforest-parent</artifactId>
9-
<version>4.2.0</version>
9+
<version>4.3.0</version>
1010
</parent>
1111

1212
<artifactId>randomcutforest-parkservices</artifactId>

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/RCFCaster.java

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
import java.util.ArrayList;
2222
import java.util.List;
23+
import java.util.Locale;
2324
import java.util.Optional;
2425
import java.util.function.Function;
2526

@@ -246,6 +247,36 @@ public TimedRangeVector extrapolate(Calibration calibration, int horizon, boolea
246247

247248
@Override
248249
public List<AnomalyDescriptor> processSequentially(double[][] data, Function<AnomalyDescriptor, Boolean> filter) {
250+
if (data == null || data.length == 0) {
251+
return new ArrayList<>();
252+
}
253+
254+
long timestamp = preprocessor.getInternalTimeStamp();
255+
long[] timestamps = new long[data.length];
256+
for (int i = 0; i < data.length; i++) {
257+
timestamps[i] = ++timestamp;
258+
}
259+
260+
return processSequentially(data, timestamps, filter);
261+
}
262+
263+
public List<AnomalyDescriptor> processSequentially(double[][] data, long[] timestamps,
264+
Function<AnomalyDescriptor, Boolean> filter) {
265+
// Precondition checks
266+
checkArgument(filter != null, "filter must not be null");
267+
if (data != null && data.length > 0) {
268+
checkArgument(timestamps != null, "timestamps must not be null when data is non-empty");
269+
checkArgument(timestamps.length == data.length, String.format(Locale.ROOT,
270+
"timestamps length (%s) must equal data length (%s)", timestamps.length, data.length));
271+
for (int i = 1; i < timestamps.length; i++) {
272+
checkArgument(timestamps[i] > timestamps[i - 1],
273+
String.format(Locale.ROOT,
274+
"timestamps must be strictly ascending: "
275+
+ "timestamps[%s]=%s is not > timestamps[%s]=%s",
276+
i, timestamps[i], i - 1, timestamps[i - 1]));
277+
}
278+
}
279+
249280
ArrayList<AnomalyDescriptor> answer = new ArrayList<>();
250281
if (data != null) {
251282
if (data.length > 0) {
@@ -254,11 +285,12 @@ public List<AnomalyDescriptor> processSequentially(double[][] data, Function<Ano
254285
if (cacheDisabled) { // turn caching on temporarily
255286
forest.setBoundingBoxCacheFraction(1.0);
256287
}
257-
long timestamp = preprocessor.getInternalTimeStamp();
258288
int length = preprocessor.getInputLength();
259-
for (double[] point : data) {
289+
for (int i = 0; i < data.length; i++) {
290+
double[] point = data[i];
291+
checkArgument(point != null, " data should not be null ");
260292
checkArgument(point.length == length, " nonuniform lengths ");
261-
ForecastDescriptor description = new ForecastDescriptor(point, timestamp++, forecastHorizon);
293+
ForecastDescriptor description = new ForecastDescriptor(point, timestamps[i], forecastHorizon);
262294
augment(description);
263295
if (filter.apply(description)) {
264296
answer.add(description);

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java

Lines changed: 76 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
import java.util.ArrayList;
4040
import java.util.Arrays;
4141
import java.util.List;
42+
import java.util.Locale;
4243
import java.util.Optional;
4344
import java.util.Random;
4445
import java.util.function.Function;
@@ -138,7 +139,11 @@ public ThresholdedRandomCutForest(Builder<?> builder) {
138139
lastAnomalyDescriptor = new RCFComputeDescriptor(null, 0, builder.forestMode, builder.transformMethod,
139140
builder.imputationMethod);
140141

141-
predictorCorrector.setAbsoluteThreshold(builder.lowerThreshold.orElse(DEFAULT_ABSOLUTE_THRESHOLD));
142+
// when autoAdjust is true, the lowerThreshold is dynamically calculated
143+
if (!builder.autoAdjust) {
144+
predictorCorrector.setAbsoluteThreshold(builder.lowerThreshold.orElse(DEFAULT_ABSOLUTE_THRESHOLD));
145+
}
146+
142147
predictorCorrector.setZfactor(builder.zFactor);
143148

144149
predictorCorrector.setScoreDifferencing(builder.scoreDifferencing.orElse(DEFAULT_SCORE_DIFFERENCING));
@@ -279,8 +284,7 @@ public AnomalyDescriptor process(double[] inputPoint, long timestamp, int[] miss
279284
* of the word batch -- the entire goal of this procedure is to provide
280285
* sequential processing and not standard batch processing). The procedure
281286
* avoids transfer of ephemeral transient objects for non-anomalies and thereby
282-
* can have additional benefits. At the moment the operation does not support
283-
* external timestamps.
287+
* can have additional benefits.
284288
*
285289
* @param data a vector of vectors (each of which has to have the same
286290
* inputLength)
@@ -289,6 +293,66 @@ public AnomalyDescriptor process(double[] inputPoint, long timestamp, int[] miss
289293
* @return collection of descriptors of the anomalies filtered by the condition
290294
*/
291295
public List<AnomalyDescriptor> processSequentially(double[][] data, Function<AnomalyDescriptor, Boolean> filter) {
296+
if (data == null || data.length == 0) {
297+
return new ArrayList<>();
298+
}
299+
300+
long timestamp = preprocessor.getInternalTimeStamp();
301+
long[] timestamps = new long[data.length];
302+
for (int i = 0; i < data.length; i++) {
303+
timestamps[i] = ++timestamp;
304+
}
305+
306+
return processSequentially(data, timestamps, filter);
307+
}
308+
309+
/**
310+
* the following function processes a list of vectors sequentially; the main
311+
* benefit of this invocation is the caching is persisted from one data point to
312+
* another and thus the execution is efficient. Moreover in many scenarios where
313+
* serialization deserialization is expensive then it may be of benefit of
314+
* invoking sequential process on a contiguous chunk of input (we avoid the use
315+
* of the word batch -- the entire goal of this procedure is to provide
316+
* sequential processing and not standard batch processing). The procedure
317+
* avoids transfer of ephemeral transient objects for non-anomalies and thereby
318+
* can have additional benefits. This variant accepts explicit external
319+
* timestamps supplied by the caller.
320+
*
321+
* @param data a vector of vectors (each of which has to have the same
322+
* inputLength)
323+
* @param timestamps a vector of timestamps (in the same order as the data, has
324+
* to be same length as data, and ascending)
325+
* @param filter a condition to drop descriptor (recommended filter:
326+
* anomalyGrade positive)
327+
* @return collection of descriptors of the anomalies filtered by the condition
328+
* @throws IllegalArgumentException if
329+
* <ul>
330+
* <li>data is non-null but timestamps is
331+
* null</li>
332+
* <li>timestamps.length != data.length</li>
333+
* <li>timestamps is not strictly
334+
* ascending</li>
335+
* <li>any data[i].length !=
336+
* preprocessor.getInputLength()</li>
337+
* </ul>
338+
*/
339+
public List<AnomalyDescriptor> processSequentially(double[][] data, long[] timestamps,
340+
Function<AnomalyDescriptor, Boolean> filter) {
341+
// Precondition checks
342+
checkArgument(filter != null, "filter must not be null");
343+
if (data != null && data.length > 0) {
344+
checkArgument(timestamps != null, "timestamps must not be null when data is non-empty");
345+
checkArgument(timestamps.length == data.length, String.format(Locale.ROOT,
346+
"timestamps length (%s) must equal data length (%s)", timestamps.length, data.length));
347+
for (int i = 1; i < timestamps.length; i++) {
348+
checkArgument(timestamps[i] > timestamps[i - 1],
349+
String.format(Locale.ROOT,
350+
"timestamps must be strictly ascending: "
351+
+ "timestamps[%s]=%s is not > timestamps[%s]=%s",
352+
i, timestamps[i], i - 1, timestamps[i - 1]));
353+
}
354+
}
355+
292356
ArrayList<AnomalyDescriptor> answer = new ArrayList<>();
293357

294358
if (data != null && data.length > 0) {
@@ -297,11 +361,13 @@ public List<AnomalyDescriptor> processSequentially(double[][] data, Function<Ano
297361
if (cacheDisabled) { // turn caching on temporarily
298362
forest.setBoundingBoxCacheFraction(1.0);
299363
}
300-
long timestamp = preprocessor.getInternalTimeStamp();
301364
int length = preprocessor.getInputLength();
302-
for (double[] point : data) {
365+
for (int i = 0; i < data.length; i++) {
366+
double[] point = data[i];
367+
long timestamp = timestamps[i];
368+
checkArgument(point != null, " data should not be null ");
303369
checkArgument(point.length == length, " nonuniform lengths ");
304-
AnomalyDescriptor description = new AnomalyDescriptor(point, timestamp++);
370+
AnomalyDescriptor description = new AnomalyDescriptor(point, timestamp);
305371
augment(description);
306372
if (saveDescriptor(description)) {
307373
lastAnomalyDescriptor = description.copyOf();
@@ -519,7 +585,11 @@ <P extends AnomalyDescriptor> void postProcess(P result) {
519585
reference = preprocessor.getShingledInput(shingleSize + index);
520586
result.setPastTimeStamp(preprocessor.getTimeStamp(shingleSize + index));
521587
}
588+
589+
// relative index is the source of truth. Past values always have value:
590+
// either current input or previous input.
522591
result.setPastValues(reference);
592+
523593
if (newPoint != null) {
524594
double[] values = preprocessor.getExpectedValue(index, reference, point, newPoint);
525595
if (forestMode == ForestMode.TIME_AUGMENTED) {

0 commit comments

Comments
 (0)