Skip to content

Commit 780f484

Browse files
authored
feat(query): support imports and packages in python udf scripts (#18187)
* feat(query): support pep723 scripts in python udf scripts * feat(query): support pep723 scripts in python udf scripts * feat(query): support pep723 scripts in python udf scripts * chore(query): fix * update * update * add safe_codes in builder * add safe_codes in builder * fix binder
1 parent 50f3c98 commit 780f484

File tree

30 files changed

+607
-55
lines changed

30 files changed

+607
-55
lines changed

Cargo.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -643,7 +643,7 @@ overflow-checks = true
643643
rpath = false
644644

645645
[patch.crates-io]
646-
arrow-udf-runtime = { git = "https://github.com/datafuse-extras/arrow-udf.git", rev = "92eeb3b" }
646+
arrow-udf-runtime = { git = "https://github.com/datafuse-extras/arrow-udf.git", rev = "a442343" }
647647
async-backtrace = { git = "https://github.com/datafuse-extras/async-backtrace.git", rev = "dea4553" }
648648
async-recursion = { git = "https://github.com/datafuse-extras/async-recursion.git", rev = "a353334" }
649649
backtrace = { git = "https://github.com/rust-lang/backtrace-rs.git", rev = "72265be" }

src/meta/app/src/principal/user_defined_function.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@ pub struct UDFServer {
4040
#[derive(Clone, Debug, Eq, PartialEq)]
4141
pub struct UDFScript {
4242
pub code: String,
43+
pub imports: Vec<String>,
44+
pub packages: Vec<String>,
4345
pub handler: String,
4446
pub language: String,
4547
pub arg_types: Vec<DataType>,
@@ -50,6 +52,8 @@ pub struct UDFScript {
5052
#[derive(Clone, Debug, Eq, PartialEq)]
5153
pub struct UDAFScript {
5254
pub code: String,
55+
pub imports: Vec<String>,
56+
pub packages: Vec<String>,
5357
pub language: String,
5458
// aggregate function input types
5559
pub arg_types: Vec<DataType>,
@@ -167,6 +171,8 @@ impl UserDefinedFunction {
167171
arg_types,
168172
return_type,
169173
runtime_version: runtime_version.to_string(),
174+
imports: vec![],
175+
packages: vec![],
170176
}),
171177
created_on: Utc::now(),
172178
}
@@ -226,6 +232,8 @@ impl Display for UDFDefinition {
226232
handler,
227233
language,
228234
runtime_version,
235+
imports,
236+
packages,
229237
}) => {
230238
for (i, item) in arg_types.iter().enumerate() {
231239
if i > 0 {
@@ -235,7 +243,7 @@ impl Display for UDFDefinition {
235243
}
236244
write!(
237245
f,
238-
") RETURNS {return_type} LANGUAGE {language} RUNTIME_VERSION = {runtime_version} HANDLER = {handler} AS $${code}$$"
246+
") RETURNS {return_type} LANGUAGE {language} IMPORTS = {imports:?} PACKAGES = {packages:?} RUNTIME_VERSION = {runtime_version} HANDLER = {handler} AS $${code}$$"
239247
)?;
240248
}
241249
UDFDefinition::UDAFScript(UDAFScript {
@@ -245,6 +253,8 @@ impl Display for UDFDefinition {
245253
return_type,
246254
language,
247255
runtime_version,
256+
imports,
257+
packages,
248258
}) => {
249259
for (i, item) in arg_types.iter().enumerate() {
250260
if i > 0 {
@@ -259,7 +269,7 @@ impl Display for UDFDefinition {
259269
}
260270
write!(f, "{} {}", item.name(), item.data_type())?;
261271
}
262-
write!(f, " }} RETURNS {return_type} LANGUAGE {language} RUNTIME_VERSION = {runtime_version} AS $${code}$$")?;
272+
write!(f, " }} RETURNS {return_type} LANGUAGE {language} IMPORTS = {imports:?} PACKAGES = {packages:?} RUNTIME_VERSION = {runtime_version} AS $${code}$$")?;
263273
}
264274
}
265275
Ok(())

src/meta/proto-conv/src/udf_from_to_protobuf_impl.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,8 @@ impl FromToProto for mt::UDFScript {
137137
handler: p.handler,
138138
language: p.language,
139139
runtime_version: p.runtime_version,
140+
imports: p.imports,
141+
packages: p.packages,
140142
})
141143
}
142144

@@ -171,6 +173,8 @@ impl FromToProto for mt::UDFScript {
171173
arg_types,
172174
return_type: Some(return_type),
173175
runtime_version: self.runtime_version.clone(),
176+
imports: self.imports.clone(),
177+
packages: self.packages.clone(),
174178
})
175179
}
176180
}
@@ -206,6 +210,8 @@ impl FromToProto for mt::UDAFScript {
206210
return_type,
207211
language: p.language,
208212
runtime_version: p.runtime_version,
213+
imports: p.imports,
214+
packages: p.packages,
209215
state_fields,
210216
})
211217
}
@@ -259,6 +265,8 @@ impl FromToProto for mt::UDAFScript {
259265
arg_types,
260266
state_fields,
261267
return_type: Some(return_type),
268+
imports: self.imports.clone(),
269+
packages: self.packages.clone(),
262270
})
263271
}
264272
}

