Skip to content
Closed
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 29 additions & 2 deletions arrow-pg/src/datatypes/df.rs
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,12 @@ where
let value = portal.parameter::<i64>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Int64(value));
}
Type::TEXT | Type::VARCHAR => {
Type::TEXT => {
let value = portal.parameter::<String>(i, &pg_type)?;
// Use Utf8View for TEXT type to match DataFusion's internal schema expectations
deserialized_params.push(ScalarValue::Utf8View(value));
}
Type::VARCHAR => {
let value = portal.parameter::<String>(i, &pg_type)?;
deserialized_params.push(ScalarValue::Utf8(value));
}
Expand Down Expand Up @@ -236,7 +241,17 @@ where
&DataType::Float64,
)));
}
Type::TEXT_ARRAY | Type::VARCHAR_ARRAY => {
Type::TEXT_ARRAY => {
let value = portal.parameter::<Vec<Option<String>>>(i, &pg_type)?;
let scalar_values: Vec<ScalarValue> = value.map_or(Vec::new(), |v| {
v.into_iter().map(ScalarValue::Utf8View).collect()
});
deserialized_params.push(ScalarValue::List(ScalarValue::new_list_nullable(
&scalar_values,
&DataType::Utf8View,
)));
}
Type::VARCHAR_ARRAY => {
let value = portal.parameter::<Vec<Option<String>>>(i, &pg_type)?;
let scalar_values: Vec<ScalarValue> = value.map_or(Vec::new(), |v| {
v.into_iter().map(ScalarValue::Utf8).collect()
Expand All @@ -262,6 +277,18 @@ where
// Store MAC addresses as strings for now
deserialized_params.push(ScalarValue::Utf8(value));
}
Type::UNKNOWN => {
// For unknown types, try to deserialize as integer first, then fallback to text
// This handles cases like NULL arithmetic where DataFusion can't infer types
match portal.parameter::<i32>(i, &Type::INT4) {
Ok(value) => deserialized_params.push(ScalarValue::Int32(value)),
Err(_) => {
// Fallback to text if integer parsing fails
let value = portal.parameter::<String>(i, &Type::TEXT)?;
deserialized_params.push(ScalarValue::Utf8View(value));
}
}
}
// TODO: add more advanced types (composite types, ranges, etc.)
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
Expand Down
1 change: 1 addition & 0 deletions datafusion-postgres/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ futures.workspace = true
getset = "0.1"
log = "0.4"
pgwire = { workspace = true, features = ["server-api-ring", "scram"] }
regex = "1"
postgres-types.workspace = true
rust_decimal.workspace = true
tokio = { version = "1.47", features = ["sync", "net"] }
Expand Down
150 changes: 137 additions & 13 deletions datafusion-postgres/src/handlers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,24 @@ impl SimpleQueryHandler for DfSessionService {
}
}

let df_result = self.session_context.sql(query).await;
// Add query timeout for simple queries
let query_timeout = std::time::Duration::from_secs(60); // 60 seconds
let df_result =
match tokio::time::timeout(query_timeout, self.session_context.sql(query)).await {
Ok(result) => result,
Err(_) => {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // PostgreSQL query_canceled error code
format!(
"Query execution timeout after {} seconds",
query_timeout.as_secs()
),
),
)));
}
};

