Skip to content

Commit bcf2835

Browse files
authored
Merge pull request #591 from mspruc/main
Better support for TPCH-like queries
2 parents 90095c3 + 55223bf commit bcf2835

File tree

8 files changed: 243 additions (+243) and 165 deletions (-165).

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangFilterVisitor.java

Lines changed: 13 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818

1919
package org.apache.wayang.api.sql.calcite.converter;
2020

21-
import org.apache.calcite.rel.core.Filter;
2221
import org.apache.calcite.rex.RexNode;
2322
import org.apache.calcite.sql.SqlKind;
2423

@@ -50,9 +49,18 @@ Operator visit(final WayangFilter wayangRelNode) {
5049
}
5150

5251
/** for quick sanity check **/
53-
public static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(SqlKind.AND, SqlKind.OR, SqlKind.NOT,
54-
SqlKind.EQUALS, SqlKind.NOT_EQUALS,
55-
SqlKind.LESS_THAN, SqlKind.GREATER_THAN,
56-
SqlKind.GREATER_THAN_OR_EQUAL, SqlKind.LESS_THAN_OR_EQUAL, SqlKind.LIKE, SqlKind.IS_NOT_NULL, SqlKind.IS_NULL);
52+
protected static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(
53+
SqlKind.AND,
54+
SqlKind.OR,
55+
SqlKind.NOT,
56+
SqlKind.EQUALS,
57+
SqlKind.NOT_EQUALS,
58+
SqlKind.LESS_THAN,
59+
SqlKind.GREATER_THAN,
60+
SqlKind.GREATER_THAN_OR_EQUAL,
61+
SqlKind.LESS_THAN_OR_EQUAL,
62+
SqlKind.LIKE,
63+
SqlKind.IS_NOT_NULL,
64+
SqlKind.IS_NULL);
5765

5866
}

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/WayangJoinVisitor.java

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,6 @@
1919
package org.apache.wayang.api.sql.calcite.converter;
2020

2121
import java.util.List;
22-
import java.util.stream.Collectors;
2322