src/meta/proto-conv/src/util.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ const META_CHANGE_LOG: &[(u64, &str)] = &[
159159
(127, "2025-05-18: Add: UserOption::workload_group"),
160160
(128, "2025-05-22: Add: Storage Network config"),
161161
(129, "2025-05-30: Add: New DataType Vector"),
162+
(130, "2025-06-19: Add: New UDF imports and packages in udf definition"),
162163
// Dear developer:
163164
// If you're gonna add a new metadata version, you'll have to add a test for it.
164165
// You could just copy an existing test file(e.g., `../tests/it/v024_table_meta.rs`)

src/meta/proto-conv/tests/it/main.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -120,4 +120,4 @@ mod v125_table_index;
120120
mod v126_iceberg_storage_catalog_option;
121121
mod v127_user_option_workload_group;
122122
mod v128_storage_network_config;
123-
mod v129_vector_datatype;
123+
mod v130_udf_imports_packages;

src/meta/proto-conv/tests/it/v081_udf_script.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ fn test_decode_udf_script() -> anyhow::Result<()> {
116116
language: "python".to_string(),
117117
arg_types: vec![DataType::Number(NumberDataType::Int32)],
118118
return_type: DataType::Number(NumberDataType::Float32),
119+
imports: vec![],
120+
packages: vec![],
119121
runtime_version: "3.12.2".to_string(),
120122
}),
121123
created_on: DateTime::<Utc>::default(),

src/meta/proto-conv/tests/it/v115_add_udaf_script.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,8 @@ fn test_decode_v115_add_udaf_script() -> anyhow::Result<()> {
6161
)],
6262
return_type: DataType::Number(NumberDataType::Float32),
6363
runtime_version: "".to_string(),
64+
imports: vec![],
65+
packages: vec![],
6466
}),
6567
created_on: DateTime::<Utc>::default(),
6668
};

src/meta/proto-conv/tests/it/v129_vector_datatype.rs renamed to src/meta/proto-conv/tests/it/v129_vector_datatype copy.rs

