Skip to content

Commit 522403b

Browse files
authored
feat: add fp16 support to Substrait (#18086)
## Which issue does this PR close? - Closes #16298 ## Rationale for this change Float16 is an Arrow type. Substrait serialization for the type is defined in https://github.com/apache/arrow/blame/main/format/substrait/extension_types.yaml as part of Arrow. We should support it. This picks up where #16793 leaves off. ## What changes are included in this PR? Support for converting DataType::Float16 to/from Substrait. Support for converting ScalarValue::Float16 to/from Substrait. ## Are these changes tested? Yes ## Are there any user-facing changes? Yes. The `SubstraitProducer` trait received a new method (`register_type`) which downstream implementors will need to provide an implementation for. The example custom producer has been updated with a default implementation. One public method that changed is [`datafusion_substrait::logical_plan::producer::from_empty_relation`](https://docs.rs/datafusion-substrait/50.2.0/datafusion_substrait/logical_plan/producer/fn.from_empty_relation.html). I'm not sure if that is meant to be part of the public API (for one thing, it is undocumented, though maybe this is because it serves an obvious purpose. It also returns a `Rel` which is a pretty internal structure).
1 parent ec3d20b commit 522403b

File tree

13 files changed

+166
-37
lines changed

13 files changed

+166
-37
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/substrait/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ async-recursion = "1.0"
3535
async-trait = { workspace = true }
3636
chrono = { workspace = true }
3737
datafusion = { workspace = true, features = ["sql"] }
38+
half = { workspace = true }
3839
itertools = { workspace = true }
3940
object_store = { workspace = true }
4041
pbjson-types = { workspace = true }

datafusion/substrait/src/logical_plan/consumer/expr/literal.rs

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
use crate::logical_plan::consumer::types::from_substrait_type;
1919
use crate::logical_plan::consumer::utils::{next_struct_field_name, DEFAULT_TIMEZONE};
2020
use crate::logical_plan::consumer::SubstraitConsumer;
21+
use crate::variation_const::FLOAT_16_TYPE_NAME;
2122
#[allow(deprecated)]
2223
use crate::variation_const::{
2324
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
@@ -38,6 +39,7 @@ use datafusion::common::{
3839
not_impl_err, plan_err, substrait_datafusion_err, substrait_err, ScalarValue,
3940
};
4041
use datafusion::logical_expr::Expr;
42+
use prost::Message;
4143
use std::sync::Arc;
4244
use substrait::proto;
4345
use substrait::proto::expression::literal::user_defined::Val;
@@ -440,8 +442,6 @@ pub(crate) fn from_substrait_literal(
440442
return Ok(value);
441443
}
442444

443-
// TODO: remove the code below once the producer has been updated
444-
445445
// Helper function to prevent duplicating this code - can be inlined once the non-extension path is removed
446446
let interval_month_day_nano =
447447
|user_defined: &proto::expression::literal::UserDefined| -> datafusion::common::Result<ScalarValue> {
@@ -474,6 +474,36 @@ pub(crate) fn from_substrait_literal(
474474
.get(&user_defined.type_reference)
475475
{
476476
match name.as_ref() {
477+
FLOAT_16_TYPE_NAME => {
478+
// Rules for encoding fp16 Substrait literals are defined as part of Arrow here:
479+
//
480+
// https://github.com/apache/arrow/blame/bab558061696ddc1841148d6210424b12923d48e/format/substrait/extension_types.yaml#L112
481+
482+
let Some(value) = user_defined.val.as_ref() else {
483+
return substrait_err!("Float16 value is empty");
484+
};
485+
let Val::Value(value_any) = value else {
486+
return substrait_err!(
487+
"Float16 value is not a value type literal"
488+
);
489+
};
490+
if value_any.type_url != "google.protobuf.UInt32Value" {
491+
return substrait_err!(
492+
"Float16 value is not a google.protobuf.UInt32Value"
493+
);
494+
}
495+
let decoded_value =
496+
pbjson_types::UInt32Value::decode(value_any.value.clone())
497+
.map_err(|err| {
498+
substrait_datafusion_err!(
499+
"Failed to decode float16 value: {err}"
500+
)
501+
})?;
502+
let u32_bytes = decoded_value.value.to_le_bytes();
503+
let f16_val =
504+
half::f16::from_le_bytes(u32_bytes[0..2].try_into().unwrap());
505+
return Ok(ScalarValue::Float16(Some(f16_val)));
506+
}
477507
// Kept for backwards compatibility - producers should use IntervalCompound instead
478508
#[allow(deprecated)]
479509
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => {

datafusion/substrait/src/logical_plan/consumer/types.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
use super::utils::{from_substrait_precision, next_struct_field_name, DEFAULT_TIMEZONE};
1919
use super::SubstraitConsumer;
20+
use crate::variation_const::FLOAT_16_TYPE_NAME;
2021
#[allow(deprecated)]
2122
use crate::variation_const::{
2223
DATE_32_TYPE_VARIATION_REF, DATE_64_TYPE_VARIATION_REF,
@@ -251,6 +252,7 @@ pub fn from_substrait_type(
251252
match name.as_ref() {
252253
// Kept for backwards compatibility, producers should use IntervalCompound instead
253254
INTERVAL_MONTH_DAY_NANO_TYPE_NAME => Ok(DataType::Interval(IntervalUnit::MonthDayNano)),
255+
FLOAT_16_TYPE_NAME => Ok(DataType::Float16),
254256
_ => not_impl_err!(
255257
"Unsupported Substrait user defined type with ref {} and variation {}",
256258
u.type_reference,
@@ -304,15 +306,15 @@ pub fn from_substrait_named_struct(
304306
})?,
305307
&base_schema.names,
306308
&mut name_idx,
307-
);
309+
)?;
308310
if name_idx != base_schema.names.len() {
309311
return substrait_err!(
310312
"Names list must match exactly to nested schema, but found {} uses for {} names",
311313
name_idx,
312314
base_schema.names.len()
313315
);
314316
}
315-
DFSchema::try_from(Schema::new(fields?))
317+
DFSchema::try_from(Schema::new(fields))
316318
}
317319

318320
fn from_substrait_struct_type(

datafusion/substrait/src/logical_plan/producer/expr/cast.rs

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ pub fn from_cast(
4040
nullable: true,
4141
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
4242
literal_type: Some(LiteralType::Null(to_substrait_type(
43-
data_type, true,
43+
producer, data_type, true,
4444
)?)),
4545
};
4646
return Ok(Expression {
@@ -51,7 +51,7 @@ pub fn from_cast(
5151
Ok(Expression {
5252
rex_type: Some(RexType::Cast(Box::new(
5353
substrait::proto::expression::Cast {
54-
r#type: Some(to_substrait_type(data_type, true)?),
54+
r#type: Some(to_substrait_type(producer, data_type, true)?),
5555
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
5656
failure_behavior: FailureBehavior::ThrowException.into(),
5757
},
@@ -68,7 +68,7 @@ pub fn from_try_cast(
6868
Ok(Expression {
6969
rex_type: Some(RexType::Cast(Box::new(
7070
substrait::proto::expression::Cast {
71-
r#type: Some(to_substrait_type(data_type, true)?),
71+
r#type: Some(to_substrait_type(producer, data_type, true)?),
7272
input: Some(Box::new(producer.handle_expr(expr, schema)?)),
7373
failure_behavior: FailureBehavior::ReturnNull.into(),
7474
},
@@ -79,7 +79,9 @@ pub fn from_try_cast(
7979
#[cfg(test)]
8080
mod tests {
8181
use super::*;
82-
use crate::logical_plan::producer::to_substrait_extended_expr;
82+
use crate::logical_plan::producer::{
83+
to_substrait_extended_expr, DefaultSubstraitProducer,
84+
};
8385
use datafusion::arrow::datatypes::{DataType, Field};
8486
use datafusion::common::DFSchema;
8587
use datafusion::execution::SessionStateBuilder;
@@ -92,6 +94,8 @@ mod tests {
9294
let empty_schema = DFSchemaRef::new(DFSchema::empty());
9395
let field = Field::new("out", DataType::Int32, false);
9496

97+
let mut producer = DefaultSubstraitProducer::new(&state);
98+
9599
let expr = Expr::Literal(ScalarValue::Null, None)
96100
.cast_to(&DataType::Int32, &empty_schema)
97101
.unwrap();
@@ -107,7 +111,7 @@ mod tests {
107111
nullable: true,
108112
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
109113
literal_type: Some(LiteralType::Null(
110-
to_substrait_type(&DataType::Int32, true).unwrap(),
114+
to_substrait_type(&mut producer, &DataType::Int32, true).unwrap(),
111115
)),
112116
};
113117
let expected = Expression {
@@ -131,13 +135,16 @@ mod tests {
131135
typed_null.referred_expr[0].expr_type.as_ref().unwrap()
132136
{
133137
let cast_expr = substrait::proto::expression::Cast {
134-
r#type: Some(to_substrait_type(&DataType::Int32, true).unwrap()),
138+
r#type: Some(
139+
to_substrait_type(&mut producer, &DataType::Int32, true).unwrap(),
140+
),
135141
input: Some(Box::new(Expression {
136142
rex_type: Some(RexType::Literal(Literal {
137143
nullable: true,
138144
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
139145
literal_type: Some(LiteralType::Null(
140-
to_substrait_type(&DataType::Int64, true).unwrap(),
146+
to_substrait_type(&mut producer, &DataType::Int64, true)
147+
.unwrap(),
141148
)),
142149
})),
143150
})),

datafusion/substrait/src/logical_plan/producer/expr/literal.rs

Lines changed: 45 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
use crate::logical_plan::producer::{to_substrait_type, SubstraitProducer};
1919
use crate::variation_const::{
2020
DATE_32_TYPE_VARIATION_REF, DECIMAL_128_TYPE_VARIATION_REF,
21-
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF,
21+
DEFAULT_CONTAINER_TYPE_VARIATION_REF, DEFAULT_TYPE_VARIATION_REF, FLOAT_16_TYPE_NAME,
2222
LARGE_CONTAINER_TYPE_VARIATION_REF, TIME_32_TYPE_VARIATION_REF,
2323
TIME_64_TYPE_VARIATION_REF, UNSIGNED_INTEGER_TYPE_VARIATION_REF,
2424
VIEW_CONTAINER_TYPE_VARIATION_REF,
@@ -61,6 +61,7 @@ pub(crate) fn to_substrait_literal(
6161
nullable: true,
6262
type_variation_reference: DEFAULT_TYPE_VARIATION_REF,
6363
literal_type: Some(LiteralType::Null(to_substrait_type(
64+
producer,
6465
&value.data_type(),
6566
true,
6667
)?)),
@@ -94,6 +95,41 @@ pub(crate) fn to_substrait_literal(
9495
LiteralType::I64(*n as i64),
9596
UNSIGNED_INTEGER_TYPE_VARIATION_REF,
9697
),
98+
ScalarValue::Float16(Some(f)) => {
99+
// Rules for encoding fp16 Substrait literals are defined as part of Arrow here:
100+
//
101+
// https://github.com/apache/arrow/blame/bab558061696ddc1841148d6210424b12923d48e/format/substrait/extension_types.yaml#L112
102+
//
103+
// fp16 literals are encoded as user defined literals with
104+
// a google.protobuf.UInt32Value message where the lower 16 bits are
105+
// the fp16 value.
106+
let type_anchor = producer.register_type(FLOAT_16_TYPE_NAME.to_string());
107+
108+
// The spec says "lower 16 bits" but neglects to mention the endianness.
109+
// Let's just use little-endian for now.
110+
//
111+
// See https://github.com/apache/arrow/issues/47846
112+
let f_bytes = f.to_le_bytes();
113+
let value = u32::from_le_bytes([f_bytes[0], f_bytes[1], 0, 0]);
114+
115+
let value = pbjson_types::UInt32Value { value };
116+
let encoded_value = prost::Message::encode_to_vec(&value);
117+
(
118+
LiteralType::UserDefined(
119+
substrait::proto::expression::literal::UserDefined {
120+
type_reference: type_anchor,
121+
type_parameters: vec![],
122+
val: Some(substrait::proto::expression::literal::user_defined::Val::Value(
123+
pbjson_types::Any {
124+
type_url: "google.protobuf.UInt32Value".to_string(),
125+
value: encoded_value.into(),
126+
},
127+
)),
128+
},
129+
),
130+
DEFAULT_TYPE_VARIATION_REF,
131+
)
132+
}
97133
ScalarValue::Float32(Some(f)) => {
98134
(LiteralType::Fp32(*f), DEFAULT_TYPE_VARIATION_REF)
99135
}
@@ -241,7 +277,7 @@ pub(crate) fn to_substrait_literal(
241277
),
242278
ScalarValue::Map(m) => {
243279
let map = if m.is_empty() || m.value(0).is_empty() {
244-
let mt = to_substrait_type(m.data_type(), m.is_nullable())?;
280+
let mt = to_substrait_type(producer, m.data_type(), m.is_nullable())?;
245281
let mt = match mt {
246282
substrait::proto::Type {
247283
kind: Some(r#type::Kind::Map(mt)),
@@ -354,12 +390,13 @@ fn convert_array_to_literal_list<T: OffsetSizeTrait>(
354390
.collect::<datafusion::common::Result<Vec<_>>>()?;
355391

356392
if values.is_empty() {
357-
let lt = match to_substrait_type(array.data_type(), array.is_nullable())? {
358-
substrait::proto::Type {
359-
kind: Some(r#type::Kind::List(lt)),
360-
} => lt.as_ref().to_owned(),
361-
_ => unreachable!(),
362-
};
393+
let lt =
394+
match to_substrait_type(producer, array.data_type(), array.is_nullable())? {
395+
substrait::proto::Type {
396+
kind: Some(r#type::Kind::List(lt)),
397+
} => lt.as_ref().to_owned(),
398+
_ => unreachable!(),
399+
};
363400
Ok(LiteralType::EmptyList(lt))
364401
} else {
365402
Ok(LiteralType::List(List { values }))

datafusion/substrait/src/logical_plan/producer/expr/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ pub fn to_substrait_extended_expr(
7878
})
7979
})
8080
.collect::<datafusion::common::Result<Vec<_>>>()?;
81-
let substrait_schema = to_substrait_named_struct(schema)?;
81+
let substrait_schema = to_substrait_named_struct(&mut producer, schema)?;
8282

8383
let extensions = producer.get_extensions();
8484
Ok(Box::new(ExtendedExpression {

datafusion/substrait/src/logical_plan/producer/plan.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ pub fn to_substrait_plan(
3636
let plan_rels = vec![PlanRel {
3737
rel_type: Some(plan_rel::RelType::Root(RelRoot {
3838
input: Some(*producer.handle_plan(plan)?),
39-
names: to_substrait_named_struct(plan.schema())?.names,
39+
names: to_substrait_named_struct(&mut producer, plan.schema())?.names,
4040
})),
4141
}];
4242

datafusion/substrait/src/logical_plan/producer/rel/read_rel.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ pub fn from_table_scan(
4848
});
4949

5050
let table_schema = scan.source.schema().to_dfschema_ref()?;
51-
let base_schema = to_substrait_named_struct(&table_schema)?;
51+
let base_schema = to_substrait_named_struct(producer, &table_schema)?;
5252

5353
let filter_option = if scan.filters.is_empty() {
5454
None
@@ -83,15 +83,18 @@ pub fn from_table_scan(
8383
}))
8484
}
8585

86-
pub fn from_empty_relation(e: &EmptyRelation) -> datafusion::common::Result<Box<Rel>> {
86+
pub fn from_empty_relation(
87+
producer: &mut impl SubstraitProducer,
88+
e: &EmptyRelation,
89+
) -> datafusion::common::Result<Box<Rel>> {
8790
if e.produce_one_row {
8891
return not_impl_err!("Producing a row from empty relation is unsupported");
8992
}
9093
#[allow(deprecated)]
9194
Ok(Box::new(Rel {
9295
rel_type: Some(RelType::Read(Box::new(ReadRel {
9396
common: None,
94-
base_schema: Some(to_substrait_named_struct(&e.schema)?),
97+
base_schema: Some(to_substrait_named_struct(producer, &e.schema)?),
9598
filter: None,
9699
best_effort_filter: None,
97100
projection: None,
@@ -135,7 +138,7 @@ pub fn from_values(
135138
Ok(Box::new(Rel {
136139
rel_type: Some(RelType::Read(Box::new(ReadRel {
137140
common: None,
138-
base_schema: Some(to_substrait_named_struct(&v.schema)?),
141+
base_schema: Some(to_substrait_named_struct(producer, &v.schema)?),
139142
filter: None,
140143
best_effort_filter: None,
141144
projection: None,

datafusion/substrait/src/logical_plan/producer/substrait_producer.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,10 @@ use substrait::proto::{
7070
/// self.extensions.register_function(signature)
7171
/// }
7272
///
73+
/// fn register_type(&mut self, type_name: String) -> u32 {
74+
/// self.extensions.register_type(type_name)
75+
/// }
76+
///
7377
/// fn get_extensions(self) -> Extensions {
7478
/// self.extensions
7579
/// }
@@ -114,6 +118,15 @@ pub trait SubstraitProducer: Send + Sync + Sized {
114118
/// there is one. Otherwise, it should generate a new anchor.
115119
fn register_function(&mut self, signature: String) -> u32;
116120

121+
/// Within a Substrait plan, user defined types are referenced using type anchors that are stored at
122+
/// the top level of the [Plan](substrait::proto::Plan) within
123+
/// [ExtensionType](substrait::proto::extensions::simple_extension_declaration::ExtensionType)
124+
/// messages.
125+
///
126+
/// When given a type name, this method should return the existing anchor for it if
127+
/// there is one. Otherwise, it should generate a new anchor.
128+
fn register_type(&mut self, name: String) -> u32;
129+
117130
/// Consume the producer to generate the [Extensions] for the Substrait plan based on the
118131
/// functions that have been registered
119132
fn get_extensions(self) -> Extensions;
@@ -182,7 +195,7 @@ pub trait SubstraitProducer: Send + Sync + Sized {
182195
&mut self,
183196
plan: &EmptyRelation,
184197
) -> datafusion::common::Result<Box<Rel>> {
185-
from_empty_relation(plan)
198+
from_empty_relation(self, plan)
186199
}
187200

188201
fn handle_subquery_alias(
@@ -367,6 +380,10 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
367380
self.extensions.register_function(fn_name)
368381
}
369382

383+
fn register_type(&mut self, type_name: String) -> u32 {
384+
self.extensions.register_type(type_name)
385+
}
386+
370387
fn get_extensions(self) -> Extensions {
371388
self.extensions
372389
}

0 commit comments

Comments
 (0)