2423
import org.apache.calcite.rex.RexCall;
2524
import org.apache.calcite.rex.RexInputRef;
@@ -55,7 +54,7 @@ Operator visit(final WayangJoin wayangRelNode) {
5554
final List<Integer> keys = call.getOperands().stream()
5655
.map(RexInputRef.class::cast)
5756
.map(RexInputRef::getIndex)
58-
.collect(Collectors.toList());
57+
.toList();
5958

6059
assert (keys.size() == 2) : "Amount of keys found in join was not 2, got: " + keys.size();
6160

@@ -78,7 +77,7 @@ Operator visit(final WayangJoin wayangRelNode) {
7877
childOpRight.connectTo(0, join, 1);
7978

8079
// Join returns Tuple2 - map to a Record
81-
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
80+
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<>(
8281
new JoinFlattenResult(),
8382
ReflectionUtils.specify(Tuple2.class),
8483
Record.class);

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/AggregateFunction.java

Lines changed: 21 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -17,11 +17,10 @@
1717
*/
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

20-
import java.util.Arrays;
20+
import java.math.BigDecimal;
2121
import java.util.List;
2222
import java.util.Optional;
2323
import java.util.function.BiFunction;
24-
import java.util.stream.Collectors;
2524

2625
import org.apache.calcite.rel.core.AggregateCall;
2726
import org.apache.calcite.runtime.SqlFunctions;
@@ -31,12 +30,12 @@
3130

3231
public class AggregateFunction
3332
implements FunctionDescriptor.SerializableBinaryOperator<Record> {
34-
final List<SqlKind> aggregateKinds;
33+
private final List<SqlKind> aggregateKinds;
3534

3635
public AggregateFunction(final List<AggregateCall> aggregateCalls) {
3736
this.aggregateKinds = aggregateCalls.stream()
38-
.map(call -> call.getAggregation().getKind())
39-
.collect(Collectors.toList());
37+
.map(call -> call.getAggregation().getKind())
38+
.toList();
4039
}
4140

4241
@Override
@@ -56,15 +55,15 @@ public Record apply(final Record record1, final Record record2) {
5655

5756
switch (kind) {
5857
case SUM:
59-
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum);
58+
resValues[counter] = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum, BigDecimal::add);
6059
break;
6160
case MIN:
6261
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::least, SqlFunctions::least,
63-
SqlFunctions::least, SqlFunctions::least);
62+
SqlFunctions::least, SqlFunctions::least, SqlFunctions::least);
6463
break;
6564
case MAX:
6665
resValues[counter] = this.castAndMap(field1, field2, SqlFunctions::greatest, SqlFunctions::greatest,
67-
SqlFunctions::greatest, SqlFunctions::greatest);
66+
SqlFunctions::greatest, SqlFunctions::greatest, SqlFunctions::greatest);
6867
break;
6968
case COUNT:
7069
// since aggregates inject an extra column for counting before,
@@ -76,9 +75,7 @@ public Record apply(final Record record1, final Record record2) {
7675
resValues[counter] = count;
7776
break;
7877
case AVG:
79-
assert (field1 instanceof Integer && field2 instanceof Integer)
80-
: "Expected to find integers for count but found: " + field1 + " and " + field2;
81-
final Object avg = Integer.class.cast(field1) + Integer.class.cast(field2);
78+
final Object avg = this.castAndMap(field1, field2, null, Long::sum, Integer::sum, Double::sum, BigDecimal::add);
8279

8380
resValues[counter] = avg;
8481

@@ -95,6 +92,7 @@ public Record apply(final Record record1, final Record record2) {
9592
return new Record(resValues);
9693
}
9794

95+
9896
/**
9997
* Handles casts for the record class for each interior type.
10098
*
@@ -110,7 +108,8 @@ private Object castAndMap(final Object a, final Object b,
110108
final BiFunction<String, String, Object> stringMap,
111109
final BiFunction<Long, Long, Object> longMap,
112110
final BiFunction<Integer, Integer, Object> integerMap,
113-
final BiFunction<Double, Double, Object> doubleMap) {
111+
final BiFunction<Double, Double, Object> doubleMap,
112+
final BiFunction<BigDecimal, BigDecimal, Object> bigDecimalMap) {
114113
// support operations between null and any
115114
// class
116115
if ((a == null || b == null) || (a.getClass() == b.getClass())) {
@@ -122,19 +121,16 @@ private Object castAndMap(final Object a, final Object b,
122121
// force .getClass() to be safe so
123122
// we can pass null objects to
124123
// .apply methods.
125-
switch (aWrapped.orElse(bWrapped.orElse("")).getClass().getSimpleName()) {
126-
case "String":
127-
return stringMap.apply((String) a, (String) b);
128-
case "Long":
129-
return longMap.apply((Long) a, (Long) b);
130-
case "Integer":
131-
return integerMap.apply((Integer) a, (Integer) b);
132-
case "Double":
133-
return doubleMap.apply((Double) a, (Double) b);
134-
default:
135-
throw new IllegalStateException("Unsupported operation between: " + aWrapped.getClass().toString()
136-
+ " and: " + bWrapped.getClass().toString());
137-
}
124+
return switch (aWrapped.orElse(bWrapped.orElse("")).getClass().getSimpleName()) {
125+
case "String" -> stringMap.apply((String) a, (String) b);
126+
case "Long" -> longMap.apply((Long) a, (Long) b);
127+
case "Integer" -> integerMap.apply((Integer) a, (Integer) b);
128+
case "Double" -> doubleMap.apply((Double) a, (Double) b);
129+
case "BigDecimal" -> bigDecimalMap.apply((BigDecimal) a, (BigDecimal) b);
130+
default -> throw new IllegalStateException("Unsupported operation between: "
131+
+ aWrapped.getClass().toString()
132+
+ " and: " + bWrapped.getClass().toString());
133+
};
138134
}
139135
throw new IllegalStateException("Unsupported operation between: " + a.getClass().getSimpleName() + " and: "
140136
+ b.getClass().getSimpleName());

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/CallTreeFactory.java

Lines changed: 46 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -18,30 +18,33 @@
1818
package org.apache.wayang.api.sql.calcite.converter.functions;
1919

2020
import java.io.Serializable;
21+
import java.math.BigDecimal;
22+
import java.util.Calendar;
2123
import java.util.List;
22-
import java.util.stream.Collectors;
2324

2425
import org.apache.calcite.rex.RexCall;
2526
import org.apache.calcite.rex.RexInputRef;
2627
import org.apache.calcite.rex.RexLiteral;
2728
import org.apache.calcite.rex.RexNode;
2829
import org.apache.calcite.sql.SqlKind;
29-
30+
import org.apache.calcite.util.Sarg;
3031
import org.apache.wayang.basic.data.Record;
3132
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
3233

34+
import com.google.common.collect.ImmutableRangeSet;
35+
3336
/**
3437
* AST of the {@link RexCall} arithmetic, composed into serializable nodes;
3538
* {@link Call}, {@link InputRef}, {@link Literal}
3639
*/
37-
interface CallTreeFactory<Input, Output> extends Serializable {
38-
public default Node<Output> fromRexNode(final RexNode node) {
39-
if (node instanceof RexCall) {
40-
return new Call<>((RexCall) node, this);
41-
} else if (node instanceof RexInputRef) {
42-
return new InputRef<>((RexInputRef) node);
43-
} else if (node instanceof RexLiteral) {
44-
return new Literal<>((RexLiteral) node);
40+
interface CallTreeFactory extends Serializable {
41+
public default Node fromRexNode(final RexNode node) {
42+
if (node instanceof final RexCall call) {
43+
return new Call(call, this);
44+
} else if (node instanceof final RexInputRef inputRef) {
45+
return new InputRef(inputRef);
46+
} else if (node instanceof final RexLiteral literal) {
47+
return new Literal(literal);
4548
} else {
4649
throw new UnsupportedOperationException("Unsupported RexNode in filter condition: " + node);
4750
}
@@ -55,50 +58,66 @@ public default Node<Output> fromRexNode(final RexNode node) {
5558
* @return a serializable function of +, -, * or /
5659
* @throws UnsupportedOperationException on unrecognized {@link SqlKind}
5760
*/
58-
public SerializableFunction<List<Output>, Output> deriveOperation(SqlKind kind);
61+
public SerializableFunction<List<Object>, Object> deriveOperation(SqlKind kind);
5962
}
6063

61-
interface Node<Output> extends Serializable {
62-
public Output evaluate(final Record record);
64+
interface Node extends Serializable {
65+
public Object evaluate(final Record rec);
6366
}
6467

65-
class Call<Input, Output> implements Node<Output> {
66-
final List<Node<Output>> operands;
67-
final SerializableFunction<List<Output>, Output> operation;
68+
class Call implements Node {
69+
private final List<Node> operands;
70+
final SerializableFunction<List<Object>, Object> operation;
6871

69-
protected Call(final RexCall call, final CallTreeFactory<Input, Output> tree) {
70-
operands = call.getOperands().stream().map(tree::fromRexNode).collect(Collectors.toList());
72+
protected Call(final RexCall call, final CallTreeFactory tree) {
73+
operands = call.getOperands().stream().map(tree::fromRexNode).toList();
7174
operation = tree.deriveOperation(call.getKind());
7275
}
7376

7477
@Override
75-
public Output evaluate(final Record record) {
76-
return operation.apply(operands.stream().map(op -> op.evaluate(record)).collect(Collectors.toList()));
78+
public Object evaluate(final Record rec) {
79+
return operation.apply(
80+
operands.stream()
81+
.map(op -> op.evaluate(rec))
82+
.toList());
7783
}
7884
}
7985

80-
class Literal<Output> implements Node<Output> {
81-
final Output value;
86+
class Literal implements Node {
87+
final Serializable value;
8288

8389
Literal(final RexLiteral literal) {
84-
value = (Output) literal.getValue2();
90+
value = switch (literal.getTypeName()) {
91+
case DATE -> literal.getValueAs(Calendar.class);
92+
case INTEGER -> literal.getValueAs(Double.class);
93+
case INTERVAL_DAY -> literal.getValueAs(BigDecimal.class).doubleValue();
94+
case DECIMAL -> literal.getValueAs(BigDecimal.class).doubleValue();
95+
case CHAR -> literal.getValueAs(String.class);
96+
case SARG -> {
97+
final Sarg<?> sarg = literal.getValueAs(Sarg.class);
98+
assert sarg.rangeSet instanceof Serializable : "Sarg RangeSet was not serializable.";
99+
yield (ImmutableRangeSet<?>) sarg.rangeSet;
100+
}
101+
default -> throw new UnsupportedOperationException(
102+
"Literal conversion to Java not implemented, type: " + literal.getTypeName());
103+
};
85104
}
86105

87106
@Override
88-
public Output evaluate(final Record record) {
107+
public Object evaluate(final Record rec) {
89108
return value;
90109
}
91110
}
92111

93-
class InputRef<Output> implements Node<Output> {
112+
class InputRef implements Node {
94113
private final int key;
95114

96115
InputRef(final RexInputRef inputRef) {
97116
this.key = inputRef.getIndex();
98117
}
99118

100119
@Override
101-
public Output evaluate(final Record record) {
102-
return (Output) record.getField(key);
120+
public Object evaluate(final Record rec) {
121+
return rec.getField(key);
103122
}
104123
}

0 commit comments

Comments
 (0)