Skip to content

Commit 2ef63e9

Browse files
committed
in work
1 parent 6ae25f6 commit 2ef63e9

File tree

1 file changed

+86
-40
lines changed

1 file changed

+86
-40
lines changed

rust/cubestore/cubestore/src/queryplanner/inline_aggregate/sorted_group_values.rs

Lines changed: 86 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use datafusion::logical_expr::EmitTo;
12
use datafusion::physical_plan::aggregates::group_values::multi_group_by::GroupColumn;
23

34
use std::mem::{self, size_of};
@@ -6,15 +7,15 @@ use datafusion::arrow::array::{Array, ArrayRef, RecordBatch};
67
use datafusion::arrow::compute::cast;
78
use datafusion::arrow::datatypes::{
89
BinaryType, BinaryViewType, DataType, Date32Type, Date64Type, Decimal128Type, Float32Type,
9-
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LargeBinaryType, LargeUtf8Type,
10-
Schema, SchemaRef, StringViewType, Time32MillisecondType, Time32SecondType,
11-
Time64MicrosecondType, Time64NanosecondType, TimeUnit, TimestampMicrosecondType,
12-
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType, UInt16Type,
13-
UInt32Type, UInt64Type, UInt8Type, Utf8Type,
10+
Float64Type, Int16Type, Int32Type, Int64Type, Int8Type, LargeBinaryType, LargeUtf8Type, Schema,
11+
SchemaRef, StringViewType, Time32MillisecondType, Time32SecondType, Time64MicrosecondType,
12+
Time64NanosecondType, TimeUnit, TimestampMicrosecondType, TimestampMillisecondType,
13+
TimestampNanosecondType, TimestampSecondType, UInt16Type, UInt32Type, UInt64Type, UInt8Type,
14+
Utf8Type,
1415
};
1516
use datafusion::dfschema::internal_err;
1617
use datafusion::dfschema::not_impl_err;
17-
use datafusion::error::Result as DFResult;
18+
use datafusion::error::{DataFusionError, Result as DFResult};
1819
use datafusion::physical_expr::binary_map::OutputType;
1920
use datafusion::physical_plan::aggregates::group_values::multi_group_by::{
2021
ByteGroupValueBuilder, ByteViewGroupValueBuilder, PrimitiveGroupValueBuilder,
@@ -129,41 +130,83 @@ impl SortedGroupValues {
129130
&DataType::Time32(t) => match t {
130131
TimeUnit::Second => {
131132
instantiate_primitive!(v, nullable, Time32SecondType, data_type);
132-
instantiate_primitive_comparator!(comparators, nullable, Time32SecondType);
133+
instantiate_primitive_comparator!(
134+
comparators,
135+
nullable,
136+
Time32SecondType
137+
);
133138
}
134139
TimeUnit::Millisecond => {
135140
instantiate_primitive!(v, nullable, Time32MillisecondType, data_type);
136-
instantiate_primitive_comparator!(comparators, nullable, Time32MillisecondType);
141+
instantiate_primitive_comparator!(
142+
comparators,
143+
nullable,
144+
Time32MillisecondType
145+
);
137146
}
138147
_ => {}
139148
},
140149
&DataType::Time64(t) => match t {
141150
TimeUnit::Microsecond => {
142151
instantiate_primitive!(v, nullable, Time64MicrosecondType, data_type);
143-
instantiate_primitive_comparator!(comparators, nullable, Time64MicrosecondType);
152+
instantiate_primitive_comparator!(
153+
comparators,
154+
nullable,
155+
Time64MicrosecondType
156+
);
144157
}
145158
TimeUnit::Nanosecond => {
146159
instantiate_primitive!(v, nullable, Time64NanosecondType, data_type);
147-
instantiate_primitive_comparator!(comparators, nullable, Time64NanosecondType);
160+
instantiate_primitive_comparator!(
161+
comparators,
162+
nullable,
163+
Time64NanosecondType
164+
);
148165
}
149166
_ => {}
150167
},
151168
&DataType::Timestamp(t, _) => match t {
152169
TimeUnit::Second => {
153170
instantiate_primitive!(v, nullable, TimestampSecondType, data_type);
154-
instantiate_primitive_comparator!(comparators, nullable, TimestampSecondType);
171+
instantiate_primitive_comparator!(
172+
comparators,
173+
nullable,
174+
TimestampSecondType
175+
);
155176
}
156177
TimeUnit::Millisecond => {
157-
instantiate_primitive!(v, nullable, TimestampMillisecondType, data_type);
158-
instantiate_primitive_comparator!(comparators, nullable, TimestampMillisecondType);
178+
instantiate_primitive!(
179+
v,
180+
nullable,
181+
TimestampMillisecondType,
182+
data_type
183+
);
184+
instantiate_primitive_comparator!(
185+
comparators,
186+
nullable,
187+
TimestampMillisecondType
188+
);
159189
}
160190
TimeUnit::Microsecond => {
161-
instantiate_primitive!(v, nullable, TimestampMicrosecondType, data_type);
162-
instantiate_primitive_comparator!(comparators, nullable, TimestampMicrosecondType);
191+
instantiate_primitive!(
192+
v,
193+
nullable,
194+
TimestampMicrosecondType,
195+
data_type
196+
);
197+
instantiate_primitive_comparator!(
198+
comparators,
199+
nullable,
200+
TimestampMicrosecondType
201+
);
163202
}
164203
TimeUnit::Nanosecond => {
165204
instantiate_primitive!(v, nullable, TimestampNanosecondType, data_type);
166-
instantiate_primitive_comparator!(comparators, nullable, TimestampNanosecondType);
205+
instantiate_primitive_comparator!(
206+
comparators,
207+
nullable,
208+
TimestampNanosecondType
209+
);
167210
}
168211
},
169212
&DataType::Decimal128(_, _) => {
@@ -231,8 +274,8 @@ impl SortedGroupValues {
231274
self.group_values[0].len()
232275
}
233276

234-
pub fn emit(&mut self) -> DFResult<Vec<ArrayRef>> {
235-
/* let mut output = match emit_to {
277+
fn emit(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
278+
let mut output = match emit_to {
236279
EmitTo::All => {
237280
let group_values = mem::take(&mut self.group_values);
238281
debug_assert!(self.group_values.is_empty());
@@ -253,7 +296,6 @@ impl SortedGroupValues {
253296
}
254297
};
255298

256-
// TODO: Materialize dictionaries in group keys (#7647)
257299
for (field, array) in self.schema.fields.iter().zip(&mut output) {
258300
let expected = field.data_type();
259301
if let DataType::Dictionary(_, v) = expected {
@@ -267,78 +309,82 @@ impl SortedGroupValues {
267309
}
268310
}
269311

270-
Ok(output) */
271-
todo!()
312+
Ok(output)
272313
}
273314

274315
fn clear_shrink(&mut self, batch: &RecordBatch) {
275316
self.group_values.clear();
317+
self.comparators.clear();
276318
self.rows_inds.clear();
277319
self.equal_to_results.clear();
278320
}
279321

280322
fn intern_impl(&mut self, cols: &[ArrayRef], groups: &mut Vec<usize>) -> DFResult<()> {
281-
/* let n_rows = cols[0].len();
323+
let n_rows = cols[0].len();
282324
groups.clear();
283325

284326
if n_rows == 0 {
285327
return Ok(());
286328
}
287329

330+
// Handle first row - compare with last group or create new group
288331
let first_group_idx = self.make_new_group_if_needed(cols, 0);
289332
groups.push(first_group_idx);
290333

291334
if n_rows == 1 {
292335
return Ok(());
293336
}
294337

295-
if self.rows_inds.len() < n_rows {
296-
let old_len = self.rows_inds.len();
297-
self.rows_inds.extend(old_len..n_rows);
298-
}
299-
300-
self.equal_to_results.fill(true);
338+
// Prepare buffer for vectorized comparison
301339
self.equal_to_results.resize(n_rows - 1, true);
340+
self.equal_to_results[..n_rows - 1].fill(true);
302341

303-
let lhs_rows = &self.rows_inds[0..n_rows - 1];
304-
let rhs_rows = &self.rows_inds[1..n_rows];
305-
for (col_idx, group_col) in self.group_values.iter().enumerate() {
306-
cols[col_idx].vectorized_equal_to(
307-
lhs_rows,
308-
&cols[col_idx],
309-
rhs_rows,
310-
&mut self.equal_to_results,
311-
);
342+
// Vectorized comparison: compare row[i] with row[i+1] for all columns
343+
for (col, comparator) in cols.iter().zip(&self.comparators) {
344+
comparator.compare_adjacent(col, &mut self.equal_to_results[..n_rows - 1]);
312345
}
313-
println!("!!!!! AAAAAAAAAA");
346+
347+
// Build groups based on comparison results
314348
let mut current_group_idx = first_group_idx;
315349
for i in 0..n_rows - 1 {
316350
if !self.equal_to_results[i] {
351+
// Group boundary detected - add new group
317352
for (col_idx, group_value) in self.group_values.iter_mut().enumerate() {
318353
group_value.append_val(&cols[col_idx], i + 1);
319354
}
320355
current_group_idx = self.group_values[0].len() - 1;
321356
}
322357
groups.push(current_group_idx);
323358
}
324-
println!("!!!!! BBBBBBB");
325-
Ok(()) */
359+
326360
Ok(())
327361
}
328362

363+
/// Compare the specified row with the last group and create a new group if different.
364+
///
365+
/// This is used to handle the first row of a batch, which needs to be compared
366+
/// with the last group from the previous batch to detect group boundaries across batches.
367+
///
368+
/// Returns the group index for this row.
329369
fn make_new_group_if_needed(&mut self, cols: &[ArrayRef], row: usize) -> usize {
330370
let new_group_needed = if self.group_values[0].len() == 0 {
371+
// No groups yet - always create first group
331372
true
332373
} else {
374+
// Compare with last group - if any column differs, need new group
333375
self.group_values.iter().enumerate().any(|(i, group_val)| {
334376
!group_val.equal_to(self.group_values[0].len() - 1, &cols[i], row)
335377
})
336378
};
379+
337380
if new_group_needed {
381+
// Add new group with values from this row
338382
for (i, group_value) in self.group_values.iter_mut().enumerate() {
339383
group_value.append_val(&cols[i], row);
340384
}
341385
}
386+
387+
// Return index of the group (either newly created or existing last group)
342388
self.group_values[0].len() - 1
343389
}
344390
}

0 commit comments

Comments
 (0)