Skip to content

Commit 6b340d3

Browse files
authored
feat: Support variadic function in CometFuzz (#2682)
* feat: Support variadic functions in CometFuzz * Check that the input type array has a single element
1 parent e4e1148 commit 6b340d3

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

fuzz-testing/src/main/scala/org/apache/comet/fuzz/Meta.scala

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ case class SparkMapType(keyType: SparkType, valueType: SparkType) extends SparkT
4444
case class SparkStructType(fields: Seq[SparkType]) extends SparkType
4545
case object SparkAnyType extends SparkType
4646

47-
// One accepted argument list for a function. When `varArgs` is true,
// `inputTypes` must contain exactly one element — the type accepted for
// every argument (enforced by the assertion in Meta.createFunctions).
case class FunctionSignature(inputTypes: Seq[SparkType], varArgs: Boolean = false)
4848

4949
case class Function(name: String, signatures: Seq[FunctionSignature])
5050

@@ -65,11 +65,20 @@ object Meta {
6565
(DataTypes.StringType, 0.2),
6666
(DataTypes.BinaryType, 0.1))
6767

68-
/**
 * Convenience factory for a [[Function]] that has a single signature.
 *
 * @param name SQL function name as it will appear in generated queries
 * @param inputs argument types for the single signature
 * @param varArgs when true, `inputs` must hold exactly one element that is
 *                repeated for every argument (validated in createFunctions)
 * @return the function with its one signature
 */
private def createFunctionWithInputTypes(
    name: String,
    inputs: Seq[SparkType],
    varArgs: Boolean = false): Function = {
  // Delegate to createFunctions so the variadic-signature validation runs on
  // this path too. The previous body also built a throwaway
  // `Function(name, ...)` whose value was discarded — dead code, removed.
  createFunctions(name, Seq(FunctionSignature(inputs, varArgs)))
}
7175

7276
/**
 * Builds a [[Function]] from a set of signatures, validating each one first.
 *
 * A variadic signature repeats a single element type for all arguments, so a
 * variadic [[FunctionSignature]] declaring more than one input type would be
 * ambiguous; such signatures are rejected with an assertion error.
 */
private def createFunctions(name: String, signatures: Seq[FunctionSignature]): Function = {
  for (s <- signatures) {
    assert(
      !s.varArgs || s.inputTypes.length == 1,
      s"Variadic function $s must have exactly one input type")
  }
  Function(name, signatures)
}
7584

@@ -126,13 +135,11 @@ object Meta {
126135
SparkTypeOneOf(
127136
Seq(
128137
SparkStringType,
129-
SparkArrayType(
130-
SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))),
131-
SparkTypeOneOf(
132-
Seq(
133-
SparkStringType,
134-
SparkArrayType(
135-
SparkTypeOneOf(Seq(SparkStringType, SparkNumericType, SparkBinaryType))))))),
138+
SparkBinaryType,
139+
SparkArrayType(SparkStringType),
140+
SparkArrayType(SparkNumericType),
141+
SparkArrayType(SparkBinaryType)))),
142+
varArgs = true),
136143
createFunctionWithInputTypes("concat_ws", Seq(SparkStringType, SparkStringType)),
137144
createFunctionWithInputTypes("contains", Seq(SparkStringType, SparkStringType)),
138145
createFunctionWithInputTypes("ends_with", Seq(SparkStringType, SparkStringType)),

fuzz-testing/src/main/scala/org/apache/comet/fuzz/QueryGen.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ package org.apache.comet.fuzz
2121

2222
import java.io.{BufferedWriter, FileWriter}
2323

24+
import scala.annotation.tailrec
2425
import scala.collection.mutable
2526
import scala.util.Random
2627

@@ -103,7 +104,12 @@ object QueryGen {
103104
val func = Utils.randomChoice(Meta.scalarFunc, r)
104105
try {
105106
val signature = Utils.randomChoice(func.signatures, r)
106-
val args = signature.inputTypes.map(x => pickRandomColumn(r, table, x))
107+
val args =
108+
if (signature.varArgs) {
109+
pickRandomColumns(r, table, signature.inputTypes.head)
110+
} else {
111+
signature.inputTypes.map(x => pickRandomColumn(r, table, x))
112+
}
107113

108114
// Example SELECT c0, log(c0) as x FROM test0
109115
s"SELECT ${args.mkString(", ")}, ${func.name}(${args.mkString(", ")}) AS x " +
@@ -117,6 +123,21 @@ object QueryGen {
117123
}
118124
}
119125

126+
/**
 * Picks a random, non-empty set of column names of the given type, used to
 * build argument lists for variadic functions.
 *
 * For [[SparkTypeOneOf]] the method first commits to one randomly chosen
 * concrete type (tail-recursively), then draws between 1 and the table's
 * column count candidates; duplicates collapse via the Set, so the result
 * may hold fewer names than were drawn.
 */
@tailrec
private def pickRandomColumns(r: Random, df: DataFrame, targetType: SparkType): Seq[String] = {
  targetType match {
    case SparkTypeOneOf(choices) =>
      // Resolve the union to a single concrete type before picking columns.
      pickRandomColumns(r, df, Utils.randomChoice(choices, r))
    case _ =>
      // `0 to nextInt(n)` in the original is (nextInt(n) + 1) draws — at
      // least one, so a variadic call never ends up with zero arguments.
      val draws = r.nextInt(df.columns.length) + 1
      val picked = (0 until draws).foldLeft(Set.empty[String]) { (acc, _) =>
        acc + pickRandomColumn(r, df, targetType)
      }
      picked.toSeq
  }
}
140+
120141
private def pickRandomColumn(r: Random, df: DataFrame, targetType: SparkType): String = {
121142
targetType match {
122143
case SparkAnyType =>

0 commit comments

Comments
 (0)