Skip to content

Commit 2a06d44

Browse files
committed
feat: Replace DataType with FieldRef in Cast
1 parent 337378a commit 2a06d44

File tree

31 files changed

+500
-385
lines changed

31 files changed

+500
-385
lines changed

datafusion/core/benches/sql_planner.rs

Lines changed: 65 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -103,15 +103,28 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
103103
// the actual ops here are largely unimportant as they are just a sample
104104
// of ops that could occur on a dataframe
105105
df = df
106-
.with_column(&c_name, cast(c.clone(), DataType::Utf8))
106+
.with_column(
107+
&c_name,
108+
cast(
109+
c.clone(),
110+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
111+
),
112+
)
107113
.unwrap()
108114
.with_column(
109115
&c_name,
110116
when(
111-
cast(c.clone(), DataType::Int32).gt(lit(135)),
112117
cast(
113-
cast(c.clone(), DataType::Int32) - lit(i + 3),
114-
DataType::Utf8,
118+
c.clone(),
119+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
120+
)
121+
.gt(lit(135)),
122+
cast(
123+
cast(
124+
c.clone(),
125+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
126+
) - lit(i + 3),
127+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
115128
),
116129
)
117130
.otherwise(c.clone())
@@ -122,15 +135,25 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
122135
&c_name,
123136
when(
124137
c.clone().is_not_null().and(
125-
cast(c.clone(), DataType::Int32)
126-
.between(lit(120), lit(130)),
138+
cast(
139+
c.clone(),
140+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
141+
)
142+
.between(lit(120), lit(130)),
127143
),
128144
Literal(ScalarValue::Utf8(None), None),
129145
)
130146
.otherwise(
131147
when(
132148
c.clone().is_not_null().and(regexp_like(
133-
cast(c.clone(), DataType::Utf8View),
149+
cast(
150+
c.clone(),
151+
Arc::new(Field::new(
152+
&c_name,
153+
DataType::Utf8View,
154+
true,
155+
)),
156+
),
134157
lit("[0-9]*"),
135158
None,
136159
)),
@@ -146,10 +169,16 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
146169
&c_name,
147170
when(
148171
c.clone().is_not_null().and(
149-
cast(c.clone(), DataType::Int32)
150-
.between(lit(90), lit(100)),
172+
cast(
173+
c.clone(),
174+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
175+
)
176+
.between(lit(90), lit(100)),
177+
),
178+
cast(
179+
c.clone(),
180+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
151181
),
152-
cast(c.clone(), DataType::Utf8View),
153182
)
154183
.otherwise(Literal(ScalarValue::Date32(None), None))
155184
.unwrap(),
@@ -159,10 +188,22 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
159188
&c_name,
160189
when(
161190
c.clone().is_not_null().and(
162-
cast(c.clone(), DataType::Int32).rem(lit(10)).gt(lit(7)),
191+
cast(
192+
c.clone(),
193+
Arc::new(Field::new(&c_name, DataType::Int32, true)),
194+
)
195+
.rem(lit(10))
196+
.gt(lit(7)),
163197
),
164198
regexp_replace(
165-
cast(c.clone(), DataType::Utf8View),
199+
cast(
200+
c.clone(),
201+
Arc::new(Field::new(
202+
&c_name,
203+
DataType::Utf8View,
204+
true,
205+
)),
206+
),
166207
lit("1"),
167208
lit("a"),
168209
None,
@@ -179,11 +220,21 @@ fn build_test_data_frame(ctx: &SessionContext, rt: &Runtime) -> DataFrame {
179220
&c_name,
180221
try_cast(
181222
to_timestamp(vec![c.clone(), lit("%Y-%m-%d %H:%M:%S")]),
182-
DataType::Timestamp(Nanosecond, Some("UTC".into())),
223+
Arc::new(Field::new(
224+
&c_name,
225+
DataType::Timestamp(Nanosecond, None),
226+
true,
227+
)),
183228
),
184229
)
185230
.unwrap()
186-
.with_column(&c_name, try_cast(c.clone(), DataType::Date32))
231+
.with_column(
232+
&c_name,
233+
try_cast(
234+
c.clone(),
235+
Arc::new(Field::new(&c_name, DataType::Date32, true)),
236+
),
237+
)
187238
.unwrap()
188239
}
189240

datafusion/core/tests/custom_sources_cases/provider_filter_pushdown.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ impl TableProvider for CustomProvider {
183183
Expr::Literal(ScalarValue::Int16(Some(i)), _) => *i as i64,
184184
Expr::Literal(ScalarValue::Int32(Some(i)), _) => *i as i64,
185185
Expr::Literal(ScalarValue::Int64(Some(i)), _) => *i,
186-
Expr::Cast(Cast { expr, data_type: _ }) => match expr.deref() {
186+
Expr::Cast(Cast { expr, field: _ }) => match expr.deref() {
187187
Expr::Literal(lit_value, _) => match lit_value {
188188
ScalarValue::Int8(Some(v)) => *v as i64,
189189
ScalarValue::Int16(Some(v)) => *v as i64,

datafusion/core/tests/dataframe/dataframe_functions.rs

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -416,7 +416,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
416416

417417
// the arg2 parameter is a complex expr, but it can be evaluated to the literal value
418418
let alias_expr = Expr::Alias(Alias::new(
419-
cast(lit(0.5), DataType::Float32),
419+
cast(
420+
lit(0.5),
421+
Arc::new(Field::new("arg_2", DataType::Float32, true)),
422+
),
420423
None::<&str>,
421424
"arg_2".to_string(),
422425
));
@@ -436,7 +439,10 @@ async fn test_fn_approx_percentile_cont() -> Result<()> {
436439
);
437440

438441
let alias_expr = Expr::Alias(Alias::new(
439-
cast(lit(0.1), DataType::Float32),
442+
cast(
443+
lit(0.1),
444+
Arc::new(Field::new("arg_2", DataType::Float32, true)),
445+
),
440446
None::<&str>,
441447
"arg_2".to_string(),
442448
));
@@ -1102,7 +1108,7 @@ async fn test_fn_substr() -> Result<()> {
11021108

11031109
#[tokio::test]
11041110
async fn test_cast() -> Result<()> {
1105-
let expr = cast(col("b"), DataType::Float64);
1111+
let expr = cast(col("b"), Arc::new(Field::new("b", DataType::Float64, true)));
11061112
let batches = get_batches(expr).await?;
11071113

11081114
assert_snapshot!(

datafusion/core/tests/dataframe/mod.rs

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2072,7 +2072,13 @@ async fn cast_expr_test() -> Result<()> {
20722072
.await?
20732073
.select_columns(&["c2", "c3"])?
20742074
.limit(0, Some(1))?
2075-
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
2075+
.with_column(
2076+
"sum",
2077+
cast(
2078+
col("c2") + col("c3"),
2079+
Arc::new(Field::new("sum", DataType::Int64, false)),
2080+
),
2081+
)?;
20762082

20772083
let df_results = df.clone().collect().await?;
20782084
df.clone().show().await?;
@@ -2174,7 +2180,13 @@ async fn cache_test() -> Result<()> {
21742180
.await?
21752181
.select_columns(&["c2", "c3"])?
21762182
.limit(0, Some(1))?
2177-
.with_column("sum", cast(col("c2") + col("c3"), DataType::Int64))?;
2183+
.with_column(
2184+
"sum",
2185+
cast(
2186+
col("c2") + col("c3"),
2187+
Arc::new(Field::new("sum", DataType::Int64, false)),
2188+
),
2189+
)?;
21782190

21792191
let cached_df = df.clone().cache().await?;
21802192

@@ -2672,8 +2684,13 @@ async fn write_table_with_order() -> Result<()> {
26722684
.unwrap();
26732685

26742686
// Ensure the column type matches the target table
2675-
write_df =
2676-
write_df.with_column("tablecol1", cast(col("tablecol1"), DataType::Utf8View))?;
2687+
write_df = write_df.with_column(
2688+
"tablecol1",
2689+
cast(
2690+
col("tablecol1"),
2691+
Arc::new(Field::new("tablecol1", DataType::Utf8, false)),
2692+
),
2693+
)?;
26772694

26782695
let sql_str =
26792696
"create external table data(tablecol1 varchar) stored as parquet location '"
@@ -4688,7 +4705,10 @@ async fn consecutive_projection_same_schema() -> Result<()> {
46884705
let df = df
46894706
.with_column(
46904707
"t",
4691-
cast(Expr::Literal(ScalarValue::Null, None), DataType::Int32),
4708+
cast(
4709+
Expr::Literal(ScalarValue::Null, None),
4710+
Arc::new(Field::new("t", DataType::Int32, true)),
4711+
),
46924712
)
46934713
.unwrap();
46944714
df.clone().show().await.unwrap();

datafusion/core/tests/expr_api/simplification.rs

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,10 @@ fn make_udf_add(volatility: Volatility) -> Arc<ScalarUDF> {
192192
}
193193

194194
fn cast_to_int64_expr(expr: Expr) -> Expr {
195-
Expr::Cast(Cast::new(expr.into(), DataType::Int64))
195+
Expr::Cast(Cast::new(
196+
expr.into(),
197+
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
198+
))
196199
}
197200

198201
fn to_timestamp_expr(arg: impl Into<String>) -> Expr {
@@ -748,8 +751,14 @@ fn test_simplify_concat() -> Result<()> {
748751
#[test]
749752
fn test_simplify_cycles() {
750753
// cast(now() as int64) < cast(to_timestamp(0) as int64) + i64::MAX
751-
let expr = cast(now(), DataType::Int64)
752-
.lt(cast(to_timestamp(vec![lit(0)]), DataType::Int64) + lit(i64::MAX));
754+
let expr = cast(
755+
now(),
756+
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
757+
)
758+
.lt(cast(
759+
to_timestamp(vec![lit(0)]),
760+
Arc::new(Field::new("cast_to_i64", DataType::Int64, true)),
761+
) + lit(i64::MAX));
753762
let expected = lit(true);
754763
test_simplify_with_cycle_count(expr, expected, 3);
755764
}

datafusion/core/tests/user_defined/user_defined_scalar_functions.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -713,8 +713,12 @@ impl ScalarUDFImpl for CastToI64UDF {
713713
} else {
714714
// need to use an actual cast to get the correct type
715715
Expr::Cast(datafusion_expr::Cast {
716-
expr: Box::new(arg),
717-
data_type: DataType::Int64,
716+
expr: Box::new(arg.clone()),
717+
field: Arc::new(Field::new(
718+
"cast_to_i64",
719+
DataType::Int64,
720+
info.nullable(&arg)?,
721+
)),
718722
})
719723
};
720724
// return the newly written argument to DataFusion

datafusion/expr/src/expr.rs

Lines changed: 27 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1012,13 +1012,14 @@ pub struct Cast {
10121012
/// The expression being cast
10131013
pub expr: Box<Expr>,
10141014
/// The `DataType` the expression will yield
1015-
pub data_type: DataType,
1015+
// pub data_type: DataType,
1016+
pub field: FieldRef,
10161017
}
10171018

10181019
impl Cast {
10191020
/// Create a new Cast expression
1020-
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
1021-
Self { expr, data_type }
1021+
pub fn new(expr: Box<Expr>, field: FieldRef) -> Self {
1022+
Self { expr, field }
10221023
}
10231024
}
10241025

@@ -1028,13 +1029,13 @@ pub struct TryCast {
10281029
/// The expression being cast
10291030
pub expr: Box<Expr>,
10301031
/// The `DataType` the expression will yield
1031-
pub data_type: DataType,
1032+
pub field: FieldRef,
10321033
}
10331034

10341035
impl TryCast {
10351036
/// Create a new TryCast expression
1036-
pub fn new(expr: Box<Expr>, data_type: DataType) -> Self {
1037-
Self { expr, data_type }
1037+
pub fn new(expr: Box<Expr>, field: FieldRef) -> Self {
1038+
Self { expr, field }
10381039
}
10391040
}
10401041

@@ -2460,23 +2461,26 @@ impl NormalizeEq for Expr {
24602461
(
24612462
Expr::Cast(Cast {
24622463
expr: self_expr,
2463-
data_type: self_data_type,
2464+
field: self_field,
24642465
}),
24652466
Expr::Cast(Cast {
24662467
expr: other_expr,
2467-
data_type: other_data_type,
2468+
field: other_field,
24682469
}),
24692470
)
24702471
| (
24712472
Expr::TryCast(TryCast {
24722473
expr: self_expr,
2473-
data_type: self_data_type,
2474+
field: self_field,
24742475
}),
24752476
Expr::TryCast(TryCast {
24762477
expr: other_expr,
2477-
data_type: other_data_type,
2478+
field: other_field,
24782479
}),
2479-
) => self_data_type == other_data_type && self_expr.normalize_eq(other_expr),
2480+
) => {
2481+
self_field.data_type() == other_field.data_type()
2482+
&& self_expr.normalize_eq(other_expr)
2483+
}
24802484
(
24812485
Expr::ScalarFunction(ScalarFunction {
24822486
func: self_func,
@@ -2792,15 +2796,9 @@ impl HashNode for Expr {
27922796
when_then_expr: _when_then_expr,
27932797
else_expr: _else_expr,
27942798
}) => {}
2795-
Expr::Cast(Cast {
2796-
expr: _expr,
2797-
data_type,
2798-
})
2799-
| Expr::TryCast(TryCast {
2800-
expr: _expr,
2801-
data_type,
2802-
}) => {
2803-
data_type.hash(state);
2799+
Expr::Cast(Cast { expr: _expr, field })
2800+
| Expr::TryCast(TryCast { expr: _expr, field }) => {
2801+
field.data_type().hash(state);
28042802
}
28052803
Expr::ScalarFunction(ScalarFunction { func, args: _args }) => {
28062804
func.hash(state);
@@ -3487,11 +3485,11 @@ impl Display for Expr {
34873485
}
34883486
write!(f, "END")
34893487
}
3490-
Expr::Cast(Cast { expr, data_type }) => {
3491-
write!(f, "CAST({expr} AS {data_type})")
3488+
Expr::Cast(Cast { expr, field }) => {
3489+
write!(f, "CAST({expr} AS {})", field.data_type())
34923490
}
3493-
Expr::TryCast(TryCast { expr, data_type }) => {
3494-
write!(f, "TRY_CAST({expr} AS {data_type})")
3491+
Expr::TryCast(TryCast { expr, field }) => {
3492+
write!(f, "TRY_CAST({expr} AS {})", field.data_type())
34953493
}
34963494
Expr::Not(expr) => write!(f, "NOT {expr}"),
34973495
Expr::Negative(expr) => write!(f, "(- {expr})"),
@@ -3844,7 +3842,7 @@ mod test {
38443842
fn format_cast() -> Result<()> {
38453843
let expr = Expr::Cast(Cast {
38463844
expr: Box::new(Expr::Literal(ScalarValue::Float32(Some(1.23)), None)),
3847-
data_type: DataType::Utf8,
3845+
field: Arc::new(Field::new("cast", DataType::Utf8, false)),
38483846
});
38493847
let expected_canonical = "CAST(Float32(1.23) AS Utf8)";
38503848
assert_eq!(expected_canonical, format!("{expr}"));
@@ -3871,7 +3869,10 @@ mod test {
38713869
fn test_collect_expr() -> Result<()> {
38723870
// single column
38733871
{
3874-
let expr = &Expr::Cast(Cast::new(Box::new(col("a")), DataType::Float64));
3872+
let expr = &Expr::Cast(Cast::new(
3873+
Box::new(col("a")),
3874+
Arc::new(Field::new("cast", DataType::Float64, false)),
3875+
));
38753876
let columns = expr.column_refs();
38763877
assert_eq!(1, columns.len());
38773878
assert!(columns.contains(&Column::from_name("a")));

0 commit comments

Comments
 (0)