Skip to content

Commit e2af36f

Browse files
authored
Add support for querying parameter names in prepared statements (#609)
Fixes #605
2 parents 1a1c690 + 7dfe827 commit e2af36f

File tree

3 files changed

+130
-0
lines changed

3 files changed

+130
-0
lines changed

crates/duckdb/src/error.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,16 @@ pub enum Error {
8080

8181
/// Error when the SQL contains multiple statements.
8282
MultipleStatement,
83+
8384
/// Error when the number of bound parameters does not match the number of
8485
/// parameters in the query. The first `usize` is how many parameters were
8586
/// given, the 2nd is how many were expected.
8687
InvalidParameterCount(usize, usize),
8788

89+
/// Error when a parameter is requested, but the index is out of range
90+
/// for the statement.
91+
InvalidParameterIndex(usize),
92+
8893
/// Append Error
8994
AppendError,
9095
}
@@ -108,6 +113,7 @@ impl PartialEq for Error {
108113
}
109114
(Self::StatementChangedRows(n1), Self::StatementChangedRows(n2)) => n1 == n2,
110115
(Self::InvalidParameterCount(i1, n1), Self::InvalidParameterCount(i2, n2)) => i1 == i2 && n1 == n2,
116+
(Self::InvalidParameterIndex(i1), Self::InvalidParameterIndex(i2)) => i1 == i2,
111117
(..) => false,
112118
}
113119
}
@@ -187,6 +193,7 @@ impl fmt::Display for Error {
187193
Self::InvalidParameterCount(i1, n1) => {
188194
write!(f, "Wrong number of parameters passed to query. Got {i1}, needed {n1}")
189195
}
196+
Self::InvalidParameterIndex(i) => write!(f, "Invalid parameter index: {i}"),
190197
Self::StatementChangedRows(i) => write!(f, "Query changed {i} rows"),
191198
Self::ToSqlConversionFailure(ref err) => err.fmt(f),
192199
Self::InvalidQuery => write!(f, "Query is not read-only"),
@@ -213,6 +220,7 @@ impl error::Error for Error {
213220
| Self::InvalidColumnType(..)
214221
| Self::InvalidPath(_)
215222
| Self::InvalidParameterCount(..)
223+
| Self::InvalidParameterIndex(_)
216224
| Self::StatementChangedRows(_)
217225
| Self::InvalidQuery
218226
| Self::AppendError

crates/duckdb/src/raw_statement.rs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -316,6 +316,30 @@ impl RawStatement {
316316
unsafe { ffi::duckdb_nparams(self.ptr) as usize }
317317
}
318318

319+
pub fn parameter_name(&self, idx: usize) -> Result<String> {
320+
let count = self.bind_parameter_count();
321+
if idx == 0 || idx > count {
322+
return Err(Error::InvalidParameterIndex(idx));
323+
}
324+
325+
unsafe {
326+
let name_ptr = ffi::duckdb_parameter_name(self.ptr, idx as u64);
327+
// Range check above ensures this shouldn't be null, but check defensively
328+
if name_ptr.is_null() {
329+
return Err(Error::DuckDBFailure(
330+
ffi::Error::new(ffi::DuckDBError),
331+
Some(format!("Could not retrieve parameter name for index {idx}")),
332+
));
333+
}
334+
335+
let name = CStr::from_ptr(name_ptr).to_string_lossy().to_string();
336+
337+
ffi::duckdb_free(name_ptr as *mut std::ffi::c_void);
338+
339+
Ok(name)
340+
}
341+
}
342+
319343
#[inline]
320344
pub fn sql(&self) -> Option<&CStr> {
321345
panic!("not supported")

crates/duckdb/src/statement.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,39 @@ impl Statement<'_> {
437437
self.stmt.bind_parameter_count()
438438
}
439439

440+
/// Returns the name of the parameter at the given index.
441+
///
442+
/// This can be used to query the names of named parameters (e.g., `$param_name`)
443+
/// in a prepared statement.
444+
///
445+
/// # Arguments
446+
///
447+
/// * `one_based_col_index` - One-based parameter index (1 to [`Statement::parameter_count`])
448+
///
449+
/// # Returns
450+
///
451+
/// * `Ok(String)` - The parameter name (without the `$` prefix for named params, or the numeric index for positional params)
452+
/// * `Err(InvalidParameterIndex)` - If the index is out of range
453+
///
454+
/// # Example
455+
///
456+
/// ```rust,no_run
457+
/// # use duckdb::{Connection, Result};
458+
/// fn query_parameter_names(conn: &Connection) -> Result<()> {
459+
/// let stmt = conn.prepare("SELECT $foo, $bar")?;
460+
///
461+
/// assert_eq!(stmt.parameter_count(), 2);
462+
/// assert_eq!(stmt.parameter_name(1)?, "foo");
463+
/// assert_eq!(stmt.parameter_name(2)?, "bar");
464+
///
465+
/// Ok(())
466+
/// }
467+
/// ```
468+
#[inline]
469+
pub fn parameter_name(&self, idx: usize) -> Result<String> {
470+
self.stmt.parameter_name(idx)
471+
}
472+
440473
/// Low level API to directly bind a parameter to a given index.
441474
///
442475
/// Note that the index is one-based, that is, the first parameter index is
@@ -1120,4 +1153,69 @@ mod test {
11201153
assert_eq!(expected, actual);
11211154
Ok(())
11221155
}
1156+
1157+
#[test]
1158+
fn test_parameter_name() -> Result<()> {
1159+
let db = Connection::open_in_memory()?;
1160+
1161+
{
1162+
let stmt = db.prepare("SELECT $foo, $bar")?;
1163+
1164+
assert_eq!(stmt.parameter_count(), 2);
1165+
assert_eq!(stmt.parameter_name(1)?, "foo");
1166+
assert_eq!(stmt.parameter_name(2)?, "bar");
1167+
1168+
assert!(matches!(stmt.parameter_name(0), Err(Error::InvalidParameterIndex(0))));
1169+
assert!(matches!(
1170+
stmt.parameter_name(100),
1171+
Err(Error::InvalidParameterIndex(100))
1172+
));
1173+
}
1174+
1175+
// Positional parameters return their index number as the name
1176+
{
1177+
let stmt = db.prepare("SELECT ?, ?")?;
1178+
assert_eq!(stmt.parameter_count(), 2);
1179+
assert_eq!(stmt.parameter_name(1)?, "1");
1180+
assert_eq!(stmt.parameter_name(2)?, "2");
1181+
}
1182+
1183+
// Numbered positional parameters also return their number
1184+
{
1185+
let stmt = db.prepare("SELECT ?1, ?2")?;
1186+
assert_eq!(stmt.parameter_count(), 2);
1187+
assert_eq!(stmt.parameter_name(1)?, "1");
1188+
assert_eq!(stmt.parameter_name(2)?, "2");
1189+
}
1190+
1191+
Ok(())
1192+
}
1193+
1194+
#[test]
1195+
fn test_bind_named_parameters_manually() -> Result<()> {
1196+
use std::collections::HashMap;
1197+
1198+
let db = Connection::open_in_memory()?;
1199+
let mut stmt = db.prepare("SELECT $foo > $bar")?;
1200+
1201+
let mut params: HashMap<String, i32> = HashMap::new();
1202+
params.insert("foo".to_string(), 42);
1203+
params.insert("bar".to_string(), 23);
1204+
1205+
for idx in 1..=stmt.parameter_count() {
1206+
let name = stmt.parameter_name(idx)?;
1207+
if let Some(value) = params.get(&name) {
1208+
stmt.raw_bind_parameter(idx, value)?;
1209+
}
1210+
}
1211+
1212+
stmt.raw_execute()?;
1213+
1214+
let mut rows = stmt.raw_query();
1215+
let row = rows.next()?.unwrap();
1216+
let result: bool = row.get(0)?;
1217+
assert!(result);
1218+
1219+
Ok(())
1220+
}
11231221
}

0 commit comments

Comments
 (0)