Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,8 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
context.relBuilder.peek(),
context.relBuilder.literal(context.sysLimit.joinSubsearchLimit())));
}
List<String> leftAllFields = context.fieldBuilder.getAllFieldNames(1);
List<String> rightAllFields = context.fieldBuilder.getAllFieldNames(0);
if (node.getJoinCondition().isEmpty()) {
// join-with-field-list grammar
List<String> leftColumns = context.fieldBuilder.getStaticFieldNames(1);
Expand Down Expand Up @@ -1255,12 +1257,12 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
|| (node.getArgumentMap().get("overwrite").equals(Literal.TRUE))) {
toBeRemovedFields =
duplicatedFieldNames.stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, true, context))
.map(field -> JoinAndLookupUtils.analyzeFieldsInLeft(field, context))
.toList();
} else {
toBeRemovedFields =
duplicatedFieldNames.stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, false, context))
.map(field -> JoinAndLookupUtils.analyzeFieldsInRight(field, context))
.toList();
}
Literal max = node.getArgumentMap().get("max");
Expand All @@ -1285,13 +1287,16 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
if (!toBeRemovedFields.isEmpty()) {
context.relBuilder.projectExcept(toBeRemovedFields);
}
context.fieldBuilder.reorganizeDynamicFields(leftAllFields, rightAllFields);

