Skip to content

Commit 0b35094

Browse files
committed
feat: add timezone and list support for timestamps
1 parent: 58ac220 · commit: 0b35094

File tree

1 file changed: +284 −61 lines changed

datafusion-postgres/src/datatypes.rs

Lines changed: 284 additions & 61 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1+
use std::str::FromStr;
12
use std::sync::Arc;
23

3-
use chrono::{DateTime, Datelike, FixedOffset};
4+
use chrono::{DateTime, Datelike, FixedOffset, TimeZone, Utc};
45
use chrono::{NaiveDate, NaiveDateTime};
56
use datafusion::arrow::array::*;
67
use datafusion::arrow::datatypes::*;
@@ -13,6 +14,7 @@ use pgwire::api::portal::{Format, Portal};
1314
use pgwire::api::results::{DataRowEncoder, FieldInfo, QueryResponse};
1415
use pgwire::api::Type;
1516
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
17+
use timezone::Tz;
1618

1719
pub(crate) fn into_pg_type(df_type: &DataType) -> PgWireResult<Type> {
1820
Ok(match df_type {
@@ -228,45 +230,6 @@ fn get_time64_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<Naive
228230
.value_as_datetime(idx)
229231
}
230232

231-
fn get_timestamp_second_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
232-
arr.as_any()
233-
.downcast_ref::<TimestampSecondArray>()
234-
.unwrap()
235-
.value_as_datetime(idx)
236-
}
237-
238-
fn get_timestamp_millisecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
239-
arr.as_any()
240-
.downcast_ref::<TimestampMillisecondArray>()
241-
.unwrap()
242-
.value_as_datetime(idx)
243-
}
244-
245-
fn get_timestamp_microsecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
246-
arr.as_any()
247-
.downcast_ref::<TimestampMicrosecondArray>()
248-
.unwrap()
249-
.value_as_datetime(idx)
250-
}
251-
252-
fn get_timestamp_nanosecond_value(arr: &Arc<dyn Array>, idx: usize) -> Option<NaiveDateTime> {
253-
arr.as_any()
254-
.downcast_ref::<TimestampNanosecondArray>()
255-
.unwrap()
256-
.value_as_datetime(idx)
257-
}
258-
259-
fn get_utf8_list_value(arr: &Arc<dyn Array>, idx: usize) -> Vec<Option<String>> {
260-
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
261-
list_arr
262-
.as_any()
263-
.downcast_ref::<StringArray>()
264-
.unwrap()
265-
.iter()
266-
.map(|opt| opt.map(|val| val.to_owned()))
267-
.collect()
268-
}
269-
270233
fn encode_value(
271234
encoder: &mut DataRowEncoder,
272235
arr: &Arc<dyn Array>,
@@ -307,46 +270,72 @@ fn encode_value(
307270
},
308271
DataType::Timestamp(unit, timezone) => match unit {
309272
TimeUnit::Second => {
310-
let value = get_timestamp_second_value(arr, idx);
311-
if timezone.is_some() {
312-
let value_tz = value.map(|datetime| datetime.and_utc());
313-
314-
encoder.encode_field(&value_tz)?;
273+
let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
274+
if let Some(tz) = timezone {
275+
let tz = Tz::from_str(tz.as_ref())
276+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
277+
let value = ts_array
278+
.value_as_datetime_with_tz(idx, tz)
279+
.map(|d| d.fixed_offset());
280+
encoder.encode_field(&value)?;
315281
} else {
282+
let value = ts_array.value_as_datetime(idx);
316283
encoder.encode_field(&value)?
317284
}
318285
}
319286
TimeUnit::Millisecond => {
320-
let value = get_timestamp_millisecond_value(arr, idx);
321-
if timezone.is_some() {
322-
let value_tz = value.map(|datetime| datetime.and_utc());
323-
324-
encoder.encode_field(&value_tz)?;
287+
let ts_array = arr
288+
.as_any()
289+
.downcast_ref::<TimestampMillisecondArray>()
290+
.unwrap();
291+
if let Some(tz) = timezone {
292+
let tz = Tz::from_str(tz.as_ref())
293+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
294+
let value = ts_array
295+
.value_as_datetime_with_tz(idx, tz)
296+
.map(|d| d.fixed_offset());
297+
encoder.encode_field(&value)?;
325298
} else {
299+
let value = ts_array.value_as_datetime(idx);
326300
encoder.encode_field(&value)?
327301
}
328302
}
329303
TimeUnit::Microsecond => {
330-
let value = get_timestamp_microsecond_value(arr, idx);
331-
if timezone.is_some() {
332-
let value_tz = value.map(|datetime| datetime.and_utc());
333-
334-
encoder.encode_field(&value_tz)?;
304+
let ts_array = arr
305+
.as_any()
306+
.downcast_ref::<TimestampMicrosecondArray>()
307+
.unwrap();
308+
if let Some(tz) = timezone {
309+
let tz = Tz::from_str(tz.as_ref())
310+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
311+
let value = ts_array
312+
.value_as_datetime_with_tz(idx, tz)
313+
.map(|d| d.fixed_offset());
314+
encoder.encode_field(&value)?;
335315
} else {
316+
let value = ts_array.value_as_datetime(idx);
336317
encoder.encode_field(&value)?
337318
}
338319
}
339320
TimeUnit::Nanosecond => {
340-
let value = get_timestamp_nanosecond_value(arr, idx);
341-
if timezone.is_some() {
342-
let value_tz = value.map(|datetime| datetime.and_utc());
343-
344-
encoder.encode_field(&value_tz)?;
321+
let ts_array = arr
322+
.as_any()
323+
.downcast_ref::<TimestampNanosecondArray>()
324+
.unwrap();
325+
if let Some(tz) = timezone {
326+
let tz = Tz::from_str(tz.as_ref())
327+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
328+
let value = ts_array
329+
.value_as_datetime_with_tz(idx, tz)
330+
.map(|d| d.fixed_offset());
331+
encoder.encode_field(&value)?;
345332
} else {
333+
let value = ts_array.value_as_datetime(idx);
346334
encoder.encode_field(&value)?
347335
}
348336
}
349337
},
338+
350339
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
351340
match field.data_type() {
352341
DataType::Null => encoder.encode_field(&None::<i8>)?,
@@ -361,7 +350,241 @@ fn encode_value(
361350
DataType::UInt64 => encoder.encode_field(&get_u64_list_value(arr, idx))?,
362351
DataType::Float32 => encoder.encode_field(&get_f32_list_value(arr, idx))?,
363352
DataType::Float64 => encoder.encode_field(&get_f64_list_value(arr, idx))?,
364-
DataType::Utf8 => encoder.encode_field(&get_utf8_list_value(arr, idx))?,
353+
DataType::Utf8 => {
354+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
355+
let value: Vec<_> = list_arr
356+
.as_any()
357+
.downcast_ref::<StringArray>()
358+
.unwrap()
359+
.iter()
360+
.collect();
361+
encoder.encode_field(&value)?
362+
}
363+
DataType::Binary => {
364+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
365+
let value: Vec<_> = list_arr
366+
.as_any()
367+
.downcast_ref::<BinaryArray>()
368+
.unwrap()
369+
.iter()
370+
.collect();
371+
encoder.encode_field(&value)?
372+
}
373+
DataType::LargeBinary => {
374+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
375+
let value: Vec<_> = list_arr
376+
.as_any()
377+
.downcast_ref::<LargeBinaryArray>()
378+
.unwrap()
379+
.iter()
380+
.collect();
381+
encoder.encode_field(&value)?
382+
}
383+
384+
DataType::Date32 => {
385+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
386+
let value: Vec<_> = list_arr
387+
.as_any()
388+
.downcast_ref::<Date32Array>()
389+
.unwrap()
390+
.iter()
391+
.collect();
392+
encoder.encode_field(&value)?
393+
}
394+
DataType::Date64 => {
395+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
396+
let value: Vec<_> = list_arr
397+
.as_any()
398+
.downcast_ref::<Date64Array>()
399+
.unwrap()
400+
.iter()
401+
.collect();
402+
encoder.encode_field(&value)?
403+
}
404+
DataType::Time32(unit) => match unit {
405+
TimeUnit::Second => {
406+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
407+
let value: Vec<_> = list_arr
408+
.as_any()
409+
.downcast_ref::<Time32SecondArray>()
410+
.unwrap()
411+
.iter()
412+
.collect();
413+
encoder.encode_field(&value)?
414+
}
415+
TimeUnit::Millisecond => {
416+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
417+
let value: Vec<_> = list_arr
418+
.as_any()
419+
.downcast_ref::<Time32MillisecondArray>()
420+
.unwrap()
421+
.iter()
422+
.collect();
423+
encoder.encode_field(&value)?
424+
}
425+
_ => {}
426+
},
427+
DataType::Time64(unit) => match unit {
428+
TimeUnit::Microsecond => {
429+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
430+
let value: Vec<_> = list_arr
431+
.as_any()
432+
.downcast_ref::<Time64MicrosecondArray>()
433+
.unwrap()
434+
.iter()
435+
.collect();
436+
encoder.encode_field(&value)?
437+
}
438+
TimeUnit::Nanosecond => {
439+
let list_arr = arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
440+
let value: Vec<_> = list_arr
441+
.as_any()
442+
.downcast_ref::<Time64NanosecondArray>()
443+
.unwrap()
444+
.iter()
445+
.collect();
446+
encoder.encode_field(&value)?
447+
}
448+
_ => {}
449+
},
450+
DataType::Timestamp(unit, timezone) => match unit {
451+
TimeUnit::Second => {
452+
let list_array =
453+
arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
454+
let array_iter = list_array
455+
.as_any()
456+
.downcast_ref::<TimestampSecondArray>()
457+
.unwrap()
458+
.iter();
459+
460+
if let Some(tz) = timezone {
461+
let tz = Tz::from_str(tz.as_ref())
462+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
463+
let value: Vec<_> = array_iter
464+
.map(|i| {
465+
i.and_then(|i| {
466+
DateTime::from_timestamp(i, 0).map(|dt| {
467+
Utc.from_utc_datetime(&dt.naive_utc())
468+
.with_timezone(&tz)
469+
.fixed_offset()
470+
})
471+
})
472+
})
473+
.collect();
474+
encoder.encode_field(&value)?;
475+
} else {
476+
let value: Vec<_> = array_iter
477+
.map(|i| {
478+
i.and_then(|i| {
479+
DateTime::from_timestamp(i, 0).map(|dt| dt.naive_utc())
480+
})
481+
})
482+
.collect();
483+
encoder.encode_field(&value)?
484+
}
485+
}
486+
TimeUnit::Millisecond => {
487+
let list_array =
488+
arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
489+
let array_iter = list_array
490+
.as_any()
491+
.downcast_ref::<TimestampMillisecondArray>()
492+
.unwrap()
493+
.iter();
494+
495+
if let Some(tz) = timezone {
496+
let tz = Tz::from_str(tz.as_ref())
497+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
498+
let value: Vec<_> = array_iter
499+
.map(|i| {
500+
i.and_then(|i| {
501+
DateTime::from_timestamp_millis(i).map(|dt| {
502+
Utc.from_utc_datetime(&dt.naive_utc())
503+
.with_timezone(&tz)
504+
.fixed_offset()
505+
})
506+
})
507+
})
508+
.collect();
509+
encoder.encode_field(&value)?;
510+
} else {
511+
let value: Vec<_> = array_iter
512+
.map(|i| {
513+
i.and_then(|i| {
514+
DateTime::from_timestamp_millis(i).map(|dt| dt.naive_utc())
515+
})
516+
})
517+
.collect();
518+
encoder.encode_field(&value)?
519+
}
520+
}
521+
TimeUnit::Microsecond => {
522+
let list_array =
523+
arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
524+
let array_iter = list_array
525+
.as_any()
526+
.downcast_ref::<TimestampMicrosecondArray>()
527+
.unwrap()
528+
.iter();
529+
530+
if let Some(tz) = timezone {
531+
let tz = Tz::from_str(tz.as_ref())
532+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
533+
let value: Vec<_> = array_iter
534+
.map(|i| {
535+
i.and_then(|i| {
536+
DateTime::from_timestamp_micros(i).map(|dt| {
537+
Utc.from_utc_datetime(&dt.naive_utc())
538+
.with_timezone(&tz)
539+
.fixed_offset()
540+
})
541+
})
542+
})
543+
.collect();
544+
encoder.encode_field(&value)?;
545+
} else {
546+
let value: Vec<_> = array_iter
547+
.map(|i| {
548+
i.and_then(|i| {
549+
DateTime::from_timestamp_micros(i).map(|dt| dt.naive_utc())
550+
})
551+
})
552+
.collect();
553+
encoder.encode_field(&value)?
554+
}
555+
}
556+
TimeUnit::Nanosecond => {
557+
let list_array =
558+
arr.as_any().downcast_ref::<ListArray>().unwrap().value(idx);
559+
let array_iter = list_array
560+
.as_any()
561+
.downcast_ref::<TimestampNanosecondArray>()
562+
.unwrap()
563+
.iter();
564+
565+
if let Some(tz) = timezone {
566+
let tz = Tz::from_str(tz.as_ref())
567+
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
568+
let value: Vec<_> = array_iter
569+
.map(|i| {
570+
i.map(|i| {
571+
Utc.from_utc_datetime(
572+
&DateTime::from_timestamp_nanos(i).naive_utc(),
573+
)
574+
.with_timezone(&tz)
575+
.fixed_offset()
576+
})
577+
})
578+
.collect();
579+
encoder.encode_field(&value)?;
580+
} else {
581+
let value: Vec<_> = array_iter
582+
.map(|i| i.map(|i| DateTime::from_timestamp_nanos(i).naive_utc()))
583+
.collect();
584+
encoder.encode_field(&value)?
585+
}
586+
}
587+
},
365588

366589
// TODO: more types
367590
list_type => {

0 commit comments

Comments
 (0)