// Handle query execution errors and transaction state
let df = match df_result {
Expand Down Expand Up @@ -509,19 +526,105 @@ impl ExtendedQueryHandler for DfSessionService {

let (_, plan) = &portal.statement.statement;

let param_types = plan
.get_parameter_types()
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
// Enhanced parameter type inference with fallback for NULL + NULL scenarios
let param_types = match plan.get_parameter_types() {
Ok(types) => types,
Err(e) => {
let error_msg = e.to_string();
if error_msg.contains("Cannot get result type for arithmetic operation Null + Null")
|| error_msg.contains("Invalid arithmetic operation: Null + Null")
{
// Fallback: assume all parameters are integers for arithmetic operations
log::warn!("DataFusion type inference failed for arithmetic operation, using integer fallback");
let param_count = portal.statement.parameter_types.len();
std::collections::HashMap::from_iter((0..param_count).map(|i| {
(
format!("${}", i + 1),
Some(datafusion::arrow::datatypes::DataType::Int32),
)
}))
} else {
return Err(PgWireError::ApiError(Box::new(e)));
}
}
};
let param_values = df::deserialize_parameters(portal, &ordered_param_types(&param_types))?; // Fixed: Use &param_types
let plan = plan
.clone()
.replace_params_with_values(&param_values)
.map_err(|e| PgWireError::ApiError(Box::new(e)))?; // Fixed: Use &param_values
let dataframe = self
.session_context
.execute_logical_plan(plan)
.await
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;

// Replace parameters with values, with automatic retry for type inference failures
let plan = match plan.clone().replace_params_with_values(&param_values) {
Ok(plan) => plan,
Err(e) => {
let error_msg = e.to_string();
if error_msg.contains("Cannot get result type for arithmetic operation Null + Null")
|| error_msg.contains("Invalid arithmetic operation: Null + Null")
{
log::info!(
"Retrying query with enhanced type casting for arithmetic operations"
);

// Attempt to reparse the query with explicit type casting
let original_query = &portal.statement.statement.0;
let enhanced_query = enhance_query_with_type_casting(original_query);

// Try to create a new plan with the enhanced query
match self.session_context.sql(&enhanced_query).await {
Ok(new_plan_df) => {
// Get the logical plan from the new dataframe
let new_plan = new_plan_df.logical_plan().clone();

// Try parameter substitution again with the new plan
match new_plan.replace_params_with_values(&param_values) {
Ok(final_plan) => final_plan,
Err(_) => {
// If it still fails, return helpful error message
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42804".to_string(),
"Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer".to_string(),
),
)));
}
}
}
Err(_) => {
// If enhanced query fails, return helpful error message
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"42804".to_string(),
"Cannot infer parameter types for arithmetic operation. Please use explicit type casting like $1::integer + $2::integer".to_string(),
),
)));
}
}
} else {
return Err(PgWireError::ApiError(Box::new(e)));
}
}
};
// Add query timeout to prevent long-running queries from hanging connections
let query_timeout = std::time::Duration::from_secs(60); // 60 seconds
let dataframe = match tokio::time::timeout(
query_timeout,
self.session_context.execute_logical_plan(plan),
)
.await
{
Ok(result) => result.map_err(|e| PgWireError::ApiError(Box::new(e)))?,
Err(_) => {
return Err(PgWireError::UserError(Box::new(
pgwire::error::ErrorInfo::new(
"ERROR".to_string(),
"57014".to_string(), // PostgreSQL query_canceled error code
format!(
"Query execution timeout after {} seconds",
query_timeout.as_secs()
),
),
)));
}
};
let resp = df::encode_dataframe(dataframe, &portal.result_column_format).await?;
Ok(Response::Query(resp))
}
Expand Down Expand Up @@ -555,6 +658,27 @@ impl QueryParser for Parser {
}
}

/// Enhance a SQL query by adding type casting to parameters in arithmetic operations
/// This helps DataFusion's type inference when it encounters ambiguous parameter types
fn enhance_query_with_type_casting(query: &str) -> String {
use regex::Regex;

// Pattern to match arithmetic operations with parameters: $1 + $2, $1 - $2, etc.
let arithmetic_pattern = Regex::new(r"\$(\d+)\s*([+\-*/])\s*\$(\d+)").unwrap();

// Replace untyped parameters in arithmetic operations with integer-cast parameters
let enhanced = arithmetic_pattern.replace_all(query, "$$$1::integer $2 $$$3::integer");

// Pattern to match single parameters followed by arithmetic operators (avoiding lookaround)
let single_param_pattern = Regex::new(r"\$(\d+)(\s*[+\-*/=<>])").unwrap();

// Add integer casting to remaining untyped parameters in arithmetic contexts
let enhanced = single_param_pattern.replace_all(&enhanced, "$$$1::integer$2");

log::debug!("Enhanced query: {} -> {}", query, enhanced);
enhanced.to_string()
}