return context.relBuilder.peek();
}
// The join-with-criteria grammar doesn't allow empty join condition
RexNode joinCondition =
node.getJoinCondition()
.map(c -> rexVisitor.analyzeJoinCondition(c, context))
.orElse(context.relBuilder.literal(true));
joinCondition = context.rexBuilder.castAnyToAlignTypes(joinCondition, context);
if (node.getJoinType() == SEMI || node.getJoinType() == ANTI) {
// semi and anti join only return left table outputs
context.relBuilder.join(
Expand All @@ -1302,7 +1307,7 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
// when a new project add to stack. To avoid `id0`, we will rename the `id0` to `alias.id`
// or `tableIdentifier.id`:
List<String> leftColumns = context.fieldBuilder.getStaticFieldNames(1);
List<String> rightColumns = context.fieldBuilder.getStaticFieldNames();
List<String> rightColumns = context.fieldBuilder.getStaticFieldNames(0);
List<String> rightTableName =
PlanUtils.findTable(context.relBuilder.peek()).getQualifiedName();
// Using `table.column` instead of `catalog.database.table.column` as column prefix because
Expand Down Expand Up @@ -1337,6 +1342,8 @@ public RelNode visitJoin(Join node, CalcitePlanContext context) {
}
context.relBuilder.join(
JoinAndLookupUtils.translateJoinType(node.getJoinType()), joinCondition);

context.fieldBuilder.reorganizeDynamicFields(leftAllFields, rightAllFields);
JoinAndLookupUtils.renameToExpectedFields(
rightColumnsWithAliasIfConflict, leftColumns.size(), context);
}
Expand Down Expand Up @@ -1369,9 +1376,9 @@ public Void visitInputRef(RexInputRef inputRef) {

private static RexNode buildJoinConditionByFieldName(
CalcitePlanContext context, String fieldName) {
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, false, context);
RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsForLookUp(fieldName, true, context);
return context.rexBuilder.equals(sourceKey, lookupKey);
RexNode sourceKey = JoinAndLookupUtils.analyzeFieldsInLeft(fieldName, context);
RexNode lookupKey = JoinAndLookupUtils.analyzeFieldsInRight(fieldName, context);
return context.rexBuilder.equalsWithCastAsNeeded(sourceKey, lookupKey);
}

@Override
Expand All @@ -1397,6 +1404,10 @@ public RelNode visitLookup(Lookup node, CalcitePlanContext context) {
// Get lookupColumns from top of stack (after above potential projection).
List<String> lookupTableFieldNames = context.fieldBuilder.getStaticFieldNames();

// For merging with dynamic fields later
List<String> leftAllFields = context.fieldBuilder.getAllFieldNames(1);
List<String> rightAllFields = context.fieldBuilder.getAllFieldNames(0);

// 3. Find fields which should be removed in lookup-table.
// For lookup table, the mapping fields should be dropped after join
// unless they are explicitly put in the output fields
Expand All @@ -1410,6 +1421,7 @@ public RelNode visitLookup(Lookup node, CalcitePlanContext context) {
.toList();
List<RexNode> toBeRemovedLookupFields =
toBeRemovedLookupFieldNames.stream()
.filter(d -> lookupTableFieldNames.contains(d))
.map(d -> (RexNode) context.fieldBuilder.staticField(2, 1, d))
.toList();
List<RexNode> toBeRemovedFields = new ArrayList<>(toBeRemovedLookupFields);
Expand All @@ -1421,7 +1433,7 @@ public RelNode visitLookup(Lookup node, CalcitePlanContext context) {

List<RexNode> duplicatedSourceFields =
duplicatedFieldNamesMap.keySet().stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, true, context))
.map(field -> JoinAndLookupUtils.analyzeFieldsInLeft(field, context))
.toList();
// Duplicated fields in source-field should always be removed.
toBeRemovedFields.addAll(duplicatedSourceFields);
Expand All @@ -1433,7 +1445,7 @@ public RelNode visitLookup(Lookup node, CalcitePlanContext context) {
if (!duplicatedFieldNamesMap.isEmpty() && node.getOutputStrategy() == OutputStrategy.APPEND) {
List<RexNode> duplicatedProvidedFields =
duplicatedFieldNamesMap.values().stream()
.map(field -> JoinAndLookupUtils.analyzeFieldsForLookUp(field, false, context))
.map(field -> JoinAndLookupUtils.analyzeFieldsInRight(field, context))
.toList();
for (int i = 0; i < duplicatedProvidedFields.size(); ++i) {
newCoalesceList.add(
Expand Down Expand Up @@ -1470,7 +1482,7 @@ public RelNode visitLookup(Lookup node, CalcitePlanContext context) {
context.relBuilder.projectExcept(toBeRemovedFields);
}

// TODO: dedupe dynamic fields
context.fieldBuilder.reorganizeDynamicFields(leftAllFields, rightAllFields);

// 7. Rename the fields to the expected names.
JoinAndLookupUtils.renameToExpectedFields(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,19 @@
import org.apache.calcite.avatica.util.TimeUnit;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlIntervalQualifier;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.parser.SqlParserPos;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.opensearch.sql.ast.expression.SpanUnit;
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
import org.opensearch.sql.calcite.utils.RexConverter;
import org.opensearch.sql.data.type.ExprCoreType;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.exception.SemanticCheckException;
Expand All @@ -41,6 +44,20 @@ public RexNode equals(RexNode n1, RexNode n2) {
return this.makeCall(SqlStdOperatorTable.EQUALS, n1, n2);
}

/** Make equals call with adding cast in case the node type is ANY. */
public RexNode equalsWithCastAsNeeded(RexNode n1, RexNode n2) {
if (isAnyType(n1) && isAnyType(n2)) {
n1 = castToString(n1);
n2 = castToString(n2);
} else if (isAnyType(n1)) {
n1 = castToTargetType(n1, n2);
} else if (isAnyType(n2)) {
n2 = castToTargetType(n2, n1);
}

return equals(n1, n2);
}

public RexNode and(RexNode left, RexNode right) {
final RelDataType booleanType = this.getTypeFactory().createSqlType(SqlTypeName.BOOLEAN);
return this.makeCall(booleanType, SqlStdOperatorTable.AND, List.of(left, right));
Expand Down Expand Up @@ -163,4 +180,36 @@ else if ((SqlTypeUtil.isApproximateNumeric(sourceType) || SqlTypeUtil.isDecimal(
}
return super.makeCast(pos, type, exp, matchNullability, safe, format);
}

public boolean isAnyType(RexNode node) {
return node.getType().getSqlTypeName().equals(SqlTypeName.ANY);
}

public RexNode castToString(RexNode node) {
RelDataType stringType = getTypeFactory().createSqlType(SqlTypeName.VARCHAR);
RelDataType nullableStringType = getTypeFactory().createTypeWithNullability(stringType, true);
return makeCast(nullableStringType, node, true, true);
}

/** cast node to the same type as target */
public RexNode castToTargetType(RexNode node, RexNode target) {
return makeCast(target.getType(), node, true, true);
}

/** Utility to cast ANY to specific types to avoid compare issue */
RexNode castAnyToAlignTypes(RexNode rexNode, CalcitePlanContext context) {
return rexNode.accept(
new RexConverter() {
@Override
public RexNode visitCall(RexCall call) {
if (call.getKind() == SqlKind.EQUALS) {
RexNode n0 = call.operands.get(0);
RexNode n1 = call.operands.get(1);
return super.visitCall((RexCall) equalsWithCastAsNeeded(n0, n1));
} else {
return super.visitCall(call);
}
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.logical.LogicalProject;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexLiteral;
Expand All @@ -34,6 +33,7 @@
import org.apache.calcite.util.mapping.Mappings;
import org.apache.commons.lang3.tuple.Pair;
import org.immutables.value.Value;
import org.opensearch.sql.calcite.ExtendedRexBuilder;
import org.opensearch.sql.calcite.rel.RelBuilderWrapper;
import org.opensearch.sql.calcite.rel.RelFieldBuilder;

Expand Down Expand Up @@ -78,7 +78,8 @@ public void apply(RelOptRuleCall call, LogicalAggregate aggregate, LogicalProjec

final RelBuilder rawRelBuilder = call.builder();
final RelBuilderWrapper relBuilder = new RelBuilderWrapper(rawRelBuilder);
final RexBuilder rexBuilder = aggregate.getCluster().getRexBuilder();
final ExtendedRexBuilder rexBuilder =
new ExtendedRexBuilder(aggregate.getCluster().getRexBuilder());
final RelFieldBuilder fieldBuilder = new RelFieldBuilder(rawRelBuilder, rexBuilder);
relBuilder.push(project.getInput());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,18 +38,29 @@ public static Optional<RexNode> resolveField(
if (inputFieldNames.contains(fieldName)) {
return Optional.of(context.fieldBuilder.staticField(inputCount, inputOrdinal, fieldName));
} else if (context.fieldBuilder.isDynamicFieldsExist()) {
return Optional.of(context.fieldBuilder.dynamicField(fieldName));
return Optional.of(context.fieldBuilder.dynamicField(inputCount, inputOrdinal, fieldName));
}
return Optional.empty();
}

/** Resolve field in the top of the stack */
public static Optional<RexNode> resolveField(String fieldName, CalcitePlanContext context) {
return resolveField(1, 0, fieldName, context);
}

/** Resolve field in the specified input. Throw exception if not found. */
public static RexNode resolveFieldOrThrow(
int inputCount, int inputOrdinal, String fieldName, CalcitePlanContext context) {
return resolveField(inputCount, inputOrdinal, fieldName, context)
.orElseThrow(
() -> new IllegalArgumentException(String.format("Field [%s] not found.", fieldName)));
}

/** Resolve field in the top of the stack. Throw exception if not found. */
public static RexNode resolveFieldOrThrow(String fieldName, CalcitePlanContext context) {
return resolveFieldOrThrow(1, 0, fieldName, context);
}

/**
* Resolves a qualified name to a RexNode based on the current context.
*
Expand Down Expand Up @@ -78,6 +89,8 @@ private static RexNode resolveInJoinCondition(

return resolveFieldWithAlias(nameNode, context, 2)
.or(() -> resolveFieldWithoutAlias(nameNode, context, 2))
.or(() -> resolveDynamicFieldsWithAlias(nameNode, context, 2))
.or(() -> resolveDynamicFields(nameNode, context, 2))
.orElseThrow(() -> getNotFoundException(nameNode));
}

Expand Down Expand Up @@ -139,6 +152,28 @@ private static Optional<RexNode> resolveFieldWithAlias(
return Optional.empty();
}

private static Optional<RexNode> resolveDynamicFieldsWithAlias(
QualifiedName nameNode, CalcitePlanContext context, int inputCount) {
List<String> parts = nameNode.getParts();
log.debug(
"resolveDynamicFieldsWithAlias() called with nameNode={}, parts={}, inputCount={}",
nameNode,
parts,
inputCount);

if (parts.size() >= 2) {
// Consider first part as table alias
String alias = parts.get(0);

String fieldName = String.join(".", parts.subList(1, parts.size()));
Optional<RexNode> dynamicField =
tryToResolveField(alias, DYNAMIC_FIELDS_MAP, context, inputCount);
return dynamicField.map(field -> createItemAccess(field, fieldName, context));
}

return Optional.empty();
}

private static Optional<RexNode> resolveDynamicFields(
QualifiedName nameNode, CalcitePlanContext context, int inputCount) {
List<String> parts = nameNode.getParts();
Expand Down
Loading
Loading