diff --git a/README.md b/README.md
index fbe2513..808fd4a 100644
--- a/README.md
+++ b/README.md
@@ -215,11 +215,6 @@ pipelines:
           # Type: string
           # Required: yes
           url: ""
-          # Key represents the column name for the key used to identify and
-          # update existing rows.
-          # Type: string
-          # Required: no
-          key: ""
           # Table is used as the target table into which records are inserted.
           # Type: string
           # Required: no
diff --git a/connector.yaml b/connector.yaml
index 8e62de3..1f78668 100644
--- a/connector.yaml
+++ b/connector.yaml
@@ -230,11 +230,6 @@ specification:
         validations:
           - type: required
             value: ""
-      - name: key
-        description: Key represents the column name for the key used to identify and update existing rows.
-        type: string
-        default: ""
-        validations: []
       - name: table
         description: Table is used as the target table into which records are inserted.
         type: string
diff --git a/destination.go b/destination.go
index 47d7652..0e6e4a3 100644
--- a/destination.go
+++ b/destination.go
@@ -18,7 +18,9 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"maps"
 	"math/big"
+	"slices"
 	"strings"
 
 	sq "github.com/Masterminds/squirrel"
@@ -74,6 +76,10 @@ func (d *Destination) Write(ctx context.Context, recs []opencdc.Record) (int, er
 	b := &pgx.Batch{}
 	for _, rec := range recs {
 		var err error
+		rec, err = d.ensureStructuredData(rec)
+		if err != nil {
+			return 0, fmt.Errorf("failed to clean record: %w", err)
+		}
 		switch rec.Operation {
 		case opencdc.OperationCreate:
 			err = d.handleInsert(ctx, rec, b)
@@ -117,9 +123,6 @@ func (d *Destination) Teardown(ctx context.Context) error {
 // exists and no key column name is configured, it will plainly insert the data.
 // Otherwise it upserts the record.
 func (d *Destination) handleInsert(ctx context.Context, r opencdc.Record, b *pgx.Batch) error {
-	if !d.hasKey(r) || d.config.Key == "" {
-		return d.insert(ctx, r, b)
-	}
 	return d.upsert(ctx, r, b)
 }
 
@@ -143,30 +146,21 @@ func (d *Destination) handleDelete(ctx context.Context, r opencdc.Record, b *pgx
 }
 
 func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch) error {
-	payload, err := d.getPayload(r)
-	if err != nil {
-		return fmt.Errorf("failed to get payload: %w", err)
-	}
-
-	key, err := d.getKey(r)
-	if err != nil {
-		return fmt.Errorf("failed to get key: %w", err)
-	}
-
-	keyColumnName := d.getKeyColumnName(key, d.config.Key)
-
+	payload := r.Payload.After.(opencdc.StructuredData)
+	key := r.Key.(opencdc.StructuredData)
 	tableName, err := d.getTableName(r)
 	if err != nil {
-		return fmt.Errorf("failed to get table name for write: %w", err)
+		return fmt.Errorf("failed to get table name for upsert: %w", err)
 	}
 
-	query, args, err := d.formatUpsertQuery(ctx, key, payload, keyColumnName, tableName)
+	query, args, err := d.formatUpsertQuery(ctx, key, payload, tableName)
 	if err != nil {
 		return fmt.Errorf("error formatting query: %w", err)
 	}
 
 	sdk.Logger(ctx).Trace().
-		Str("table_name", tableName).
-		Any("key", map[string]interface{}{keyColumnName: key[keyColumnName]}).
+		Str("table", tableName).
+		Str("query", query).
+		Any("key", key).
 		Msg("upserting record")
 	b.Queue(query, args...)
@@ -174,148 +168,92 @@ func (d *Destination) upsert(ctx context.Context, r opencdc.Record, b *pgx.Batch
 }
 
 func (d *Destination) remove(ctx context.Context, r opencdc.Record, b *pgx.Batch) error {
-	key, err := d.getKey(r)
-	if err != nil {
-		return err
-	}
-	keyColumnName := d.getKeyColumnName(key, d.config.Key)
+	key := r.Key.(opencdc.StructuredData)
 	tableName, err := d.getTableName(r)
 	if err != nil {
-		return fmt.Errorf("failed to get table name for write: %w", err)
+		return fmt.Errorf("failed to get table name for delete: %w", err)
+	}
+
+	where := make(sq.Eq)
+	for col, val := range key {
+		where[internal.WrapSQLIdent(col)] = val
 	}
 
-	sdk.Logger(ctx).Trace().
-		Str("table_name", tableName).
-		Any("key", map[string]interface{}{keyColumnName: key[keyColumnName]}).
-		Msg("deleting record")
 	query, args, err := d.stmtBuilder.
 		Delete(internal.WrapSQLIdent(tableName)).
-		Where(sq.Eq{internal.WrapSQLIdent(keyColumnName): key[keyColumnName]}).
+		Where(where).
 		ToSql()
 	if err != nil {
 		return fmt.Errorf("error formatting delete query: %w", err)
 	}
 
-	b.Queue(query, args...)
-	return nil
-}
-
-// insert is an append-only operation that doesn't care about keys, but
-// can error on constraints violations so should only be used when no table
-// key or unique constraints are otherwise present.
-func (d *Destination) insert(ctx context.Context, r opencdc.Record, b *pgx.Batch) error {
-	tableName, err := d.getTableName(r)
-	if err != nil {
-		return err
-	}
-
-	key, err := d.getKey(r)
-	if err != nil {
-		return err
-	}
-
-	payload, err := d.getPayload(r)
-	if err != nil {
-		return err
-	}
-
-	colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
-	if err != nil {
-		return fmt.Errorf("error formatting columns and values: %w", err)
-	}
-
 	sdk.Logger(ctx).Trace().
-		Str("table_name", tableName).
-		Msg("inserting record")
-
-	query, args, err := d.stmtBuilder.
-		Insert(internal.WrapSQLIdent(tableName)).
-		Columns(colArgs...).
-		Values(valArgs...).
-		ToSql()
-	if err != nil {
-		return fmt.Errorf("error formatting insert query: %w", err)
-	}
+		Str("table", tableName).
+		Str("query", query).
+		Any("key", key).
+		Msg("deleting record")
 
 	b.Queue(query, args...)
 	return nil
 }
 
-func (d *Destination) getPayload(r opencdc.Record) (opencdc.StructuredData, error) {
-	if r.Payload.After == nil {
-		return opencdc.StructuredData{}, nil
-	}
-	return d.structuredDataFormatter(r.Payload.After)
-}
-
-func (d *Destination) getKey(r opencdc.Record) (opencdc.StructuredData, error) {
-	if r.Key == nil {
-		return opencdc.StructuredData{}, nil
-	}
-	return d.structuredDataFormatter(r.Key)
-}
-
-func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.StructuredData, error) {
-	if data == nil {
-		return opencdc.StructuredData{}, nil
-	}
-	if sdata, ok := data.(opencdc.StructuredData); ok {
-		return sdata, nil
-	}
-	raw := data.Bytes()
-	if len(raw) == 0 {
-		return opencdc.StructuredData{}, nil
-	}
-
-	m := make(map[string]interface{})
-	err := json.Unmarshal(raw, &m)
-	if err != nil {
-		return nil, err
-	}
-	return m, nil
-}
-
 // formatUpsertQuery manually formats the UPSERT and ON CONFLICT query statements.
 // The `ON CONFLICT` portion of this query needs to specify the constraint
 // name.
 // * In our case, we can only rely on the record.Key's parsed key value.
 // * If other schema constraints prevent a write, this won't upsert on
 //   that conflict.
-func (d *Destination) formatUpsertQuery(ctx context.Context, key opencdc.StructuredData, payload opencdc.StructuredData, keyColumnName string, tableName string) (string, []interface{}, error) {
-	upsertQuery := fmt.Sprintf("ON CONFLICT (%s) DO UPDATE SET", internal.WrapSQLIdent(keyColumnName))
-	for column := range payload {
-		// tuples form a comma separated list, so they need a comma at the end.
-		// `EXCLUDED` references the new record's values. This will overwrite
-		// every column's value except for the key column.
-		wrappedCol := internal.WrapSQLIdent(column)
-		tuple := fmt.Sprintf("%s=EXCLUDED.%s,", wrappedCol, wrappedCol)
-		// TODO: Consider removing this space.
-		upsertQuery += " "
-		// add the tuple to the query string
-		upsertQuery += tuple
-	}
-
-	// remove the last comma from the list of tuples
-	upsertQuery = strings.TrimSuffix(upsertQuery, ",")
-
-	// we have to manually append a semicolon to the upsert sql;
-	upsertQuery += ";"
-
-	colArgs, valArgs, err := d.formatColumnsAndValues(ctx, tableName, key, payload)
+func (d *Destination) formatUpsertQuery(
+	ctx context.Context,
+	key, payload opencdc.StructuredData,
+	tableName string,
+) (string, []interface{}, error) {
+	colArgs, valArgs, err := d.formatColumnsAndValues(ctx, key, payload, tableName)
 	if err != nil {
 		return "", nil, fmt.Errorf("error formatting columns and values: %w", err)
 	}
 
-	return d.stmtBuilder.
+	stmt := d.stmtBuilder.
 		Insert(internal.WrapSQLIdent(tableName)).
 		Columns(colArgs...).
-		Values(valArgs...).
-		SuffixExpr(sq.Expr(upsertQuery)).
-		ToSql()
+		Values(valArgs...)
+
+	if len(key) > 0 {
+		keyColumns := slices.Collect(maps.Keys(key))
+		for i := range keyColumns {
+			keyColumns[i] = internal.WrapSQLIdent(keyColumns[i])
+		}
+
+		var setOnConflict []string
+		for column := range payload {
+			// tuples form a comma separated list (joined below).
+			// `EXCLUDED` references the new record's values. This will overwrite
+			// every column's value except for the key columns.
+			wrappedCol := internal.WrapSQLIdent(column)
+			tuple := fmt.Sprintf("%s=EXCLUDED.%s", wrappedCol, wrappedCol)
+			// add the tuple to the SET list
+			setOnConflict = append(setOnConflict, tuple)
+		}
+
+		upsertQuery := fmt.Sprintf(
+			"ON CONFLICT (%s) DO UPDATE SET %s",
+			strings.Join(keyColumns, ","),
+			strings.Join(setOnConflict, ","),
+		)
+
+		stmt = stmt.Suffix(upsertQuery)
+	}
+
+	return stmt.ToSql()
 }
 
 // formatColumnsAndValues turns the key and payload into a slice of ordered
 // columns and values for upserting into Postgres.
-func (d *Destination) formatColumnsAndValues(ctx context.Context, table string, key, payload opencdc.StructuredData) ([]string, []interface{}, error) {
+func (d *Destination) formatColumnsAndValues(
+	ctx context.Context,
+	key, payload opencdc.StructuredData,
+	table string,
+) ([]string, []interface{}, error) {
 	var colArgs []string
 	var valArgs []interface{}
 
@@ -343,22 +281,51 @@ func (d *Destination) formatColumnsAndValues(ctx context.Context, table string,
 	return colArgs, valArgs, nil
 }
 
-// getKeyColumnName will return the name of the first item in the key or the
-// connector-configured default name of the key column name.
-func (d *Destination) getKeyColumnName(key opencdc.StructuredData, defaultKeyName string) string {
-	if len(key) > 1 {
-		// Go maps aren't order preserving, so anything over len 1 will have
-		// non-deterministic results until we handle composite keys.
- panic("composite keys not yet supported") +func (d *Destination) hasKey(e opencdc.Record) bool { + structuredKey, ok := e.Key.(opencdc.StructuredData) + if !ok { + return false } - for k := range key { - return k + return len(structuredKey) > 0 +} + +// ensureStructuredData makes sure the record key and payload are structured data. +func (d *Destination) ensureStructuredData(r opencdc.Record) (opencdc.Record, error) { + payloadAfter, err := d.structuredDataFormatter(r.Payload.After) + if err != nil { + return opencdc.Record{}, fmt.Errorf("failed to get structured data for .Payload.After: %w", err) } - return defaultKeyName + key, err := d.structuredDataFormatter(r.Key) + if err != nil { + return opencdc.Record{}, fmt.Errorf("failed to get structured data for .Key: %w", err) + } + + r.Key = key + r.Payload.After = payloadAfter + return r, nil } -func (d *Destination) hasKey(e opencdc.Record) bool { - return e.Key != nil && len(e.Key.Bytes()) > 0 +func (d *Destination) structuredDataFormatter(data opencdc.Data) (opencdc.StructuredData, error) { + switch data := data.(type) { + case opencdc.StructuredData: + // already structured data, no need to convert + return data, nil + case opencdc.RawData: + raw := data.Bytes() + if len(raw) == 0 { + return opencdc.StructuredData{}, nil + } + m := make(map[string]interface{}) + err := json.Unmarshal(raw, &m) + if err != nil { + return nil, fmt.Errorf("failed to JSON unmarshal raw data: %w", err) + } + return m, nil + case nil: + return opencdc.StructuredData{}, nil + default: + return nil, fmt.Errorf("unexpected data type %T, expected StructuredData or RawData", data) + } } func (d *Destination) formatValue(ctx context.Context, table string, column string, val interface{}) (interface{}, error) { diff --git a/destination/config.go b/destination/config.go index 0569ab2..e33072a 100644 --- a/destination/config.go +++ b/destination/config.go @@ -36,8 +36,6 @@ type Config struct { URL string `json:"url" validate:"required"` // Table is used as the target table into which records are inserted. Table string `json:"table" default:"{{ index .Metadata \"opencdc.collection\" }}"` - // Key represents the column name for the key used to identify and update existing rows. - Key string `json:"key"` } func (c *Config) Validate(ctx context.Context) error { diff --git a/source/logrepl/handler.go b/source/logrepl/handler.go index 77ebfa2..b4039db 100644 --- a/source/logrepl/handler.go +++ b/source/logrepl/handler.go @@ -16,7 +16,6 @@ package logrepl import ( "context" - "errors" "fmt" "sync" "time" @@ -83,8 +82,13 @@ func (h *CDCHandler) scheduleFlushing(ctx context.Context) { ticker := time.NewTicker(h.flushInterval) defer ticker.Stop() - for range time.Tick(h.flushInterval) { - h.flush(ctx) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + h.flush(ctx) + } } } @@ -96,20 +100,18 @@ func (h *CDCHandler) flush(ctx context.Context) { return } - if errors.Is(ctx.Err(), context.Canceled) { - close(h.out) + select { + case <-ctx.Done(): sdk.Logger(ctx).Warn(). Err(ctx.Err()). Int("records", len(h.recordBatch)). Msg("CDCHandler flushing records cancelled") - return + case h.out <- h.recordBatch: + sdk.Logger(ctx).Debug(). + Int("records", len(h.recordBatch)). + Msg("CDCHandler sending batch of records") + h.recordBatch = make([]opencdc.Record, 0, h.batchSize) } - - h.out <- h.recordBatch - sdk.Logger(ctx).Debug(). - Int("records", len(h.recordBatch)). 
- Msg("CDCHandler sending batch of records") - h.recordBatch = make([]opencdc.Record, 0, h.batchSize) } // Handle is the handler function that receives all logical replication messages. diff --git a/source/logrepl/handler_test.go b/source/logrepl/handler_test.go index 5fb0f31..5956377 100644 --- a/source/logrepl/handler_test.go +++ b/source/logrepl/handler_test.go @@ -79,8 +79,10 @@ func TestHandler_Batching_ContextCancelled(t *testing.T) { <-ctx.Done() underTest.addToBatch(ctx, newTestRecord(0)) - _, recordReceived := <-ch - is.True(!recordReceived) + recs, gotRecs, err := cchan.ChanOut[[]opencdc.Record](ch).RecvTimeout(context.Background(), time.Second) + is.Equal(recs, nil) + is.True(!gotRecs) + is.Equal(err, context.DeadlineExceeded) } func newTestRecord(id int) opencdc.Record {