Skip to content

Commit e420f32

Browse files
authored
[fix](variant) function element at compute signature (#59083)
The `ELEMENTAT` function should override the `computeSignature` method to determine the return value type for variant.
1 parent dbc4805 commit e420f32

File tree

7 files changed

+154
-253
lines changed

7 files changed

+154
-253
lines changed

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignature.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,6 @@ default FunctionSignature computeSignature(FunctionSignature signature) {
114114
.then(ComputeSignatureHelper::implementFollowToArgumentReturnType)
115115
.then(ComputeSignatureHelper::normalizeDecimalV2)
116116
.then(ComputeSignatureHelper::ensureNestedNullableOfArray)
117-
.then(ComputeSignatureHelper::dynamicComputeVariantArgs)
118117
.get();
119118
}
120119

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelper.java

Lines changed: 0 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
import org.apache.doris.nereids.types.NullType;
3434
import org.apache.doris.nereids.types.StructType;
3535
import org.apache.doris.nereids.types.TimeV2Type;
36-
import org.apache.doris.nereids.types.VariantType;
3736
import org.apache.doris.nereids.types.coercion.AnyDataType;
3837
import org.apache.doris.nereids.types.coercion.ComplexDataType;
3938
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
@@ -559,58 +558,6 @@ private static FunctionSignature defaultTimePrecisionPromotion(FunctionSignature
559558
return signature;
560559
}
561560

562-
/**
563-
* Dynamically compute function signature for variant type arguments.
564-
* This method handles cases where the function signature contains variant types
565-
* and needs to be adjusted based on the actual argument types.
566-
*
567-
* @param signature Original function signature
568-
* @param arguments List of actual arguments passed to the function
569-
* @return Updated function signature with resolved variant types
570-
*/
571-
public static FunctionSignature dynamicComputeVariantArgs(
572-
FunctionSignature signature, List<Expression> arguments) {
573-
574-
List<DataType> newArgTypes = Lists.newArrayListWithCapacity(arguments.size());
575-
boolean findVariantType = false;
576-
577-
for (int i = 0; i < arguments.size(); i++) {
578-
// Get signature type for current argument position
579-
DataType sigType;
580-
if (i >= signature.argumentsTypes.size()) {
581-
sigType = signature.getVarArgType().orElseThrow(
582-
() -> new AnalysisException("function arity not match with signature"));
583-
} else {
584-
sigType = signature.argumentsTypes.get(i);
585-
}
586-
587-
// Get actual type of the argument expression
588-
DataType expressionType = arguments.get(i).getDataType();
589-
590-
// If both signature type and expression type are variant,
591-
// use expression type and update return type
592-
if (sigType instanceof VariantType && expressionType instanceof VariantType) {
593-
// return type is variant, update return type to expression type
594-
if (signature.returnType instanceof VariantType) {
595-
signature = signature.withReturnType(expressionType);
596-
if (findVariantType) {
597-
throw new AnalysisException("variant type is not supported in multiple arguments");
598-
} else {
599-
findVariantType = true;
600-
}
601-
}
602-
newArgTypes.add(expressionType);
603-
} else {
604-
// Otherwise keep original signature type
605-
newArgTypes.add(sigType);
606-
}
607-
}
608-
609-
// Update signature with new argument types
610-
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
611-
return signature;
612-
}
613-
614561
private static FunctionSignature defaultDecimalV3PrecisionPromotion(
615562
FunctionSignature signature, List<Expression> arguments) {
616563
DecimalV3Type finalType = null;

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Crc32Internal.java

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ public FunctionSignature computeSignature(FunctionSignature signature) {
9494
sig = ComputeSignatureHelper.implementFollowToArgumentReturnType(sig, getArguments());
9595
sig = ComputeSignatureHelper.normalizeDecimalV2(sig, getArguments());
9696
sig = ComputeSignatureHelper.ensureNestedNullableOfArray(sig, getArguments());
97-
sig = ComputeSignatureHelper.dynamicComputeVariantArgs(sig, getArguments());
9897
return sig;
9998
}
10099

fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ElementAt.java

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
2828
import org.apache.doris.nereids.types.ArrayType;
2929
import org.apache.doris.nereids.types.BigIntType;
30+
import org.apache.doris.nereids.types.DataType;
3031
import org.apache.doris.nereids.types.MapType;
3132
import org.apache.doris.nereids.types.StructType;
3233
import org.apache.doris.nereids.types.VarcharType;
@@ -96,4 +97,18 @@ public Expression rewriteWhenAnalyze() {
9697
}
9798
return this;
9899
}
100+
101+
@Override
102+
public FunctionSignature computeSignature(FunctionSignature signature) {
103+
List<Expression> arguments = getArguments();
104+
DataType expressionType = arguments.get(0).getDataType();
105+
DataType sigType = signature.argumentsTypes.get(0);
106+
if (expressionType instanceof VariantType && sigType instanceof VariantType) {
107+
// only keep the variant max subcolumns count
108+
VariantType variantType = new VariantType(((VariantType) expressionType).getVariantMaxSubcolumnsCount());
109+
signature = signature.withArgumentType(0, variantType);
110+
signature = signature.withReturnType(variantType);
111+
}
112+
return super.computeSignature(signature);
113+
}
99114
}

fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/functions/ComputeSignatureHelperTest.java

Lines changed: 0 additions & 198 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,22 @@
1818
package org.apache.doris.nereids.trees.expressions.functions;
1919

2020
import org.apache.doris.catalog.FunctionSignature;
21-
import org.apache.doris.nereids.exceptions.AnalysisException;
2221
import org.apache.doris.nereids.trees.expressions.Expression;
2322
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
2423
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
2524
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
2625
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
2726
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
2827
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
29-
import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
3028
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
3129
import org.apache.doris.nereids.trees.expressions.literal.Literal;
3230
import org.apache.doris.nereids.trees.expressions.literal.MapLiteral;
3331
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
3432
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
3533
import org.apache.doris.nereids.trees.expressions.literal.TimeV2Literal;
36-
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
3734
import org.apache.doris.nereids.types.ArrayType;
3835
import org.apache.doris.nereids.types.BigIntType;
3936
import org.apache.doris.nereids.types.BooleanType;
40-
import org.apache.doris.nereids.types.DataType;
4137
import org.apache.doris.nereids.types.DateTimeType;
4238
import org.apache.doris.nereids.types.DateTimeV2Type;
4339
import org.apache.doris.nereids.types.DateType;
@@ -50,7 +46,6 @@
5046
import org.apache.doris.nereids.types.NullType;
5147
import org.apache.doris.nereids.types.SmallIntType;
5248
import org.apache.doris.nereids.types.TimeV2Type;
53-
import org.apache.doris.nereids.types.VariantType;
5449
import org.apache.doris.nereids.types.coercion.AnyDataType;
5550
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
5651
import org.apache.doris.nereids.types.coercion.FollowToArgumentType;
@@ -537,199 +532,6 @@ void testComplexNestedMixedTimePrecisionPromotion() {
537532
((ArrayType) ((MapType) signature.returnType).getValueType()).getItemType());
538533
}
539534

