Skip to content

Commit 2ae30af

Browse files
authored
Support serializing generate_series in datafusion-proto (#17200)
* Allow `generate_series` to be serialized via protobuf
* Add breaking change to the upgrade guide
1 parent b84ddfd commit 2ae30af

File tree

9 files changed

+1480
-97
lines changed

9 files changed

+1480
-97
lines changed

datafusion/core/tests/execution/coop.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ use datafusion_physical_plan::union::InterleaveExec;
5454
use futures::StreamExt;
5555
use parking_lot::RwLock;
5656
use rstest::rstest;
57+
use std::any::Any;
5758
use std::error::Error;
5859
use std::fmt::Formatter;
5960
use std::ops::Range;
@@ -80,6 +81,10 @@ impl std::fmt::Display for RangeBatchGenerator {
8081
}
8182

8283
impl LazyBatchGenerator for RangeBatchGenerator {
84+
/// Returns this generator as `&dyn Any` so that callers can downcast it
/// to the concrete `RangeBatchGenerator` type.
fn as_any(&self) -> &dyn Any {
85+
self
86+
}
87+
8388
/// Returns the `boundedness` value stored on this generator.
fn boundedness(&self) -> Boundedness {
8489
self.boundedness
8590
}

datafusion/functions-table/src/generate_series.rs

Lines changed: 166 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -31,17 +31,28 @@ use datafusion_expr::{Expr, TableType};
3131
use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec};
3232
use datafusion_physical_plan::ExecutionPlan;
3333
use parking_lot::RwLock;
34+
use std::any::Any;
3435
use std::fmt;
3536
use std::str::FromStr;
3637
use std::sync::Arc;
3738

3839
/// Empty generator that produces no rows - used when series arguments contain null values
3940
#[derive(Debug, Clone)]
40-
struct Empty {
41+
pub struct Empty {
4142
name: &'static str,
4243
}
4344

45+
// Read-only accessors for `Empty`.
impl Empty {
46+
/// Returns the static `name` this generator was constructed with.
// NOTE(review): presumably the name of the table function that produced
// this generator — confirm against the call sites.
pub fn name(&self) -> &'static str {
47+
self.name
48+
}
49+
}
50+
4451
impl LazyBatchGenerator for Empty {
52+
/// Returns this generator as `&dyn Any` so that callers can downcast it
/// to the concrete `Empty` type.
fn as_any(&self) -> &dyn Any {
53+
self
54+
}
55+
4556
/// Always returns `Ok(None)`: the empty generator never yields a batch.
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
4657
Ok(None)
4758
}
@@ -54,7 +65,7 @@ impl fmt::Display for Empty {
5465
}
5566

5667
/// Trait for values that can be generated in a series
57-
trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
68+
pub trait SeriesValue: fmt::Debug + Clone + Send + Sync + 'static {
5869
type StepType: fmt::Debug + Clone + Send + Sync;
5970
type ValueType: fmt::Debug + Clone + Send + Sync;
6071

@@ -101,12 +112,22 @@ impl SeriesValue for i64 {
101112
}
102113

103114
#[derive(Debug, Clone)]
104-
struct TimestampValue {
115+
pub struct TimestampValue {
105116
value: i64,
106117
parsed_tz: Option<Tz>,
107118
tz_str: Option<Arc<str>>,
108119
}
109120

121+
// Read-only accessors for the otherwise private fields of `TimestampValue`.
impl TimestampValue {
122+
/// Returns the raw `i64` held in the `value` field.
// NOTE(review): the time unit of this integer is not visible here — confirm.
pub fn value(&self) -> i64 {
123+
self.value
124+
}
125+
126+
/// Returns a reference to the timezone string, or `None` when the value
/// was built without one.
pub fn tz_str(&self) -> Option<&Arc<str>> {
127+
self.tz_str.as_ref()
128+
}
129+
}
130+
110131
impl SeriesValue for TimestampValue {
111132
type StepType = IntervalMonthDayNano;
112133
type ValueType = i64;
@@ -167,7 +188,7 @@ impl SeriesValue for TimestampValue {
167188

168189
/// Indicates the arguments used for generating a series.
169190
#[derive(Debug, Clone)]
170-
enum GenSeriesArgs {
191+
pub enum GenSeriesArgs {
171192
/// ContainsNull signifies that at least one argument(start, end, step) was null, thus no series will be generated.
172193
ContainsNull { name: &'static str },
173194
/// Int64Args holds the start, end, and step values for generating integer series when all arguments are not null.
@@ -203,13 +224,115 @@ enum GenSeriesArgs {
203224

204225
/// Table that generates a series of integers/timestamps from `start`(inclusive) to `end`, incrementing by step
205226
#[derive(Debug, Clone)]
206-
struct GenerateSeriesTable {
227+
pub struct GenerateSeriesTable {
207228
schema: SchemaRef,
208229
args: GenSeriesArgs,
209230
}
210231

232+
// Construction of `GenerateSeriesTable` and of the lazy batch generator it
// hands to the execution plan.
impl GenerateSeriesTable {
233+
/// Creates a table from an output `schema` and the series arguments.
pub fn new(schema: SchemaRef, args: GenSeriesArgs) -> Self {
234+
Self { schema, args }
235+
}
236+
237+
/// Builds the batch generator that corresponds to this table's arguments.
///
/// `batch_size` is stored on the returned generator state (it is not used
/// for the `ContainsNull` case, which produces an `Empty` generator).
///
/// # Errors
///
/// Returns `DataFusionError::Internal` when a timezone string is present
/// in `TimestampArgs` but cannot be parsed.
pub fn as_generator(
238+
&self,
239+
batch_size: usize,
240+
) -> Result<Arc<RwLock<dyn LazyBatchGenerator>>> {
241+
let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
242+
// A null argument means no series: use the generator that yields nothing.
GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
243+
// Integer series: the state starts with `current` equal to `start`.
GenSeriesArgs::Int64Args {
244+
start,
245+
end,
246+
step,
247+
include_end,
248+
name,
249+
} => Arc::new(RwLock::new(GenericSeriesState {
250+
schema: self.schema(),
251+
start: *start,
252+
end: *end,
253+
step: *step,
254+
current: *start,
255+
batch_size,
256+
include_end: *include_end,
257+
name,
258+
})),
259+
// Timestamp series: resolve the timezone once and share it between
// the `start`, `end` and `current` values.
GenSeriesArgs::TimestampArgs {
260+
start,
261+
end,
262+
step,
263+
tz,
264+
include_end,
265+
name,
266+
} => {
267+
// Parse the optional timezone string; a parse failure becomes an
// internal error, and a missing timezone falls back to "+00:00".
let parsed_tz = tz
268+
.as_ref()
269+
.map(|s| Tz::from_str(s.as_ref()))
270+
.transpose()
271+
.map_err(|e| {
272+
datafusion_common::DataFusionError::Internal(format!(
273+
"Failed to parse timezone: {e}"
274+
))
275+
})?
276+
// The `unwrap` here only applies to the fixed literal "+00:00".
.unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
277+
Arc::new(RwLock::new(GenericSeriesState {
278+
schema: self.schema(),
279+
start: TimestampValue {
280+
value: *start,
281+
parsed_tz: Some(parsed_tz),
282+
tz_str: tz.clone(),
283+
},
284+
end: TimestampValue {
285+
value: *end,
286+
parsed_tz: Some(parsed_tz),
287+
tz_str: tz.clone(),
288+
},
289+
step: *step,
290+
current: TimestampValue {
291+
value: *start,
292+
parsed_tz: Some(parsed_tz),
293+
tz_str: tz.clone(),
294+
},
295+
batch_size,
296+
include_end: *include_end,
297+
name,
298+
}))
299+
}
300+
// Date series: reuses `TimestampValue`, but with no timezone at all.
GenSeriesArgs::DateArgs {
301+
start,
302+
end,
303+
step,
304+
include_end,
305+
name,
306+
} => Arc::new(RwLock::new(GenericSeriesState {
307+
schema: self.schema(),
308+
start: TimestampValue {
309+
value: *start,
310+
parsed_tz: None,
311+
tz_str: None,
312+
},
313+
end: TimestampValue {
314+
value: *end,
315+
parsed_tz: None,
316+
tz_str: None,
317+
},
318+
step: *step,
319+
current: TimestampValue {
320+
value: *start,
321+
parsed_tz: None,
322+
tz_str: None,
323+
},
324+
batch_size,
325+
include_end: *include_end,
326+
name,
327+
})),
328+
};
329+
330+
Ok(generator)
331+
}
332+
}
333+
211334
#[derive(Debug, Clone)]
212-
struct GenericSeriesState<T: SeriesValue> {
335+
pub struct GenericSeriesState<T: SeriesValue> {
213336
schema: SchemaRef,
214337
start: T,
215338
end: T,
@@ -220,7 +343,41 @@ struct GenericSeriesState<T: SeriesValue> {
220343
name: &'static str,
221344
}
222345

346+
// Read-only accessors for the private fields of `GenericSeriesState`.
impl<T: SeriesValue> GenericSeriesState<T> {
347+
/// Returns the static `name` stored on this state.
pub fn name(&self) -> &'static str {
348+
self.name
349+
}
350+
351+
/// Returns the `batch_size` this state was created with.
pub fn batch_size(&self) -> usize {
352+
self.batch_size
353+
}
354+
355+
/// Returns the `include_end` flag.
// NOTE(review): presumably controls whether `end` itself is emitted — confirm.
pub fn include_end(&self) -> bool {
356+
self.include_end
357+
}
358+
359+
/// Returns a reference to the series start value.
pub fn start(&self) -> &T {
360+
&self.start
361+
}
362+
363+
/// Returns a reference to the series end value.
pub fn end(&self) -> &T {
364+
&self.end
365+
}
366+
367+
/// Returns a reference to the step, typed as the value's `StepType`.
pub fn step(&self) -> &T::StepType {
368+
&self.step
369+
}
370+
371+
/// Returns a reference to the `current` value of the series.
// NOTE(review): presumably the next value to be emitted — confirm against
// `generate_next_batch`.
pub fn current(&self) -> &T {
372+
&self.current
373+
}
374+
}
375+
223376
impl<T: SeriesValue> LazyBatchGenerator for GenericSeriesState<T> {
377+
/// Returns this generator as `&dyn Any` so that callers can downcast it
/// to the concrete `GenericSeriesState<T>` type.
fn as_any(&self) -> &dyn Any {
378+
self
379+
}
380+
224381
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
225382
let mut buf = Vec::with_capacity(self.batch_size);
226383

@@ -295,7 +452,7 @@ fn validate_interval_step(
295452

296453
#[async_trait]
297454
impl TableProvider for GenerateSeriesTable {
298-
fn as_any(&self) -> &dyn std::any::Any {
455+
/// Returns this table provider as `&dyn Any`, allowing a downcast to
/// `GenerateSeriesTable`.
fn as_any(&self) -> &dyn Any {
299456
self
300457
}
301458

@@ -319,94 +476,8 @@ impl TableProvider for GenerateSeriesTable {
319476
Some(projection) => Arc::new(self.schema.project(projection)?),
320477
None => self.schema(),
321478
};
322-
let generator: Arc<RwLock<dyn LazyBatchGenerator>> = match &self.args {
323-
GenSeriesArgs::ContainsNull { name } => Arc::new(RwLock::new(Empty { name })),
324-
GenSeriesArgs::Int64Args {
325-
start,
326-
end,
327-
step,
328-
include_end,
329-
name,
330-
} => Arc::new(RwLock::new(GenericSeriesState {
331-
schema: self.schema(),
332-
start: *start,
333-
end: *end,
334-
step: *step,
335-
current: *start,
336-
batch_size,
337-
include_end: *include_end,
338-
name,
339-
})),
340-
GenSeriesArgs::TimestampArgs {
341-
start,
342-
end,
343-
step,
344-
tz,
345-
include_end,
346-
name,
347-
} => {
348-
let parsed_tz = tz
349-
.as_ref()
350-
.map(|s| Tz::from_str(s.as_ref()))
351-
.transpose()
352-
.map_err(|e| {
353-
datafusion_common::DataFusionError::Internal(format!(
354-
"Failed to parse timezone: {e}"
355-
))
356-
})?
357-
.unwrap_or_else(|| Tz::from_str("+00:00").unwrap());
358-
Arc::new(RwLock::new(GenericSeriesState {
359-
schema: self.schema(),
360-
start: TimestampValue {
361-
value: *start,
362-
parsed_tz: Some(parsed_tz),
363-
tz_str: tz.clone(),
364-
},
365-
end: TimestampValue {
366-
value: *end,
367-
parsed_tz: Some(parsed_tz),
368-
tz_str: tz.clone(),
369-
},
370-
step: *step,
371-
current: TimestampValue {
372-
value: *start,
373-
parsed_tz: Some(parsed_tz),
374-
tz_str: tz.clone(),
375-
},
376-
batch_size,
377-
include_end: *include_end,
378-
name,
379-
}))
380-
}
381-
GenSeriesArgs::DateArgs {
382-
start,
383-
end,
384-
step,
385-
include_end,
386-
name,
387-
} => Arc::new(RwLock::new(GenericSeriesState {
388-
schema: self.schema(),
389-
start: TimestampValue {
390-
value: *start,
391-
parsed_tz: None,
392-
tz_str: None,
393-
},
394-
end: TimestampValue {
395-
value: *end,
396-
parsed_tz: None,
397-
tz_str: None,
398-
},
399-
step: *step,
400-
current: TimestampValue {
401-
value: *start,
402-
parsed_tz: None,
403-
tz_str: None,
404-
},
405-
batch_size,
406-
include_end: *include_end,
407-
name,
408-
})),
409-
};
479+
480+
let generator = self.as_generator(batch_size)?;
410481

411482
Ok(Arc::new(LazyMemoryExec::try_new(schema, vec![generator])?))
412483
}

datafusion/physical-plan/src/memory.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ impl RecordBatchStream for MemoryStream {
134134
}
135135

136136
pub trait LazyBatchGenerator: Send + Sync + fmt::Debug + fmt::Display {
137+
/// Returns the generator as [`Any`] so that it can be
138+
/// downcast to a specific implementation.
139+
fn as_any(&self) -> &dyn Any;
140+
137141
fn boundedness(&self) -> Boundedness {
138142
Boundedness::Bounded
139143
}
@@ -219,6 +223,11 @@ impl LazyMemoryExec {
219223
.eq_properties
220224
.add_orderings(std::iter::once(ordering));
221225
}
226+
227+
/// Returns a reference to the batch generators held by this exec.
///
/// The generators are borrowed, not cloned; each one is still wrapped in
/// its `Arc<RwLock<..>>`.
228+
pub fn generators(&self) -> &Vec<Arc<RwLock<dyn LazyBatchGenerator>>> {
229+
&self.batch_generators
230+
}
222231
}
223232

224233
impl fmt::Debug for LazyMemoryExec {
@@ -394,6 +403,10 @@ mod lazy_memory_tests {
394403
}
395404

396405
impl LazyBatchGenerator for TestGenerator {
406+
/// Returns this generator as `&dyn Any` so that callers can downcast it
/// to the concrete `TestGenerator` type.
fn as_any(&self) -> &dyn Any {
407+
self
408+
}
409+
397410
fn generate_next_batch(&mut self) -> Result<Option<RecordBatch>> {
398411
if self.counter >= self.max_batches {
399412
return Ok(None);

0 commit comments

Comments
 (0)