Skip to content

Commit 99d0190

Browse files
authored
Merge pull request #323 from cipherstash/go-json-struct-tests
fix: simple protocol type error for JSONB encrypted types
2 parents 86b0ca4 + 1faea84 commit 99d0190

File tree

3 files changed

+89
-45
lines changed

3 files changed

+89
-45
lines changed

packages/cipherstash-proxy/src/postgresql/backend.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -578,7 +578,7 @@ where
578578

579579
debug!(target: PROTOCOL, client_id = self.context.client_id, RowDescription = ?description);
580580

581-
if let Some(statement) = self.context.get_statement_from_describe() {
581+
if let Some(statement) = self.context.get_statement_for_row_decription() {
582582
let projection_types = statement
583583
.projection_columns
584584
.iter()

packages/cipherstash-proxy/src/postgresql/context/mod.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -248,6 +248,19 @@ impl Context {
248248
}
249249
}
250250

251+
pub fn get_statement_for_row_decription(&self) -> Option<Arc<Statement>> {
252+
if let Some(statement) = self.get_statement_from_describe() {
253+
return Some(statement.clone());
254+
}
255+
256+
if let Some(Portal::Encrypted { statement, .. }) = self.get_portal_from_execute().as_deref()
257+
{
258+
return Some(statement.clone());
259+
};
260+
261+
None
262+
}
263+
251264
pub fn get_statement_from_describe(&self) -> Option<Arc<Statement>> {
252265
let queue = self.describe.read().ok()?;
253266
let describe = queue.next()?;

tests/integration/golang/pgx_test.go

Lines changed: 75 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -220,57 +220,88 @@ func TestPgxInsertEncryptedWithDomainTypeAndReturning(t *testing.T) {
220220
}
221221
}
222222

223-
/*
224-
func TestPgxEncryptedMapDate(t *testing.T) {
223+
// EncryptedRow represents a row with id and encrypted_text fields
224+
type EncryptedRow struct {
225+
ID int `db:"id"`
226+
EncryptedText string `db:"encrypted_text"`
227+
}
228+
229+
// Scan implements the sql.Scanner interface for EncryptedRow
230+
func (er *EncryptedRow) Scan(src interface{}) error {
231+
return fmt.Errorf("Scan method not implemented for single value")
232+
}
233+
234+
func TestPgxInsertEncryptedWithStructScan(t *testing.T) {
235+
t.Parallel()
225236
conn := setupPgxConnection(t)
226237
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
227238
defer cancel()
228239

229-
column := "encrypted_date"
230-
dates := []time.Time{
231-
time.Date(1, 1, 1, 0, 0, 0, 0, time.UTC),
232-
time.Date(1000, 1, 1, 0, 0, 0, 0, time.UTC),
233-
time.Date(1600, 1, 1, 0, 0, 0, 0, time.UTC),
234-
time.Date(1700, 1, 1, 0, 0, 0, 0, time.UTC),
235-
time.Date(1800, 1, 1, 0, 0, 0, 0, time.UTC),
236-
time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC),
237-
time.Date(1990, 1, 1, 0, 0, 0, 0, time.UTC),
238-
time.Date(1999, 12, 31, 0, 0, 0, 0, time.UTC),
239-
time.Date(2000, 1, 1, 0, 0, 0, 0, time.UTC),
240-
time.Date(2001, 1, 2, 0, 0, 0, 0, time.UTC),
241-
time.Date(2004, 2, 29, 0, 0, 0, 0, time.UTC),
242-
time.Date(2013, 7, 4, 0, 0, 0, 0, time.UTC),
243-
time.Date(2013, 12, 25, 0, 0, 0, 0, time.UTC),
244-
time.Date(2029, 1, 1, 0, 0, 0, 0, time.UTC),
245-
time.Date(2081, 1, 1, 0, 0, 0, 0, time.UTC),
246-
time.Date(2096, 2, 29, 0, 0, 0, 0, time.UTC),
247-
time.Date(2550, 1, 1, 0, 0, 0, 0, time.UTC),
248-
time.Date(9999, 12, 31, 0, 0, 0, 0, time.UTC),
240+
tx, err := conn.Begin(ctx)
241+
require.NoError(t, err)
242+
defer tx.Rollback(ctx)
243+
244+
encryptedTextValue := "test encrypted content"
245+
insertStmt := `INSERT INTO encrypted (id, encrypted_text) VALUES ($1, $2) RETURNING id, encrypted_text`
246+
247+
for _, mode := range modes {
248+
id := rand.Int()
249+
250+
t.Run(mode.String(), func(t *testing.T) {
251+
t.Run("insert_with_struct_scan", func(t *testing.T) {
252+
var row EncryptedRow
253+
err := conn.QueryRow(context.Background(), insertStmt, mode, id, encryptedTextValue).Scan(&row.ID, &row.EncryptedText)
254+
require.NoError(t, err)
255+
require.Equal(t, id, row.ID)
256+
require.Equal(t, encryptedTextValue, row.EncryptedText)
257+
})
258+
})
249259
}
250-
insertStmt := fmt.Sprintf(`INSERT INTO encrypted (id, %s) VALUES ($1, $2)`, column)
251-
selectStmt := fmt.Sprintf(`SELECT id, %s FROM encrypted WHERE id=$1`, column)
260+
}
252261

253-
for _, value := range dates {
254-
t.Run(value.String(), func(t *testing.T) {
255-
for _, mode := range modes {
256-
id := rand.Int()
257-
t.Run(mode.String(), func(t *testing.T) {
258-
t.Run("insert", func(t *testing.T) {
259-
_, err := conn.Exec(ctx, insertStmt, mode, id, value)
260-
require.NoError(t, err)
261-
})
262+
// EncryptedRowWithJsonb represents a row with id, encrypted_text and encrypted_jsonb fields
263+
type EncryptedRowWithJsonb struct {
264+
ID int `db:"id"`
265+
EncryptedText string `db:"encrypted_text"`
266+
EncryptedJsonb map[string]interface{} `db:"encrypted_jsonb"`
267+
}
262268

263-
t.Run("select", func(t *testing.T) {
264-
var rid int
265-
var rv time.Time
266-
err := conn.QueryRow(context.Background(), selectStmt, mode, id).Scan(&rid, &rv)
267-
require.NoError(t, err)
268-
require.Equal(t, id, rid)
269-
require.Equal(t, value, rv)
270-
})
271-
})
272-
}
269+
// Scan implements the sql.Scanner interface for EncryptedRowWithJsonb
270+
func (er *EncryptedRowWithJsonb) Scan(src interface{}) error {
271+
return fmt.Errorf("Scan method not implemented for single value")
272+
}
273+
274+
func TestPgxInsertEncryptedWithJsonbStructScan(t *testing.T) {
275+
t.Parallel()
276+
conn := setupPgxConnection(t)
277+
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
278+
defer cancel()
279+
280+
tx, err := conn.Begin(ctx)
281+
require.NoError(t, err)
282+
defer tx.Rollback(ctx)
283+
284+
encryptedTextValue := "test encrypted content"
285+
encryptedJsonbValue := `{"key":"value","number":42}`
286+
expectedJsonbValue := map[string]interface{}{
287+
"key": "value",
288+
"number": float64(42),
289+
}
290+
insertStmt := `INSERT INTO encrypted (id, encrypted_text, encrypted_jsonb) VALUES ($1, $2, $3) RETURNING id, encrypted_text, encrypted_jsonb`
291+
292+
for _, mode := range modes {
293+
id := rand.Int()
294+
295+
t.Run(mode.String(), func(t *testing.T) {
296+
t.Run("insert_with_jsonb_struct_scan", func(t *testing.T) {
297+
var row EncryptedRowWithJsonb
298+
err := conn.QueryRow(context.Background(), insertStmt, mode, id, encryptedTextValue, encryptedJsonbValue).Scan(&row.ID, &row.EncryptedText, &row.EncryptedJsonb)
299+
require.NoError(t, err)
300+
require.Equal(t, id, row.ID)
301+
require.Equal(t, encryptedTextValue, row.EncryptedText)
302+
require.Equal(t, expectedJsonbValue, row.EncryptedJsonb)
303+
})
273304
})
274305
}
275306
}
276-
*/
307+

0 commit comments

Comments
 (0)