Skip to content

Commit e929ea2

Browse files
authored
feat: impl function: regexp_extract & regexp_extract_all (#17658)
* feat: impl function: `regexp_extract` & `regexp_extract_all` Signed-off-by: Kould <[email protected]> * feat: impl function: `regexp_extract(string, pattern, name_list)` Signed-off-by: Kould <[email protected]> * test: add unit test for `regexp_extract` Signed-off-by: Kould <[email protected]> * chore: remove nullable Signed-off-by: Kould <[email protected]> * chore: add null source case on regexp.txt Signed-off-by: Kould <[email protected]> --------- Signed-off-by: Kould <[email protected]>
1 parent 7187df1 commit e929ea2

File tree

5 files changed

+832
-0
lines changed

5 files changed

+832
-0
lines changed

src/query/functions/src/scalars/string_multi_args.rs

Lines changed: 258 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
use std::sync::Arc;
1616

1717
use databend_common_expression::passthrough_nullable;
18+
use databend_common_expression::types::array::ArrayColumnBuilder;
1819
use databend_common_expression::types::nullable::NullableColumn;
1920
use databend_common_expression::types::number::Int64Type;
2021
use databend_common_expression::types::number::NumberScalar;
@@ -32,6 +33,7 @@ use databend_common_expression::FunctionRegistry;
3233
use databend_common_expression::FunctionSignature;
3334
use databend_common_expression::Scalar;
3435
use databend_common_expression::Value;
36+
use regex::Match;
3537
use string::StringColumnBuilder;
3638

3739
pub fn register(registry: &mut FunctionRegistry) {
@@ -323,6 +325,110 @@ pub fn register(registry: &mut FunctionRegistry) {
323325
}
324326
});
325327

328+
registry.register_passthrough_nullable_2_arg::<StringType, StringType, StringType, _, _>(
329+
"regexp_extract",
330+
|_, _, _| FunctionDomain::MayThrow,
331+
|source_arg, pat_arg, ctx| {
332+
inner_regexp_extract(&source_arg, &pat_arg, &Value::Scalar(0), ctx)
333+
},
334+
);
335+
336+
registry.register_passthrough_nullable_3_arg::<StringType, StringType, UInt32Type, StringType, _, _>(
337+
"regexp_extract",
338+
|_, _, _, _| FunctionDomain::MayThrow,
339+
|source_arg, pat_arg, group_arg, ctx| {
340+
inner_regexp_extract(&source_arg, &pat_arg, &group_arg, ctx)
341+
}
342+
);
343+
344+
registry.register_passthrough_nullable_3_arg::<StringType, StringType, ArrayType<StringType>, MapType<StringType, StringType>, _, _>(
345+
"regexp_extract",
346+
|_, _, _, _| FunctionDomain::MayThrow,
347+
|source_arg, pat_arg, name_list_arg, ctx| {
348+
let len = [&source_arg, &pat_arg].iter().find_map(|arg| match arg {
349+
Value::Column(col) => Some(col.len()),
350+
_ => None,
351+
}).or_else(|| match &name_list_arg {
352+
Value::Column(col) => Some(col.len()),
353+
_ => None,
354+
});
355+
356+
let cached_reg = match &pat_arg {
357+
Value::Scalar(pat) => {
358+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
359+
Ok(re) => Some(re),
360+
_ => None,
361+
}
362+
}
363+
_ => None,
364+
};
365+
366+
let size = len.unwrap_or(1);
367+
let mut builder =
368+
MapType::<StringType, StringType>::create_builder(size, ctx.generics);
369+
370+
for idx in 0..size {
371+
let source = unsafe { source_arg.index_unchecked(idx) };
372+
let pat = unsafe { pat_arg.index_unchecked(idx) };
373+
let name_list = unsafe { name_list_arg.index_unchecked(idx) };
374+
let mut local_re = None;
375+
if cached_reg.is_none() {
376+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
377+
Ok(re) => {
378+
local_re = Some(re);
379+
}
380+
Err(err) => {
381+
ctx.set_error(builder.len(), err);
382+
builder.push_default();
383+
continue;
384+
}
385+
}
386+
};
387+
let re = cached_reg
388+
.as_ref()
389+
.unwrap_or_else(|| local_re.as_ref().unwrap());
390+
let captures = re.captures_iter(source).last();
391+
if let Some(captures) = &captures {
392+
if name_list.len() + 1 > captures.len() {
393+
ctx.set_error(builder.len(), "Not enough group names in regexp_extract");
394+
builder.push_default();
395+
continue;
396+
}
397+
}
398+
for (i, name) in name_list.iter().enumerate() {
399+
let value = captures
400+
.as_ref()
401+
.and_then(|caps| caps.get(i + 1).as_ref().map(Match::as_str))
402+
.unwrap_or("");
403+
builder.put_item((name, value))
404+
}
405+
builder.commit_row();
406+
}
407+
if len.is_some() {
408+
Value::Column(builder.build())
409+
} else {
410+
Value::Scalar(builder.build_scalar())
411+
}
412+
}
413+
);
414+
415+
registry
416+
.register_passthrough_nullable_2_arg::<StringType, StringType, ArrayType<StringType>, _, _>(
417+
"regexp_extract_all",
418+
|_, _, _| FunctionDomain::MayThrow,
419+
|source_arg, pat_arg, ctx| {
420+
regexp_extract_all(&source_arg, &pat_arg, &Value::Scalar(0), ctx)
421+
},
422+
);
423+
424+
registry.register_passthrough_nullable_3_arg::<StringType, StringType, UInt32Type, ArrayType<StringType>, _, _>(
425+
"regexp_extract_all",
426+
|_, _, _, _| FunctionDomain::MayThrow,
427+
|source_arg, pat_arg, group_arg, ctx| {
428+
regexp_extract_all(&source_arg, &pat_arg, &group_arg, ctx)
429+
}
430+
);
431+
326432
// Notes: https://dev.mysql.com/doc/refman/8.0/en/regexp.html#function_regexp-replace
327433
registry.register_function_factory("regexp_replace", |_, args_type| {
328434
let has_null = args_type.iter().any(|t| t.is_nullable_or_null());
@@ -418,6 +524,158 @@ pub fn register(registry: &mut FunctionRegistry) {
418524
});
419525
}
420526

