Skip to content

Commit 95e15ef

Browse files
authored
Merge pull request #261 from cipherstash/fix/handle-domain-type
Handle PostgreSQL OIDs for custom and domain types
2 parents 2fa8701 + 08c8c1d commit 95e15ef

File tree

14 files changed

+173
-25
lines changed

14 files changed

+173
-25
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/cipherstash-proxy-integration/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,5 +29,6 @@ cipherstash-config = "0.2.3"
2929
clap = "4.5.32"
3030
fake = { version = "4", features = ["chrono", "derive"] }
3131
hex = "0.4.3"
32+
postgres-types = { version = "0.2.9", features = ["derive"] }
3233
tap = "1.0.1"
3334
uuid = { version = "1.11.0", features = ["serde", "v4"] }

packages/cipherstash-proxy-integration/src/common.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,11 @@ pub async fn clear() {
4949
}
5050

5151
pub async fn reset_schema() {
52-
let client = connect_with_tls(PROXY).await;
52+
let port = std::env::var("CS_DATABASE__PORT")
53+
.map(|s| s.parse().unwrap())
54+
.unwrap_or(PG_LATEST);
55+
56+
let client = connect_with_tls(port).await;
5357
client.simple_query(TEST_SCHEMA_SQL).await.unwrap();
5458
}
5559

packages/cipherstash-proxy-integration/src/decrypt/insert_returning.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ mod tests {
9090
let encrypted_text_result: String = row.get("encrypted_text");
9191
assert_eq!(encrypted_text, encrypted_text_result);
9292

93-
let encrypted_text_result: String = row.get(3);
93+
let encrypted_text_result: String = row.get(4);
9494
assert_eq!(encrypted_text, encrypted_text_result);
9595
}
9696
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#[cfg(test)]
2+
mod tests {
3+
use crate::common::{connect_with_tls, insert, query_by, random_id, trace, PROXY};
4+
use tokio_postgres::types::{FromSql, ToSql};
5+
6+
#[derive(Debug, ToSql, FromSql, PartialEq)]
7+
#[postgres(name = "domain_type_with_check")]
8+
pub struct Domain(String);
9+
10+
///
11+
/// Tests insertion of custom domain type
12+
///
13+
#[tokio::test]
14+
async fn insert_domain_type() {
15+
trace();
16+
17+
let id = random_id();
18+
let encrypted_domain = Domain("ZZ".to_string());
19+
20+
let sql = "INSERT INTO encrypted (id, plaintext_domain) VALUES ($1, $2)";
21+
insert(sql, &[&id, &encrypted_domain]).await;
22+
23+
let sql = "SELECT plaintext_domain FROM encrypted WHERE id = $1";
24+
let result = query_by::<Domain>(sql, &id).await;
25+
26+
let expected = vec![encrypted_domain];
27+
assert_eq!(expected, result);
28+
}
29+
30+
///
31+
/// Tests insertion of custom domain type with returned values
32+
///
33+
#[tokio::test]
34+
async fn insert_domain_type_with_encrypted_and_returning() {
35+
trace();
36+
37+
let id = random_id();
38+
let encrypted_domain = Domain("ZZ".to_string());
39+
let encrypted_text = "blah-vtha".to_string();
40+
41+
let sql = "INSERT INTO encrypted (id, plaintext_domain, encrypted_text) VALUES ($1, $2, $3) RETURNING id, plaintext_domain, encrypted_text";
42+
43+
let client = connect_with_tls(PROXY).await;
44+
let result = client
45+
.query(sql, &[&id, &encrypted_domain, &encrypted_text])
46+
.await
47+
.unwrap();
48+
49+
assert_eq!(result.len(), 1);
50+
51+
for row in result {
52+
let result_id: i64 = row.get("id");
53+
assert_eq!(id, result_id);
54+
55+
let result_encrypted_text: String = row.get("encrypted_text");
56+
assert_eq!(encrypted_text, result_encrypted_text);
57+
58+
let result_domain: Domain = row.get("plaintext_domain");
59+
assert_eq!(encrypted_domain, result_domain);
60+
}
61+
}
62+
}

packages/cipherstash-proxy-integration/src/insert/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
mod insert_domain_type;
12
mod insert_with_literal;
23
mod insert_with_null_literal;
34
mod insert_with_null_param;

