Skip to content

Commit 20e4079

Browse files
authored
feat: support async function for merge not match clause (#17941)
* add logic test * rm unused code * update * fix * skip empty * modify test * seperate cast schema * modify test * modify test
1 parent 4f707b2 commit 20e4079

File tree

11 files changed

+438
-157
lines changed

11 files changed

+438
-157
lines changed

src/query/pipeline/transforms/src/processors/transforms/transform_pipeline_helper.rs

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,15 @@ pub trait TransformPipelineHelper {
145145
where
146146
F: Fn() -> Result<R>,
147147
R: Transform + 'static;
148+
149+
fn try_create_async_transform_pipeline_builder_with_len<F, R>(
150+
&mut self,
151+
f: F,
152+
transform_len: usize,
153+
) -> Result<TransformPipeBuilder>
154+
where
155+
F: Fn() -> Result<R>,
156+
R: AsyncTransform + 'static;
148157
}
149158

150159
impl TransformPipelineHelper for Pipeline {
@@ -179,4 +188,25 @@ impl TransformPipelineHelper for Pipeline {
179188
transform_len,
180189
)
181190
}
191+
192+
fn try_create_async_transform_pipeline_builder_with_len<F, R>(
193+
&mut self,
194+
f: F,
195+
transform_len: usize,
196+
) -> Result<TransformPipeBuilder>
197+
where
198+
F: Fn() -> Result<R>,
199+
R: AsyncTransform + 'static,
200+
{
201+
self.add_transform_with_specified_len(
202+
|input, output| {
203+
Ok(ProcessorPtr::create(AsyncTransformer::create(
204+
input,
205+
output,
206+
f()?,
207+
)))
208+
},
209+
transform_len,
210+
)
211+
}
182212
}

src/query/service/src/pipelines/builders/builder_mutation.rs

Lines changed: 104 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,12 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
use std::collections::HashMap;
1516
use std::sync::Arc;
1617

1718
use databend_common_base::base::tokio::sync::Semaphore;
1819
use databend_common_catalog::table::Table;
20+
use databend_common_exception::ErrorCode;
1921
use databend_common_exception::Result;
2022
use databend_common_expression::BlockThresholds;
2123
use databend_common_expression::DataSchema;
@@ -34,11 +36,18 @@ use databend_common_pipeline_transforms::processors::TransformPipelineHelper;
3436
use databend_common_sql::binder::MutationStrategy;
3537
use databend_common_sql::executor::physical_plans::Mutation;
3638
use databend_common_sql::executor::physical_plans::MutationKind;
39+
use databend_common_sql::DefaultExprBinder;
3740
use databend_common_storages_fuse::operations::TransformSerializeBlock;
3841
use databend_common_storages_fuse::operations::UnMatchedExprs;
3942
use databend_common_storages_fuse::FuseTable;
4043

44+
use crate::pipelines::processors::transforms::build_cast_exprs;
45+
use crate::pipelines::processors::transforms::build_expression_transform;
46+
use crate::pipelines::processors::transforms::AsyncFunctionBranch;
47+
use crate::pipelines::processors::transforms::CastSchemaBranch;
4148
use crate::pipelines::processors::transforms::TransformAddComputedColumns;
49+
use crate::pipelines::processors::transforms::TransformBranchedAsyncFunction;
50+
use crate::pipelines::processors::transforms::TransformBranchedCastSchema;
4251
use crate::pipelines::processors::transforms::TransformResortAddOnWithoutSourceSchema;
4352
use crate::pipelines::PipelineBuilder;
4453

