Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 58 additions & 16 deletions destination.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"encoding/json"
"fmt"
"math/big"
"strings"

sq "github.com/Masterminds/squirrel"
Expand All @@ -26,6 +27,7 @@ import (
"github.com/conduitio/conduit-connector-postgres/internal"
sdk "github.com/conduitio/conduit-connector-sdk"
"github.com/jackc/pgx/v5"
"github.com/shopspring/decimal"
)

type Destination struct {
Expand All @@ -35,6 +37,7 @@ type Destination struct {
getTableName destination.TableFn

conn *pgx.Conn
dbInfo *internal.DbInfo
stmtBuilder sq.StatementBuilderType
}

Expand All @@ -61,6 +64,7 @@ func (d *Destination) Open(ctx context.Context) error {
return fmt.Errorf("invalid table name or table name function: %w", err)
}

d.dbInfo = internal.NewDbInfo(conn)
return nil
}

Expand Down Expand Up @@ -156,7 +160,7 @@ func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch
return fmt.Errorf("failed to get table name for write: %w", err)
}

query, args, err := d.formatUpsertQuery(key, payload, keyColumnName, tableName)
query, args, err := d.formatUpsertQuery(ctx, key, payload, keyColumnName, tableName)
if err != nil {
return fmt.Errorf("error formatting query: %w", err)
}
Expand Down Expand Up @@ -215,7 +219,11 @@ func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch
return err
}

colArgs, valArgs := d.formatColumnsAndValues(key, payload)
colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
if err != nil {
return fmt.Errorf("error formatting columns and values: %w", err)
}

sdk.Logger(ctx).Trace().
Str("table_name", tableName).
Msg("inserting record")
Expand Down Expand Up @@ -272,12 +280,7 @@ func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.Struct
// * In our case, we can only rely on the record.Key's parsed key value.
// * If other schema constraints prevent a write, this won't upsert on
// that conflict.
func (d *Destination) formatUpsertQuery(
key opencdc.StructuredData,
payload opencdc.StructuredData,
keyColumnName string,
tableName string,
) (string, []interface{}, error) {
func (d *Destination) formatUpsertQuery(ctx context.Context, key opencdc.StructuredData, payload opencdc.StructuredData, keyColumnName string, tableName string) (string, []interface{}, error) {
upsertQuery := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", internal.WrapSQLIdent(keyColumnName))
for column := range payload {
// tuples form a comma separated list, so they need a comma at the end.
Expand All @@ -294,10 +297,13 @@ func (d *Destination) formatUpsertQuery(
// remove the last comma from the list of tuples
upsertQuery = strings.TrimSuffix(upsertQuery, ",")

// we have to manually append a semi colon to the upsert sql;
// we have to manually append a semicolon to the upsert sql;
upsertQuery += ";"

colArgs, valArgs := d.formatColumnsAndValues(key, payload)
colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
if err != nil {
return "", nil, fmt.Errorf("error formatting columns and values: %w", err)
}

return d.stmtBuilder.
Insert(internal.WrapSQLIdent(tableName)).
Expand All @@ -309,32 +315,40 @@ func (d *Destination) formatUpsertQuery(

// formatColumnsAndValues turns the key and payload into a slice of ordered
// columns and values for upserting into Postgres.
func (d *Destination) formatColumnsAndValues(key, payload opencdc.StructuredData) ([]string, []interface{}) {
func (d *Destination) formatColumnsAndValues(ctx context.Context, table string, key, payload opencdc.StructuredData) ([]string, []interface{}, error) {
var colArgs []string
var valArgs []interface{}

// range over both the key and payload values in order to format the
// query for args and values in proper order
for key, val := range key {
colArgs = append(colArgs, internal.WrapSQLIdent(key))
valArgs = append(valArgs, val)
formatted, err := d.formatValue(ctx, table, key, val)
if err != nil {
return nil, nil, fmt.Errorf("error formatting value: %w", err)
}
valArgs = append(valArgs, formatted)
delete(payload, key) // NB: Delete Key from payload arguments
}

for field, value := range payload {
for field, val := range payload {
colArgs = append(colArgs, internal.WrapSQLIdent(field))
valArgs = append(valArgs, value)
formatted, err := d.formatValue(ctx, table, field, val)
if err != nil {
return nil, nil, fmt.Errorf("error formatting value: %w", err)
}
valArgs = append(valArgs, formatted)
}

return colArgs, valArgs
return colArgs, valArgs, nil
}

// getKeyColumnName will return the name of the first item in the key or the
// connector-configured default name of the key column name.
func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyName string) string {
if len(key) > 1 {
// Go maps aren't order preserving, so anything over len 1 will have
// non deterministic results until we handle composite keys.
// non-deterministic results until we handle composite keys.
panic("composite keys not yet supported")
}
for k := range key {
Expand All @@ -346,3 +360,31 @@ func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyNam
func (d *Destination) hasKey(e opencdc.Record) bool {
return e.Key != nil && len(e.Key.Bytes()) > 0
}

func (d *Destination) formatValue(ctx context.Context, table string, column string, val interface{}) (interface{}, error) {
switch v := val.(type) {
case *big.Rat:
return d.formatBigRat(ctx, table, column, v)
case big.Rat:
return d.formatBigRat(ctx, table, column, &v)
default:
return val, nil
}
}

// formatBigRat formats a big.Rat into a string that can be written into a NUMERIC/DECIMAL column.
func (d *Destination) formatBigRat(ctx context.Context, table string, column string, v *big.Rat) (string, error) {
if v == nil {
return "", nil
}

// we need to get the scale of the column so we that we can properly
// round the result of dividing the input big.Rat's numerator and denominator.
scale, err := d.dbInfo.GetNumericColumnScale(ctx, table, column)
if err != nil {
return "", fmt.Errorf("failed getting scale of numeric column: %w", err)
}

//nolint:gosec // no risk of overflow, because the scale in Pg is always <= 16383
return decimal.NewFromBigRat(v, int32(scale)).String(), nil
}
Loading