fn ordered_param_types(types: &HashMap<String, Option<DataType>>) -> Vec<Option<&DataType>> {
// Datafusion stores the parameters as a map. In our case, the keys will be
// `$1`, `$2` etc. The values will be the parameter types.
Expand Down
21 changes: 19 additions & 2 deletions datafusion-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -133,8 +133,25 @@ pub async fn serve_with_handlers(
let tls_acceptor_ref = tls_acceptor.clone();

tokio::spawn(async move {
if let Err(e) = process_socket(socket, tls_acceptor_ref, factory_ref).await {
warn!("Error processing socket: {e}");
// Add connection timeout to prevent hanging connections
let timeout_duration = std::time::Duration::from_secs(300); // 5 minutes
match tokio::time::timeout(
timeout_duration,
process_socket(socket, tls_acceptor_ref, factory_ref),
)
.await
{
Ok(result) => {
if let Err(e) = result {
warn!("Error processing socket: {e}");
}
}
Err(_) => {
warn!(
"Connection timed out after {} seconds",
timeout_duration.as_secs()
);
}
}
});
}
Expand Down
92 changes: 92 additions & 0 deletions datafusion-postgres/src/pg_catalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -936,6 +936,40 @@ impl PgCatalogSchemaProvider {
None,
);

// current_setting(setting_name) function
data.add_function(
2077,
"current_setting",
11,
10,
12,
1.0,
0.0,
0,
0,
"f",
false,
true,
true,
false,
"s",
"s",
1,
0,
25, // returns TEXT
"25", // takes TEXT parameter
None,
None,
None,
None,
None,
"current_setting",
None,
None,
None,
None,
);

data
}
}
Expand Down Expand Up @@ -2020,6 +2054,63 @@ pub fn create_format_type_udf() -> ScalarUDF {
)
}

pub fn create_current_setting_udf() -> ScalarUDF {
// Define the function implementation for current_setting(setting_name)
let func = move |args: &[ColumnarValue]| {
let args = ColumnarValue::values_to_arrays(args)?;
let setting_names = &args[0];

// Handle different setting name requests with reasonable defaults
let mut builder = StringBuilder::new();

for i in 0..setting_names.len() {
let setting_name = setting_names
.as_any()
.downcast_ref::<StringArray>()
.ok_or_else(|| DataFusionError::Internal("Expected string array".to_string()))?
.value(i);

// Provide reasonable defaults for common PostgreSQL settings
let value = match setting_name.to_lowercase().as_str() {
"server_version" => "16.0", // Match modern PostgreSQL version
"server_version_num" => "160000",
"client_encoding" => "UTF8",
"timezone" => "UTC",
"datestyle" => "ISO, MDY",
"default_transaction_isolation" => "read committed",
"application_name" => "datafusion-postgres",
"session_authorization" => "postgres",
"is_superuser" => "on",
"integer_datetimes" => "on",
"search_path" => "\"$user\", public",
"standard_conforming_strings" => "on",
"synchronous_commit" => "on",
"wal_level" => "replica",
"max_connections" => "100",
"shared_preload_libraries" => "",
"log_statement" => "none",
"log_min_messages" => "warning",
"default_text_search_config" => "pg_catalog.english",
_ => "", // Return empty string for unknown settings
};

builder.append_value(value);
}

let array: ArrayRef = Arc::new(builder.finish());
Ok(ColumnarValue::Array(array))
};

// Wrap the implementation in a scalar function
create_udf(
"current_setting",
vec![DataType::Utf8],
DataType::Utf8,
Volatility::Stable,
Arc::new(func),
)
}

/// Install pg_catalog and postgres UDFs to current `SessionContext`
pub fn setup_pg_catalog(
session_context: &SessionContext,
Expand All @@ -2042,6 +2133,7 @@ pub fn setup_pg_catalog(
session_context.register_udf(create_has_table_privilege_2param_udf());
session_context.register_udf(create_pg_table_is_visible());
session_context.register_udf(create_format_type_udf());
session_context.register_udf(create_current_setting_udf());

Ok(())
}
Loading