@@ -136,16 +145,110 @@ impl PipelineBuilder {
136145

137146
// fill default columns
138147
let table_default_schema = &table.schema_with_stream().remove_computed_fields();
148+
let default_schema: DataSchemaRef = Arc::new(table_default_schema.into());
149+
150+
let mut expression_transforms = Vec::with_capacity(unmatched.len());
151+
let mut data_schemas = HashMap::with_capacity(unmatched.len());
152+
let mut trigger_non_null_errors = Vec::with_capacity(unmatched.len());
153+
let mut async_function_branches = HashMap::with_capacity(unmatched.len());
154+
let mut cast_schema_branches = HashMap::with_capacity(unmatched.len());
155+
for (idx, item) in unmatched.iter().enumerate() {
156+
let mut input_schema = item.0.clone();
157+
let mut default_expr_binder = DefaultExprBinder::try_new(self.ctx.clone())?;
158+
if let Some((async_funcs, new_default_schema, new_default_schema_no_cast)) =
159+
default_expr_binder
160+
.split_async_default_exprs(input_schema.clone(), default_schema.clone())?
161+
{
162+
async_function_branches.insert(idx, AsyncFunctionBranch {
163+
async_func_descs: async_funcs,
164+
});
165+
166+
if new_default_schema != new_default_schema_no_cast {
167+
cast_schema_branches.insert(idx, CastSchemaBranch {
168+
to_schema: new_default_schema.clone(),
169+
from_schema: new_default_schema_no_cast.clone(),
170+
exprs: build_cast_exprs(
171+
new_default_schema_no_cast.clone(),
172+
new_default_schema.clone(),
173+
)?,
174+
});
175+
}
176+
// update input_schema, which is used in `TransformResortAddOnWithoutSourceSchema`
177+
input_schema = new_default_schema;
178+
}
179+
180+
data_schemas.insert(idx, input_schema.clone());
181+
match build_expression_transform(
182+
input_schema,
183+
default_schema.clone(),
184+
tbl.clone(),
185+
self.ctx.clone(),
186+
) {
187+
Ok(expression_transform) => {
188+
expression_transforms.push(Some(expression_transform));
189+
trigger_non_null_errors.push(None);
190+
}
191+
Err(err) => {
192+
if err.code() != ErrorCode::BAD_ARGUMENTS {
193+
return Err(err);
194+
}
195+
196+
expression_transforms.push(None);
197+
trigger_non_null_errors.push(Some(err));
198+
}
199+
};
200+
}
201+
202+
if !async_function_branches.is_empty() {
203+
let branches = Arc::new(async_function_branches);
204+
let mut builder = self
205+
.main_pipeline
206+
.try_create_async_transform_pipeline_builder_with_len(
207+
|| {
208+
Ok(TransformBranchedAsyncFunction {
209+
ctx: self.ctx.clone(),
210+
branches: branches.clone(),
211+
})
212+
},
213+
transform_len,
214+
)?;
215+
if need_match {
216+
builder.add_items_prepend(vec![create_dummy_item()]);
217+
}
218+
self.main_pipeline.add_pipe(builder.finalize());
219+
}
220+
221+
if !cast_schema_branches.is_empty() {
222+
let branches = Arc::new(cast_schema_branches);
223+
let mut builder = self
224+
.main_pipeline
225+
.try_create_transform_pipeline_builder_with_len(
226+
|| {
227+
Ok(TransformBranchedCastSchema {
228+
ctx: self.ctx.clone(),
229+
branches: branches.clone(),
230+
})
231+
},
232+
transform_len,
233+
)?;
234+
if need_match {
235+
builder.add_items_prepend(vec![create_dummy_item()]);
236+
}
237+
self.main_pipeline.add_pipe(builder.finalize());
238+
}
239+
139240
let mut builder = self
140241
.main_pipeline
141242
.try_create_transform_pipeline_builder_with_len(
142243
|| {
143244
TransformResortAddOnWithoutSourceSchema::try_new(
144245
self.ctx.clone(),
145246
Arc::new(DataSchema::from(table_default_schema)),
146-
unmatched.clone(),
147247
tbl.clone(),
148248
Arc::new(DataSchema::from(table.schema_with_stream())),
249+
data_schemas.clone(),
250+
expression_transforms.clone(),
251+
trigger_non_null_errors.clone(),
149252
)
150253
},
151254
transform_len,

src/query/service/src/pipelines/processors/transforms/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@ mod transform_add_const_columns;
2323
mod transform_add_internal_columns;
2424
mod transform_add_stream_columns;
2525
mod transform_async_function;
26+
mod transform_branched_async_function;
27+
mod transform_branched_cast_schema;
2628
mod transform_cache_scan;
2729
mod transform_cast_schema;
2830
mod transform_create_sets;
@@ -50,9 +52,14 @@ pub use transform_add_const_columns::TransformAddConstColumns;
5052
pub use transform_add_internal_columns::TransformAddInternalColumns;
5153
pub use transform_add_stream_columns::TransformAddStreamColumns;
5254
pub use transform_async_function::TransformAsyncFunction;
55+
pub use transform_branched_async_function::AsyncFunctionBranch;
56+
pub use transform_branched_async_function::TransformBranchedAsyncFunction;
57+
pub use transform_branched_cast_schema::CastSchemaBranch;
58+
pub use transform_branched_cast_schema::TransformBranchedCastSchema;
5359
pub use transform_cache_scan::CacheSourceState;
5460
pub use transform_cache_scan::HashJoinCacheState;
5561
pub use transform_cache_scan::TransformCacheScan;
62+
pub use transform_cast_schema::build_cast_exprs;
5663
pub use transform_cast_schema::TransformCastSchema;
5764
pub use transform_create_sets::TransformCreateSets;
5865
pub use transform_expression_scan::TransformExpressionScan;
@@ -64,6 +71,7 @@ pub use transform_null_if::TransformNullIf;
6471
pub use transform_recursive_cte_scan::TransformRecursiveCteScan;
6572
pub use transform_recursive_cte_source::TransformRecursiveCteSource;
6673
pub use transform_resort_addon::TransformResortAddOn;
74+
pub use transform_resort_addon_without_source_schema::build_expression_transform;
6775
pub use transform_resort_addon_without_source_schema::TransformResortAddOnWithoutSourceSchema;
6876
pub use transform_srf::TransformSRF;
6977
pub use transform_udf_script::TransformUdfScript;

src/query/service/src/pipelines/processors/transforms/transform_async_function.rs

Lines changed: 30 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -59,27 +59,7 @@ impl TransformAsyncFunction {
5959
sequence_name: &String,
6060
data_type: &DataType,
6161
) -> Result<()> {
62-
let count = data_block.num_rows() as u64;
63-
let value = if count == 0 {
64-
UInt64Type::from_data(vec![])
65-
} else {
66-
let tenant = self.ctx.get_tenant();
67-
let catalog = self.ctx.get_default_catalog()?;
68-
let req = GetSequenceNextValueReq {
69-
ident: SequenceIdent::new(&tenant, sequence_name),
70-
count,
71-
};
72-
let resp = catalog.get_sequence_next_value(req).await?;
73-
let range = resp.start..resp.start + count;
74-
UInt64Type::from_data(range.collect::<Vec<u64>>())
75-
};
76-
let entry = BlockEntry {
77-
data_type: data_type.clone(),
78-
value: Value::Column(value),
79-
};
80-
data_block.add_column(entry);
81-
82-
Ok(())
62+
transform_sequence(&self.ctx, data_block, sequence_name, data_type).await
8363
}
8464
}
8565

@@ -114,3 +94,32 @@ impl AsyncTransform for TransformAsyncFunction {
11494
Ok(data_block)
11595
}
11696
}
97+
98+
pub async fn transform_sequence(
99+
ctx: &Arc<QueryContext>,
100+
data_block: &mut DataBlock,
101+
sequence_name: &String,
102+
data_type: &DataType,
103+
) -> Result<()> {
104+
let count = data_block.num_rows() as u64;
105+
let value = if count == 0 {
106+
UInt64Type::from_data(vec![])
107+
} else {
108+
let tenant = ctx.get_tenant();
109+
let catalog = ctx.get_default_catalog()?;
110+
let req = GetSequenceNextValueReq {
111+
ident: SequenceIdent::new(&tenant, sequence_name),
112+
count,
113+
};
114+
let resp = catalog.get_sequence_next_value(req).await?;
115+
let range = resp.start..resp.start + count;
116+
UInt64Type::from_data(range.collect::<Vec<u64>>())
117+
};
118+
let entry = BlockEntry {
119+
data_type: data_type.clone(),
120+
value: Value::Column(value),
121+
};
122+
data_block.add_column(entry);
123+
124+
Ok(())
125+
}
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
// Copyright 2021 Datafuse Labs
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
use std::collections::HashMap;
16+
use std::sync::Arc;
17+
18+
use databend_common_exception::Result;
19+
use databend_common_expression::BlockMetaInfoDowncast;
20+
use databend_common_expression::DataBlock;
21+
use databend_common_expression::SourceSchemaIndex;
22+
use databend_common_pipeline_transforms::processors::AsyncTransform;
23+
24+
use crate::pipelines::processors::transforms::transform_async_function::transform_sequence;
25+
use crate::sessions::QueryContext;
26+
use crate::sql::executor::physical_plans::AsyncFunctionDesc;
27+
use crate::sql::plans::AsyncFunctionArgument;
28+
29+
/// The key of branches is `SourceSchemaIndex`, see `TransformResortAddOnWithoutSourceSchema`.
30+
pub struct TransformBranchedAsyncFunction {
31+
pub ctx: Arc<QueryContext>,
32+
pub branches: Arc<HashMap<SourceSchemaIndex, AsyncFunctionBranch>>,
33+
}
34+
35+
pub struct AsyncFunctionBranch {
36+
pub async_func_descs: Vec<AsyncFunctionDesc>,
37+
}
38+
39+
#[async_trait::async_trait]
40+
impl AsyncTransform for TransformBranchedAsyncFunction {
41+
const NAME: &'static str = "BranchedAsyncFunction";
42+
43+
#[async_backtrace::framed]
44+
async fn transform(&mut self, mut block: DataBlock) -> Result<DataBlock> {
45+
// see the comment details of `TransformResortAddOnWithoutSourceSchema`.
46+
if block.get_meta().is_none() {
47+
return Ok(block);
48+
}
49+
let input_schema_idx =
50+
SourceSchemaIndex::downcast_from(block.clone().get_owned_meta().unwrap()).unwrap();
51+
let Some(branch) = self.branches.get(&input_schema_idx) else {
52+
// no async function to execute in this branch, just return the original block
53+
return Ok(block);
54+
};
55+
56+
let AsyncFunctionBranch { async_func_descs } = branch;
57+
58+
for async_func_desc in async_func_descs.iter() {
59+
match &async_func_desc.func_arg {
60+
AsyncFunctionArgument::SequenceFunction(sequence_name) => {
61+
transform_sequence(
62+
&self.ctx,
63+
&mut block,
64+
sequence_name,
65+
&async_func_desc.data_type,
66+
)
67+
.await?;
68+
}
69+
AsyncFunctionArgument::DictGetFunction(_) => unreachable!(),
70+
}
71+
}
72+
Ok(block)
73+
}
74+
}

0 commit comments

Comments
 (0)