diff --git a/destination.go b/destination.go index 2732c6f..47d7652 100644 --- a/destination.go +++ b/destination.go @@ -18,6 +18,7 @@ import ( "context" "encoding/json" "fmt" + "math/big" "strings" sq "github.com/Masterminds/squirrel" @@ -26,6 +27,7 @@ import ( "github.com/conduitio/conduit-connector-postgres/internal" sdk "github.com/conduitio/conduit-connector-sdk" "github.com/jackc/pgx/v5" + "github.com/shopspring/decimal" ) type Destination struct { @@ -35,6 +37,7 @@ type Destination struct { getTableName destination.TableFn conn *pgx.Conn + dbInfo *internal.DbInfo stmtBuilder sq.StatementBuilderType } @@ -61,6 +64,7 @@ func (d *Destination) Open(ctx context.Context) error { return fmt.Errorf("invalid table name or table name function: %w", err) } + d.dbInfo = internal.NewDbInfo(conn) return nil } @@ -156,7 +160,7 @@ func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch return fmt.Errorf("failed to get table name for write: %w", err) } - query, args, err := d.formatUpsertQuery(key, payload, keyColumnName, tableName) + query, args, err := d.formatUpsertQuery(ctx, key, payload, keyColumnName, tableName) if err != nil { return fmt.Errorf("error formatting query: %w", err) } @@ -215,7 +219,11 @@ func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch return err } - colArgs, valArgs := d.formatColumnsAndValues(key, payload) + colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload) + if err != nil { + return fmt.Errorf("error formatting columns and values: %w", err) + } + sdk.Logger(ctx).Trace(). Str("table_name", tableName). Msg("inserting record") @@ -272,12 +280,7 @@ func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.Struct // * In our case, we can only rely on the record.Key's parsed key value. // * If other schema constraints prevent a write, this won't upsert on // that conflict. -func (d *Destination) formatUpsertQuery( - key opencdc.StructuredData, - payload opencdc.StructuredData, - keyColumnName string, - tableName string, -) (string, []interface{}, error) { +func (d *Destination) formatUpsertQuery(ctx context.Context, key opencdc.StructuredData, payload opencdc.StructuredData, keyColumnName string, tableName string) (string, []interface{}, error) { upsertQuery := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", internal.WrapSQLIdent(keyColumnName)) for column := range payload { // tuples form a comma separated list, so they need a comma at the end. @@ -294,10 +297,13 @@ func (d *Destination) formatUpsertQuery( // remove the last comma from the list of tuples upsertQuery = strings.TrimSuffix(upsertQuery, ",") - // we have to manually append a semi colon to the upsert sql; + // we have to manually append a semicolon to the upsert sql; upsertQuery += ";" - colArgs, valArgs := d.formatColumnsAndValues(key, payload) + colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload) + if err != nil { + return "", nil, fmt.Errorf("error formatting columns and values: %w", err) + } return d.stmtBuilder. Insert(internal.WrapSQLIdent(tableName)). @@ -309,7 +315,7 @@ func (d *Destination) formatUpsertQuery( // formatColumnsAndValues turns the key and payload into a slice of ordered // columns and values for upserting into Postgres. 
-func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData) ([]string, []interface{}) { +func (d *Destination) formatColumnsAndValues(ctx context.Context, table string, key, payload opencdc.StructuredData) ([]string, []interface{}, error) { var colArgs []string var valArgs []interface{} @@ -317,16 +323,24 @@ func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData // query for args and values in proper order for key, val := range key { colArgs = append(colArgs, internal.WrapSQLIdent(key)) - valArgs = append(valArgs, val) + formatted, err := d.formatValue(ctx, table, key, val) + if err != nil { + return nil, nil, fmt.Errorf("error formatting value: %w", err) + } + valArgs = append(valArgs, formatted) delete(payload, key) // NB: Delete Key from payload arguments } - for field, value := range payload { + for field, val := range payload { colArgs = append(colArgs, internal.WrapSQLIdent(field)) - valArgs = append(valArgs, value) + formatted, err := d.formatValue(ctx, table, field, val) + if err != nil { + return nil, nil, fmt.Errorf("error formatting value: %w", err) + } + valArgs = append(valArgs, formatted) } - return colArgs, valArgs + return colArgs, valArgs, nil } // getKeyColumnName will return the name of the first item in the key or the @@ -334,7 +348,7 @@ func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyName string) string { if len(key) > 1 { // Go maps aren't order preserving, so anything over len 1 will have - // non deterministic results until we handle composite keys. + // non-deterministic results until we handle composite keys. panic("composite keys not yet supported") } for k := range key { @@ -346,3 +360,31 @@ func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyNam func (d *Destination) hasKey(e opencdc.Record) bool { return e.Key != nil && len(e.Key.Bytes()) > 0 } + +func (d *Destination) formatValue(ctx context.Context, table string, column string, val interface{}) (interface{}, error) { + switch v := val.(type) { + case *big.Rat: + return d.formatBigRat(ctx, table, column, v) + case big.Rat: + return d.formatBigRat(ctx, table, column, &v) + default: + return val, nil + } +} + +// formatBigRat formats a big.Rat into a string that can be written into a NUMERIC/DECIMAL column. +func (d *Destination) formatBigRat(ctx context.Context, table string, column string, v *big.Rat) (string, error) { + if v == nil { + return "", nil + } + + // we need to get the scale of the column so that we can properly + // round the result of dividing the input big.Rat's numerator and denominator.
+ scale, err := d.dbInfo.GetNumericColumnScale(ctx, table, column) + if err != nil { + return "", fmt.Errorf("failed getting scale of numeric column: %w", err) + } + + //nolint:gosec // no risk of overflow, because the scale in Pg is always <= 16383 + return decimal.NewFromBigRat(v, int32(scale)).String(), nil +} diff --git a/destination_integration_test.go b/destination_integration_test.go index 5d42004..f97e7fe 100644 --- a/destination_integration_test.go +++ b/destination_integration_test.go @@ -17,12 +17,14 @@ package postgres import ( "context" "fmt" + "math/big" "strings" "testing" "github.com/conduitio/conduit-commons/opencdc" "github.com/conduitio/conduit-connector-postgres/test" sdk "github.com/conduitio/conduit-connector-sdk" + "github.com/google/go-cmp/cmp" "github.com/jackc/pgx/v5" "github.com/matryer/is" ) @@ -71,11 +73,13 @@ func TestDestination_Write(t *testing.T) { "column1": "foo", "column2": 123, "column3": true, + "column4": nil, "UppercaseColumn1": 222, }, }, }, - }, { + }, + { name: "create", record: opencdc.Record{ Position: opencdc.Position("foo"), @@ -87,11 +91,13 @@ func TestDestination_Write(t *testing.T) { "column1": "foo", "column2": 456, "column3": false, + "column4": nil, "UppercaseColumn1": 333, }, }, }, - }, { + }, + { name: "insert on update (upsert)", record: opencdc.Record{ Position: opencdc.Position("foo"), @@ -103,11 +109,13 @@ func TestDestination_Write(t *testing.T) { "column1": "bar", "column2": 567, "column3": true, + "column4": nil, "UppercaseColumn1": 444, }, }, }, - }, { + }, + { name: "update on conflict", record: opencdc.Record{ Position: opencdc.Position("foo"), @@ -119,11 +127,13 @@ func TestDestination_Write(t *testing.T) { "column1": "foobar", "column2": 567, "column3": true, + "column4": nil, "UppercaseColumn1": 555, }, }, }, - }, { + }, + { name: "delete", record: opencdc.Record{ Position: opencdc.Position("foo"), @@ -132,6 +142,24 @@ func TestDestination_Write(t *testing.T) { Key: opencdc.StructuredData{"id": 4}, }, }, + { + name: "write a big.Rat", + record: opencdc.Record{ + Position: opencdc.Position("foo"), + Operation: opencdc.OperationSnapshot, + Metadata: map[string]string{opencdc.MetadataCollection: tableName}, + Key: opencdc.StructuredData{"id": 123}, + Payload: opencdc.Change{ + After: opencdc.StructuredData{ + "column1": "abcdef", + "column2": 567, + "column3": true, + "column4": big.NewRat(123, 100), + "UppercaseColumn1": 555, + }, + }, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { @@ -146,7 +174,16 @@ func TestDestination_Write(t *testing.T) { switch tt.record.Operation { case opencdc.OperationCreate, opencdc.OperationSnapshot, opencdc.OperationUpdate: is.NoErr(err) - is.Equal(tt.record.Payload.After, got) + is.Equal( + "", + cmp.Diff( + tt.record.Payload.After, + got, + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), + ), + ) // -want, +got case opencdc.OperationDelete: is.Equal(err, pgx.ErrNoRows) } @@ -179,43 +216,50 @@ func TestDestination_Batch(t *testing.T) { is.NoErr(err) }() - records := []opencdc.Record{{ - Position: opencdc.Position("foo1"), - Operation: opencdc.OperationCreate, - Key: opencdc.StructuredData{"id": 5}, - Payload: opencdc.Change{ - After: opencdc.StructuredData{ - "column1": "foo1", - "column2": 1, - "column3": false, - "UppercaseColumn1": 111, + records := []opencdc.Record{ + { + Position: opencdc.Position("foo1"), + Operation: opencdc.OperationCreate, + Key: opencdc.StructuredData{"id": 5}, + Payload: opencdc.Change{ + After: opencdc.StructuredData{ 
+ "column1": "foo1", + "column2": 1, + "column3": false, + "column4": nil, + "UppercaseColumn1": 111, + }, }, }, - }, { - Position: opencdc.Position("foo2"), - Operation: opencdc.OperationCreate, - Key: opencdc.StructuredData{"id": 6}, - Payload: opencdc.Change{ - After: opencdc.StructuredData{ - "column1": "foo2", - "column2": 2, - "column3": true, - "UppercaseColumn1": 222, + { + Position: opencdc.Position("foo2"), + Operation: opencdc.OperationCreate, + Key: opencdc.StructuredData{"id": 6}, + Payload: opencdc.Change{ + After: opencdc.StructuredData{ + "column1": "foo2", + "column2": 2, + "column3": true, + "column4": nil, + "UppercaseColumn1": 222, + }, }, }, - }, { - Position: opencdc.Position("foo3"), - Operation: opencdc.OperationCreate, - Key: opencdc.StructuredData{"id": 7}, - Payload: opencdc.Change{ - After: opencdc.StructuredData{ - "column1": "foo3", - "column2": 3, - "column3": false, - "UppercaseColumn1": 333, + { + Position: opencdc.Position("foo3"), + Operation: opencdc.OperationCreate, + Key: opencdc.StructuredData{"id": 7}, + Payload: opencdc.Change{ + After: opencdc.StructuredData{ + "column1": "foo3", + "column2": 3, + "column3": false, + "column4": nil, + "UppercaseColumn1": 333, + }, }, }, - }} + } i, err := d.Write(ctx, records) is.NoErr(err) @@ -231,7 +275,7 @@ func TestDestination_Batch(t *testing.T) { func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id any) (opencdc.StructuredData, error) { row := conn.QueryRow( ctx, - fmt.Sprintf(`SELECT column1, column2, column3, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName), + fmt.Sprintf(`SELECT column1, column2, column3, column4, "UppercaseColumn1" FROM %q WHERE id = $1`, tableName), id, ) @@ -239,17 +283,28 @@ func queryTestTable(ctx context.Context, conn test.Querier, tableName string, id col1 string col2 int col3 bool + col4Str *string uppercaseCol1 int ) - err := row.Scan(&col1, &col2, &col3, &uppercaseCol1) + + err := row.Scan(&col1, &col2, &col3, &col4Str, &uppercaseCol1) if err != nil { return nil, err } + // Handle the potential nil case for col4 + var col4 interface{} + if col4Str != nil { + r := new(big.Rat) + r.SetString(*col4Str) + col4 = r + } + return opencdc.StructuredData{ "column1": col1, "column2": col2, "column3": col3, + "column4": col4, "UppercaseColumn1": uppercaseCol1, }, nil } diff --git a/go.mod b/go.mod index d8734c1..70b8177 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,7 @@ require ( github.com/jackc/pgx/v5 v5.7.5 github.com/matryer/is v1.4.1 github.com/rs/zerolog v1.34.0 + github.com/shopspring/decimal v1.4.0 gopkg.in/tomb.v2 v2.0.0-20161208151619-d5d1b5820637 ) @@ -181,7 +182,6 @@ require ( github.com/sashamelentyev/interfacebloat v1.1.0 // indirect github.com/sashamelentyev/usestdlibvars v1.28.0 // indirect github.com/securego/gosec/v2 v2.22.2 // indirect - github.com/shopspring/decimal v1.4.0 // indirect github.com/sirupsen/logrus v1.9.3 // indirect github.com/sivchari/containedctx v1.0.3 // indirect github.com/sivchari/tenv v1.12.1 // indirect diff --git a/internal/db_info.go b/internal/db_info.go new file mode 100644 index 0000000..30394bd --- /dev/null +++ b/internal/db_info.go @@ -0,0 +1,105 @@ +// Copyright © 2025 Meroxa, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "errors" + "fmt" + + "github.com/jackc/pgx/v5" +) + +// DbInfo provides information about tables in a database. +type DbInfo struct { + conn *pgx.Conn + cache map[string]*tableCache +} + +// tableCache stores information about a table's columns. +// The information is cached for the lifetime of the DbInfo instance. +type tableCache struct { + columns map[string]int +} + +func NewDbInfo(conn *pgx.Conn) *DbInfo { + return &DbInfo{ + conn: conn, + cache: map[string]*tableCache{}, + } +} + +func (d *DbInfo) GetNumericColumnScale(ctx context.Context, table string, column string) (int, error) { + // Check if the scale for this column is already cached + tableInfo, ok := d.cache[table] + if ok { + scale, ok := tableInfo.columns[column] + if ok { + return scale, nil + } + } else { + // Table is not cached yet, initialize its cache entry + d.cache[table] = &tableCache{ + columns: map[string]int{}, + } + } + + // Fetch scale from database + scale, err := d.numericScaleFromDb(ctx, table, column) + if err != nil { + return 0, err + } + + d.cache[table].columns[column] = scale + + return scale, nil +} + +func (d *DbInfo) numericScaleFromDb(ctx context.Context, table string, column string) (int, error) { + // Query to get the column type and numeric scale + query := ` + SELECT + data_type, + numeric_scale + FROM + information_schema.columns + WHERE + table_name = $1 + AND column_name = $2 + ` + + var dataType string + var numericScale *int + + err := d.conn.QueryRow(ctx, query, table, column).Scan(&dataType, &numericScale) + if err != nil { + if errors.Is(err, pgx.ErrNoRows) { + return 0, fmt.Errorf("column %s not found in table %s", column, table) + } + return 0, fmt.Errorf("error querying column info: %w", err) + } + + // Check if the column is of the numeric/decimal type + if dataType != "numeric" && dataType != "decimal" { + return 0, fmt.Errorf("column %s in table %s is not a numeric type (actual type: %s)", column, table, dataType) + } + + // Handle case where numeric_scale is NULL + if numericScale == nil { + return 0, nil // The default scale is 0 when not specified + } + + return *numericScale, nil +} diff --git a/source/logrepl/cdc_test.go b/source/logrepl/cdc_test.go index 8586b0c..487117d 100644 --- a/source/logrepl/cdc_test.go +++ b/source/logrepl/cdc_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "math/big" "strings" "testing" "time" @@ -163,8 +164,8 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column1": "bizz", "column2": int32(456), "column3": false, - "column4": 12.3, - "column5": int64(14), + "column4": big.NewRat(123, 10), + "column5": big.NewRat(14, 1), "column6": []byte(`{"foo2": "bar2"}`), "column7": []byte(`{"foo2": "baz2"}`), "key": nil, @@ -198,8 +199,8 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column1": "test cdc updates", "column2": int32(123), "column3": false, - "column4": 12.2, - "column5": int64(4), + "column4": big.NewRat(122, 10), + "column5": big.NewRat(4, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), @@ -235,8 +236,8 @@ func
TestCDCIterator_Operation_NextN(t *testing.T) { "column1": "test cdc updates", "column2": int32(123), "column3": false, - "column4": 12.2, - "column5": int64(4), + "column4": big.NewRat(122, 10), + "column5": big.NewRat(4, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), @@ -247,8 +248,8 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column1": "test cdc full updates", "column2": int32(123), "column3": false, - "column4": 12.2, - "column5": int64(4), + "column4": big.NewRat(122, 10), + "column5": big.NewRat(4, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "key": []uint8("1"), @@ -323,7 +324,7 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { "column2": int32(789), "column3": false, "column4": nil, - "column5": int64(9), + "column5": big.NewRat(9, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(3), @@ -355,7 +356,14 @@ func TestCDCIterator_Operation_NextN(t *testing.T) { tt.want.Metadata[opencdc.MetadataReadAt] = got.Metadata[opencdc.MetadataReadAt] tt.want.Position = got.Position - is.Equal("", cmp.Diff(tt.want, got, cmpopts.IgnoreUnexported(opencdc.Record{}))) + is.Equal("", cmp.Diff( + tt.want, + got, + cmpopts.IgnoreUnexported(opencdc.Record{}), + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), + )) is.NoErr(i.Ack(ctx, got.Position)) }) } @@ -521,8 +529,8 @@ func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, from "column1": fmt.Sprintf("test-%d", i), "column2": int32(i) * 100, //nolint:gosec // fine, we know the value is small enough "column3": false, - "column4": 12.3, - "column5": int64(14), + "column4": big.NewRat(123, 10), + "column5": big.NewRat(14, 1), "column6": nil, "column7": nil, "UppercaseColumn1": nil, @@ -539,6 +547,9 @@ func verifyOpenCDCRecords(is *is.I, got []opencdc.Record, tableName string, from cmpOpts := []cmp.Option{ cmpopts.IgnoreUnexported(opencdc.Record{}), cmpopts.IgnoreFields(opencdc.Record{}, "Position", "Metadata"), + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), } is.Equal("", cmp.Diff(want, got, cmpOpts...)) // mismatch (-want +got) } diff --git a/source/logrepl/combined_test.go b/source/logrepl/combined_test.go index 3d2a3bd..d48976d 100644 --- a/source/logrepl/combined_test.go +++ b/source/logrepl/combined_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "math/big" "testing" "time" @@ -319,6 +320,9 @@ func TestCombinedIterator_NextN(t *testing.T) { is.Equal("", cmp.Diff( expectedRecords[6], records[0].Payload.After.(opencdc.StructuredData), + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), )) is.NoErr(i.Ack(ctx, records[0].Position)) @@ -364,8 +368,8 @@ func testRecords() []opencdc.StructuredData { "column1": "foo", "column2": int32(123), "column3": false, - "column4": 12.2, - "column5": int64(4), + "column4": big.NewRat(122, 10), + "column5": big.NewRat(4, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(1), @@ -376,8 +380,8 @@ func testRecords() []opencdc.StructuredData { "column1": "bar", "column2": int32(456), "column3": true, - "column4": 13.42, - "column5": int64(8), + "column4": big.NewRat(1342, 100), // 13.42 + "column5": big.NewRat(8, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(2), @@ -389,7 +393,7 @@ func testRecords() []opencdc.StructuredData { "column2": int32(789), "column3": 
false, "column4": nil, - "column5": int64(9), + "column5": big.NewRat(9, 1), "column6": []byte(`{"foo": "bar"}`), "column7": []byte(`{"foo": "baz"}`), "UppercaseColumn1": int32(3), @@ -400,7 +404,7 @@ func testRecords() []opencdc.StructuredData { "column1": nil, "column2": nil, "column3": nil, - "column4": 91.1, + "column4": big.NewRat(911, 10), // 91.1 "column5": nil, "column6": nil, "column7": nil, @@ -412,8 +416,8 @@ func testRecords() []opencdc.StructuredData { "column1": "bizz", "column2": int32(1010), "column3": false, - "column4": 872.2, - "column5": int64(101), + "column4": big.NewRat(8722, 10), // 872.2 + "column5": big.NewRat(101, 1), "column6": []byte(`{"foo12": "bar12"}`), "column7": []byte(`{"foo13": "bar13"}`), "UppercaseColumn1": nil, @@ -424,8 +428,8 @@ func testRecords() []opencdc.StructuredData { "column1": "buzz", "column2": int32(10101), "column3": true, - "column4": 121.9, - "column5": int64(51), + "column4": big.NewRat(1219, 10), // 121.9 + "column5": big.NewRat(51, 1), "column6": []byte(`{"foo7": "bar7"}`), "column7": []byte(`{"foo8": "bar8"}`), "UppercaseColumn1": nil, diff --git a/source/logrepl/internal/relationset_test.go b/source/logrepl/internal/relationset_test.go index 614a6e1..c120bfe 100644 --- a/source/logrepl/internal/relationset_test.go +++ b/source/logrepl/internal/relationset_test.go @@ -107,6 +107,7 @@ func setupTableAllTypes(ctx context.Context, t *testing.T, conn test.Querier) st table := test.RandomIdentifier(t) query := ` CREATE TABLE %s ( + id bigserial PRIMARY KEY, col_bit bit(8), col_varbit varbit(8), col_boolean boolean, @@ -260,6 +261,7 @@ func insertRowAllTypes(ctx context.Context, t *testing.T, conn test.Querier, tab func isValuesAllTypes(is *is.I, got map[string]any) { want := map[string]any{ + "id": int64(1), "col_bit": pgtype.Bits{ Bytes: []byte{0b01}, Len: 8, @@ -312,7 +314,7 @@ func isValuesAllTypes(is *is.I, got map[string]any) { "col_macaddr": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x26}, "col_macaddr8": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x03, 0x04, 0x27}, "col_money": "$28.00", - "col_numeric": float64(292929.29), + "col_numeric": big.NewRat(29292929, 100), "col_path": pgtype.Path{ P: []pgtype.Vec2{{X: 30, Y: 31}, {X: 32, Y: 33}, {X: 34, Y: 35}}, Closed: false, @@ -351,11 +353,15 @@ func isValuesAllTypes(is *is.I, got map[string]any) { cmp.Comparer(func(x, y netip.Prefix) bool { return x.String() == y.String() }), + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), )) } func isValuesAllTypesStandalone(is *is.I, got map[string]any) { want := map[string]any{ + "id": int64(1), "col_bit": pgtype.Bits{ Bytes: []byte{0b01}, Len: 8, @@ -408,7 +414,7 @@ func isValuesAllTypesStandalone(is *is.I, got map[string]any) { "col_macaddr": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x26}, "col_macaddr8": net.HardwareAddr{0x08, 0x00, 0x2b, 0x01, 0x02, 0x03, 0x04, 0x27}, "col_money": "$28.00", - "col_numeric": float64(292929.29), + "col_numeric": big.NewRat(29292929, 100), "col_path": pgtype.Path{ P: []pgtype.Vec2{{X: 30, Y: 31}, {X: 32, Y: 33}, {X: 34, Y: 35}}, Closed: false, @@ -447,5 +453,8 @@ func isValuesAllTypesStandalone(is *is.I, got map[string]any) { cmp.Comparer(func(x, y netip.Prefix) bool { return x.String() == y.String() }), + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), )) } diff --git a/source/snapshot/fetch_worker_test.go b/source/snapshot/fetch_worker_test.go index 5b7bf9c..d558fa7 100644 --- a/source/snapshot/fetch_worker_test.go +++ 
b/source/snapshot/fetch_worker_test.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "math/big" "strings" "testing" "time" @@ -277,17 +278,23 @@ func Test_FetcherRun_Initial(t *testing.T) { ) expectedMatch := []opencdc.StructuredData{ - {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": 12.2, "column5": int64(4), "column6": value6, "column7": value7, "UppercaseColumn1": int32(1)}, - {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": 13.42, "column5": int64(8), "column6": value6, "column7": value7, "UppercaseColumn1": int32(2)}, - {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": int64(9), "column6": value6, "column7": value7, "UppercaseColumn1": int32(3)}, - {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": 91.1, "column5": nil, "column6": nil, "column7": nil, "UppercaseColumn1": nil}, + {"id": int64(1), "key": []uint8{49}, "column1": "foo", "column2": int32(123), "column3": false, "column4": big.NewRat(122, 10), "column5": big.NewRat(4, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(1)}, + {"id": int64(2), "key": []uint8{50}, "column1": "bar", "column2": int32(456), "column3": true, "column4": big.NewRat(1342, 100), "column5": big.NewRat(8, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(2)}, + {"id": int64(3), "key": []uint8{51}, "column1": "baz", "column2": int32(789), "column3": false, "column4": nil, "column5": big.NewRat(9, 1), "column6": value6, "column7": value7, "UppercaseColumn1": int32(3)}, + {"id": int64(4), "key": []uint8{52}, "column1": nil, "column2": nil, "column3": nil, "column4": big.NewRat(911, 10), "column5": nil, "column6": nil, "column7": nil, "UppercaseColumn1": nil}, } for i, got := range gotFetchData { t.Run(fmt.Sprintf("payload_%d", i+1), func(t *testing.T) { is := is.New(t) is.Equal(got.Key, opencdc.StructuredData{"id": int64(i + 1)}) - is.Equal("", cmp.Diff(expectedMatch[i], got.Payload)) + is.Equal("", cmp.Diff( + expectedMatch[i], + got.Payload, + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), + )) is.Equal(got.Position, position.SnapshotPosition{ LastRead: int64(i + 1), @@ -342,18 +349,27 @@ func Test_FetcherRun_Resume(t *testing.T) { // validate generated record is.Equal(dd[0].Key, opencdc.StructuredData{"id": int64(3)}) - is.Equal("", cmp.Diff(dd[0].Payload, opencdc.StructuredData{ - "id": int64(3), - "key": []uint8{51}, - "column1": "baz", - "column2": int32(789), - "column3": false, - "column4": nil, - "column5": int64(9), - "column6": []byte(`{"foo": "bar"}`), - "column7": []byte(`{"foo": "baz"}`), - "UppercaseColumn1": int32(3), - })) + is.Equal( + "", + cmp.Diff( + dd[0].Payload, + opencdc.StructuredData{ + "id": int64(3), + "key": []uint8{51}, + "column1": "baz", + "column2": int32(789), + "column3": false, + "column4": nil, + "column5": big.NewRat(9, 1), + "column6": []byte(`{"foo": "bar"}`), + "column7": []byte(`{"foo": "baz"}`), + "UppercaseColumn1": int32(3), + }, + cmp.Comparer(func(x, y *big.Rat) bool { + return x.Cmp(y) == 0 + }), + ), + ) is.Equal(dd[0].Position, position.SnapshotPosition{ LastRead: 3, diff --git a/source/types/numeric.go b/source/types/numeric.go index c7dfbf0..0505090 100644 --- a/source/types/numeric.go +++ b/source/types/numeric.go @@ -15,42 +15,23 @@ package types import ( + "math/big" + "github.com/jackc/pgx/v5/pgtype" ) type 
NumericFormatter struct{} -// Format coerces `pgtype.Numeric` to int or double depending on the exponent. -// Returns error when value is invalid. -func (NumericFormatter) Format(num pgtype.Numeric) (any, error) { - // N.B. The numeric type in pgx is represented by two ints. - // When the type in Postgres is defined as `NUMERIC(10)' the scale is assumed to be 0. - // However, pgx may represent the number as two ints e.g. 1200 -> (int=12,exp=2) = 12*10^2. as well - // as a type with zero exponent, e.g. 121 -> (int=121,exp=0). - // Thus, a Numeric type with positive or zero exponent is assumed to be an integer. - if num.Exp >= 0 { - i8v, err := num.Int64Value() - if err != nil { - return nil, err - } - - v, err := i8v.Value() - if err != nil { - return nil, err - } - - return v, nil +// BigRatFromNumeric converts a pgtype.Numeric to a big.Rat. +func (NumericFormatter) BigRatFromNumeric(num pgtype.Numeric) (*big.Rat, error) { + if num.Int == nil { + return nil, nil } - - f8v, err := num.Float64Value() + v := new(big.Rat) + driverVal, err := num.Value() if err != nil { return nil, err } - - v, err := f8v.Value() - if err != nil { - return nil, err - } - + v.SetString(driverVal.(string)) return v, nil } diff --git a/source/types/types.go b/source/types/types.go index 5f826fa..87b1c67 100644 --- a/source/types/types.go +++ b/source/types/types.go @@ -30,9 +30,9 @@ func Format(oid uint32, v any) (any, error) { switch t := v.(type) { case pgtype.Numeric: - return Numeric.Format(t) + return Numeric.BigRatFromNumeric(t) case *pgtype.Numeric: - return Numeric.Format(*t) + return Numeric.BigRatFromNumeric(*t) case []uint8: if oid == pgtype.XMLOID { return string(t), nil diff --git a/source/types/types_test.go b/source/types/types_test.go index 3c42d2f..fb1dea7 100644 --- a/source/types/types_test.go +++ b/source/types/types_test.go @@ -15,6 +15,7 @@ package types import ( + "math/big" "testing" "time" @@ -47,13 +48,13 @@ func Test_Format(t *testing.T) { { name: "pgtype.Numeric", input: []any{ - pgxNumeric(t, "12.2121"), pgxNumeric(t, "101"), &pgtype.Numeric{}, nil, + pgxNumeric(t, "12.2121"), pgxNumeric(t, "101"), pgxNumeric(t, "0"), &pgtype.Numeric{}, nil, }, inputOID: []uint32{ - 0, 0, 0, 0, + 0, 0, 0, 0, 0, }, expect: []any{ - float64(12.2121), int64(101), nil, nil, + big.NewRat(122121, 10000), big.NewRat(101, 1), big.NewRat(0, 1), nil, nil, }, }, { diff --git a/source_integration_test.go b/source_integration_test.go index ac044cb..527078b 100644 --- a/source_integration_test.go +++ b/source_integration_test.go @@ -16,6 +16,7 @@ package postgres import ( "context" + "fmt" "strings" "testing" @@ -111,3 +112,127 @@ func TestSource_ParseConfig(t *testing.T) { }) } } + +func TestSource_Read(t *testing.T) { + ctx := test.Context(t) + is := is.New(t) + + conn := test.ConnectSimple(ctx, t, test.RegularConnString) + table := setupSourceTable(ctx, t, conn) + insertSourceRow(ctx, t, conn, table) + + s := NewSource() + err := sdk.Util.ParseConfig( + ctx, + map[string]string{ + "url": test.RepmgrConnString, + "tables": table, + "snapshotMode": "initial", + "cdcMode": "logrepl", + }, + s.Config(), + Connector.NewSpecification().SourceParams, + ) + is.NoErr(err) + + err = s.Open(ctx, nil) + is.NoErr(err) + + recs, err := s.ReadN(ctx, 1) + is.NoErr(err) + + fmt.Println(recs) +} + +// setupSourceTable creates a new table with all types and returns its name. 
+func setupSourceTable(ctx context.Context, t *testing.T, conn test.Querier) string { + is := is.New(t) + table := test.RandomIdentifier(t) + // todo still need to support: + // bit, varbit, box, char(n), cidr, circle, inet, interval, line, lseg, + // macaddr, macaddr8, money, path, pg_lsn, pg_snapshot, point, polygon, + // time, timetz, tsquery, tsvector, xml + query := ` + CREATE TABLE %s ( + id bigserial PRIMARY KEY, + col_boolean boolean, + col_bytea bytea, + col_varchar varchar(10), + col_date date, + col_float4 float4, + col_float8 float8, + col_int2 int2, + col_int4 int4, + col_int8 int8, + col_json json, + col_jsonb jsonb, + col_numeric numeric(8,2), + col_serial2 serial2, + col_serial4 serial4, + col_serial8 serial8, + col_text text, + col_timestamp timestamp, + col_timestamptz timestamptz, + col_uuid uuid + )` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(ctx, query) + is.NoErr(err) + + t.Cleanup(func() { + query := `DROP TABLE %s` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(context.Background(), query) + is.NoErr(err) + }) + return table +} + +func insertSourceRow(ctx context.Context, t *testing.T, conn test.Querier, table string) { + is := is.New(t) + query := ` + INSERT INTO %s ( + col_boolean, + col_bytea, + col_varchar, + col_date, + col_float4, + col_float8, + col_int2, + col_int4, + col_int8, + col_json, + col_jsonb, + col_numeric, + col_serial2, + col_serial4, + col_serial8, + col_text, + col_timestamp, + col_timestamptz, + col_uuid + ) VALUES ( + true, -- col_boolean + '\x07', -- col_bytea + '9', -- col_varchar + '2022-03-14', -- col_date + 15, -- col_float4 + 16.16, -- col_float8 + 32767, -- col_int2 + 2147483647, -- col_int4 + 9223372036854775807, -- col_int8 + '{"foo": "bar"}', -- col_json + '{"foo": "baz"}', -- col_jsonb + '292929.29', -- col_numeric + 32767, -- col_serial2 + 2147483647, -- col_serial4 + 9223372036854775807, -- col_serial8 + 'foo bar baz', -- col_text + '2022-03-14 15:16:17', -- col_timestamp + '2022-03-14 15:16:17-08', -- col_timestamptz + 'bd94ee0b-564f-4088-bf4e-8d5e626caf66' -- col_uuid + )` + query = fmt.Sprintf(query, table) + _, err := conn.Exec(ctx, query) + is.NoErr(err) +} diff --git a/test/docker-compose.yml b/test/docker-compose.yml index 8b6a1fc..16344bc 100644 --- a/test/docker-compose.yml +++ b/test/docker-compose.yml @@ -1,6 +1,6 @@ services: pg-0: - image: docker.io/bitnami/postgresql-repmgr:16 + image: docker.io/bitnami/postgresql-repmgr:17.5.0 ports: - "5433:5432" volumes:
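Note on the destination-side NUMERIC handling introduced above: the following is a minimal, self-contained sketch (not part of the patch) of the conversion that destination.go now performs when a record carries a big.Rat payload value. The scale is hard-coded here for illustration; in the connector it is looked up through internal.DbInfo.GetNumericColumnScale, which queries information_schema.columns and caches the result per table and column so repeated writes to the same column do not re-issue the metadata query.

package main

import (
	"fmt"
	"math/big"

	"github.com/shopspring/decimal"
)

func main() {
	// A big.Rat value as produced by the source's numeric decoding,
	// e.g. the NUMERIC value 292929.29 from col_numeric.
	v := big.NewRat(29292929, 100)

	// Scale as DbInfo.GetNumericColumnScale would return it for a column
	// declared as numeric(8,2). Hard-coded here for illustration only.
	scale := int32(2)

	// The destination renders the value as a string with the column's
	// scale before binding it as a query argument.
	fmt.Println(decimal.NewFromBigRat(v, scale).String()) // prints 292929.29
}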