540-
@Test
541-
void testNoDynamicComputeVariantArgs() {
542-
FunctionSignature signature = FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE);
543-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, Collections.emptyList());
544-
Assertions.assertTrue(signature.returnType instanceof DoubleType);
545-
}
546-
547-
@Test
548-
void testDynamicComputeVariantArgsSingleVariant() {
549-
VariantType variantType = new VariantType(100);
550-
FunctionSignature signature = FunctionSignature.ret(VariantType.INSTANCE)
551-
.args(VariantType.INSTANCE, IntegerType.INSTANCE);
552-
553-
List<Expression> arguments = Lists.newArrayList(
554-
new MockVariantExpression(variantType),
555-
new IntegerLiteral(42));
556-
557-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
558-
559-
Assertions.assertTrue(signature.returnType instanceof VariantType);
560-
Assertions.assertEquals(100, ((VariantType) signature.returnType).getVariantMaxSubcolumnsCount());
561-
Assertions.assertEquals(10000, ((VariantType) signature.returnType).getVariantMaxSparseColumnStatisticsSize());
562-
563-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
564-
Assertions.assertEquals(100, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
565-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
566-
567-
Assertions.assertTrue(signature.getArgType(1) instanceof IntegerType);
568-
}
569-
570-
@Test
571-
void testDynamicComputeVariantArgsMultipleVariants() {
572-
VariantType variantType1 = new VariantType(150);
573-
VariantType variantType2 = new VariantType(250);
574-
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
575-
.args(VariantType.INSTANCE, VariantType.INSTANCE);
576-
577-
List<Expression> arguments = Lists.newArrayList(
578-
new MockVariantExpression(variantType1),
579-
new MockVariantExpression(variantType2));
580-
581-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
582-
583-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
584-
Assertions.assertEquals(150, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
585-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
586-
Assertions.assertTrue(signature.getArgType(1) instanceof VariantType);
587-
Assertions.assertEquals(250, ((VariantType) signature.getArgType(1)).getVariantMaxSubcolumnsCount());
588-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(1)).getVariantMaxSparseColumnStatisticsSize());
589-
Assertions.assertTrue(signature.returnType instanceof IntegerType);
590-
}
591-
592-
@Test
593-
void testDynamicComputeVariantArgsMixedTypesWithSingleVariant() {
594-
VariantType variantType = new VariantType(75);
595-
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
596-
.args(VariantType.INSTANCE, IntegerType.INSTANCE, DoubleType.INSTANCE);
597-
598-
List<Expression> arguments = Lists.newArrayList(
599-
new MockVariantExpression(variantType),
600-
new IntegerLiteral(10),
601-
new DoubleLiteral(3.14));
602-
603-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
604-
605-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
606-
Assertions.assertEquals(75, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
607-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
608-
Assertions.assertTrue(signature.getArgType(1) instanceof IntegerType);
609-
Assertions.assertTrue(signature.getArgType(2) instanceof DoubleType);
610-
611-
Assertions.assertTrue(signature.returnType instanceof BooleanType);
612-
}
613-
614-
@Test
615-
void testDynamicComputeVariantArgsWithNullLiteral() {
616-
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
617-
.args(VariantType.INSTANCE, IntegerType.INSTANCE);
618-
619-
List<Expression> arguments = Lists.newArrayList(
620-
new NullLiteral(),
621-
new IntegerLiteral(10));
622-
623-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
624-
625-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
626-
Assertions.assertEquals(0, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
627-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
628-
Assertions.assertTrue(signature.getArgType(1) instanceof IntegerType);
629-
}
630-
631-
@Test
632-
void testDynamicComputeVariantArgsNoVariantReturnType() {
633-
VariantType variantType = new VariantType(300);
634-
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
635-
.args(VariantType.INSTANCE);
636-
637-
List<Expression> arguments = Lists.newArrayList(
638-
new MockVariantExpression(variantType));
639-
640-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
641-
642-
Assertions.assertTrue(signature.returnType instanceof IntegerType);
643-
644-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
645-
Assertions.assertEquals(300, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
646-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
647-
}
648-
649-
@Test
650-
void testDynamicComputeVariantArgsWithVarArgsThrowsException() {
651-
VariantType variantType1 = new VariantType(150);
652-
VariantType variantType2 = new VariantType(250);
653-
FunctionSignature signature = FunctionSignature.ret(VariantType.INSTANCE)
654-
.args(VariantType.INSTANCE, VariantType.INSTANCE);
655-
656-
List<Expression> arguments = Lists.newArrayList(
657-
new MockVariantExpression(variantType1),
658-
new MockVariantExpression(variantType2));
659-
660-
AnalysisException exception = Assertions.assertThrows(AnalysisException.class, () -> {
661-
ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
662-
});
663-
664-
Assertions.assertEquals("variant type is not supported in multiple arguments", exception.getMessage());
665-
}
666-
667-
@Test
668-
void testDynamicComputeVariantArgsWithComputeSignature() {
669-
VariantType variantType = new VariantType(200);
670-
FunctionSignature signature = FunctionSignature.ret(VariantType.INSTANCE)
671-
.args(VariantType.INSTANCE);
672-
673-
List<Expression> arguments = Lists.newArrayList(
674-
new MockVariantExpression(variantType));
675-
676-
signature = ComputeSignatureHelper.dynamicComputeVariantArgs(signature, arguments);
677-
678-
Assertions.assertTrue(signature.returnType instanceof VariantType);
679-
Assertions.assertEquals(200, ((VariantType) signature.returnType).getVariantMaxSubcolumnsCount());
680-
Assertions.assertEquals(10000, ((VariantType) signature.returnType).getVariantMaxSparseColumnStatisticsSize());
681-
Assertions.assertTrue(signature.getArgType(0) instanceof VariantType);
682-
Assertions.assertEquals(200, ((VariantType) signature.getArgType(0)).getVariantMaxSubcolumnsCount());
683-
Assertions.assertEquals(10000, ((VariantType) signature.getArgType(0)).getVariantMaxSparseColumnStatisticsSize());
684-
}
685-
686-
/**
687-
* Mock Expression class for testing VariantType
688-
*/
689-
private static class MockVariantExpression extends Expression {
690-
private final VariantType variantType;
691-
692-
public MockVariantExpression(VariantType variantType) {
693-
super(Collections.emptyList());
694-
this.variantType = variantType;
695-
}
696-
697-
@Override
698-
public DataType getDataType() {
699-
return variantType;
700-
}
701-
702-
@Override
703-
public boolean nullable() {
704-
return true;
705-
}
706-
707-
@Override
708-
public Expression withChildren(List<Expression> children) {
709-
return this;
710-
}
711-
712-
@Override
713-
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
714-
return visitor.visit(this, context);
715-
}
716-
717-
@Override
718-
public int arity() {
719-
return 0;
720-
}
721-
722-
@Override
723-
public Expression child(int index) {
724-
throw new IndexOutOfBoundsException("MockVariantExpression has no children");
725-
}
726-
727-
@Override
728-
public List<Expression> children() {
729-
return Collections.emptyList();
730-
}
731-
}
732-
733535
@Test
734536
void testDateV1AndDateTimeV1TypeConversion() {
735537
// Test DateType -> DateV2Type conversion with implementAnyDataTypeWithOutIndex

0 commit comments

Comments
 (0)