Skip to content

Commit bca60ae

Browse files
authored
fix: panic when the string_agg parameter is not of type string (#18416)
* fix: panic when the `string_agg` parameter is not of type string * chore: codefmt * chore: replace CoreNumber with NumberType * chore: codefmt
1 parent 2c99447 commit bca60ae

File tree

8 files changed

+328
-66
lines changed

8 files changed

+328
-66
lines changed

โ€Žsrc/query/expression/src/types.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ pub use self::bitmap::BitmapType;
6262
pub use self::boolean::Bitmap;
6363
pub use self::boolean::BooleanType;
6464
pub use self::boolean::MutableBitmap;
65+
pub use self::compute_view::StringConvert;
6566
pub use self::date::DateType;
6667
pub use self::decimal::*;
6768
pub use self::empty_array::EmptyArrayType;

โ€Žsrc/query/expression/src/types/compute_view.rs

Lines changed: 157 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,17 @@ use std::fmt::Debug;
1717
use std::marker::PhantomData;
1818
use std::ops::Range;
1919

20-
use databend_common_column::buffer::Buffer;
2120
use num_traits::AsPrimitive;
2221

23-
use super::simple_type::SimpleType;
2422
use super::AccessType;
25-
use crate::types::CoreNumber;
23+
use super::AnyType;
24+
use super::ArgType;
25+
use super::NumberType;
26+
use super::StringColumn;
27+
use super::StringType;
28+
use crate::display::scalar_ref_to_string;
29+
use crate::types::string::StringDomain;
30+
use crate::types::string::StringIterator;
2631
use crate::types::Number;
2732
use crate::types::SimpleDomain;
2833
use crate::Column;
@@ -31,23 +36,23 @@ use crate::ScalarRef;
3136

3237
pub trait Compute<F, T>: Debug + Clone + PartialEq + 'static
3338
where
34-
F: SimpleType,
35-
T: SimpleType,
39+
F: AccessType,
40+
T: AccessType,
3641
{
37-
fn compute(value: &F::Scalar) -> T::Scalar;
42+
fn compute<'a>(value: F::ScalarRef<'a>) -> T::ScalarRef<'a>;
3843

3944
fn compute_domain(domain: &F::Domain) -> T::Domain;
4045
}
4146

4247
impl<T> Compute<T, T> for T
43-
where T: SimpleType
48+
where T: AccessType
4449
{
45-
fn compute(value: &T::Scalar) -> T::Scalar {
46-
*value
50+
fn compute<'a>(value: T::ScalarRef<'a>) -> T::ScalarRef<'a> {
51+
value
4752
}
4853

4954
fn compute_domain(domain: &T::Domain) -> T::Domain {
50-
*domain
55+
domain.to_owned()
5156
}
5257
}
5358

@@ -56,107 +61,109 @@ pub struct ComputeView<C, F, T>(PhantomData<(C, F, T)>);
5661

5762
impl<C, F, T> AccessType for ComputeView<C, F, T>
5863
where
59-
F: SimpleType,
60-
T: SimpleType,
64+
F: AccessType,
65+
T: AccessType,
6166
C: Compute<F, T>,
6267
{
6368
type Scalar = T::Scalar;
64-
type ScalarRef<'a> = T::Scalar;
65-
type Column = Buffer<F::Scalar>;
69+
type ScalarRef<'a> = T::ScalarRef<'a>;
70+
type Column = F::Column;
6671
type Domain = T::Domain;
6772
type ColumnIterator<'a> =
68-
std::iter::Map<std::slice::Iter<'a, F::Scalar>, fn(&'a F::Scalar) -> T::Scalar>;
73+
std::iter::Map<F::ColumnIterator<'a>, fn(F::ScalarRef<'a>) -> T::ScalarRef<'a>>;
6974

7075
fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar {
71-
scalar
76+
T::to_owned_scalar(scalar)
7277
}
7378

7479
fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> {
75-
*scalar
80+
T::to_scalar_ref(scalar)
7681
}
7782