File renamed without changes.
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
// Copyright 2023 Datafuse Labs.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use chrono::DateTime;
16+
use chrono::Utc;
17+
use databend_common_expression::types::DataType;
18+
use databend_common_expression::types::NumberDataType;
19+
use databend_common_meta_app::principal::UDFDefinition;
20+
use databend_common_meta_app::principal::UDFScript;
21+
use databend_common_meta_app::principal::UserDefinedFunction;
22+
use fastrace::func_name;
23+
24+
use crate::common;
25+
26+
// These bytes are built when a new version is introduced,
27+
// and are kept for backward compatibility test.
28+
//
29+
// *************************************************************
30+
// * These messages should never be updated, *
31+
// * only be added when a new version is added, *
32+
// * or be removed when an old version is no longer supported. *
33+
// *************************************************************
34+
//
35+
// The message bytes are built from the output of `test_pb_from_to()`
36+
/// Compatibility test for a `UserDefinedFunction` whose `UDFScript` definition
/// carries non-empty `imports` and `packages` lists.
///
/// Builds the expected value and hands it to two helpers from `crate::common`.
/// NOTE(review): the helpers are not visible here; presumably `test_pb_from_to`
/// round-trips the value through protobuf and `test_load_old` decodes `bytes`
/// as written by metadata version 130 and compares with `want()` — confirm.
#[test]
37+
fn test_decode_v130_udf_script() -> anyhow::Result<()> {
38+
// Frozen fixture. The ASCII of the import paths (`@s1/a.zip`, `@s2/b.py`) and
// of the package names (`numpy`, `pandas`) is embedded in the bytes below, next
// to the name, description, code, handler, language and runtime version strings.
let bytes = vec![
39+
10, 5, 109, 121, 95, 102, 110, 18, 21, 84, 104, 105, 115, 32, 105, 115, 32, 97, 32, 100,
40+
101, 115, 99, 114, 105, 112, 116, 105, 111, 110, 50, 119, 10, 9, 115, 111, 109, 101, 32,
41+
99, 111, 100, 101, 18, 5, 109, 121, 95, 102, 110, 26, 6, 112, 121, 116, 104, 111, 110, 34,
42+
19, 154, 2, 9, 58, 0, 160, 6, 130, 1, 168, 6, 24, 160, 6, 130, 1, 168, 6, 24, 42, 19, 154,
43+
2, 9, 74, 0, 160, 6, 130, 1, 168, 6, 24, 160, 6, 130, 1, 168, 6, 24, 50, 6, 51, 46, 49, 50,
44+
46, 50, 58, 9, 64, 115, 49, 47, 97, 46, 122, 105, 112, 58, 8, 64, 115, 50, 47, 98, 46, 112,
45+
121, 66, 5, 110, 117, 109, 112, 121, 66, 6, 112, 97, 110, 100, 97, 115, 160, 6, 130, 1,
46+
168, 6, 24, 42, 23, 49, 57, 55, 48, 45, 48, 49, 45, 48, 49, 32, 48, 48, 58, 48, 48, 58, 48,
47+
48, 32, 85, 84, 67, 160, 6, 130, 1, 168, 6, 24,
48+
];
49+
50+
// `want` is a closure so that each helper call below receives its own
// freshly built expected value.
let want = || UserDefinedFunction {
51+
name: "my_fn".to_string(),
52+
description: "This is a description".to_string(),
53+
definition: UDFDefinition::UDFScript(UDFScript {
54+
code: "some code".to_string(),
55+
handler: "my_fn".to_string(),
56+
language: "python".to_string(),
57+
arg_types: vec![DataType::Number(NumberDataType::Int32)],
58+
return_type: DataType::Number(NumberDataType::Float32),
59+
imports: vec!["@s1/a.zip".to_string(), "@s2/b.py".to_string()],
60+
packages: vec!["numpy".to_string(), "pandas".to_string()],
61+
runtime_version: "3.12.2".to_string(),
62+
}),
63+
// `DateTime::<Utc>::default()` matches the "1970-01-01 00:00:00 UTC" text
// stored near the end of `bytes`.
created_on: DateTime::<Utc>::default(),
64+
};
65+
66+
common::test_pb_from_to(func_name!(), want())?;
67+
// 130 is the version number passed alongside the frozen bytes; it matches the
// `v130` in this test's name.
common::test_load_old(func_name!(), bytes.as_slice(), 130, want())
68+
}

0 commit comments

Comments
 (0)