
Commit cf90ce9

refactor
1 parent d4de082

File tree

8 files changed: +271 -316 lines changed

src/database.rs

Lines changed: 142 additions & 164 deletions
Large diffs are not rendered by default.

src/dml.rs

Lines changed: 28 additions & 27 deletions
@@ -16,6 +16,9 @@ use tracing::{error, info};
 
 use crate::database::Database;
 
+/// Type alias for DML information extracted from logical plan
+type DmlInfo = (String, String, Option<Expr>, Option<Vec<(String, Expr)>>);
+
 /// Custom query planner that intercepts DML operations
 pub struct DmlQueryPlanner {
     planner: DefaultPhysicalPlanner,
@@ -82,7 +85,7 @@ fn extract_dml_info(
     input: &LogicalPlan,
     table_name: &str,
     extract_assignments: bool,
-) -> Result<(String, String, Option<Expr>, Option<Vec<(String, Expr)>>)> {
+) -> Result<DmlInfo> {
     let mut current_plan = input;
     let mut predicate = None;
     let mut assignments = None;
@@ -100,21 +103,21 @@ fn extract_dml_info(
                 current_plan = filter.input.as_ref();
             }
             LogicalPlan::TableScan(scan) => {
-                if !scan.filters.is_empty() {
-                    project_id = scan.filters.iter()
-                        .find_map(extract_project_id)
-                        .unwrap_or(project_id);
-
-                    let combined = scan.filters.iter()
-                        .cloned()
-                        .reduce(|acc, filter| Expr::BinaryExpr(BinaryExpr {
-                            left: Box::new(acc),
-                            op: Operator::And,
-                            right: Box::new(filter),
-                        }));
-
-                    predicate = predicate.or(combined);
-                }
+                project_id = scan.filters.iter()
+                    .find_map(extract_project_id)
+                    .unwrap_or(project_id);
+
+                predicate = predicate.or_else(|| {
+                    (!scan.filters.is_empty()).then(|| {
+                        scan.filters.iter()
+                            .cloned()
+                            .reduce(|acc, filter| Expr::BinaryExpr(BinaryExpr {
+                                left: Box::new(acc),
+                                op: Operator::And,
+                                right: Box::new(filter),
+                            }))
+                    }).flatten()
+                });
                 break;
             }
             _ => match current_plan.inputs().first() {
@@ -141,13 +144,12 @@ fn extract_assignments_from_projection(proj: &datafusion::logical_expr::Projecti
         .filter_map(|(expr, field)| {
             let field_name = field.name();
             match expr {
-                Expr::Column(col) if &col.name == field_name => None,
-                Expr::Alias(alias) if &alias.name == field_name => match &*alias.expr {
-                    Expr::Column(col) if &col.name == field_name => None,
-                    _ => Some((field_name.clone(), (*alias.expr).clone())),
-                },
-                _ if !matches!(expr, Expr::Column(_)) => Some((field_name.clone(), expr.clone())),
-                _ => None,
+                Expr::Column(col) if col.name == *field_name => None,
+                Expr::Alias(alias) if alias.name == *field_name =>
+                    (!matches!(&*alias.expr, Expr::Column(col) if col.name == *field_name))
+                        .then(|| (field_name.clone(), (*alias.expr).clone())),
+                Expr::Column(_) => None,
+                _ => Some((field_name.clone(), expr.clone())),
             }
         })
         .collect())
@@ -156,13 +158,12 @@ fn extract_assignments_from_projection(proj: &datafusion::logical_expr::Projecti
 /// Extract project_id from filter expression
 fn extract_project_id(expr: &Expr) -> Option<String> {
     match expr {
-        Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) => {
+        Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) =>
             match (left.as_ref(), right.as_ref()) {
                 (Expr::Column(col), Expr::Literal(val, _)) | (Expr::Literal(val, _), Expr::Column(col))
                     if col.name == "project_id" => Some(val.to_string()),
                 _ => None,
-            }
-        }
+            },
         Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right }) =>
             extract_project_id(left).or_else(|| extract_project_id(right)),
         _ => None,
@@ -421,7 +422,7 @@ fn convert_expr_to_delta(expr: &Expr) -> Result<Expr> {
         Expr::Column(col) => Ok(Expr::Column(Column::from_name(&col.name))),
         Expr::BinaryExpr(binary) => Ok(Expr::BinaryExpr(BinaryExpr {
             left: Box::new(convert_expr_to_delta(&binary.left)?),
-            op: binary.op.clone(),
+            op: binary.op,
             right: Box::new(convert_expr_to_delta(&binary.right)?),
         })),
         _ => Ok(expr.clone()),
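
Review note on src/dml.rs: the `binary.op.clone()` → `binary.op` change relies on DataFusion's `Operator` being `Copy`. In the `TableScan` arm, `Option::or_else` defers building the combined filter until it is known that no predicate was found earlier in the plan, and `bool::then(..).flatten()` collapses the empty-filter case to `None`. A minimal standalone sketch of that pattern, using a hypothetical integer AND-fold in place of the `Expr` fold (not code from this commit):

fn combine(existing: Option<u32>, filters: &[u32]) -> Option<u32> {
    existing.or_else(|| {
        // `then` yields Some(..) only for a non-empty slice; `reduce` itself
        // returns an Option, hence the trailing `flatten`.
        (!filters.is_empty())
            .then(|| filters.iter().copied().reduce(|acc, f| acc & f))
            .flatten()
    })
}

fn main() {
    assert_eq!(combine(None, &[0b110, 0b011]), Some(0b010));
    assert_eq!(combine(Some(7), &[1, 2]), Some(7)); // an existing predicate wins
    assert_eq!(combine(None, &[]), None);
}

Since `reduce` already returns `None` for an empty iterator, the `is_empty` guard is strictly redundant, but it keeps the original code's intent explicit.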

src/functions.rs

Lines changed: 55 additions & 68 deletions
@@ -91,42 +91,38 @@ fn create_to_char_udf() -> ScalarUDF {
 
 /// Format timestamps according to PostgreSQL format patterns
 fn format_timestamps(timestamp_array: &ArrayRef, format_str: &str) -> datafusion::error::Result<ArrayRef> {
-    // Try to handle both microsecond and nanosecond timestamps
+    let chrono_format = postgres_to_chrono_format(format_str);
     let mut builder = StringBuilder::new();
 
-    if let Some(timestamps) = timestamp_array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
-        for i in 0..timestamps.len() {
-            if timestamps.is_null(i) {
-                builder.append_null();
-            } else {
-                let timestamp_us = timestamps.value(i);
-                let datetime =
-                    DateTime::<Utc>::from_timestamp_micros(timestamp_us).ok_or_else(|| DataFusionError::Execution("Invalid timestamp".to_string()))?;
-
-                // Convert PostgreSQL format to chrono format
-                let chrono_format = postgres_to_chrono_format(format_str);
-                let formatted = datetime.format(&chrono_format).to_string();
+    let format_fn = |timestamp_us: i64| -> datafusion::error::Result<String> {
+        DateTime::<Utc>::from_timestamp_micros(timestamp_us)
+            .ok_or_else(|| DataFusionError::Execution("Invalid timestamp".to_string()))
+            .map(|dt| dt.format(&chrono_format).to_string())
+    };
 
-                builder.append_value(&formatted);
+    match timestamp_array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
+        Some(timestamps) => {
+            for i in 0..timestamps.len() {
+                if timestamps.is_null(i) {
+                    builder.append_null();
+                } else {
+                    builder.append_value(&format_fn(timestamps.value(i))?);
+                }
             }
         }
-    } else if let Some(timestamps) = timestamp_array.as_any().downcast_ref::<TimestampNanosecondArray>() {
-        for i in 0..timestamps.len() {
-            if timestamps.is_null(i) {
-                builder.append_null();
-            } else {
-                let timestamp_ns = timestamps.value(i);
-                let datetime = DateTime::<Utc>::from_timestamp_nanos(timestamp_ns);
-
-                // Convert PostgreSQL format to chrono format
-                let chrono_format = postgres_to_chrono_format(format_str);
-                let formatted = datetime.format(&chrono_format).to_string();
-
-                builder.append_value(&formatted);
+        None => match timestamp_array.as_any().downcast_ref::<TimestampNanosecondArray>() {
+            Some(timestamps) => {
+                for i in 0..timestamps.len() {
+                    if timestamps.is_null(i) {
+                        builder.append_null();
+                    } else {
+                        let timestamp_us = timestamps.value(i) / 1000; // Convert nanos to micros
+                        builder.append_value(&format_fn(timestamp_us)?);
+                    }
+                }
             }
+            None => return Err(DataFusionError::Execution("First argument must be a timestamp".to_string())),
         }
-    } else {
-        return Err(DataFusionError::Execution("First argument must be a timestamp".to_string()));
     }
 
     Ok(Arc::new(builder.finish()))
@@ -610,50 +606,41 @@ fn create_time_bucket_udf() -> ScalarUDF {
 /// Parse interval string to microseconds
 fn parse_interval_to_micros(interval_str: &str) -> datafusion::error::Result<i64> {
     let trimmed = interval_str.trim();
-
-    // Try to parse with whitespace first (e.g., "30 minutes")
     let parts: Vec<&str> = trimmed.split_whitespace().collect();
 
-    let (value, unit) = if parts.len() == 2 {
-        // Format: "30 minutes"
-        let value = parts[0].parse::<i64>()
-            .map_err(|_| DataFusionError::Execution("Invalid interval value".to_string()))?;
-        (value, parts[1].to_lowercase())
-    } else if parts.len() == 1 {
-        // Try to parse format without space (e.g., "30m")
-        let part = parts[0];
-
-        // Find where the number ends and the unit begins
-        let split_pos = part.chars()
-            .position(|c| c.is_alphabetic())
-            .ok_or_else(|| DataFusionError::Execution(
-                "Invalid interval format. Expected format: 'N unit' (e.g., '5 minutes' or '5m')".to_string()
-            ))?;
-
-        let (num_str, unit_str) = part.split_at(split_pos);
-
-        let value = num_str.parse::<i64>()
-            .map_err(|_| DataFusionError::Execution("Invalid interval value".to_string()))?;
-
-        (value, unit_str.to_lowercase())
-    } else {
-        return Err(DataFusionError::Execution(
+    let (value, unit) = match parts.as_slice() {
+        [value_str, unit_str] => {
+            let value = value_str.parse::<i64>()
+                .map_err(|_| DataFusionError::Execution("Invalid interval value".to_string()))?;
+            (value, unit_str.to_lowercase())
+        }
+        [combined] => {
+            let split_pos = combined.chars()
+                .position(|c| c.is_alphabetic())
+                .ok_or_else(|| DataFusionError::Execution(
+                    "Invalid interval format. Expected format: 'N unit' (e.g., '5 minutes' or '5m')".to_string()
+                ))?;
+
+            let (num_str, unit_str) = combined.split_at(split_pos);
+            let value = num_str.parse::<i64>()
+                .map_err(|_| DataFusionError::Execution("Invalid interval value".to_string()))?;
+            (value, unit_str.to_lowercase())
+        }
+        _ => return Err(DataFusionError::Execution(
             "Invalid interval format. Expected format: 'N unit' (e.g., '5 minutes' or '5m')".to_string(),
-        ));
+        )),
     };
 
     let micros_per_unit = match unit.as_str() {
         "second" | "seconds" | "sec" | "secs" | "s" => 1_000_000,
-        "minute" | "minutes" | "min" | "mins" | "m" => 60 * 1_000_000,
-        "hour" | "hours" | "hr" | "hrs" | "h" => 3600 * 1_000_000,
-        "day" | "days" | "d" => 86400 * 1_000_000,
-        "week" | "weeks" | "w" => 7 * 86400 * 1_000_000,
-        _ => {
-            return Err(DataFusionError::Execution(format!(
-                "Unsupported time unit: {}. Supported units: second(s), minute(s), hour(s), day(s), week(s)",
-                unit
-            )));
-        }
+        "minute" | "minutes" | "min" | "mins" | "m" => 60_000_000,
+        "hour" | "hours" | "hr" | "hrs" | "h" => 3_600_000_000,
+        "day" | "days" | "d" => 86_400_000_000,
+        "week" | "weeks" | "w" => 604_800_000_000,
+        _ => return Err(DataFusionError::Execution(format!(
+            "Unsupported time unit: {}. Supported units: second(s), minute(s), hour(s), day(s), week(s)",
+            unit
+        ))),
     };
 
     Ok(value * micros_per_unit)
@@ -819,7 +806,7 @@ impl Accumulator for PercentileAccumulator {
                 if !binary_array.is_null(i) {
                     let bytes = binary_array.value(i);
                     let other_digest = TDigestWrapper::from_bytes(bytes)
-                        .map_err(|e| DataFusionError::Execution(e))?;
+                        .map_err(DataFusionError::Execution)?;
 
                     self.digest.merge(&other_digest);
                 }
@@ -913,15 +900,15 @@ impl ScalarUDFImpl for ApproxPercentileUDF {
             let percentile = percentile_values.value(i);
 
             // Validate percentile is between 0 and 1
-            if percentile < 0.0 || percentile > 1.0 {
+            if !(0.0..=1.0).contains(&percentile) {
                 return Err(DataFusionError::Execution(
                     format!("Percentile must be between 0 and 1, got {}", percentile),
                 ));
             }
 
             let digest_bytes = digest_values.value(i);
             let wrapper = TDigestWrapper::from_bytes(digest_bytes)
-                .map_err(|e| DataFusionError::Execution(e))?;
+                .map_err(DataFusionError::Execution)?;
 
             match wrapper.to_digest() {
                 Some(digest) => {
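
Review note on src/functions.rs: matching on `parts.as_slice()` replaces the `parts.len()` checks with slice patterns, and the folded literals are unchanged values (60 * 1_000_000 == 60_000_000, ..., 7 * 86400 * 1_000_000 == 604_800_000_000). A simplified standalone sketch of the parse step, assuming plain String errors in place of DataFusionError (not code from this commit):

fn parse_interval(s: &str) -> Result<(i64, String), String> {
    let parts: Vec<&str> = s.trim().split_whitespace().collect();
    match parts.as_slice() {
        // "30 minutes": value and unit separated by whitespace
        [value, unit] => Ok((value.parse().map_err(|_| "invalid value")?, unit.to_lowercase())),
        // "30m": split at the first alphabetic character
        [combined] => {
            let pos = combined.chars().position(|c| c.is_alphabetic()).ok_or("missing unit")?;
            let (num, unit) = combined.split_at(pos);
            Ok((num.parse().map_err(|_| "invalid value")?, unit.to_lowercase()))
        }
        _ => Err("expected 'N unit', e.g. '5 minutes' or '5m'".into()),
    }
}

fn main() {
    assert_eq!(parse_interval("30 minutes"), Ok((30, "minutes".into())));
    assert_eq!(parse_interval("5m"), Ok((5, "m".into())));
    assert!(parse_interval("oops").is_err());
}

The `.map_err(DataFusionError::Execution)` changes are point-free style: the enum variant constructor is passed directly instead of being wrapped in a closure.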

src/optimizers.rs

Lines changed: 24 additions & 29 deletions
@@ -11,29 +11,28 @@ pub mod time_range_partition_pruner {
         match expr {
             Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
                 // Check if this is a timestamp comparison
-                if let (Expr::Column(col), Expr::Literal(ScalarValue::TimestampNanosecond(Some(ts), _tz), _)) = (left.as_ref(), right.as_ref()) {
-                    if col.name == "timestamp" {
-                        // Convert timestamp to date for partition filter
-                        let datetime = chrono::DateTime::from_timestamp_nanos(*ts);
-                        let date = datetime.date_naive();
+                if let (Expr::Column(col), Expr::Literal(ScalarValue::TimestampNanosecond(Some(ts), _tz), _)) = (left.as_ref(), right.as_ref())
+                    && col.name == "timestamp" {
+                    // Convert timestamp to date for partition filter
+                    let datetime = chrono::DateTime::from_timestamp_nanos(*ts);
+                    let date = datetime.date_naive();
 
-                        let date_scalar = ScalarValue::Date32(Some(date.and_hms_opt(0, 0, 0).unwrap().and_utc().timestamp() as i32 / 86400));
+                    let date_scalar = ScalarValue::Date32(Some(date.and_hms_opt(0, 0, 0).unwrap().and_utc().timestamp() as i32 / 86400));
 
-                        // Create corresponding date filter
-                        let date_col = Expr::Column(datafusion::common::Column::new_unqualified("date"));
-                        let date_filter = match op {
-                            Operator::Gt | Operator::GtEq => {
-                                Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), *op, Box::new(Expr::Literal(date_scalar, None))))
-                            }
-                            Operator::Lt | Operator::LtEq => {
-                                Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), *op, Box::new(Expr::Literal(date_scalar, None))))
-                            }
-                            Operator::Eq => Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), Operator::Eq, Box::new(Expr::Literal(date_scalar, None)))),
-                            _ => return None,
-                        };
+                    // Create corresponding date filter
+                    let date_col = Expr::Column(datafusion::common::Column::new_unqualified("date"));
+                    let date_filter = match op {
+                        Operator::Gt | Operator::GtEq => {
+                            Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), *op, Box::new(Expr::Literal(date_scalar, None))))
+                        }
+                        Operator::Lt | Operator::LtEq => {
+                            Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), *op, Box::new(Expr::Literal(date_scalar, None))))
+                        }
+                        Operator::Eq => Expr::BinaryExpr(BinaryExpr::new(Box::new(date_col), Operator::Eq, Box::new(Expr::Literal(date_scalar, None)))),
+                        _ => return None,
+                    };
 
-                        return Some(date_filter);
-                    }
+                    return Some(date_filter);
                 }
                 None
             }
@@ -43,7 +42,7 @@
     }
 
     /// Utilities for checking project_id filters
-    pub struct ProjectIdPushdown {}
+    pub struct ProjectIdPushdown;
 
     impl ProjectIdPushdown {
         pub fn has_project_id_filter(filters: &[Expr]) -> bool {
@@ -52,18 +51,14 @@ impl ProjectIdPushdown {
 
     pub fn contains_project_id(expr: &Expr) -> bool {
         match expr {
-            Expr::BinaryExpr(BinaryExpr { left, op, right }) if *op == Operator::Eq => {
+            Expr::BinaryExpr(BinaryExpr { left, op: Operator::Eq, right }) =>
                 matches!(
                     (left.as_ref(), right.as_ref()),
                     (Expr::Column(col), Expr::Literal(_, _)) | (Expr::Literal(_, _), Expr::Column(col))
                         if col.name == "project_id"
-                )
-            }
-            Expr::BinaryExpr(BinaryExpr {
-                left,
-                op: Operator::And,
-                right,
-            }) => Self::contains_project_id(left) || Self::contains_project_id(right),
+                ),
+            Expr::BinaryExpr(BinaryExpr { left, op: Operator::And, right }) =>
+                Self::contains_project_id(left) || Self::contains_project_id(right),
            _ => false,
        }
    }
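
Review note on src/optimizers.rs: the flattened `if let ... && col.name == "timestamp"` is a let-chain, which requires a toolchain where let-chains are stable (stabilized in Rust 1.88 for edition 2024); on older editions the nested `if let` + inner `if` form has to stay. A standalone sketch of the same flattening, with hypothetical names:

// Nested form (before):
//     if let Some(ts) = value {
//         if name == "timestamp" { return Some(ts); }
//     }
// Let-chain form (after):
fn match_timestamp(name: &str, value: Option<i64>) -> Option<i64> {
    if let Some(ts) = value
        && name == "timestamp"
    {
        return Some(ts);
    }
    None
}

Also note `pub struct ProjectIdPushdown {}` → `pub struct ProjectIdPushdown;` turns it into a unit struct; only associated functions appear to be used, so no constructing call sites should need updating.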

src/pgwire_handlers.rs

Lines changed: 13 additions & 11 deletions
@@ -100,10 +100,12 @@ impl SimpleQueryHandler for LoggingSimpleQueryHandler {
     {
         // Log UPDATE and DELETE queries
         let query_lower = query.trim().to_lowercase();
-        if query_lower.starts_with("update") || query_lower.contains(" update ") {
-            info!("UPDATE query executed: {}", query);
-        } else if query_lower.starts_with("delete") || query_lower.contains(" delete ") {
-            info!("DELETE query executed: {}", query);
+        let is_dml = ["update", "delete"].iter()
+            .any(|&cmd| query_lower.starts_with(cmd) || query_lower.contains(&format!(" {} ", cmd)));
+
+        if is_dml {
+            let cmd_type = if query_lower.contains("update") { "UPDATE" } else { "DELETE" };
+            info!("{} query executed: {}", cmd_type, query);
         }
 
         // Delegate to inner handler
@@ -174,14 +176,14 @@ impl ExtendedQueryHandler for LoggingExtendedQueryHandler {
         PgWireError: From<<C as Sink<PgWireBackendMessage>>::Error>,
     {
         // Log UPDATE and DELETE queries being executed
-        // portal.statement is an Arc<StoredStatement>, not Option
-        let statement = &portal.statement;
-        let query = &statement.statement.0;
+        let query = &portal.statement.statement.0;
         let query_lower = query.trim().to_lowercase();
-        if query_lower.starts_with("update") || query_lower.contains(" update ") {
-            info!("UPDATE query executed (extended): {}", query);
-        } else if query_lower.starts_with("delete") || query_lower.contains(" delete ") {
-            info!("DELETE query executed (extended): {}", query);
+        let is_dml = ["update", "delete"].iter()
+            .any(|&cmd| query_lower.starts_with(cmd) || query_lower.contains(&format!(" {} ", cmd)));
+
+        if is_dml {
+            let cmd_type = if query_lower.contains("update") { "UPDATE" } else { "DELETE" };
+            info!("{} query executed (extended): {}", cmd_type, query);
         }
 
         <DfSessionService as ExtendedQueryHandler>::do_query(&self.inner, client, portal, max_rows).await
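
Review note on src/pgwire_handlers.rs: both handlers now share the same detection shape; extracted as a free function it might look like the sketch below (hypothetical `classify_dml`, not part of this commit). One behavioral caveat of the committed version: `cmd_type` is derived from `contains("update")`, so a DELETE statement whose text merely mentions "update" is logged as UPDATE.

fn classify_dml(query: &str) -> Option<&'static str> {
    let q = query.trim().to_lowercase();
    ["update", "delete"]
        .into_iter()
        .find(|&cmd| q.starts_with(cmd) || q.contains(&format!(" {} ", cmd)))
        .map(|cmd| if cmd == "update" { "UPDATE" } else { "DELETE" })
}

fn main() {
    assert_eq!(classify_dml("UPDATE t SET x = 1"), Some("UPDATE"));
    assert_eq!(classify_dml("  delete from t"), Some("DELETE"));
    assert_eq!(classify_dml("select 1"), None);
}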
