Skip to content

Commit 31939cf

Browse files
committed
Read Postgres' NUMERIC as big.Rat
1 parent 9063503 commit 31939cf

File tree

12 files changed

+470
-118
lines changed

12 files changed

+470
-118
lines changed

destination.go

Lines changed: 52 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import (
1818
"context"
1919
"encoding/json"
2020
"fmt"
21+
"math/big"
2122
"strings"
2223

2324
sq "github.com/Masterminds/squirrel"
@@ -26,6 +27,7 @@ import (
2627
"github.com/conduitio/conduit-connector-postgres/internal"
2728
sdk "github.com/conduitio/conduit-connector-sdk"
2829
"github.com/jackc/pgx/v5"
30+
"github.com/shopspring/decimal"
2931
)
3032

3133
type Destination struct {
@@ -35,6 +37,7 @@ type Destination struct {
3537
getTableName destination.TableFn
3638

3739
conn *pgx.Conn
40+
dbInfo *internal.DbInfo
3841
stmtBuilder sq.StatementBuilderType
3942
}
4043

@@ -61,6 +64,7 @@ func (d *Destination) Open(ctx context.Context) error {
6164
return fmt.Errorf("invalid table name or table name function: %w", err)
6265
}
6366

67+
d.dbInfo = internal.NewDbInfo(conn)
6468
return nil
6569
}
6670

@@ -215,7 +219,11 @@ func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch
215219
return err
216220
}
217221

