|
17 | 17 |
|
18 | 18 | import com.datasqrl.flinkrunner.stdlib.utils.AutoRegisterSystemFunction; |
19 | 19 | import com.google.auto.service.AutoService; |
20 | | -import java.util.ArrayList; |
21 | | -import lombok.SneakyThrows; |
| 20 | +import java.util.LinkedList; |
| 21 | +import java.util.List; |
| 22 | +import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.JsonNode; |
22 | 23 | import org.apache.flink.shaded.jackson2.com.fasterxml.jackson.databind.ObjectMapper; |
23 | 24 | import org.apache.flink.table.functions.AggregateFunction; |
24 | 25 | import org.apache.flink.util.jackson.JacksonMapperFactory; |
25 | 26 |
|
26 | 27 | /** Aggregation function that aggregates JSON objects into a JSON array. */ |
27 | 28 | @AutoService(AutoRegisterSystemFunction.class) |
28 | | -public class jsonb_array_agg extends AggregateFunction<FlinkJsonType, ArrayAgg> |
| 29 | +public class jsonb_array_agg extends AggregateFunction<FlinkJsonType, ArrayAggAccumulator> |
29 | 30 | implements AutoRegisterSystemFunction { |
30 | 31 |
|
31 | 32 | private static final ObjectMapper mapper = JacksonMapperFactory.createObjectMapper(); |
32 | 33 |
|
33 | 34 | @Override |
34 | | - public ArrayAgg createAccumulator() { |
35 | | - return new ArrayAgg(new ArrayList<>()); |
| 35 | + public ArrayAggAccumulator createAccumulator() { |
| 36 | + return new ArrayAggAccumulator(new LinkedList<>(), new LinkedList<>()); |
36 | 37 | } |
37 | 38 |
|
38 | | - public void accumulate(ArrayAgg accumulator, String value) { |
39 | | - accumulator.add(mapper.getNodeFactory().textNode(value)); |
| 39 | + public void accumulate(ArrayAggAccumulator acc, String value) { |
| 40 | + acc.add(mapper.getNodeFactory().textNode(value)); |
40 | 41 | } |
41 | 42 |
|
42 | | - @SneakyThrows |
43 | | - public void accumulate(ArrayAgg accumulator, FlinkJsonType value) { |
44 | | - if (value != null) { |
45 | | - accumulator.add(value.json); |
46 | | - } else { |
47 | | - accumulator.add(null); |
48 | | - } |
| 43 | + public void accumulate(ArrayAggAccumulator acc, FlinkJsonType value) { |
| 44 | + acc.add(value == null ? null : value.json); |
49 | 45 | } |
50 | 46 |
|
51 | | - public void accumulate(ArrayAgg accumulator, Double value) { |
52 | | - accumulator.add(mapper.getNodeFactory().numberNode(value)); |
| 47 | + public void accumulate(ArrayAggAccumulator acc, Double value) { |
| 48 | + acc.add(mapper.getNodeFactory().numberNode(value)); |
53 | 49 | } |
54 | 50 |
|
55 | | - public void accumulate(ArrayAgg accumulator, Long value) { |
56 | | - accumulator.add(mapper.getNodeFactory().numberNode(value)); |
| 51 | + public void accumulate(ArrayAggAccumulator acc, Long value) { |
| 52 | + acc.add(mapper.getNodeFactory().numberNode(value)); |
57 | 53 | } |
58 | 54 |
|
59 | | - public void accumulate(ArrayAgg accumulator, Integer value) { |
60 | | - accumulator.add(mapper.getNodeFactory().numberNode(value)); |
| 55 | + public void accumulate(ArrayAggAccumulator acc, Integer value) { |
| 56 | + acc.add(mapper.getNodeFactory().numberNode(value)); |
61 | 57 | } |
62 | 58 |
|
63 | | - public void retract(ArrayAgg accumulator, String value) { |
64 | | - accumulator.remove(mapper.getNodeFactory().textNode(value)); |
| 59 | + public void retract(ArrayAggAccumulator acc, String value) { |
| 60 | + var nodeVal = mapper.getNodeFactory().textNode(value); |
| 61 | + if (!acc.remove(nodeVal)) { |
| 62 | + acc.addRetract(nodeVal); |
| 63 | + } |
65 | 64 | } |
66 | 65 |
|
67 | | - @SneakyThrows |
68 | | - public void retract(ArrayAgg accumulator, FlinkJsonType value) { |
69 | | - if (value != null) { |
70 | | - accumulator.remove(value.json); |
71 | | - } else { |
72 | | - accumulator.remove(null); |
| 66 | + public void retract(ArrayAggAccumulator acc, FlinkJsonType value) { |
| 67 | + var finalVal = value == null ? null : value.json; |
| 68 | + if (!acc.remove(finalVal)) { |
| 69 | + acc.addRetract(finalVal); |
73 | 70 | } |
74 | 71 | } |
75 | 72 |
|
76 | | - public void retract(ArrayAgg accumulator, Double value) { |
77 | | - accumulator.remove(mapper.getNodeFactory().numberNode(value)); |
| 73 | + public void retract(ArrayAggAccumulator acc, Double value) { |
| 74 | + var nodeVal = mapper.getNodeFactory().numberNode(value); |
| 75 | + if (!acc.getElements().remove(nodeVal)) { |
| 76 | + acc.addRetract(nodeVal); |
| 77 | + } |
78 | 78 | } |
79 | 79 |
|
80 | | - public void retract(ArrayAgg accumulator, Long value) { |
81 | | - accumulator.remove(mapper.getNodeFactory().numberNode(value)); |
| 80 | + public void retract(ArrayAggAccumulator acc, Long value) { |
| 81 | + var nodeVal = mapper.getNodeFactory().numberNode(value); |
| 82 | + if (!acc.getElements().remove(nodeVal)) { |
| 83 | + acc.addRetract(nodeVal); |
| 84 | + } |
82 | 85 | } |
83 | 86 |
|
84 | | - public void retract(ArrayAgg accumulator, Integer value) { |
85 | | - accumulator.remove(mapper.getNodeFactory().numberNode(value)); |
| 87 | + public void retract(ArrayAggAccumulator acc, Integer value) { |
| 88 | + var nodeVal = mapper.getNodeFactory().numberNode(value); |
| 89 | + if (!acc.getElements().remove(nodeVal)) { |
| 90 | + acc.addRetract(nodeVal); |
| 91 | + } |
86 | 92 | } |
87 | 93 |
|
88 | | - public void merge(ArrayAgg accumulator, java.lang.Iterable<ArrayAgg> iterable) { |
89 | | - iterable.forEach(o -> accumulator.getObjects().addAll(o.getObjects())); |
| 94 | + public void merge(ArrayAggAccumulator acc, Iterable<ArrayAggAccumulator> iterable) { |
| 95 | + for (ArrayAggAccumulator otherAcc : iterable) { |
| 96 | + acc.getElements().addAll(otherAcc.getElements()); |
| 97 | + acc.getRetractElements().addAll(otherAcc.getRetractElements()); |
| 98 | + } |
| 99 | + |
| 100 | + List<JsonNode> newRetractBuffer = new LinkedList<>(); |
| 101 | + for (JsonNode elem : acc.getRetractElements()) { |
| 102 | + if (!acc.remove(elem)) { |
| 103 | + newRetractBuffer.add(elem); |
| 104 | + } |
| 105 | + } |
| 106 | + |
| 107 | + acc.getRetractElements().clear(); |
| 108 | + acc.getRetractElements().addAll(newRetractBuffer); |
90 | 109 | } |
91 | 110 |
|
92 | 111 | @Override |
93 | | - public FlinkJsonType getValue(ArrayAgg accumulator) { |
94 | | - // Replacing var with explicit type declaration for Java 11 compatibility |
| 112 | + public FlinkJsonType getValue(ArrayAggAccumulator acc) { |
95 | 113 | var arrayNode = mapper.createArrayNode(); |
96 | | - for (Object o : accumulator.getObjects()) { |
| 114 | + for (Object o : acc.getElements()) { |
97 | 115 | if (o instanceof FlinkJsonType) { |
98 | 116 | arrayNode.add(((FlinkJsonType) o).json); |
99 | 117 | } else { |
|
0 commit comments