Skip to content

Commit af2f391

Browse files
authored
Merge pull request #549 from mspruc/main
Serialisability for sql-api projections
2 parents 06db2b1 + ea66604 commit af2f391

File tree

2 files changed

+89
-21
lines changed

2 files changed

+89
-21
lines changed

wayang-api/wayang-api-sql/src/main/java/org/apache/wayang/api/sql/calcite/converter/functions/ProjectMapFuncImpl.java

Lines changed: 21 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import java.util.ArrayList;
2222
import java.util.List;
2323
import java.util.function.BinaryOperator;
24+
import java.util.stream.Collectors;
2425

2526
import org.apache.calcite.rex.RexCall;
2627
import org.apache.calcite.rex.RexInputRef;
@@ -30,36 +31,30 @@
3031
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
3132

3233
import org.apache.wayang.core.function.FunctionDescriptor;
34+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
3335
import org.apache.wayang.basic.data.Record;
3436

35-
3637
public class ProjectMapFuncImpl implements
3738
FunctionDescriptor.SerializableFunction<Record, Record> {
38-
private final List<RexNode> projects;
39+
final List<SerializableFunction<Record, Object>> projections;
3940

4041
public ProjectMapFuncImpl(final List<RexNode> projects) {
41-
this.projects = projects;
42-
}
43-
44-
@Override
45-
public Record apply(final Record record) {
46-
47-
final List<Object> projectedRecord = new ArrayList<>();
48-
for (int i = 0; i < projects.size(); i++) {
49-
final RexNode exp = projects.get(i);
42+
this.projections = projects.stream().map(exp -> {
5043
if (exp instanceof RexInputRef) {
51-
projectedRecord.add(record.getField(((RexInputRef) exp).getIndex()));
44+
final int key = ((RexInputRef) exp).getIndex();
45+
return (SerializableFunction<Record, Object>) record -> record.getField(key);
5246
} else if (exp instanceof RexLiteral) {
53-
final RexLiteral literal = (RexLiteral) exp;
54-
projectedRecord.add(literal.getValue());
47+
final Object literalValue = ((RexLiteral) exp).getValue();
48+
return (SerializableFunction<Record, Object>) record -> literalValue;
5549
} else if (exp instanceof RexCall) {
56-
projectedRecord.add(evaluateRexCall(record, (RexCall) exp));
50+
return (SerializableFunction<Record, Object>) record -> evaluateRexCall(record, (RexCall) exp);
51+
} else {
52+
throw new UnsupportedOperationException("Could not resolve record for exp: " + exp);
5753
}
58-
}
59-
return new Record(projectedRecord.toArray(new Object[0]));
54+
}).collect(Collectors.toList());
6055
}
6156

62-
public static Object evaluateRexCall(final Record record, final RexCall rexCall) {
57+
public Object evaluateRexCall(final Record record, final RexCall rexCall) {
6358
if (rexCall == null) {
6459
return null;
6560
}
@@ -70,7 +65,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
7065

7166
if (operator == SqlStdOperatorTable.PLUS) {
7267
// Handle addition
73-
return evaluateNaryOperation(record, operands, Double::sum);
68+
return evaluateNaryOperation(record, operands, (a, b) -> a + b);
7469
} else if (operator == SqlStdOperatorTable.MINUS) {
7570
// Handle subtraction
7671
return evaluateNaryOperation(record, operands, (a, b) -> a - b);
@@ -85,7 +80,7 @@ public static Object evaluateRexCall(final Record record, final RexCall rexCall)
8580
}
8681
}
8782

88-
public static Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
83+
public Object evaluateNaryOperation(final Record record, final List<RexNode> operands,
8984
final BinaryOperator<Double> operation) {
9085
if (operands.isEmpty()) {
9186
return null;
@@ -110,7 +105,7 @@ public static Object evaluateNaryOperation(final Record record, final List<RexNo
110105
return result;
111106
}
112107

113-
public static Object evaluateRexNode(final Record record, final RexNode rexNode) {
108+
public Object evaluateRexNode(final Record record, final RexNode rexNode) {
114109
if (rexNode instanceof RexCall) {
115110
// Recursively evaluate a RexCall
116111
return evaluateRexCall(record, (RexCall) rexNode);
@@ -124,4 +119,9 @@ public static Object evaluateRexNode(final Record record, final RexNode rexNode)
124119
return null; // Unsupported or unknown expression
125120
}
126121
}
122+
123+
@Override
124+
public Record apply(final Record record) {
125+
return new Record(projections.stream().map(func -> func.apply(record)).toArray());
126+
}
127127
}

wayang-api/wayang-api-sql/src/test/java/org/apache/wayang/api/sql/SqlToWayangRelTest.java

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,24 @@
1717

1818
package org.apache.wayang.api.sql;
1919

20+
import org.apache.calcite.jdbc.CalciteSchema;
2021
import org.apache.calcite.jdbc.JavaTypeFactoryImpl;
2122
import org.apache.calcite.rel.RelNode;
2223
import org.apache.calcite.rel.externalize.RelWriterImpl;
2324
import org.apache.calcite.rel.rules.CoreRules;
2425
import org.apache.calcite.rel.type.RelDataTypeFactory;
26+
import org.apache.calcite.rex.RexBuilder;
27+
import org.apache.calcite.rex.RexNode;
2528
import org.apache.calcite.sql.SqlExplainLevel;
2629
import org.apache.calcite.sql.SqlNode;
30+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
2731
import org.apache.calcite.sql.parser.SqlParseException;
2832
import org.apache.calcite.sql.type.SqlTypeName;
2933
import org.apache.calcite.tools.RuleSet;
3034
import org.apache.calcite.tools.RuleSets;
3135

3236
import org.apache.wayang.api.sql.calcite.convention.WayangConvention;
37+
import org.apache.wayang.api.sql.calcite.converter.functions.FilterPredicateImpl;
3338
import org.apache.wayang.api.sql.calcite.optimizer.Optimizer;
3439
import org.apache.wayang.api.sql.calcite.rules.WayangRules;
3540
import org.apache.wayang.api.sql.calcite.schema.SchemaUtils;
@@ -41,18 +46,25 @@
4146
import org.apache.wayang.api.sql.context.SqlContext;
4247
import org.apache.wayang.basic.data.Tuple2;
4348
import org.apache.wayang.core.api.Configuration;
49+
import org.apache.wayang.core.function.FunctionDescriptor.SerializableFunction;
50+
import org.apache.wayang.core.function.FunctionDescriptor.SerializablePredicate;
4451
import org.apache.wayang.core.plan.wayangplan.Operator;
4552
import org.apache.wayang.core.plan.wayangplan.PlanTraversal;
4653
import org.apache.wayang.core.plan.wayangplan.WayangPlan;
4754
import org.apache.wayang.java.Java;
55+
import org.apache.wayang.spark.Spark;
4856
import org.apache.wayang.basic.data.Record;
4957
import org.json.simple.JSONObject;
5058
import org.json.simple.parser.JSONParser;
5159
import org.json.simple.parser.ParseException;
5260

5361
import org.junit.Test;
5462

63+
import java.io.ByteArrayInputStream;
64+
import java.io.ByteArrayOutputStream;
5565
import java.io.IOException;
66+
import java.io.ObjectInputStream;
67+
import java.io.ObjectOutputStream;
5668
import java.io.PrintWriter;
5769
import java.io.StringWriter;
5870
import java.sql.SQLException;
@@ -371,6 +383,62 @@ public void joinWithLargeLeftTableIndexMirrorAlias() throws Exception {
371383
assert (resultTally.equals(shouldBeTally));
372384
}
373385

386+
// tests sql-apis ability to serialize projections and joins
387+
@Test
388+
public void sparkInnerJoin() throws Exception {
389+
final SqlContext sqlContext = createSqlContext("/data/largeLeftTableIndex.csv");
390+
391+
final Tuple2<Collection<Record>, WayangPlan> t = this.buildCollectorAndWayangPlan(sqlContext,
392+
"SELECT * FROM fs.largeLeftTableIndex AS na INNER JOIN fs.largeLeftTableIndex AS nb ON nb.NAMEB = na.NAMEA " //
393+
);
394+
395+
final Collection<Record> result = t.field0;
396+
final WayangPlan wayangPlan = t.field1;
397+
398+
PlanTraversal.upstream().traverse(wayangPlan.getSinks()).getTraversedNodes().forEach(node -> {
399+
node.addTargetPlatform(Spark.platform());
400+
});
401+
402+
sqlContext.execute(wayangPlan);
403+
404+
final List<Record> shouldBe = List.of(
405+
new Record("test1", "test1", "test2", "test1", "test1", "test2"),
406+
new Record("test2", "" , "test2", "" , "test2", "test2"),
407+
new Record("" , "test2", "test2", "test2", "" , "test2")
408+
);
409+
410+
final Map<Record, Integer> resultTally = result.stream()
411+
.collect(Collectors.toMap(rec -> rec, rec -> 1, Integer::sum));
412+
final Map<Record, Integer> shouldBeTally = shouldBe.stream()
413+
.collect(Collectors.toMap(rec -> rec, rec -> 1, Integer::sum));
414+
415+
assert (resultTally.equals(shouldBeTally));
416+
}
417+
418+
//@Test
419+
public void rexSerializationTest() throws Exception {
420+
// create filterPredicateImpl for serialisation
421+
final RelDataTypeFactory typeFactory = new JavaTypeFactoryImpl();
422+
final RexBuilder rb = new RexBuilder(typeFactory);
423+
final RexNode leftOperand = rb.makeInputRef(typeFactory.createSqlType(SqlTypeName.VARCHAR), 0);
424+
final RexNode rightOperand = rb.makeLiteral("test");
425+
final RexNode cond = rb.makeCall(SqlStdOperatorTable.EQUALS, leftOperand, rightOperand);
426+
final SerializablePredicate<?> fpImpl = new FilterPredicateImpl(cond);
427+
428+
final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream();
429+
final ObjectOutputStream objectOutputStream = new ObjectOutputStream(byteArrayOutputStream);
430+
objectOutputStream.writeObject(fpImpl);
431+
objectOutputStream.close();
432+
433+
final ByteArrayInputStream byteArrayInputStream = new ByteArrayInputStream(
434+
byteArrayOutputStream.toByteArray());
435+
final ObjectInputStream objectInputStream = new ObjectInputStream(byteArrayInputStream);
436+
final Object deserializedObject = objectInputStream.readObject();
437+
objectInputStream.close();
438+
439+
assert (((FilterPredicateImpl) deserializedObject).test(new Record("test")));
440+
}
441+
374442
@Test
375443
public void exampleFilterTableRefToTableRef() throws Exception {
376444
final SqlContext sqlContext = createSqlContext("/data/exampleRefToRef.csv");

0 commit comments

Comments
 (0)