218-
colArgs, valArgs := d.formatColumnsAndValues(key, payload)
222+
colArgs, valArgs, err := d.formatColumnsAndValues(tableName, key, payload)
223+
if err != nil {
224+
return fmt.Errorf("error formatting columns and values: %w", err)
225+
}
226+
219227
sdk.Logger(ctx).Trace().
220228
Str("table_name", tableName).
221229
Msg("inserting record")
@@ -294,10 +302,13 @@ func (d *Destination) formatUpsertQuery(
294302
// remove the last comma from the list of tuples
295303
upsertQuery = strings.TrimSuffix(upsertQuery, ",")
296304

297-
// we have to manually append a semi colon to the upsert sql;
305+
// we have to manually append a semicolon to the upsert sql;
298306
upsertQuery += ";"
299307

300-
colArgs, valArgs := d.formatColumnsAndValues(key, payload)
308+
colArgs, valArgs, err := d.formatColumnsAndValues(tableName, key, payload)
309+
if err != nil {
310+
return "", nil, fmt.Errorf("error formatting columns and values: %w", err)
311+
}
301312

302313
return d.stmtBuilder.
303314
Insert(internal.WrapSQLIdent(tableName)).
@@ -309,24 +320,32 @@ func (d *Destination) formatUpsertQuery(
309320

310321
// formatColumnsAndValues turns the key and payload into a slice of ordered
311322
// columns and values for upserting into Postgres.
312-
func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData) ([]string, []interface{}) {
323+
func (d *Destination) formatColumnsAndValues(table string, key, payload opencdc.StructuredData) ([]string, []interface{}, error) {
313324
var colArgs []string
314325
var valArgs []interface{}
315326

316327
// range over both the key and payload values in order to format the
317328
// query for args and values in proper order
318329
for key, val := range key {
319330
colArgs = append(colArgs, internal.WrapSQLIdent(key))
320-
valArgs = append(valArgs, val)
331+
formatted, err := d.formatValue(table, key, val)
332+
if err != nil {
333+
return nil, nil, fmt.Errorf("error formatting value: %w", err)
334+
}
335+
valArgs = append(valArgs, formatted)
321336
delete(payload, key) // NB: Delete Key from payload arguments
322337
}
323338

324-
for field, value := range payload {
339+
for field, val := range payload {
325340
colArgs = append(colArgs, internal.WrapSQLIdent(field))
326-
valArgs = append(valArgs, value)
341+
formatted, err := d.formatValue(table, field, val)
342+
if err != nil {
343+
return nil, nil, fmt.Errorf("error formatting value: %w", err)
344+
}
345+
valArgs = append(valArgs, formatted)
327346
}
328347

329-
return colArgs, valArgs
348+
return colArgs, valArgs, nil
330349
}
331350

332351
// getKeyColumnName will return the name of the first item in the key or the
@@ -346,3 +365,28 @@ func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyNam
346365
func (d *Destination) hasKey(e opencdc.Record) bool {
347366
return e.Key != nil && len(e.Key.Bytes()) > 0
348367
}
368+
369+
func (d *Destination) formatValue(table string, column string, val interface{}) (interface{}, error) {
370+
switch v := val.(type) {
371+
case *big.Rat:
372+
return d.formatBigRat(table, column, v)
373+
case big.Rat:
374+
return d.formatBigRat(table, column, &v)
375+
default:
376+
return val, nil
377+
}
378+
}
379+
380+
func (d *Destination) formatBigRat(table string, column string, v *big.Rat) (string, error) {
381+
scale, err := d.dbInfo.GetNumericColumnScale(table, column)
382+
if err != nil {
383+
return "", fmt.Errorf("failed getting scale of numeric column: %w", err)
384+
}
385+
386+
if v == nil {
387+
return "", nil
388+
}
389+
390+
//nolint:gosec // no risk of overflow, because the scale in Pg is always <= 16383
391+
return decimal.NewFromBigRat(v, int32(scale)).String(), nil
392+
}

destination_integration_test.go

Lines changed: 93 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,14 @@ package postgres
1717
import (
1818
"context"
1919
"fmt"
20+
"math/big"
2021
"strings"
2122
"testing"
2223

2324
"github.com/conduitio/conduit-commons/opencdc"
2425
"github.com/conduitio/conduit-connector-postgres/test"
2526
sdk "github.com/conduitio/conduit-connector-sdk"
27+
"github.com/google/go-cmp/cmp"
2628
"github.com/jackc/pgx/v5"
2729
"github.com/matryer/is"
2830
)
@@ -71,11 +73,13 @@ func TestDestination_Write(t *testing.T) {
7173
"column1": "foo",
7274
"column2": 123,
7375
"column3": true,
76+
"column4": nil,
7477
"UppercaseColumn1": 222,
7578
},
7679
},
7780
},
78-
}, {
81+
},
82+
{
7983
name: "create",
8084
record: opencdc.Record{
8185
Position: opencdc.Position("foo"),
@@ -87,11 +91,13 @@ func TestDestination_Write(t *testing.T) {
8791
"column1": "foo",
8892
"column2": 456,
8993
"column3": false,
94+
"column4": nil,
9095
"UppercaseColumn1": 333,
9196
},
9297
},
9398
},
94-
}, {
99+
},
100+
{
95101
name: "insert on update (upsert)",
96102
record: opencdc.Record{
97103
Position: opencdc.Position("foo"),
@@ -103,11 +109,13 @@ func TestDestination_Write(t *testing.T) {
103109
"column1": "bar",
104110
"column2": 567,
105111
"column3": true,
112+
"column4": nil,
106113
"UppercaseColumn1": 444,
107114
},
108115
},
109116
},
110-
}, {
117+
},
118+
{
111119
name: "update on conflict",
112120
record: opencdc.Record{
113121
Position: opencdc.Position("foo"),
@@ -119,11 +127,13 @@ func TestDestination_Write(t *testing.T) {
119127
"column1": "foobar",
120128
"column2": 567,
121129
"column3": true,
130+
"column4": nil,
122131
"UppercaseColumn1": 555,
123132
},
124133
},
125134
},
126-
}, {
135+
},
136+
{
127137
name: "delete",
128138
record: opencdc.Record{
129139
Position: opencdc.Position("foo"),
@@ -132,6 +142,24 @@ func TestDestination_Write(t *testing.T) {
132142
Key: opencdc.StructuredData{"id": 4},
133143
},
134144
},
145+
{
146+
name: "write a big.Rat",
147+
record: opencdc.Record{
148+
Position: opencdc.Position("foo"),
149+
Operation: opencdc.OperationSnapshot,
150+
Metadata: map[string]string{opencdc.MetadataCollection: tableName},
151+
Key: opencdc.StructuredData{"id": 123},
152+
Payload: opencdc.Change{
153+
After: opencdc.StructuredData{
154+
"column1": "abcdef",
155+
"column2": 567,
156+
"column3": true,
157+
"column4": big.NewRat(123, 100),
158+
"UppercaseColumn1": 555,
159+
},
160+
},
161+
},
162+
},
135163
}
136164
for _, tt := range tests {
137165
t.Run(tt.name, func(t *testing.T) {
@@ -146,7 +174,16 @@ func TestDestination_Write(t *testing.T) {
146174
switch tt.record.Operation {
147175
case opencdc.OperationCreate, opencdc.OperationSnapshot, opencdc.OperationUpdate:
148176
is.NoErr(err)
149-
is.Equal(tt.record.Payload.After, got)
177+
is.Equal(
178+
"",
179+
cmp.Diff(
180+
tt.record.Payload.After,
181+
got,
182+
cmp.Comparer(func(x, y *big.Rat) bool {
183+
return x.Cmp(y) == 0
184+
}),
185+
),
186+
) // -want, +got
150187
case opencdc.OperationDelete:
151188
is.Equal(err, pgx.ErrNoRows)
152189
}
@@ -179,43 +216,50 @@ func TestDestination_Batch(t *testing.T) {
179216
is.NoErr(err)
180217
}()
181218

182-
records := []opencdc.Record{{
183-
Position: opencdc.Position("foo1"),
184-
Operation: opencdc.OperationCreate,
185-
Key: opencdc.StructuredData{"id": 5},
186-
Payload: opencdc.Change{
187-
After: opencdc.StructuredData{
188-
"column1": "foo1",
189-
"column2": 1,
190-
"column3": false,
191-
"UppercaseColumn1": 111,
219+
records := []opencdc.Record{
220+
{
221+
Position: opencdc.Position("foo1"),
222+
Operation: opencdc.OperationCreate,
223+
Key: opencdc.StructuredData{"id": 5},
224+
Payload: opencdc.Change{
225+
After: opencdc.StructuredData{
226+
"column1": "foo1",
227+
"column2": 1,
228+
"column3": false,
229+
"column4": nil,
230+
"UppercaseColumn1": 111,
231+
},
192232
},
193233
},
194-
}, {
195-
Position: opencdc.Position("foo2"),
196-
Operation: opencdc.OperationCreate,
197-
Key: opencdc.StructuredData{"id": 6},
198-
Payload: opencdc.Change{
199-
After: opencdc.StructuredData{
200-
"column1": "foo2",
201-
"column2": 2,
202-
"column3": true,
203-
"UppercaseColumn1": 222,
234+
{
235+
Position: opencdc.Position("foo2"),
236+
Operation: opencdc.OperationCreate,
237+
Key: opencdc.StructuredData{"id": 6},
238+
Payload: opencdc.Change{
239+
After: opencdc.StructuredData{
240+
"column1": "foo2",
241+
"column2": 2,
242+
"column3": true,
243+
"column4": nil,
244+
"UppercaseColumn1": 222,
245+
},
204246
},
205247
},
206-
}, {
207-
Position: opencdc.Position("foo3"),
208-
Operation: opencdc.OperationCreate,
209-
Key: opencdc.StructuredData{"id": 7},
210-
Payload: opencdc.Change{
211-
After: opencdc.StructuredData{
212-
"column1": "foo3",
213-
"column2": 3,
214-
"column3": false,
215-
"UppercaseColumn1": 333,
248+
{
249+
Position: opencdc.Position("foo3"),
250+
Operation: opencdc.OperationCreate,
251+
Key: opencdc.StructuredData{"id": 7},
252+
Payload: opencdc.Change{
253+
After: opencdc.StructuredData{
254+
"column1": "foo3",
255+
"column2": 3,
256+
"column3": false,
257+
"column4": nil,
258+
"UppercaseColumn1": 333,
259+
},
216260
},
217261
},
218-
}}
262+
}
219263

220264
i, err := d.Write(ctx, records)
221265
is.NoErr(err)
@@ -231,25 +275,36 @@ func TestDestination_Batch(t *testing.T) {
231275
func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id any) (opencdc.StructuredData, error) {
232276
row := conn.QueryRow(
233277
ctx,
234-
fmt.Sprintf(`SELECT column1, column2, column3, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName),
278+
fmt.Sprintf(`SELECT column1, column2, column3, column4, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName),
235279
id,
236280
)
237281

238282
var (
239283
col1 string
240284
col2 int
241285
col3 bool
286+
col4Str *string
242287
uppercaseCol1 int
243288
)
244-
err := row.Scan(&col1, &col2, &col3, &uppercaseCol1)
289+
290+
err := row.Scan(&col1, &col2, &col3, &col4Str, &uppercaseCol1)
245291
if err != nil {
246292
return nil, err
247293
}
248294

295+
// Handle the potential nil case for col4
296+
var col4 interface{}
297+
if col4Str != nil {
298+
r := new(big.Rat)
299+
r.SetString(*col4Str)
300+
col4 = r
301+
}
302+
249303
return opencdc.StructuredData{
250304
"column1": col1,
251305
"column2": col2,
252306
"column3": col3,
307+
"column4": col4,
253308
"UppercaseColumn1": uppercaseCol1,
254309
}, nil
255310
}

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ require (
1515
github.com/jackc/pgx/v5 v5.7.5
1616
github.com/matryer/is v1.4.1
1717
github.com/rs/zerolog v1.34.0
18+
github.com/shopspring/decimal v1.4.0
1819
gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637
1920
)
2021

@@ -181,7 +182,6 @@ require (
181182
github.com/sashamelentyev/interfacebloat v1.1.0 // indirect
182183
github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect
183184
github.com/securego/gosec/v2 v2.22.2 // indirect
184-
github.com/shopspring/decimal v1.4.0 // indirect
185185
github.com/sirupsen/logrus v1.9.3 // indirect
186186
github.com/sivchari/containedctx v1.0.3 // indirect
187187
github.com/sivchari/tenv v1.12.1 // indirect

0 commit comments

Comments
 (0)