7883
fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
79-
F::downcast_scalar(scalar).map(|v| C::compute(&v))
84+
F::try_downcast_scalar(scalar).map(|v| C::compute(v))
8085
}
8186

8287
fn try_downcast_column(col: &Column) -> Option<Self::Column> {
83-
F::downcast_column(col)
88+
F::try_downcast_column(col)
8489
}
8590

8691
fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
87-
F::downcast_domain(domain).map(|domain| C::compute_domain(&domain))
92+
F::try_downcast_domain(domain).map(|domain| C::compute_domain(&domain))
8893
}
8994

9095
fn column_len(col: &Self::Column) -> usize {
91-
col.len()
96+
F::column_len(col)
9297
}
9398

9499
fn index_column(col: &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
95-
col.get(index).map(C::compute)
100+
F::index_column(col, index).map(C::compute)
96101
}
97102

98103
unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> {
99-
debug_assert!(index < col.len());
100-
C::compute(col.get_unchecked(index))
104+
let scalar = F::index_column_unchecked(col, index);
105+
C::compute(scalar)
101106
}
102107

103108
fn slice_column(col: &Self::Column, range: Range<usize>) -> Self::Column {
104-
col.clone().sliced(range.start, range.end - range.start)
109+
F::slice_column(col, range)
105110
}
106111

107112
fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> {
108-
col.iter().map(C::compute as fn(&F::Scalar) -> T::Scalar)
113+
F::iter_column(col).map(C::compute)
109114
}
110115

111116
fn scalar_memory_size(_: &Self::ScalarRef<'_>) -> usize {
112117
std::mem::size_of::<F>()
113118
}
114119

115120
fn column_memory_size(col: &Self::Column) -> usize {
116-
col.len() * std::mem::size_of::<F>()
121+
F::column_len(col) * std::mem::size_of::<F>()
117122
}
118123

119124
fn compare(lhs: Self::ScalarRef<'_>, rhs: Self::ScalarRef<'_>) -> Ordering {
120-
T::compare(&lhs, &rhs)
125+
T::compare(lhs, rhs)
121126
}
122127

123128
fn equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
124-
left == right
129+
T::equal(left, right)
125130
}
126131

127132
fn not_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
128-
left != right
133+
T::not_equal(left, right)
129134
}
130135

131136
fn greater_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
132-
T::greater_than(&left, &right)
137+
T::greater_than(left, right)
133138
}
134139

135140
fn less_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
136-
T::less_than(&left, &right)
141+
T::less_than(left, right)
137142
}
138143

139144
fn greater_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
140-
T::greater_than_equal(&left, &right)
145+
T::greater_than_equal(left, right)
141146
}
142147

143148
fn less_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
144-
T::less_than_equal(&left, &right)
149+
T::less_than_equal(left, right)
145150
}
146151
}
147152

148153
/// For number convert
149-
pub type NumberConvertView<F, T> = ComputeView<NumberConvert<F, T>, CoreNumber<F>, CoreNumber<T>>;
154+
pub type NumberConvertView<F, T> = ComputeView<NumberConvert<F, T>, NumberType<F>, NumberType<T>>;
150155

151156
#[derive(Debug, Clone, PartialEq, Eq)]
152157
pub struct NumberConvert<F, T>(std::marker::PhantomData<(F, T)>);
153158

154-
impl<F, T> Compute<CoreNumber<F>, CoreNumber<T>> for NumberConvert<F, T>
159+
impl<F, T> Compute<NumberType<F>, NumberType<T>> for NumberConvert<F, T>
155160
where
156161
F: Number + AsPrimitive<T>,
157162
T: Number,
158163
{
159-
fn compute(value: &F) -> T {
164+
fn compute<'a>(
165+
value: <NumberType<F> as AccessType>::ScalarRef<'a>,
166+
) -> <NumberType<T> as AccessType>::ScalarRef<'a> {
160167
value.as_()
161168
}
162169

