Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

package org.apache.wayang.api.sql.calcite.converter;

import java.util.ArrayList;
import java.util.List;
import java.util.Set;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import java.io.Serializable;

import org.apache.wayang.api.sql.calcite.converter.functions.FlattenJoinResult;
import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult;
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.data.Tuple2;
Expand Down Expand Up @@ -48,7 +48,7 @@ Operator visit(final WayangJoin wayangRelNode) {
childOpLeft.connectTo(0, join, 0);
childOpRight.connectTo(0, join, 1);

final SerializableFunction<Tuple2<Record, Record>, Record> mp = new FlattenJoinResult();
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new JoinFlattenResult();

final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
mp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,18 +19,13 @@
package org.apache.wayang.api.sql.calcite.converter;

import org.apache.calcite.rel.core.Filter;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.runtime.SqlFunctions;
import org.apache.calcite.sql.SqlKind;

import org.apache.wayang.api.sql.calcite.converter.functions.FilterPredicateImpl;
import org.apache.wayang.api.sql.calcite.rel.WayangFilter;
import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.operators.FilterOperator;
import org.apache.wayang.core.function.FunctionDescriptor;
import org.apache.wayang.core.plan.wayangplan.Operator;

import java.util.EnumSet;
Expand All @@ -42,7 +37,6 @@ public class WayangFilterVisitor extends WayangRelNodeVisitor<WayangFilter> {

@Override
Operator visit(final WayangFilter wayangRelNode) {

final Operator childOp = wayangRelConverter.convert(wayangRelNode.getInput(0));

final RexNode condition = ((Filter) wayangRelNode).getCondition();
Expand All @@ -56,106 +50,8 @@ Operator visit(final WayangFilter wayangRelNode) {
return filter;
}

private class FilterPredicateImpl implements FunctionDescriptor.SerializablePredicate<Record> {

private final RexNode condition;

private FilterPredicateImpl(final RexNode condition) {
this.condition = condition;
}

@Override
public boolean test(final Record record) {
return condition.accept(new EvaluateFilterCondition(true, record));
}
}

private class EvaluateFilterCondition extends RexVisitorImpl<Boolean> {

final Record record;

protected EvaluateFilterCondition(final boolean deep, final Record record) {
super(deep);
this.record = record;
}

@Override
public Boolean visitCall(final RexCall call) {
final SqlKind kind = call.getKind();

if (!kind.belongsTo(WayangFilterVisitor.SUPPORTED_OPS))
throw new IllegalStateException(
"Cannot handle this filter predicate yet: " + kind + " during RexCall: " + call);

switch (kind) {
// Since NOT captures only one operand we just get
// the first
case NOT:
assert (call.getOperands().size() == 1) : "SqlKind.NOT should only have 1 operand in call got: " + call.getOperands().size() + ", call: " + call;
return !(call.getOperands().get(0).accept(this));
case AND:
return call.getOperands().stream().allMatch(operator -> operator.accept(this));
case OR:
return call.getOperands().stream().anyMatch(operator -> operator.accept(this));
default:
assert (call.getOperands().size() == 2);
return eval(record, kind, call.getOperands().get(0), call.getOperands().get(1));
}
}

public boolean eval(final Record record, final SqlKind kind, final RexNode leftOperand,
final RexNode rightOperand) {

if (leftOperand instanceof RexInputRef && rightOperand instanceof RexLiteral) {
final RexInputRef rexInputRef = (RexInputRef) leftOperand;
final int index = rexInputRef.getIndex();
final Object field = record.getField(index);
final RexLiteral rexLiteral = (RexLiteral) rightOperand;
switch (kind) {
case LIKE:
return SqlFunctions.like(field.toString(), rexLiteral.toString().replace("'", ""));
case GREATER_THAN:
return isGreaterThan(field, rexLiteral);
case LESS_THAN:
return isLessThan(field, rexLiteral);
case EQUALS:
return isEqualTo(field, rexLiteral);
case GREATER_THAN_OR_EQUAL:
return isGreaterThan(field, rexLiteral) || isEqualTo(field, rexLiteral);
case LESS_THAN_OR_EQUAL:
return isLessThan(field, rexLiteral) || isEqualTo(field, rexLiteral);
default:
throw new IllegalStateException("Predicate not supported yet");

}

} else {
throw new IllegalStateException("Predicate not supported yet");
}

}

private boolean isGreaterThan(final Object o, final RexLiteral rexLiteral) {
// return rexLiteral.getValue().compareTo(o)< 0;
return ((Comparable) o).compareTo(rexLiteral.getValueAs(o.getClass())) > 0;

}

private boolean isLessThan(final Object o, final RexLiteral rexLiteral) {
return ((Comparable) o).compareTo(rexLiteral.getValueAs(o.getClass())) < 0;
}

private boolean isEqualTo(final Object o, final RexLiteral rexLiteral) {
try {
return ((Comparable) o).compareTo(rexLiteral.getValueAs(o.getClass())) == 0;
} catch (final Exception e) {
throw new IllegalStateException("Predicate not supported yet");
}
}
}

/** for quick sanity check **/
private static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(SqlKind.AND, SqlKind.OR, SqlKind.NOT,
public static final EnumSet<SqlKind> SUPPORTED_OPS = EnumSet.of(SqlKind.AND, SqlKind.OR, SqlKind.NOT,
SqlKind.EQUALS, SqlKind.NOT_EQUALS,
SqlKind.LESS_THAN, SqlKind.GREATER_THAN,
SqlKind.GREATER_THAN_OR_EQUAL, SqlKind.LESS_THAN_OR_EQUAL, SqlKind.LIKE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,131 +18,74 @@

package org.apache.wayang.api.sql.calcite.converter;

import java.util.List;
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexVisitorImpl;
import org.apache.calcite.sql.SqlKind;

import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult;
import org.apache.wayang.api.sql.calcite.converter.functions.JoinKeyExtractor;
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;

import org.apache.wayang.basic.data.Record;
import org.apache.wayang.basic.data.Tuple2;
import org.apache.wayang.basic.operators.JoinOperator;
import org.apache.wayang.basic.operators.MapOperator;
import org.apache.wayang.core.function.FunctionDescriptor;
import org.apache.wayang.core.function.TransformationDescriptor;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.util.ReflectionUtils;

public class WayangJoinVisitor extends WayangRelNodeVisitor<WayangJoin> {

WayangJoinVisitor(WayangRelConverter wayangRelConverter) {
WayangJoinVisitor(final WayangRelConverter wayangRelConverter) {
super(wayangRelConverter);
}

@Override
Operator visit(WayangJoin wayangRelNode) {
Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0));
Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1));
Operator visit(final WayangJoin wayangRelNode) {
final Operator childOpLeft = wayangRelConverter.convert(wayangRelNode.getInput(0));
final Operator childOpRight = wayangRelConverter.convert(wayangRelNode.getInput(1));

final RexNode condition = ((Join) wayangRelNode).getCondition();
final RexCall call = (RexCall) condition;

RexNode condition = ((Join) wayangRelNode).getCondition();
final List<Integer> keys = call.getOperands().stream()
.map(RexInputRef.class::cast)
.map(RexInputRef::getIndex)
.collect(Collectors.toList());

assert (keys.size() == 2) : "Amount of keys found in join was not 2, got: " + keys.size();

if (!condition.isA(SqlKind.EQUALS)) {
throw new UnsupportedOperationException("Only equality joins supported");
}

//offset of the index in the right child
int offset = wayangRelNode.getInput(0).getRowType().getFieldCount();
// offset of the index in the right child
final int offset = wayangRelNode.getInput(0).getRowType().getFieldCount();

int leftKeyIndex = condition.accept(new KeyIndex(false, Child.LEFT));
int rightKeyIndex = condition.accept(new KeyIndex(false, Child.RIGHT)) - offset;
final int leftKeyIndex = keys.get(0);
final int rightKeyIndex = keys.get(1) - offset;

JoinOperator<Record, Record, Object> join = new JoinOperator<>(
new TransformationDescriptor<>(new KeyExtractor(leftKeyIndex), Record.class, Object.class),
new TransformationDescriptor<>(new KeyExtractor(rightKeyIndex), Record.class, Object.class)
);
final JoinOperator<Record, Record, Object> join = new JoinOperator<>(
new TransformationDescriptor<>(new JoinKeyExtractor(leftKeyIndex), Record.class, Object.class),
new TransformationDescriptor<>(new JoinKeyExtractor(rightKeyIndex), Record.class, Object.class));

//call connectTo on both operators (left and right)
// call connectTo on both operators (left and right)
childOpLeft.connectTo(0, join, 0);
childOpRight.connectTo(0, join, 1);

// Join returns Tuple2 - map to a Record
MapOperator<Tuple2, Record> mapOperator = new MapOperator(
new MapFunctionImpl(),
Tuple2.class,
Record.class
);
final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
new JoinFlattenResult(),
ReflectionUtils.specify(Tuple2.class),
Record.class);
join.connectTo(0, mapOperator, 0);

return mapOperator;
}

/**
* Extracts key index from the call
*/
private class KeyIndex extends RexVisitorImpl<Integer> {
final Child child;

protected KeyIndex(boolean deep, Child child) {
super(deep);
this.child = child;
}

@Override
public Integer visitCall(RexCall call) {
RexNode operand = call.getOperands().get(child.ordinal());
if (!(operand instanceof RexInputRef)) {
throw new UnsupportedOperationException("Unsupported operation");
}
RexInputRef rexInputRef = (RexInputRef) operand;
return rexInputRef.getIndex();
}
}

/**
* Extracts the key
*/
private class KeyExtractor implements FunctionDescriptor.SerializableFunction<Record, Object> {
private final int index;

public KeyExtractor(int index) {
this.index = index;
}

public Object apply(final Record record) {
return record.getField(index);
}
}

/**
* Flattens Tuple2<Record, Record> to Record
*/
private class MapFunctionImpl implements FunctionDescriptor.SerializableFunction<Tuple2<Record, Record>, Record> {
public MapFunctionImpl() {
super();
}

@Override
public Record apply(final Tuple2<Record, Record> tuple2) {
int length1 = tuple2.getField0().size();
int length2 = tuple2.getField1().size();

int totalLength = length1 + length2;

Object[] fields = new Object[totalLength];

for (int i = 0; i < length1; i++) {
fields[i] = tuple2.getField0().getField(i);
}
for (int j = length1; j < totalLength; j++) {
fields[j] = tuple2.getField1().getField(j - length1);
}
return new Record(fields);

}
}

// Helpers
private enum Child {
LEFT, RIGHT
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
import java.io.Serializable;
import java.util.List;
import java.util.stream.Collectors;

import org.apache.calcite.rel.core.Join;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinFuncImpl;

import org.apache.wayang.api.sql.calcite.converter.functions.JoinFlattenResult;
import org.apache.wayang.api.sql.calcite.converter.functions.MultiConditionJoinKeyExtractor;
import org.apache.wayang.api.sql.calcite.rel.WayangJoin;
import org.apache.wayang.basic.data.Record;
Expand Down Expand Up @@ -119,7 +121,7 @@ Operator visit(WayangJoin wayangRelNode) {
childOpRight.connectTo(0, join, 1);

// Join returns Tuple2 - map to a Record
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new MultiConditionJoinFuncImpl();
final SerializableFunction<Tuple2<Record, Record>, Record> mp = new JoinFlattenResult();

final MapOperator<Tuple2<Record, Record>, Record> mapOperator = new MapOperator<Tuple2<Record, Record>, Record>(
mp,
Expand Down
Loading
Loading