
Commit db87d62

Commit message: in work
1 parent 2ef63e9 commit db87d62

6 files changed, +226 -163 lines changed

rust/cubestore/cubestore/src/queryplanner/inline_aggregate/inline_aggregate_stream.rs

Lines changed: 142 additions & 157 deletions
@@ -7,6 +7,7 @@ use crate::metastore::multi_index::MultiPartition;
 use crate::metastore::table::Table;
 use crate::metastore::{Column, ColumnType, IdRow, Index, Partition};
 use crate::queryplanner::filter_by_key_range::FilterByKeyRangeExec;
+use crate::queryplanner::inline_aggregate::sorted_group_values::SortedGroupValues;
 use crate::queryplanner::merge_sort::LastRowByUniqueKeyExec;
 use crate::queryplanner::metadata_cache::{MetadataCacheFactory, NoopParquetMetadataCache};
 use crate::queryplanner::optimizations::{CubeQueryPlanner, PreOptimizeRule};
@@ -50,8 +51,8 @@ use datafusion::dfschema::internal_err;
 use datafusion::dfschema::not_impl_err;
 use datafusion::error::DataFusionError;
 use datafusion::error::Result as DFResult;
-use datafusion::execution::TaskContext;
-use datafusion::logical_expr::{Expr, GroupsAccumulator, LogicalPlan};
+use datafusion::execution::{RecordBatchStream, TaskContext};
+use datafusion::logical_expr::{EmitTo, Expr, GroupsAccumulator, LogicalPlan};
 use datafusion::physical_expr::expressions::Column as DFColumn;
 use datafusion::physical_expr::LexOrdering;
 use datafusion::physical_expr::{self, GroupsAccumulatorAdapter};
@@ -135,6 +136,7 @@ pub(crate) struct InlineAggregateStream {
     input_done: bool,

     accumulators: Vec<Box<dyn GroupsAccumulator>>,
+    group_values: SortedGroupValues,
    current_group_indices: Vec<usize>,
 }

@@ -189,6 +191,7 @@ impl InlineAggregateStream {

         let exec_state = ExecutionState::ReadingInput;
         let current_group_indices = Vec::with_capacity(batch_size);
+        let group_values = SortedGroupValues::try_new(group_schema)?;

         Ok(InlineAggregateStream {
             schema: agg_schema,
@@ -201,6 +204,7 @@ impl InlineAggregateStream {
             exec_state,
             batch_size,
             current_group_indices,
+            group_values,
             input_done: false,
         })
     }
@@ -303,144 +307,68 @@ impl Stream for InlineAggregateStream {
     ) -> Poll<Option<Self::Item>> {
         loop {
             match &self.exec_state {
-                ExecutionState::ReadingInput => 'reading_input: {
+                ExecutionState::ReadingInput => {
                     match ready!(self.input.poll_next_unpin(cx)) {
-                        // New batch to aggregate in partial aggregation operator
-                        Some(Ok(batch)) if self.mode == InlineAggregateMode::Partial => {
-                            /* let timer = elapsed_compute.timer();
-                            let input_rows = batch.num_rows();
-
-                            // Do the grouping
-                            self.group_aggregate_batch(batch)?;
-
-                            self.update_skip_aggregation_probe(input_rows);
-
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
-                            assert!(!self.input_done);
-
-                            // If the number of group values equals or exceeds the soft limit,
-                            // emit all groups and switch to producing output
-                            if self.hit_soft_group_limit() {
-                                timer.done();
-                                self.set_input_done_and_produce_output()?;
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
-                            }
-
-                            if let Some(to_emit) = self.group_ordering.emit_to() {
-                                timer.done();
-                                if let Some(batch) = self.emit(to_emit, false)? {
-
-                                    ExecutionState::ProducingOutput(batch);
-                                };
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
-                            }
-
-                            self.emit_early_if_necessary()?;
-
-                            self.switch_to_skip_aggregation()?;
-
-                            timer.done(); */
-                            todo!()
-                        }
-
-                        // New batch to aggregate in terminal aggregation operator
-                        // (Final/FinalPartitioned/Single/SinglePartitioned)
+                        // New input batch to aggregate
                         Some(Ok(batch)) => {
-                            /* let timer = elapsed_compute.timer();
-
-                            // Make sure we have enough capacity for `batch`, otherwise spill
-                            self.spill_previous_if_necessary(&batch)?;
-
-                            // Do the grouping
-
-
-                            // If we can begin emitting rows, do so,
-                            // otherwise keep consuming input
-                            assert!(!self.input_done);
-
-                            // If the number of group values equals or exceeds the soft limit,
-                            // emit all groups and switch to producing output
-                            if self.hit_soft_group_limit() {
-                                timer.done();
-                                self.set_input_done_and_produce_output()?;
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
+                            // Aggregate the batch
+                            if let Err(e) = self.group_aggregate_batch(batch) {
+                                return Poll::Ready(Some(Err(e)));
                             }

-                            if let Some(to_emit) = self.group_ordering.emit_to() {
-                                timer.done();
-                                if let Some(batch) = self.emit(to_emit, false)? {
-                                    self.exec_state =
-                                        ExecutionState::ProducingOutput(batch);
-                                };
-                                // make sure the exec_state just set is not overwritten below
-                                break 'reading_input;
+                            // Try to emit a batch if we have enough groups
+                            match self.emit_early_if_ready() {
+                                Ok(Some(batch)) => {
+                                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                                }
+                                Ok(None) => {
+                                    // Not enough groups yet, continue reading
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
                             }
-
-                            timer.done(); */
-                            todo!()
                         }

-                        // Found error from input stream
+                        // Error from input stream
                         Some(Err(e)) => {
-                            // inner had error, return to caller
                             return Poll::Ready(Some(Err(e)));
                         }

-                        // Found end from input stream
+                        // Input stream exhausted - emit all remaining groups
                         None => {
-                            // inner is done, emit all rows and switch to producing output
-                            //self.set_input_done_and_produce_output()?;
-                            todo!()
+                            self.input_done = true;
+
+                            match self.emit(EmitTo::All) {
+                                Ok(Some(batch)) => {
+                                    self.exec_state = ExecutionState::ProducingOutput(batch);
+                                }
+                                Ok(None) => {
+                                    // No groups to emit, we're done
+                                    self.exec_state = ExecutionState::Done;
+                                }
+                                Err(e) => {
+                                    return Poll::Ready(Some(Err(e)));
+                                }
+                            }
                         }
                     }
                 }

                 ExecutionState::ProducingOutput(batch) => {
-                    // slice off a part of the batch, if needed
-                    /* let output_batch;
-                    let size = self.batch_size;
-                    (self.exec_state, output_batch) = if batch.num_rows() <= size {
-                        (
-                            if self.input_done {
-                                ExecutionState::Done
-                            }
-                            // In Partial aggregation, we also need to check
-                            // if we should trigger partial skipping
-                            else if self.mode == AggregateMode::Partial
-                                && self.should_skip_aggregation()
-                            {
-                                ExecutionState::SkippingAggregation
-                            } else {
-                                ExecutionState::ReadingInput
-                            },
-                            batch.clone(),
-                        )
+                    let batch = batch.clone();
+
+                    // Determine next state
+                    self.exec_state = if self.input_done {
+                        ExecutionState::Done
                     } else {
-                        // output first batch_size rows
-                        let size = self.batch_size;
-                        let num_remaining = batch.num_rows() - size;
-                        let remaining = batch.slice(size, num_remaining);
-                        let output = batch.slice(0, size);
-                        (ExecutionState::ProducingOutput(remaining), output)
+                        ExecutionState::ReadingInput
                     };
-                    // Empty record batches should not be emitted.
-                    // They need to be treated as [`Option<RecordBatch>`]es and handled separately
-                    debug_assert!(output_batch.num_rows() > 0);
-                    return Poll::Ready(Some(Ok(
-                        output_batch.record_output(&self.baseline_metrics)
-                    ))); */
-                    todo!()
+
+                    return Poll::Ready(Some(Ok(batch)));
                 }

                 ExecutionState::Done => {
-                    // release the memory reservation since sending back output batch itself needs
-                    // some memory reservation, so make some room for it.
-                    /* self.clear_all();
-                    let _ = self.update_memory_reservation(); */
                     return Poll::Ready(None);
                 }
             }
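
Note: the rewritten poll_next above is a three-state loop: read and aggregate sorted input, hand out one output batch, then finish. A minimal sketch of the state enum it drives is shown below, assuming roughly this shape; the actual ExecutionState is declared earlier in inline_aggregate_stream.rs, outside this diff, and may carry additional variants.

// Minimal sketch only; the real ExecutionState lives elsewhere in this file
// and is not part of the diff shown here.
use datafusion::arrow::record_batch::RecordBatch;

#[allow(dead_code)]
enum ExecutionState {
    /// Pull the next sorted batch from the input and aggregate it.
    ReadingInput,
    /// Hand one already-built output batch back to the caller.
    ProducingOutput(RecordBatch),
    /// Input exhausted and all groups emitted; the stream returns None.
    Done,
}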
@@ -449,58 +377,109 @@ impl Stream for InlineAggregateStream {
 }

 impl InlineAggregateStream {
+    /// Emit groups based on EmitTo strategy.
+    ///
+    /// Returns None if there are no groups to emit.
+    /// Emit groups based on EmitTo strategy.
+    ///
+    /// Returns None if there are no groups to emit.
+    fn emit(&mut self, emit_to: EmitTo) -> DFResult<Option<RecordBatch>> {
+        if self.group_values.is_empty() {
+            return Ok(None);
+        }
+
+        // Get group values arrays
+        let group_arrays = self.group_values.emit(emit_to)?;
+
+        // Get aggregate arrays based on mode
+        let mut aggr_arrays = vec![];
+        for acc in &mut self.accumulators {
+            match self.mode {
+                InlineAggregateMode::Partial => {
+                    // Emit intermediate state
+                    let state = acc.state(emit_to)?;
+                    aggr_arrays.extend(state);
+                }
+                InlineAggregateMode::Final => {
+                    // Emit final aggregated values
+                    aggr_arrays.push(acc.evaluate(emit_to)?);
+                }
+            }
+        }
+
+        // Combine group columns and aggregate columns
+        let mut columns = group_arrays;
+        columns.extend(aggr_arrays);
+
+        let batch = RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
+
+        Ok(Some(batch))
+    }
+
+    /// Check if we have enough groups to emit a batch, keeping the last (potentially incomplete) group.
+    ///
+    /// For sorted aggregation, we emit batches of size batch_size when we have accumulated
+    /// more than batch_size groups. We always keep the last group as it may continue in the next input batch.
+    fn should_emit_early(&self) -> bool {
+        // Need at least (batch_size + 1) groups to emit batch_size and keep 1
+        self.group_values.len() > self.batch_size
+    }
+
+    /// Emit a batch of groups if we have enough accumulated, keeping the last group.
+    ///
+    /// Returns Some(batch) if emitted, None otherwise.
+    fn emit_early_if_ready(&mut self) -> DFResult<Option<RecordBatch>> {
+        if !self.should_emit_early() {
+            return Ok(None);
+        }
+
+        // Emit exactly batch_size groups, keeping the rest (including last incomplete group)
+        self.emit(EmitTo::First(self.batch_size))
+    }
+
     fn group_aggregate_batch(&mut self, batch: RecordBatch) -> DFResult<()> {
         // Evaluate the grouping expressions
-        /* let group_by_values = evaluate_group_by(&self.group_by, &batch)?;
+        let group_by_values = evaluate_group_by(&self.group_by, &batch)?;

         // Evaluate the aggregation expressions.
         let input_values = evaluate_many(&self.aggregate_arguments, &batch)?;

         // Evaluate the filter expressions, if any, against the inputs
         let filter_values = evaluate_optional(&self.filter_expressions, &batch)?;

-        for group_values in &group_by_values {
-            // calculate the group indices for each input row
-            let starting_num_groups = self.group_values.len();
-            self.group_values
-                .intern(group_values, &mut self.current_group_indices)?;
-            let group_indices = &self.current_group_indices;
-
-            // Update ordering information if necessary
-            /* let total_num_groups = self.group_values.len();
-            if total_num_groups > starting_num_groups {
-                self.group_ordering
-                    .new_groups(group_values, group_indices, total_num_groups)?;
-            } */
-
-            // Gather the inputs to call the actual accumulator
-            let t = self
-                .accumulators
-                .iter_mut()
-                .zip(input_values.iter())
-                .zip(filter_values.iter());
-
-            for ((acc, values), opt_filter) in t {
-                let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
-
-                // Call the appropriate method on each aggregator with
-                // the entire input row and the relevant group indexes
-                match self.mode {
-                    InlineAggregateMode::Partial => {
-                        acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+        assert_eq!(group_by_values.len(), 1, "Exactly 1 group value required");
+        self.group_values
+            .intern(&group_by_values[0], &mut self.current_group_indices)?;
+        let group_indices = &self.current_group_indices;
+
+        let total_num_groups = self.group_values.len();
+        // Gather the inputs to call the actual accumulator
+        let t = self
+            .accumulators
+            .iter_mut()
+            .zip(input_values.iter())
+            .zip(filter_values.iter());
+
+        for ((acc, values), opt_filter) in t {
+            let opt_filter = opt_filter.as_ref().map(|filter| filter.as_boolean());
+
+            // Call the appropriate method on each aggregator with
+            // the entire input row and the relevant group indexes
+            match self.mode {
+                InlineAggregateMode::Partial => {
+                    acc.update_batch(values, group_indices, opt_filter, total_num_groups)?;
+                }
+                _ => {
+                    if opt_filter.is_some() {
+                        return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
                     }
-                    _ => {
-                        if opt_filter.is_some() {
-                            return internal_err!("aggregate filter should be applied in partial stage, there should be no filter in final stage");
-                        }

-                        // if aggregation is over intermediate states,
-                        // use merge
-                        acc.merge_batch(values, group_indices, None, total_num_groups)?;
-                    }
+                    // if aggregation is over intermediate states,
+                    // use merge
+                    acc.merge_batch(values, group_indices, None, total_num_groups)?;
                 }
             }
-        } */
+        }
         Ok(())
     }
 }
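
Note: the early-emit rule encoded by should_emit_early and emit_early_if_ready relies on the input being sorted on the group keys: only the last buffered group can still grow, so the stream emits only when strictly more than batch_size groups are buffered and always holds the last one back. A tiny standalone illustration of just that rule (not code from this commit):

// Toy model of the early-emit rule: emit `batch_size` groups only when more
// than `batch_size` are buffered, so the last (possibly still-open) group is
// always held back for the next input batch.
fn groups_to_emit(buffered_groups: usize, batch_size: usize) -> Option<usize> {
    if buffered_groups > batch_size {
        Some(batch_size)
    } else {
        None
    }
}

fn main() {
    // With batch_size = 3 and 5 buffered groups, 3 are emitted and 2 kept.
    assert_eq!(groups_to_emit(5, 3), Some(3));
    // With only 3 buffered groups nothing is emitted yet.
    assert_eq!(groups_to_emit(3, 3), None);
}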
@@ -609,3 +588,9 @@ fn evaluate_group_by(
         })
         .collect()
 }
+
+impl RecordBatchStream for InlineAggregateStream {
+    fn schema(&self) -> SchemaRef {
+        Arc::clone(&self.schema)
+    }
+}
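
Note: the SortedGroupValues type this stream now delegates to is defined in sorted_group_values.rs, which is not part of this excerpt. Inferred purely from the call sites above (try_new, intern, len, is_empty, emit), its surface looks roughly like the sketch below; the names, signatures, and stubbed bodies are assumptions, not the committed code.

// Hypothetical sketch of the SortedGroupValues surface, inferred from how it
// is called in inline_aggregate_stream.rs; the real implementation lives in
// sorted_group_values.rs and is not shown in this commit excerpt.
use arrow::array::ArrayRef;
use arrow::datatypes::SchemaRef;
use datafusion::error::Result as DFResult;
use datafusion::logical_expr::EmitTo;

pub struct SortedGroupValues {
    schema: SchemaRef,
    // ...buffered key columns for the current run of sorted groups
}

impl SortedGroupValues {
    /// Create an empty buffer for the given group-key schema.
    pub fn try_new(schema: SchemaRef) -> DFResult<Self> {
        Ok(Self { schema })
    }

    /// Assign a group index to every input row, appending a new group each
    /// time the (sorted) key columns change; indices are written into `groups`.
    pub fn intern(&mut self, _cols: &[ArrayRef], _groups: &mut Vec<usize>) -> DFResult<()> {
        unimplemented!("sketch only")
    }

    /// Number of distinct groups currently buffered.
    pub fn len(&self) -> usize {
        unimplemented!("sketch only")
    }

    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Hand back the buffered key columns for the first N groups or for all of them.
    pub fn emit(&mut self, _emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
        unimplemented!("sketch only")
    }
}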