packages/cipherstash-proxy-integration/src/map_unique_index.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#[cfg(test)]
22
mod tests {
3-
use crate::common::{clear, connect_with_tls, random_id, reset_schema, trace, PROXY};
3+
use crate::common::{clear, connect_with_tls, random_id, trace, PROXY};
44
use chrono::NaiveDate;
55

66
#[tokio::test]
@@ -225,7 +225,7 @@ mod tests {
225225
async fn map_unique_index_all_with_wildcard() {
226226
trace();
227227

228-
reset_schema().await;
228+
clear().await;
229229

230230
let client = connect_with_tls(PROXY).await;
231231

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#[cfg(test)]
2+
mod tests {
3+
use crate::common::{insert, query_by, random_id, trace};
4+
use tokio_postgres::types::{FromSql, ToSql};
5+
6+
#[derive(Debug, ToSql, FromSql, PartialEq)]
7+
#[postgres(name = "domain_type_with_check")]
8+
pub struct Domain(String);
9+
10+
///
11+
/// Tests insertion of custom domain type
12+
///
13+
#[tokio::test]
14+
async fn select_domain_type() {
15+
trace();
16+
17+
let id = random_id();
18+
let encrypted_val = Domain("ZZ".to_string());
19+
20+
let insert_sql = "INSERT INTO encrypted (id, plaintext_domain) VALUES ($1, $2)";
21+
insert(insert_sql, &[&id, &encrypted_val]).await;
22+
23+
let select_sql = "SELECT plaintext_domain FROM encrypted WHERE id = $1";
24+
let result = query_by::<Domain>(select_sql, &id).await;
25+
26+
let expected = vec![encrypted_val];
27+
assert_eq!(expected, result);
28+
}
29+
}

packages/cipherstash-proxy/src/postgresql/backend.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,8 @@ where
355355
) -> Result<Option<BytesMut>, Error> {
356356
let mut description = ParamDescription::try_from(bytes)?;
357357

358+
debug!(target: PROTOCOL, client_id = self.context.client_id, ParamDescription = ?description);
359+
358360
if let Some(statement) = self.context.get_statement_from_describe() {
359361
let param_types = statement
360362
.param_columns
@@ -389,6 +391,8 @@ where
389391
) -> Result<Option<BytesMut>, Error> {
390392
let mut description = RowDescription::try_from(bytes)?;
391393

394+
debug!(target: PROTOCOL, client_id = self.context.client_id, RowDescription = ?description);
395+
392396
if let Some(statement) = self.context.get_statement_from_describe() {
393397
let projection_types = statement
394398
.projection_columns

packages/cipherstash-proxy/src/postgresql/messages/param_description.rs

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,15 @@ use std::io::Cursor;
2828
2929
#[derive(Debug)]
3030
pub struct ParamDescription {
31-
pub types: Vec<postgres_types::Type>,
31+
pub types: Vec<i32>,
3232
dirty: bool,
3333
}
3434

3535
impl ParamDescription {
3636
pub fn map_types(&mut self, mapped_types: &[Option<Type>]) {
3737
for (idx, t) in mapped_types.iter().enumerate() {
3838
if let Some(t) = t {
39-
self.types[idx] = t.clone();
39+
self.types[idx] = t.oid() as i32;
4040
self.dirty = true;
4141
}
4242
}
@@ -69,8 +69,6 @@ impl TryFrom<&BytesMut> for ParamDescription {
6969
let mut types = vec![];
7070
for _idx in 0..count {
7171
let type_oid = cursor.get_i32();
72-
let type_oid = postgres_types::Type::from_oid(type_oid as u32)
73-
.unwrap_or(postgres_types::Type::UNKNOWN);
7472
types.push(type_oid)
7573
}
7674

@@ -97,7 +95,7 @@ impl TryFrom<ParamDescription> for BytesMut {
9795
bytes.put_i16(count as i16);
9896

9997
for type_oid in parameter_description.types.into_iter() {
100-
bytes.put_i32(type_oid.oid() as i32);
98+
bytes.put_i32(type_oid);
10199
}
102100

103101
Ok(bytes)
@@ -124,9 +122,9 @@ mod tests {
124122

125123
let mut pd = ParamDescription {
126124
types: vec![
127-
postgres_types::Type::TEXT,
128-
postgres_types::Type::INT4,
129-
postgres_types::Type::INT8,
125+
postgres_types::Type::TEXT.oid() as i32,
126+
postgres_types::Type::INT4.oid() as i32,
127+
postgres_types::Type::INT8.oid() as i32,
130128
],
131129
dirty: false,
132130
};
@@ -145,9 +143,9 @@ mod tests {
145143
assert!(pd.requires_rewrite());
146144

147145
let expected = vec![
148-
postgres_types::Type::TEXT,
149-
postgres_types::Type::INT4,
150-
postgres_types::Type::TEXT,
146+
postgres_types::Type::TEXT.oid() as i32,
147+
postgres_types::Type::INT4.oid() as i32,
148+
postgres_types::Type::TEXT.oid() as i32,
151149
];
152150

153151
assert_eq!(pd.types, expected);
@@ -165,8 +163,14 @@ mod tests {
165163
info!("{:?}", description);
166164

167165
assert_eq!(description.types.len(), 2);
168-
assert_eq!(description.types[0], postgres_types::Type::INT8);
169-
assert_eq!(description.types[1], postgres_types::Type::JSONB);
166+
assert_eq!(
167+
description.types[0],
168+
postgres_types::Type::INT8.oid() as i32
169+
);
170+
assert_eq!(
171+
description.types[1],
172+
postgres_types::Type::JSONB.oid() as i32
173+
);
170174

171175
let bytes = BytesMut::try_from(description).unwrap();
172176
assert_eq!(bytes, expected);

0 commit comments

Comments
 (0)