Commit f56006a

docs: Update contributor guide for adding a new expression (#2704)
1 parent dcf9f09 commit f56006a

1 file changed: docs/source/contributor-guide/adding_a_new_expression.md (+210 additions, -57 deletions)

@@ -41,26 +41,172 @@ Once you know what you want to add, you'll need to update the query planner to r

### Adding the Expression in Scala

DataFusion Comet uses a framework based on the `CometExpressionSerde` trait for converting Spark expressions to protobuf. Instead of a large match statement, each expression type has its own serialization handler. For aggregate expressions, use the `CometAggregateExpressionSerde` trait instead.

#### Creating a CometExpressionSerde Implementation

First, create an object that extends `CometExpressionSerde[T]`, where `T` is the Spark expression type. This is typically added to one of the serde files in `spark/src/main/scala/org/apache/comet/serde/` (e.g., `math.scala`, `strings.scala`, `arrays.scala`).

For example, the `unhex` function looks like this:

```scala
object CometUnhex extends CometExpressionSerde[Unhex] {
  override def convert(
      expr: Unhex,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    val childExpr = exprToProtoInternal(expr.child, inputs, binding)
    val failOnErrorExpr = exprToProtoInternal(Literal(expr.failOnError), inputs, binding)

    val optExpr =
      scalarFunctionExprToProtoWithReturnType(
        "unhex",
        expr.dataType,
        false,
        childExpr,
        failOnErrorExpr)
    optExprWithInfo(optExpr, expr, expr.child)
  }
}
```

The `CometExpressionSerde` trait provides three methods you can override:

* `convert(expr: T, inputs: Seq[Attribute], binding: Boolean): Option[Expr]` - **Required**. Converts the Spark expression to protobuf. Return `None` if the expression cannot be converted.
* `getSupportLevel(expr: T): SupportLevel` - Optional. Returns the level of support for the expression. See the "Using getSupportLevel" section below for details.
* `getExprConfigName(expr: T): String` - Optional. Returns a short name used in configuration keys. Defaults to the Spark class name.
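
As a hedged illustration (a sketch based on the trait description above, not code taken from the Comet repository), a handler could override `getExprConfigName` to give its configuration keys a short, stable name:

```scala
// Sketch only: assumes the short name feeds the spark.comet.expr.<exprName>.* keys
// described later in this guide (e.g. spark.comet.expr.unhex.enabled).
object CometUnhex extends CometExpressionSerde[Unhex] {
  override def getExprConfigName(expr: Unhex): String = "unhex"

  override def convert(
      expr: Unhex,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // ... conversion logic as in the example above ...
  }
}
```

If you don't override `getExprConfigName`, the configuration keys use the Spark class name (`Unhex` in this case).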

For simple scalar functions that map directly to a DataFusion function, you can use the built-in `CometScalarFunction` implementation:

```scala
classOf[Cos] -> CometScalarFunction("cos")
```

#### Registering the Expression Handler

Once you've created your `CometExpressionSerde` implementation, register it in `QueryPlanSerde.scala` by adding it to the appropriate expression map (e.g., `mathExpressions`, `stringExpressions`, `predicateExpressions`):

```scala
private val mathExpressions: Map[Class[_ <: Expression], CometExpressionSerde[_]] = Map(
  // ... other expressions ...
  classOf[Unhex] -> CometUnhex,
  classOf[Hex] -> CometHex)
```

The `exprToProtoInternal` method automatically uses this mapping to find and invoke your handler when it encounters the corresponding Spark expression type.

A few things to note:

* `convert` must call `exprToProtoInternal` on child expressions so that they are also converted to protobuf.
* `scalarFunctionExprToProtoWithReturnType` is for scalar functions that need return type information. Your expression may use a different method depending on the type of expression.
* Use helper methods such as `createBinaryExpr` and `createUnaryExpr` from `QueryPlanSerde` for common expression patterns.

#### Using getSupportLevel

The `getSupportLevel` method lets you control whether an expression should be executed by Comet based on conditions such as data types, parameter values, or other expression-specific constraints. This is particularly useful when:

1. Your expression only supports specific data types
2. Your expression has known incompatibilities with Spark's behavior
3. Your expression has edge cases that aren't yet supported

The method returns one of three `SupportLevel` values:

* **`Compatible(notes: Option[String] = None)`** - Comet supports this expression with full compatibility with Spark, or with known differences only in edge cases that are unlikely to affect most users. This is the default if you don't override `getSupportLevel`.
* **`Incompatible(notes: Option[String] = None)`** - Comet supports this expression, but results can differ from Spark's. The expression will only be used if `spark.comet.expr.allowIncompatible=true` or the expression-specific config `spark.comet.expr.<exprName>.allowIncompatible=true` is set.
* **`Unsupported(notes: Option[String] = None)`** - Comet does not support this expression under the current conditions. The expression will not be used, and execution falls back to Spark.

All three support levels accept an optional `notes` parameter to provide additional context.

##### Examples

**Example 1: Restricting to specific data types**

The `Abs` expression only supports numeric types:

```scala
object CometAbs extends CometExpressionSerde[Abs] {
  override def getSupportLevel(expr: Abs): SupportLevel = {
    expr.child.dataType match {
      case _: NumericType =>
        Compatible()
      case _ =>
        // Spark supports NumericType, DayTimeIntervalType, and YearMonthIntervalType
        Unsupported(Some("Only integral, floating-point, and decimal types are supported"))
    }
  }

  override def convert(
      expr: Abs,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // ... conversion logic ...
  }
}
```

**Example 2: Validating parameter values**

The `TruncDate` expression only supports specific format strings:

```scala
object CometTruncDate extends CometExpressionSerde[TruncDate] {
  val supportedFormats: Seq[String] =
    Seq("year", "yyyy", "yy", "quarter", "mon", "month", "mm", "week")

  override def getSupportLevel(expr: TruncDate): SupportLevel = {
    expr.format match {
      case Literal(fmt: UTF8String, _) =>
        if (supportedFormats.contains(fmt.toString.toLowerCase(Locale.ROOT))) {
          Compatible()
        } else {
          Unsupported(Some(s"Format $fmt is not supported"))
        }
      case _ =>
        Incompatible(
          Some("Invalid format strings will throw an exception instead of returning NULL"))
    }
  }

  override def convert(
      expr: TruncDate,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // ... conversion logic ...
  }
}
```

**Example 3: Marking known incompatibilities**

The `ArrayAppend` expression has known behavioral differences from Spark:

```scala
object CometArrayAppend extends CometExpressionSerde[ArrayAppend] {
  override def getSupportLevel(expr: ArrayAppend): SupportLevel = Incompatible(None)

  override def convert(
      expr: ArrayAppend,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // ... conversion logic ...
  }
}
```

This expression will only be used when users explicitly enable incompatible expressions via configuration.
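
For instance, a user could opt in at the session level. This is a usage sketch built from the configuration keys described in this guide; the `<exprName>` segment depends on `getExprConfigName` (by default the Spark class name), so `ArrayAppend` here is illustrative:

```scala
// Allow every expression marked Incompatible to run in Comet:
spark.conf.set("spark.comet.expr.allowIncompatible", "true")

// Or opt in for a single expression via its expression-specific config:
spark.conf.set("spark.comet.expr.ArrayAppend.allowIncompatible", "true")

// Expressions can also be disabled entirely:
spark.conf.set("spark.comet.expr.ArrayAppend.enabled", "false")
```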

##### How getSupportLevel Affects Execution

When the query planner encounters an expression:

1. It first checks whether the expression is explicitly disabled via `spark.comet.expr.<exprName>.enabled=false`
2. It then calls `getSupportLevel` on the expression handler
3. Based on the result:
   - `Compatible()`: the expression proceeds to conversion
   - `Incompatible()`: the expression is skipped unless `spark.comet.expr.allowIncompatible=true` or the expression-specific allow config is set
   - `Unsupported()`: the expression is skipped and execution falls back to Spark

Any notes provided are logged to help with debugging and with understanding why an expression was not used.

#### Adding Spark-side Tests for the New Expression

@@ -92,9 +238,9 @@ test("unhex") {

### Adding the Expression To the Protobuf Definition

Once you have the expression implemented in Scala, you might need to update the protobuf definition to include the new expression. You may not need to do this if the expression is already covered by the existing protobuf definition (e.g. you're adding a new scalar function that uses the `ScalarFunc` message).

You can find the protobuf definition in `native/proto/src/proto/expr.proto`, in particular the `Expr` and `AggExpr` messages. If you were to add a new expression called `Add2`, you would add a new case to the `Expr` message like so:

```proto
message Expr {
```

@@ -118,51 +264,58 @@ message Add2 {

With the serialization complete, the next step is to implement the expression in Rust and ensure that the incoming plan can make use of it.

How this works is somewhat dependent on the type of expression you're adding. Expression implementations live in the `native/spark-expr/src/` directory, organized by category (e.g., `math_funcs/`, `string_funcs/`, `array_funcs/`).

#### Generally Adding a New Expression

If you're adding a new expression that requires custom protobuf serialization, you may need to:

1. Add a new message to the protobuf definition in `native/proto/src/proto/expr.proto`
2. Update the Rust deserialization code to handle the new protobuf message type

For most expressions, you can skip this step if you're using the existing scalar function infrastructure.

#### Adding a New Scalar Function Expression

For a new scalar function, you can reuse a lot of code by updating the `create_comet_physical_fun` method in `native/spark-expr/src/comet_scalar_funcs.rs`. Add a match case for your function name:

```rust
match fun_name {
    // ... other functions ...
    "unhex" => {
        let func = Arc::new(spark_unhex);
        make_comet_scalar_udf!("unhex", func, without data_type)
    }
    // ... more functions ...
}
```

The `make_comet_scalar_udf!` macro has several variants depending on whether your function needs:

- A data type parameter: `make_comet_scalar_udf!("ceil", spark_ceil, data_type)`
- No data type parameter: `make_comet_scalar_udf!("unhex", func, without data_type)`
- An eval mode: `make_comet_scalar_udf!("decimal_div", spark_decimal_div, data_type, eval_mode)`
- A fail_on_error flag: `make_comet_scalar_udf!("spark_modulo", func, without data_type, fail_on_error)`

#### Implementing the Function

Then implement your function in an appropriate module under `native/spark-expr/src/`. The function signature will look like:

```rust
pub fn spark_unhex(args: &[ColumnarValue]) -> Result<ColumnarValue, DataFusionError> {
    // Do the work here
}
```

If your function uses the data type or eval mode, the signature will include those as additional parameters:

```rust
pub fn spark_ceil(
    args: &[ColumnarValue],
    data_type: &DataType
) -> Result<ColumnarValue, DataFusionError> {
    // Implementation
}
```

### API Differences Between Spark Versions

@@ -173,33 +326,33 @@ If the expression you're adding has different behavior across different Spark ve

## Shimming to Support Different Spark Versions

If the expression you're adding behaves differently across Spark versions, you can use the shim system: each supported Spark version has its own `CometExprShim` trait, located at `spark/src/main/spark-$SPARK_VERSION/org/apache/comet/shims/CometExprShim.scala`.

The `CometExprShim` trait provides several mechanisms for handling version differences:

1. **Version-specific methods** - Override methods in the trait to provide version-specific behavior
2. **Version-specific expression handling** - Use `versionSpecificExprToProtoInternal` to handle expressions that only exist in certain Spark versions

For example, the `StringDecode` expression only exists in certain Spark versions. The shim handles this:

```scala
trait CometExprShim {
  def versionSpecificExprToProtoInternal(
      expr: Expression,
      inputs: Seq[Attribute],
      binding: Boolean): Option[Expr] = {
    expr match {
      case s: StringDecode =>
        stringDecode(expr, s.charset, s.bin, inputs, binding)
      case _ => None
    }
  }
}
```

The `QueryPlanSerde.exprToProtoInternal` method calls `versionSpecificExprToProtoInternal` first, allowing shims to intercept and handle version-specific expressions before falling back to the standard expression maps.

Your `CometExpressionSerde` implementation can also access shim methods by mixing in the `CometExprShim` trait, though in most cases you can read the expression's properties directly if they're available across all supported Spark versions.
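
As a rough sketch (the expression type `MyExpr` and the shim helper it would call are hypothetical, not Comet APIs), a handler that needs version-specific behavior might be structured like this:

```scala
// Hypothetical sketch: mix CometExprShim into the serde handler so that any
// behavior that differs between Spark versions is resolved by the shim.
object CometMyExpr extends CometExpressionSerde[MyExpr] with CometExprShim {
  override def convert(
      expr: MyExpr,
      inputs: Seq[Attribute],
      binding: Boolean): Option[ExprOuterClass.Expr] = {
    // Call shim-provided helpers here for anything version-dependent;
    // read the expression's properties directly when they exist in all versions.
    // ... conversion logic ...
  }
}
```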

## Resources
