Commit 1ab2a3c

feat(parkservices): preserve last descriptor fields; mark NaN feature indices in processSequentially (#414)
This PR

- `processSequentially`: before constructing each `AnomalyDescriptor`, scan the point for `Double.NaN` values and set `missingValues` to the indices of the missing features. This enables `processSequentially` to handle missing values.
- Preserve `currentInput` and `missingValues` during `ThresholdedRandomCutForest` serialization so the last processed descriptor survives save/restore. Enables downstream consumers to surface/impute per-feature missingness.

Testing

- Unit tests for TRCF round-trip serialization and NaN handling.
- Manual verification on sample streams.

Signed-off-by: kaituo <[email protected]>
1 parent 1cb5fd8 commit 1ab2a3c

15 files changed

+160
-16
lines changed

Java/README.md

Lines changed: 5 additions & 5 deletions
@@ -157,7 +157,7 @@ vector data point, scores the data point, and then updates the model with this
 point. The program output appends a column of anomaly scores to the input:
 
 ```text
-$ java -cp core/target/randomcutforest-core-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner < ../example-data/rcf-paper.csv > example_output.csv
+$ java -cp core/target/randomcutforest-core-4.4.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner < ../example-data/rcf-paper.csv > example_output.csv
 $ tail example_output.csv
 -5.0029,0.0170,-0.0057,0.8129401629464965
 -4.9975,-0.0102,-0.0065,0.6591046054520615
@@ -176,8 +176,8 @@ read additional usage instructions, including options for setting model
 hyperparameters, using the `--help` flag:
 
 ```text
-$ java -cp core/target/randomcutforest-core-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner --help
-Usage: java -cp target/random-cut-forest-4.3.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner [options] < input_file > output_file
+$ java -cp core/target/randomcutforest-core-4.4.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner --help
+Usage: java -cp target/random-cut-forest-4.4.0.jar com.amazon.randomcutforest.runner.AnomalyScoreRunner [options] < input_file > output_file
 
 Compute scalar anomaly scores from the input rows and append them to the output rows.
 
@@ -239,14 +239,14 @@ framework. Build an executable jar containing the benchmark code by running
 To invoke the full benchmark suite:
 
 ```text
-% java -jar benchmark/target/randomcutforest-benchmark-4.3.0-jar-with-dependencies.jar
+% java -jar benchmark/target/randomcutforest-benchmark-4.4.0-jar-with-dependencies.jar
 ```
 
 The full benchmark suite takes a long time to run. You can also pass a regex at the command-line, then only matching
 benchmark methods will be executed.
 
 ```text
-% java -jar benchmark/target/randomcutforest-benchmark-4.3.0-jar-with-dependencies.jar RandomCutForestBenchmark\.updateAndGetAnomalyScore
+% java -jar benchmark/target/randomcutforest-benchmark-4.4.0-jar-with-dependencies.jar RandomCutForestBenchmark\.updateAndGetAnomalyScore
 ```
 
 [rcf-paper]: http://proceedings.mlr.press/v48/guha16.pdf

Java/benchmark/pom.xml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>4.3.0</version>
+        <version>4.4.0</version>
     </parent>
 
     <artifactId>randomcutforest-benchmark</artifactId>

Java/core/pom.xml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>4.3.0</version>
+        <version>4.4.0</version>
     </parent>
 
     <artifactId>randomcutforest-core</artifactId>

Java/examples/pom.xml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>4.3.0</version>
+        <version>4.4.0</version>
     </parent>
 
     <artifactId>randomcutforest-examples</artifactId>

Java/parkservices/pom.xml

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
     <parent>
         <groupId>software.amazon.randomcutforest</groupId>
         <artifactId>randomcutforest-parent</artifactId>
-        <version>4.3.0</version>
+        <version>4.4.0</version>
     </parent>
 
     <artifactId>randomcutforest-parkservices</artifactId>

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/PredictorCorrector.java

Lines changed: 2 additions & 0 deletions
@@ -730,6 +730,7 @@ protected void saveScores(ScoringStrategy strategy, int choice, double[] scoreVe
     protected <P extends AnomalyDescriptor> P detect(P result, RCFComputeDescriptor lastSignificantDescriptor,
             RandomCutForest forest) {
         if (result.getRCFPoint() == null) {
+            lastDescriptor = result.copyOf();
             return result;
         }
         float[] point = result.getRCFPoint();
@@ -747,6 +748,7 @@ protected <P extends AnomalyDescriptor> P detect(P result, RCFComputeDescriptor
 
         // we will not have zero scores affect any thresholding
         if (score == 0) {
+            lastDescriptor = result.copyOf();
             return result;
         }
 

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForest.java

Lines changed: 23 additions & 1 deletion
@@ -318,8 +318,11 @@ public List<AnomalyDescriptor> processSequentially(double[][] data, Function<Ano
      * can have additional benefits. At the moment the operation does not support
      * external timestamps.
      *
+     *
+     *
      * @param data       a vector of vectors (each of which has to have the same
-     *                   inputLength)
+     *                   inputLength). Missing values are represented by Double.NaN
+     *                   in a vector.
      * @param timestamps a vector of timestamps (in the same order as the data, has
      *                   to be same length as data, and ascending)
      * @param filter     a condition to drop descriptor (recommended filter:
@@ -368,6 +371,11 @@ public List<AnomalyDescriptor> processSequentially(double[][] data, long[] times
             checkArgument(point != null, " data should not be null ");
             checkArgument(point.length == length, " nonuniform lengths ");
             AnomalyDescriptor description = new AnomalyDescriptor(point, timestamp);
+            // check missing values in point.
+            int[] missingValues = generateMissingIndicesArray(point);
+            if (missingValues != null) {
+                description.setMissingValues(missingValues);
+            }
             augment(description);
             if (saveDescriptor(description)) {
                 lastAnomalyDescriptor = description.copyOf();
@@ -390,6 +398,20 @@ public List<AnomalyDescriptor> processSequentially(double[][] data) {
         return processSequentially(data, x -> x.getAnomalyGrade() > 0);
     }
 
+    private int[] generateMissingIndicesArray(double[] point) {
+        List<Integer> intArray = new ArrayList<>();
+        for (int i = 0; i < point.length; i++) {
+            if (Double.isNaN(point[i])) {
+                intArray.add(i);
+            }
+        }
+        // Return null if the array is empty
+        if (intArray.size() == 0) {
+            return null;
+        }
+        return intArray.stream().mapToInt(Integer::intValue).toArray();
+    }
+
     /**
      * a function that extrapolates the data seen by the ThresholdedRCF model, and
      * uses the transformations allowed (as opposed to just using RCFs). The
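
For readers who want to try the NaN scan outside the library, here is a minimal standalone sketch of the same contract as `generateMissingIndicesArray` above: return the indices of `Double.NaN` entries, or `null` when nothing is missing. The class name `MissingIndexSketch` and the method name `missingIndices` are hypothetical, and the `IntStream` formulation is an alternative way to write it, not the code in this commit.

```java
import java.util.Arrays;
import java.util.stream.IntStream;

// Hypothetical standalone sketch; not part of the randomcutforest code base.
public class MissingIndexSketch {

    /** Returns the indices of NaN entries in point, or null when none are missing. */
    public static int[] missingIndices(double[] point) {
        int[] indices = IntStream.range(0, point.length)
                .filter(i -> Double.isNaN(point[i]))
                .toArray();
        // Mirror the commit's convention: null (not an empty array) means "no missing values".
        return indices.length == 0 ? null : indices;
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(missingIndices(new double[] { 1.0, Double.NaN }))); // prints [1]
        System.out.println(Arrays.toString(missingIndices(new double[] { 3.0, 4.0 })));        // prints null
    }
}
```

Because the result is `null` rather than an empty array when no value is missing, callers need a null check before iterating, as the new unit test in this commit does with `assertNull`.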

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/returntypes/ComputeDescriptorMapper.java

Lines changed: 3 additions & 1 deletion
@@ -34,7 +34,7 @@ public class ComputeDescriptorMapper implements IStateMapper<RCFComputeDescripto
     @Override
     public RCFComputeDescriptor toModel(ComputeDescriptorState state, long seed) {
 
-        RCFComputeDescriptor descriptor = new RCFComputeDescriptor(null, 0L);
+        RCFComputeDescriptor descriptor = new RCFComputeDescriptor(state.getCurrentInput(), state.getInputTimeStamp());
         descriptor.setRCFScore(state.getScore());
         descriptor.setInternalTimeStamp(state.getInternalTimeStamp());
         descriptor.setAttribution(new DiVectorMapper().toModel(state.getAttribution()));
@@ -72,6 +72,8 @@ public ComputeDescriptorState toState(RCFComputeDescriptor descriptor) {
         state.setAnomalyGrade(descriptor.getAnomalyGrade());
         state.setThreshold(descriptor.getThreshold());
         state.setCorrectionMode(descriptor.getCorrectionMode().name());
+        state.setInputTimeStamp(descriptor.getInputTimestamp());
+        state.setCurrentInput(descriptor.getCurrentInput());
         return state;
     }
 

Java/parkservices/src/main/java/com/amazon/randomcutforest/parkservices/state/returntypes/ComputeDescriptorState.java

Lines changed: 3 additions & 1 deletion
@@ -23,7 +23,7 @@
 
 @Data
 public class ComputeDescriptorState implements Serializable {
-    private static final long serialVersionUID = 1L;
+    private static final long serialVersionUID = 2L;
 
     private long internalTimeStamp;
     private double score;
@@ -42,4 +42,6 @@ public class ComputeDescriptorState implements Serializable {
     private double threshold;
     private double anomalyGrade;
     private String correctionMode;
+    private long inputTimeStamp;
+    private double[] currentInput;
 }
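
The round trip that the two new fields enable can be illustrated with a small standalone sketch. It assumes the state object is persisted with Java's built-in object serialization, which this diff does not confirm. The `State` class below is a hypothetical stand-in that models only `inputTimeStamp` and `currentInput`, not the real `ComputeDescriptorState`.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.Arrays;

// Hypothetical standalone sketch; not part of the randomcutforest code base.
public class DescriptorStateSketch {

    /** Stand-in for the state class, modeling only the two fields added in this commit. */
    public static class State implements Serializable {
        private static final long serialVersionUID = 2L;
        public long inputTimeStamp;
        public double[] currentInput;
    }

    /** Serializes the state to bytes and reads it back, as a save/restore cycle would. */
    public static State roundTrip(State state) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (ObjectOutputStream out = new ObjectOutputStream(bytes)) {
                out.writeObject(state);
            }
            try (ObjectInputStream in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                return (State) in.readObject();
            }
        } catch (IOException | ClassNotFoundException e) {
            throw new IllegalStateException(e);
        }
    }

    public static void main(String[] args) {
        State before = new State();
        before.inputTimeStamp = 30L;
        before.currentInput = new double[] { Double.NaN, 5.0 };
        State after = roundTrip(before);
        // NaN entries survive the round trip, so missing features remain detectable after restore.
        System.out.println(after.inputTimeStamp + " " + Arrays.toString(after.currentInput)); // prints 30 [NaN, 5.0]
    }
}
```

One consequence worth noting under the same assumption: with built-in Java serialization, a stream written by a class declaring `serialVersionUID = 1L` is rejected with `InvalidClassException` when read by a class declaring `2L`, so previously saved states would need to be regenerated or migrated.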

Java/parkservices/src/test/java/com/amazon/randomcutforest/parkservices/ThresholdedRandomCutForestTest.java

Lines changed: 34 additions & 0 deletions
@@ -672,4 +672,38 @@ void psqEmptyDataReturnsEmpty() {
         assertTrue(out2.isEmpty());
     }
 
+    @Test
+    void testProcessSequentiallyWithMissingValues() {
+        ThresholdedRandomCutForest f = ThresholdedRandomCutForest.builder().dimensions(2).shingleSize(1).build();
+
+        double[][] data = { { 1.0, Double.NaN }, { 3.0, 4.0 }, { Double.NaN, 5.0 }, { Double.NaN, Double.NaN } };
+        long[] stamps = { 10L, 20L, 30L, 40L };
+
+        List<AnomalyDescriptor> descriptors = f.processSequentially(data, stamps, d -> true);
+
+        assertEquals(4, descriptors.size());
+
+        AnomalyDescriptor first = descriptors.get(0);
+        assertEquals(10L, first.getInputTimestamp());
+        assertArrayEquals(new double[] { 1.0, Double.NaN }, first.getCurrentInput());
+        assertNotNull(first.getMissingValues());
+        assertArrayEquals(new int[] { 1 }, first.getMissingValues());
+
+        AnomalyDescriptor second = descriptors.get(1);
+        assertEquals(20L, second.getInputTimestamp());
+        assertArrayEquals(new double[] { 3.0, 4.0 }, second.getCurrentInput());
+        assertNull(second.getMissingValues());
+
+        AnomalyDescriptor third = descriptors.get(2);
+        assertEquals(30L, third.getInputTimestamp());
+        assertArrayEquals(new double[] { Double.NaN, 5.0 }, third.getCurrentInput());
+        assertNotNull(third.getMissingValues());
+        assertArrayEquals(new int[] { 0 }, third.getMissingValues());
+
+        AnomalyDescriptor fourth = descriptors.get(3);
+        assertEquals(40L, fourth.getInputTimestamp());
+        assertArrayEquals(new double[] { Double.NaN, Double.NaN }, fourth.getCurrentInput());
+        assertNotNull(fourth.getMissingValues());
+        assertArrayEquals(new int[] { 0, 1 }, fourth.getMissingValues());
+    }
 }
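
The commit message mentions that downstream consumers can impute per-feature missingness from the recorded indices. The sketch below shows one possible strategy, carrying forward the previous observation for each missing feature. The class `ImputeSketch` and the method `imputeFromPrevious` are hypothetical, and carry-forward is an illustrative choice, not something this commit implements. The `missing` argument follows the convention tested above, where `null` means no value is missing.

```java
import java.util.Arrays;

// Hypothetical standalone sketch; not part of the randomcutforest code base.
public class ImputeSketch {

    /**
     * Returns a copy of current in which every index listed in missing
     * (null means nothing is missing) is replaced by the value from previous.
     */
    public static double[] imputeFromPrevious(double[] current, int[] missing, double[] previous) {
        double[] filled = Arrays.copyOf(current, current.length);
        if (missing != null) {
            for (int index : missing) {
                filled[index] = previous[index];
            }
        }
        return filled;
    }

    public static void main(String[] args) {
        double[] previous = { 3.0, 4.0 };
        double[] current = { Double.NaN, 5.0 };
        int[] missing = { 0 }; // as a descriptor's missing-value indices would report for this point
        System.out.println(Arrays.toString(imputeFromPrevious(current, missing, previous))); // prints [3.0, 5.0]
    }
}
```

The input array is copied rather than modified in place, so the original observation with its NaN markers stays available for auditing.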