@@ -166,3 +173,116 @@ where
166173
SimpleDomain { min, max }
167174
}
168175
}
176+
177+
/// For number convert
178+
pub type StringConvertView = ComputeView<StringConvert, AnyType, OwnedStringType>;
179+
180+
#[derive(Debug, Clone, PartialEq, Eq)]
181+
pub struct OwnedStringType;
182+
183+
impl AccessType for OwnedStringType {
184+
type Scalar = String;
185+
type ScalarRef<'a> = String;
186+
type Column = StringColumn;
187+
type Domain = StringDomain;
188+
type ColumnIterator<'a> = std::iter::Map<StringIterator<'a>, fn(&str) -> String>;
189+
190+
fn to_owned_scalar(scalar: Self::ScalarRef<'_>) -> Self::Scalar {
191+
scalar.to_string()
192+
}
193+
194+
fn to_scalar_ref(scalar: &Self::Scalar) -> Self::ScalarRef<'_> {
195+
scalar.clone()
196+
}
197+
198+
fn try_downcast_scalar<'a>(scalar: &ScalarRef<'a>) -> Option<Self::ScalarRef<'a>> {
199+
scalar.as_string().map(|s| s.to_string())
200+
}
201+
202+
fn try_downcast_column(col: &Column) -> Option<Self::Column> {
203+
col.as_string().cloned()
204+
}
205+
206+
fn try_downcast_domain(domain: &Domain) -> Option<Self::Domain> {
207+
domain.as_string().cloned()
208+
}
209+
210+
fn column_len(col: &Self::Column) -> usize {
211+
col.len()
212+
}
213+
214+
fn index_column(col: &Self::Column, index: usize) -> Option<Self::ScalarRef<'_>> {
215+
col.index(index).map(str::to_string)
216+
}
217+
218+
#[inline]
219+
unsafe fn index_column_unchecked(col: &Self::Column, index: usize) -> Self::ScalarRef<'_> {
220+
col.value_unchecked(index).to_string()
221+
}
222+
223+
fn slice_column(col: &Self::Column, range: Range<usize>) -> Self::Column {
224+
col.clone().sliced(range.start, range.end - range.start)
225+
}
226+
227+
fn iter_column(col: &Self::Column) -> Self::ColumnIterator<'_> {
228+
col.iter().map(str::to_string)
229+
}
230+
231+
fn scalar_memory_size(scalar: &Self::ScalarRef<'_>) -> usize {
232+
scalar.len()
233+
}
234+
235+
fn column_memory_size(col: &Self::Column) -> usize {
236+
col.memory_size()
237+
}
238+
239+
#[inline(always)]
240+
fn compare(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> Ordering {
241+
left.cmp(&right)
242+
}
243+
244+
#[inline(always)]
245+
fn equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
246+
left == right
247+
}
248+
249+
#[inline(always)]
250+
fn not_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
251+
left != right
252+
}
253+
254+
#[inline(always)]
255+
fn greater_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
256+
left > right
257+
}
258+
259+
#[inline(always)]
260+
fn greater_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
261+
left >= right
262+
}
263+
264+
#[inline(always)]
265+
fn less_than(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
266+
left < right
267+
}
268+
269+
#[inline(always)]
270+
fn less_than_equal(left: Self::ScalarRef<'_>, right: Self::ScalarRef<'_>) -> bool {
271+
left <= right
272+
}
273+
}
274+
275+
#[derive(Debug, Clone, PartialEq, Eq)]
276+
pub struct StringConvert;
277+
278+
impl Compute<AnyType, OwnedStringType> for StringConvert {
279+
fn compute<'a>(
280+
value: <AnyType as AccessType>::ScalarRef<'a>,
281+
) -> <OwnedStringType as AccessType>::ScalarRef<'a> {
282+
scalar_ref_to_string(&value)
283+
}
284+
285+
fn compute_domain(_: &<AnyType as AccessType>::Domain) -> StringDomain {
286+
StringType::full_domain()
287+
}
288+
}

0 commit comments

Comments
ย (0)