Skip to content

Commit 79d1fe7

Browse files
authored
Merge pull request #588 from mspruc/main
PLUS & MINUS arithmetic for filters in sql-api
2 parents 9c9bf38 + 0ec8d1f commit 79d1fe7

File tree

3 files changed

+48
-31
lines changed

3 files changed

+48
-31
lines changed

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangAggregateVisitor.java

Lines changed: 11 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,15 @@
2020

2121
import java.util.HashSet;
2222
import java.util.List;
23-
import java.util.Set;
2423

25-
import org.apache.calcite.rel.core.Aggregate;
2624
import org.apache.calcite.rel.core.AggregateCall;
2725

2826
import org.apache.wayang.api.sql.calcite.converter.functions.AggregateAddCols;
2927
import org.apache.wayang.api.sql.calcite.converter.functions.AggregateFunction;
3028
import org.apache.wayang.api.sql.calcite.converter.functions.AggregateKeyExtractor;
3129
import org.apache.wayang.api.sql.calcite.converter.functions.AggregateGetResult;
3230
import org.apache.wayang.api.sql.calcite.rel.WayangAggregate;
31+
3332
import org.apache.wayang.basic.data.Record;
3433
import org.apache.wayang.basic.operators.GlobalReduceOperator;
3534
import org.apache.wayang.basic.operators.MapOperator;
@@ -48,10 +47,8 @@ public class WayangAggregateVisitor extends WayangRelNodeVisitor<WayangAggregate
4847
@Override
4948
Operator visit(final WayangAggregate wayangRelNode) {
5049
final Operator childOp = wayangRelConverter.convert(wayangRelNode.getInput(0));
51-
Operator aggregateOperator;
5250

53-
final List<AggregateCall> aggregateCalls = ((Aggregate) wayangRelNode).getAggCallList();
54-
final int groupCount = wayangRelNode.getGroupCount();
51+
final List<AggregateCall> aggregateCalls = wayangRelNode.getAggCallList();
5552
final HashSet<Integer> groupingFields = new HashSet<>(wayangRelNode.getGroupSet().asSet());
5653

5754
final MapOperator<Record, Record> mapOperator = new MapOperator<>(
@@ -60,22 +57,15 @@ Operator visit(final WayangAggregate wayangRelNode) {
6057
Record.class);
6158
childOp.connectTo(0, mapOperator, 0);
6259

63-
if (groupCount > 0) {
64-
ReduceByOperator<Record, Object> reduceByOperator;
65-
reduceByOperator = new ReduceByOperator<>(
66-
new TransformationDescriptor<>(new AggregateKeyExtractor(groupingFields), Record.class, Object.class),
67-
new ReduceDescriptor<>(new AggregateFunction(aggregateCalls),
68-
DataUnitType.createGrouped(Record.class),
69-
DataUnitType.createBasicUnchecked(Record.class)));
70-
aggregateOperator = reduceByOperator;
71-
} else {
72-
GlobalReduceOperator<Record> globalReduceOperator;
73-
globalReduceOperator = new GlobalReduceOperator<>(
74-
new ReduceDescriptor<>(new AggregateFunction(aggregateCalls),
75-
DataUnitType.createGrouped(Record.class),
76-
DataUnitType.createBasicUnchecked(Record.class)));
77-
aggregateOperator = globalReduceOperator;
78-
}
60+
final Operator aggregateOperator = wayangRelNode.getGroupCount() > 0 ? new ReduceByOperator<>(
61+
new TransformationDescriptor<>(new AggregateKeyExtractor(groupingFields), Record.class, Object.class),
62+
new ReduceDescriptor<>(new AggregateFunction(aggregateCalls),
63+
DataUnitType.createGrouped(Record.class),
64+
DataUnitType.createBasicUnchecked(Record.class)))
65+
: new GlobalReduceOperator<>(
66+
new ReduceDescriptor<>(new AggregateFunction(aggregateCalls),
67+
DataUnitType.createGrouped(Record.class),
68+
DataUnitType.createBasicUnchecked(Record.class)));
7969

8070
mapOperator.connectTo(0, aggregateOperator, 0);
8171

@@ -85,6 +75,5 @@ Operator visit(final WayangAggregate wayangRelNode) {
8575
Record.class);
8676
aggregateOperator.connectTo(0, mapOperator2, 0);
8777
return mapOperator2;
88-
8978
}
9079
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangFilterVisitor.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,7 @@ public class WayangFilterVisitor extends WayangRelNodeVisitor<WayangFilter> {
3838
@Override
3939
Operator visit(final WayangFilter wayangRelNode) {
4040
final Operator childOp = wayangRelConverter.convert(wayangRelNode.getInput(0));
41-
42-
final RexNode condition = ((Filter) wayangRelNode).getCondition();
41+
final RexNode condition = wayangRelNode.getCondition();
4342

4443
final FilterOperator<Record> filter = new FilterOperator<>(
4544
new FilterPredicateImpl(condition),

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/FilterPredicateImpl.java

Lines changed: 36 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,11 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter.functions;
2020

21+
import java.util.Date;
2122
import java.util.List;
23+
import java.util.Objects;
24+
import java.util.Optional;
25+
2226
import org.apache.calcite.rex.RexNode;
2327
import org.apache.calcite.runtime.SqlFunctions;
2428
import org.apache.calcite.sql.SqlKind;
@@ -56,31 +60,54 @@ public SerializableFunction<List<Object>, Object> deriveOperation(final SqlKind
5660
isLessThan(input.get(0), input.get(1)) || isEqualTo(input.get(0), input.get(1));
5761
case AND -> input.stream().map(Boolean.class::cast).allMatch(Boolean::booleanValue);
5862
case OR -> input.stream().map(Boolean.class::cast).anyMatch(Boolean::booleanValue);
63+
case MINUS -> widenToDouble.apply(input.get(0)) - widenToDouble.apply(input.get(1));
64+
case PLUS -> widenToDouble.apply(input.get(0)) + widenToDouble.apply(input.get(1));
5965
default -> throw new UnsupportedOperationException("Kind not supported: " + kind);
6066
};
6167
}
6268
}
6369

6470
/**
65-
* Widening conversions
71+
* Widens number types to optional double
72+
* see also {@link #widenToDouble}
73+
* @return Optional.empty if no conversion available, Optional.of(double) otherwise
74+
*/
75+
final SerializableFunction<Object, Optional<Double>> widenToOptionalDouble =
76+
field -> field instanceof final Number number ? Optional.of(number.doubleValue()) :
77+
field instanceof final Date date ? Optional.of((double) date.getTime()) :
78+
Optional.empty();
79+
80+
81+
/**
82+
* Consumes the option from {@link #widenToOptionalDouble()}, and eagerly provides the underlying double.
83+
* @throws UnsupportedOperationException if conversion was not possible
6684
*/
67-
final SerializableFunction<Object, Comparable> ensureComparable = (a) -> a instanceof Integer val ? val.longValue() : (Comparable<?>) a;
85+
final SerializableFunction<Object, Double> widenToDouble = field -> widenToOptionalDouble
86+
.andThen(option -> option.orElseThrow(() -> new UnsupportedOperationException("Could not convert: " + option + " to double.")))
87+
.apply(field);
88+
89+
/**
90+
* Widening conversions, all numbers to double
91+
*/
92+
final SerializableFunction<Object, Comparable> ensureComparable = field ->
93+
field instanceof Number || field instanceof Date ? widenToDouble.apply(field) :
94+
Comparable.class.cast(field);
95+
6896

6997
/**
7098
* Java equivalent of SQL like clauses
99+
*
71100
* @param s1
72101
* @param s2
73102
* @return true if {@code s1} like {@code s2}
74103
*/
75104
private boolean like(final String s1, final String s2) {
76-
final SqlFunctions.LikeFunction likeFunction = new SqlFunctions.LikeFunction();
77-
final boolean isMatch = likeFunction.like(s1, s2);
78-
79-
return isMatch;
105+
return new SqlFunctions.LikeFunction().like(s1, s2);
80106
}
81107

82108
/**
83109
* Java equivalent of sql greater than clauses
110+
*
84111
* @param o1
85112
* @param o2
86113
* @return true if {@code o1 > o2}
@@ -91,6 +118,7 @@ private boolean isGreaterThan(final Object o1, final Object o2) {
91118

92119
/**
93120
* Java equivalent of sql less than clauses
121+
*
94122
* @param o1
95123
* @param o2
96124
* @return true if {@code o1 < o2}
@@ -101,11 +129,12 @@ private boolean isLessThan(final Object o1, final Object o2) {
101129

102130
/**
103131
* Java equivalent of SQL equals clauses
132+
*
104133
* @param o1
105134
* @param o2
106135
* @return true if {@code o1 == o2}
107136
*/
108137
private boolean isEqualTo(final Object o1, final Object o2) {
109-
return ensureComparable.apply(o1).equals(ensureComparable.apply(o2));
138+
return Objects.equals(ensureComparable.apply(o1), ensureComparable.apply(o2));
110139
}
111140
}

0 commit comments

Comments
 (0)