Skip to content

Commit 94addf4

Browse files
authored
Merge pull request #314 from cipherstash/refactor-name
♻️ refactor: convert Name struct to enum and extract to own module
2 parents 1cf7ebf + 099bc1e commit 94addf4

File tree

10 files changed

+340
-76
lines changed

10 files changed

+340
-76
lines changed

packages/cipherstash-proxy/src/postgresql/context/mod.rs

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,7 @@ pub mod column;
22

33
use super::{
44
format_code::FormatCode,
5-
messages::{
6-
describe::{Describe, Target},
7-
Name,
8-
},
5+
messages::{describe::Describe, Name, Target},
96
Column,
107
};
118
use crate::{
@@ -264,7 +261,7 @@ impl Context {
264261
} => self.get_portal_statement(name),
265262
Describe {
266263
ref name,
267-
target: Target::PreparedStatement,
264+
target: Target::Statement,
268265
} => self.get_statement(name),
269266
}
270267
}
@@ -583,7 +580,7 @@ mod tests {
583580
use crate::{
584581
config::LogConfig,
585582
log,
586-
postgresql::messages::{describe::Target, Name},
583+
postgresql::messages::{Name, Target},
587584
};
588585
use cipherstash_client::IdentifiedBy;
589586
use eql_mapper::Schema;
@@ -621,15 +618,15 @@ mod tests {
621618

622619
let mut context = Context::new(1, schema);
623620

624-
let name = Name("name".to_string());
621+
let name = Name::from("name");
625622

626623
context.add_statement(name.clone(), statement());
627624

628625
let statement = context.get_statement(&name).unwrap();
629626

630627
let describe = Describe {
631628
name,
632-
target: Target::PreparedStatement,
629+
target: Target::Statement,
633630
};
634631
context.set_describe(describe);
635632

@@ -646,8 +643,8 @@ mod tests {
646643

647644
let mut context = Context::new(1, schema);
648645

649-
let statement_name = Name("statement".to_string());
650-
let portal_name = Name("portal".to_string());
646+
let statement_name = Name::from("statement");
647+
let portal_name = Name::from("portal");
651648

652649
// Add statement to context
653650
context.add_statement(statement_name.clone(), statement());
@@ -688,8 +685,8 @@ mod tests {
688685
let mut context = Context::new(1, schema);
689686

690687
// Create multiple statements
691-
let statement_name_1 = Name("statement_1".to_string());
692-
let statement_name_2 = Name("statement_2".to_string());
688+
let statement_name_1 = Name::from("statement_1");
689+
let statement_name_2 = Name::from("statement_2");
693690

694691
// Add statements to context
695692
context.add_statement(statement_name_1.clone(), statement());
@@ -698,7 +695,7 @@ mod tests {
698695
// Replicate pipelined execution
699696
// Add multiple portals with the same name
700697
// Pointing to different statements
701-
let portal_name = Name("portal".to_string());
698+
let portal_name = Name::from("portal");
702699

703700
let statement_1 = context.get_statement(&statement_name_1).unwrap();
704701
context.add_portal(portal_name.clone(), portal(&statement_1));
@@ -737,14 +734,14 @@ mod tests {
737734

738735
let mut context = Context::new(1, schema);
739736

740-
let statement_name_1 = Name("statement_1".to_string());
737+
let statement_name_1 = Name::from("statement_1");
741738
let portal_name_1 = Name::unnamed();
742739

743-
let statement_name_2 = Name("statement_2".to_string());
740+
let statement_name_2 = Name::from("statement_2");
744741
let portal_name_2 = Name::unnamed();
745742

746-
let statement_name_3 = Name("statement_3".to_string());
747-
let portal_name_3 = Name("portal_3".to_string());
743+
let statement_name_3 = Name::from("statement_3");
744+
let portal_name_3 = Name::from("portal_3");
748745

749746
// Add statement to context
750747
context.add_statement(statement_name_1.clone(), statement());

packages/cipherstash-proxy/src/postgresql/frontend.rs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,10 @@ use crate::log::{CONTEXT, MAPPER, PROTOCOL};
1515
use crate::postgresql::context::column::Column;
1616
use crate::postgresql::context::Portal;
1717
use crate::postgresql::data::literal_from_sql;
18+
use crate::postgresql::messages::close::Close;
1819
use crate::postgresql::messages::ready_for_query::ReadyForQuery;
1920
use crate::postgresql::messages::terminate::Terminate;
20-
use crate::postgresql::messages::Name;
21+
use crate::postgresql::messages::{Name, Target};
2122
use crate::prometheus::{
2223
CLIENTS_BYTES_RECEIVED_TOTAL, ENCRYPTED_VALUES_TOTAL, ENCRYPTION_DURATION_SECONDS,
2324
ENCRYPTION_ERROR_TOTAL, ENCRYPTION_REQUESTS_TOTAL, SERVER_BYTES_SENT_TOTAL,
@@ -287,6 +288,9 @@ where
287288
return Ok(());
288289
}
289290
}
291+
Code::Close => {
292+
self.close_handler(&bytes).await?;
293+
}
290294
code => {
291295
debug!(target: PROTOCOL,
292296
client_id = self.context.client_id,
@@ -322,6 +326,19 @@ where
322326
Ok(())
323327
}
324328

329+
async fn close_handler(&mut self, bytes: &BytesMut) -> Result<(), Error> {
330+
let close = Close::try_from(bytes)?;
331+
debug!(target: PROTOCOL, client_id = self.context.client_id, ?close);
332+
match close.target {
333+
Target::Portal => self.context.close_portal(&close.name),
334+
Target::Statement => {
335+
self.context.close_portal(&close.name);
336+
// self.context.close_statement(&close.name);
337+
}
338+
}
339+
Ok(())
340+
}
341+
325342
async fn execute_handler(&mut self, bytes: &BytesMut) -> Result<(), Error> {
326343
let execute = Execute::try_from(bytes)?;
327344
debug!(target: PROTOCOL, client_id = self.context.client_id, ?execute);

packages/cipherstash-proxy/src/postgresql/messages/bind.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -227,10 +227,10 @@ impl TryFrom<&BytesMut> for Bind {
227227
let _len = cursor.get_i32();
228228

229229
let portal = cursor.read_string()?;
230-
let portal = Name(portal);
230+
let portal = Name::from(portal);
231231

232232
let prepared_statement = cursor.read_string()?;
233-
let prepared_statement = Name(prepared_statement);
233+
let prepared_statement = Name::from(prepared_statement);
234234

235235
let num_param_format_codes = cursor.get_i16();
236236
let mut param_format_codes = Vec::new();
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
use crate::error::{Error, ProtocolError};
2+
use crate::postgresql::protocol::BytesMutReadString;
3+
use crate::{SIZE_I32, SIZE_U8};
4+
5+
use bytes::{Buf, BufMut, BytesMut};
6+
use std::convert::TryFrom;
7+
use std::ffi::CString;
8+
use std::io::Cursor;
9+
10+
use super::target::Target;
11+
use super::{FrontendCode, Name};
12+
13+
///
14+
/// Close b'C' (Frontend) message.
15+
///
16+
/// See: <https://www.postgresql.org/docs/current/protocol-message-formats.html>
17+
///
18+
/// Byte1('C')
19+
/// Identifies the message as a Close command.
20+
///
21+
/// Int32
22+
/// Length of message contents in bytes, including self.
23+
///
24+
/// Byte1
25+
/// 'S' to close a prepared statement; or 'P' to close a portal.
26+
///
27+
/// String
28+
/// The name of the prepared statement or portal to close (an empty string selects the unnamed prepared statement or portal).
29+
30+
#[derive(Debug, Clone)]
31+
pub(crate) struct Close {
32+
pub target: Target,
33+
pub name: Name,
34+
}
35+
36+
impl TryFrom<&BytesMut> for Close {
37+
type Error = Error;
38+
39+
fn try_from(bytes: &BytesMut) -> Result<Close, Self::Error> {
40+
let mut cursor = Cursor::new(bytes);
41+
let code = cursor.get_u8();
42+
43+
if FrontendCode::from(code) != FrontendCode::Close {
44+
return Err(ProtocolError::UnexpectedMessageCode {
45+
expected: FrontendCode::Close.into(),
46+
received: code as char,
47+
}
48+
.into());
49+
}
50+
51+
let _len = cursor.get_i32(); // read and progress cursor
52+
let target = cursor.get_u8();
53+
let target = Target::try_from(target)?;
54+
let name = cursor.read_string()?;
55+
let name = Name::from(name);
56+
57+
Ok(Close { target, name })
58+
}
59+
}
60+
61+
impl TryFrom<Close> for BytesMut {
62+
type Error = Error;
63+
64+
fn try_from(close: Close) -> Result<BytesMut, Error> {
65+
let mut bytes = BytesMut::new();
66+
67+
let name = CString::new(close.name.as_str())?;
68+
let name = name.as_bytes_with_nul();
69+
70+
let len = SIZE_I32 + SIZE_U8 + name.len();
71+
72+
bytes.put_u8(FrontendCode::Close.into());
73+
bytes.put_i32(len as i32);
74+
bytes.put_u8(close.target.into());
75+
bytes.put_slice(name);
76+
77+
Ok(bytes)
78+
}
79+
}
80+
81+
#[cfg(test)]
82+
mod tests {
83+
use super::*;
84+
use crate::{config::LogConfig, log, postgresql::messages::Name};
85+
use bytes::BytesMut;
86+
use std::convert::TryFrom;
87+
88+
fn to_message(s: &[u8]) -> BytesMut {
89+
BytesMut::from(s)
90+
}
91+
92+
#[test]
93+
pub fn test_close_statement() {
94+
log::init(LogConfig::default());
95+
96+
// Close unnamed prepared statement: C\0\0\0\x06S\0
97+
let bytes = to_message(b"C\0\0\0\x06S\0");
98+
let close = Close::try_from(&bytes).unwrap();
99+
100+
assert!(matches!(close.target, Target::Statement));
101+
assert!(close.name.is_unnamed());
102+
}
103+
104+
#[test]
105+
pub fn test_close_portal() {
106+
log::init(LogConfig::default());
107+
108+
// Close unnamed portal: C\0\0\0\x06P\0
109+
let bytes = to_message(b"C\0\0\0\x06P\0");
110+
let close = Close::try_from(&bytes).unwrap();
111+
112+
assert!(matches!(close.target, Target::Portal));
113+
assert!(close.name.is_unnamed());
114+
}
115+
116+
#[test]
117+
pub fn test_close_named_statement() {
118+
log::init(LogConfig::default());
119+
120+
// Close named prepared statement "stmt1": C\0\0\0\x0bSstmt1\0
121+
let bytes = to_message(b"C\0\0\0\x0bSstmt1\0");
122+
let close = Close::try_from(&bytes).unwrap();
123+
124+
assert!(matches!(close.target, Target::Statement));
125+
assert_eq!(close.name.as_str(), "stmt1");
126+
assert!(!close.name.is_unnamed());
127+
}
128+
129+
#[test]
130+
pub fn test_close_to_bytes() {
131+
log::init(LogConfig::default());
132+
133+
let close = Close {
134+
target: Target::Portal,
135+
name: Name::from("portal1"),
136+
};
137+
138+
let bytes = BytesMut::try_from(close).unwrap();
139+
let parsed = Close::try_from(&bytes).unwrap();
140+
141+
assert!(matches!(parsed.target, Target::Portal));
142+
assert_eq!(parsed.name.as_str(), "portal1");
143+
}
144+
}

packages/cipherstash-proxy/src/postgresql/messages/describe.rs

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use std::convert::TryFrom;
77
use std::ffi::CString;
88
use std::io::Cursor;
99

10+
use super::target::Target;
1011
use super::{FrontendCode, Name};
1112

1213
///
@@ -32,29 +33,6 @@ pub(crate) struct Describe {
3233
pub name: Name,
3334
}
3435

35-
///
36-
/// The target of the describe message.
37-
///
38-
/// Valid values are PreparedStatment or Portal
39-
///
40-
/// A Portal is a parsed statement PLUS any bound parameters
41-
/// Describe with `Target::Portal` returns the RowDescription describing the result set.
42-
/// The assuumption is that the parameters are already bound to the portal, so the Describe message is not required to include any parameter information.
43-
///
44-
/// Calls to Execute are made on a Portal (not a prepared statement) as execute requires any bound parameters
45-
///
46-
/// A PreparedStatement is the parsed statement
47-
/// Describe with `Target::PreparedStatement` returns a ParameterDescription followed by the RowDescription.
48-
///
49-
///
50-
/// See https://www.postgresql.org/docs/current/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY
51-
///
52-
#[derive(Debug, Clone)]
53-
pub enum Target {
54-
Portal,
55-
PreparedStatement,
56-
}
57-
5836
impl TryFrom<&BytesMut> for Describe {
5937
type Error = Error;
6038

@@ -74,7 +52,7 @@ impl TryFrom<&BytesMut> for Describe {
7452
let target = cursor.get_u8();
7553
let target = Target::try_from(target)?;
7654
let name = cursor.read_string()?;
77-
let name = Name(name);
55+
let name = Name::from(name);
7856

7957
Ok(Describe { target, name })
8058
}
@@ -86,28 +64,16 @@ impl TryFrom<Describe> for BytesMut {
8664
fn try_from(describe: Describe) -> Result<BytesMut, Error> {
8765
let mut bytes = BytesMut::new();
8866

89-
let name = CString::new(describe.name.0.as_str())?;
67+
let name = CString::new(describe.name.as_str())?;
9068
let name = name.as_bytes_with_nul();
9169

9270
let len = SIZE_I32 + SIZE_U8 + name.len();
9371

9472
bytes.put_u8(FrontendCode::Describe.into());
9573
bytes.put_i32(len as i32);
96-
bytes.put_u8(describe.target as u8);
74+
bytes.put_u8(describe.target.into());
9775
bytes.put_slice(name);
9876

9977
Ok(bytes)
10078
}
10179
}
102-
103-
impl TryFrom<u8> for Target {
104-
type Error = Error;
105-
106-
fn try_from(t: u8) -> Result<Target, Error> {
107-
match t as char {
108-
'S' => Ok(Target::PreparedStatement),
109-
'P' => Ok(Target::Portal),
110-
t => Err(ProtocolError::UnexpectedDescribeTarget(t).into()),
111-
}
112-
}
113-
}

packages/cipherstash-proxy/src/postgresql/messages/execute.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ impl TryFrom<&BytesMut> for Execute {
2929
let _len = cursor.get_i32(); // read and progress cursor
3030

3131
let portal = cursor.read_string()?;
32-
let portal = Name(portal);
32+
let portal = Name::from(portal);
3333
let max_rows = cursor.get_i32();
3434

3535
Ok(Execute { portal, max_rows })

0 commit comments

Comments
 (0)