527+
fn regexp_extract_all(
528+
source_arg: &Value<StringType>,
529+
pat_arg: &Value<StringType>,
530+
group_arg: &Value<UInt32Type>,
531+
ctx: &mut EvalContext,
532+
) -> Value<ArrayType<StringType>> {
533+
let len = [&source_arg, &pat_arg]
534+
.iter()
535+
.find_map(|arg| match arg {
536+
Value::Column(col) => Some(col.len()),
537+
_ => None,
538+
})
539+
.or_else(|| match &group_arg {
540+
Value::Column(col) => Some(col.len()),
541+
_ => None,
542+
});
543+
let cached_reg = match &pat_arg {
544+
Value::Scalar(pat) => {
545+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
546+
Ok(re) => Some(re),
547+
_ => None,
548+
}
549+
}
550+
_ => None,
551+
};
552+
553+
let size = len.unwrap_or(1);
554+
let mut builder = ArrayColumnBuilder::<StringType>::with_capacity(size, 0, ctx.generics);
555+
for idx in 0..size {
556+
let source = unsafe { source_arg.index_unchecked(idx) };
557+
let pat = unsafe { pat_arg.index_unchecked(idx) };
558+
let group = unsafe { group_arg.index_unchecked(idx) as usize };
559+
560+
let mut local_re = None;
561+
if cached_reg.is_none() {
562+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
563+
Ok(re) => {
564+
local_re = Some(re);
565+
}
566+
Err(err) => {
567+
ctx.set_error(builder.len(), err);
568+
builder.push_default();
569+
continue;
570+
}
571+
}
572+
};
573+
574+
let re = cached_reg
575+
.as_ref()
576+
.unwrap_or_else(|| local_re.as_ref().unwrap());
577+
let mut row = StringColumnBuilder::with_capacity(0);
578+
if group > 9 {
579+
ctx.set_error(builder.len(), "Group index must be between 0 and 9!");
580+
}
581+
for caps in re.captures_iter(source) {
582+
if group >= caps.len() {
583+
ctx.set_error(
584+
builder.len(),
585+
format!(
586+
"Pattern has {} groups. Cannot access group {}",
587+
caps.len(),
588+
group
589+
),
590+
);
591+
row.put_str("");
592+
row.commit_row();
593+
continue;
594+
}
595+
if let Some(v) = caps.get(group).map(|ma| ma.as_str()) {
596+
row.put_str(v);
597+
} else {
598+
row.put_str("");
599+
}
600+
row.commit_row();
601+
}
602+
builder.push(row.build());
603+
}
604+
if len.is_some() {
605+
Value::Column(builder.build())
606+
} else {
607+
Value::Scalar(builder.build_scalar())
608+
}
609+
}
610+
611+
fn inner_regexp_extract(
612+
source_arg: &Value<StringType>,
613+
pat_arg: &Value<StringType>,
614+
group_arg: &Value<UInt32Type>,
615+
ctx: &mut EvalContext,
616+
) -> Value<StringType> {
617+
let len = [&source_arg, &pat_arg]
618+
.iter()
619+
.find_map(|arg| match arg {
620+
Value::Column(col) => Some(col.len()),
621+
_ => None,
622+
})
623+
.or_else(|| match &group_arg {
624+
Value::Column(col) => Some(col.len()),
625+
_ => None,
626+
});
627+
628+
let cached_reg = match &pat_arg {
629+
Value::Scalar(pat) => {
630+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
631+
Ok(re) => Some(re),
632+
_ => None,
633+
}
634+
}
635+
_ => None,
636+
};
637+
638+
let size = len.unwrap_or(1);
639+
let mut builder = StringColumnBuilder::with_capacity(size);
640+
for idx in 0..size {
641+
let source = unsafe { source_arg.index_unchecked(idx) };
642+
let pat = unsafe { pat_arg.index_unchecked(idx) };
643+
let group = unsafe { group_arg.index_unchecked(idx) as usize };
644+
645+
let mut local_re = None;
646+
if cached_reg.is_none() {
647+
match regexp::build_regexp_from_pattern("regexp_extract", pat, None) {
648+
Ok(re) => {
649+
local_re = Some(re);
650+
}
651+
Err(err) => {
652+
ctx.set_error(builder.len(), err);
653+
builder.put_str("");
654+
continue;
655+
}
656+
}
657+
};
658+
let re = cached_reg
659+
.as_ref()
660+
.unwrap_or_else(|| local_re.as_ref().unwrap());
661+
if let Some(caps) = re.captures(source) {
662+
if group > 9 {
663+
ctx.set_error(builder.len(), "Group index must be between 0 and 9!");
664+
builder.put_str("");
665+
} else if let Some(ma) = caps.get(group) {
666+
builder.put_str(ma.as_str());
667+
}
668+
}
669+
builder.put_str("");
670+
builder.commit_row();
671+
}
672+
if len.is_some() {
673+
Value::Column(builder.build())
674+
} else {
675+
Value::Scalar(builder.build_scalar())
676+
}
677+
}
678+
421679
fn concat_fn(args: &[Value<AnyType>], _: &mut EvalContext) -> Value<AnyType> {
422680
let len = args.iter().find_map(|arg| match arg {
423681
Value::Column(col) => Some(col.len()),

src/query/functions/tests/it/scalars/regexp.rs

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ fn test_string() {
3131
test_regexp_replace(regexp_file);
3232
test_regexp_substr(regexp_file);
3333
test_glob(regexp_file);
34+
test_regexp_extract(regexp_file);
3435
}
3536

3637
fn test_regexp_instr(file: &mut impl Write) {
@@ -690,3 +691,92 @@ fn test_regexp_substr(file: &mut impl Write) {
690691
match_type_error_five_columns,
691692
);
692693
}
694+
695+
fn test_regexp_extract(file: &mut impl Write) {
696+
run_ast(file, "regexp_extract('abc def ghi', '[a-z]+')", &[]);
697+
run_ast(file, "regexp_extract('abc def ghi', '[a-z]+', 2)", &[]);
698+
run_ast(file, "regexp_extract('abc def ghi', NULL)", &[]);
699+
run_ast(file, "regexp_extract('abc def ghi', '')", &[]);
700+
run_ast(file, "regexp_extract('', '[a-z]+')", &[]);
701+
run_ast(file, "regexp_extract('123 456', '[a-z]+')", &[]);
702+
run_ast(
703+
file,
704+
"regexp_extract('John Doe', '([A-Za-z]+) ([A-Za-z]+)', 1)",
705+
&[],
706+
);
707+
708+
run_ast(file, "regexp_extract(s, '([A-Za-z]+) ([A-Za-z]+)', 1)", &[
709+
(
710+
"s",
711+
StringType::from_data(vec!["John Doe", "James Davis", "Lisa Taylor"]),
712+
),
713+
]);
714+
715+
run_ast(file, "regexp_extract('name: John, age: 30', 'name: ([A-Za-z]+), age: ([0-9]+)', ['name', 'age'])", &[]);
716+
run_ast(file, "regexp_extract('name: John, age: 30', 'name: ([A-Za-z]+), age: ([0-9]+)', ['name', 'age'])", &[]);
717+
run_ast(
718+
file,
719+
"regexp_extract('name: John, age: 30', NULL, ['name', 'age'])",
720+
&[],
721+
);
722+
run_ast(
723+
file,
724+
"regexp_extract('name: John, age: 30', 'name: ([A-Za-z]+), age: ([0-9]+)', [])",
725+
&[],
726+
);
727+
run_ast(file, "regexp_extract('name: John, age: 30', 'name: ([A-Za-z]+), age: ([0-9]+)', ['name', 'age'])", &[]);
728+
729+
run_ast(
730+
file,
731+
"regexp_extract(s, 'name: ([A-Za-z]+), age: ([0-9]+)', ['name', 'age'])",
732+
&[(
733+
"s",
734+
StringType::from_data(vec![
735+
"name: John, age: 30",
736+
"name: James, age: 25",
737+
"name: Lisa, age: 19",
738+
]),
739+
)],
740+
);
741+
742+
run_ast(file, "regexp_extract_all('abc def ghi', '[a-z]+')", &[]);
743+
run_ast(
744+
file,
745+
"regexp_extract_all('John Doe, Jane Smith', '([A-Za-z]+) ([A-Za-z]+)', 1)",
746+
&[],
747+
);
748+
run_ast(file, "regexp_extract_all('abc def ghi', NULL)", &[]);
749+
run_ast(file, "regexp_extract_all('', '[a-z]+')", &[]);
750+
run_ast(file, "regexp_extract_all('123 456', '[a-z]+')", &[]);
751+
run_ast(file, "regexp_extract_all('name: John, age: 30; name: Jane, age: 25', 'name: ([A-Za-z]+), age: ([0-9]+)')", &[]);
752+
753+
run_ast(
754+
file,
755+
"regexp_extract_all(s, '([A-Za-z]+) ([A-Za-z]+)', 1)",
756+
&[(
757+
"s",
758+
StringType::from_data(vec![
759+
"John Doe, Jane Smith",
760+
"James Davis, Robert Wilson",
761+
"Lisa Taylor, Sarah Williams",
762+
]),
763+
)],
764+
);
765+
766+
// null source
767+
run_ast(
768+
file,
769+
"regexp_extract(null, '(\\d+)-(\\d+)-(\\d+)', ['y', 'm'])",
770+
&[],
771+
);
772+
run_ast(
773+
file,
774+
"regexp_extract_all(null, 'Order-(\\d+)-(\\d+)', 2)",
775+
&[],
776+
);
777+
run_ast(
778+
file,
779+
"regexp_extract(null, '([A-Za-z]+) ([A-Za-z]+), Age: (\\d+)', 3)",
780+
&[],
781+
);
782+
}

0 commit comments

Comments
 (0)