Skip to content

Commit 8e31f94

Browse files
authored
pg qrep: don't assume all columns nullable (#3515)
Following on from #3504; that PR was likely unnecessary, since I think hamba infers as much.
1 parent adf31a3 commit 8e31f94

File tree

4 files changed

+60
-30
lines changed

4 files changed

+60
-30
lines changed

flow/connectors/postgres/qrep_query_executor.go

Lines changed: 47 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import (
44
"context"
55
"fmt"
66
"log/slog"
7+
"maps"
8+
"slices"
79

810
"github.com/jackc/pgx/v5"
911
"github.com/jackc/pgx/v5/pgconn"
@@ -58,43 +60,68 @@ func (qe *QRepQueryExecutor) ExecuteQuery(ctx context.Context, query string, arg
5860
return rows, nil
5961
}
6062

61-
func (qe *QRepQueryExecutor) executeQueryInTx(ctx context.Context, tx pgx.Tx, cursorName string, fetchSize int) (pgx.Rows, error) {
62-
qe.logger.Info("Executing query in transaction")
63-
q := fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName)
63+
// cursorToSchema builds a QRecordSchema from the cursor's field descriptions (obtained via FETCH 0), marking a column nullable only when pg_attribute has a matching row without attnotnull.
64+
func (qe *QRepQueryExecutor) cursorToSchema(
65+
ctx context.Context,
66+
tx pgx.Tx,
67+
cursorName string,
68+
) (types.QRecordSchema, error) {
69+
type attId struct {
70+
relid uint32
71+
num uint16
72+
}
6473

65-
rows, err := tx.Query(ctx, q)
74+
rows, err := tx.Query(ctx, "FETCH 0 FROM "+cursorName)
6675
if err != nil {
67-
qe.logger.Error("[pg_query_executor] failed to execute query in tx", slog.Any("error", err))
68-
return nil, err
76+
return types.QRecordSchema{}, fmt.Errorf("failed to fetch 0 for field descriptions: %w", err)
6977
}
70-
71-
return rows, nil
72-
}
73-
74-
// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema.
75-
func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) types.QRecordSchema {
78+
fds := rows.FieldDescriptions()
79+
tableOIDset := make(map[uint32]struct{})
80+
nullPointers := make(map[attId]*bool, len(fds))
7681
qfields := make([]types.QField, len(fds))
7782
for i, fd := range fds {
83+
tableOIDset[fd.TableOID] = struct{}{}
7884
ctype := qe.postgresOIDToQValueKind(fd.DataTypeOID, qe.customTypeMapping, qe.version)
79-
// there isn't a way to know if a column is nullable or not
8085
if ctype == types.QValueKindNumeric || ctype == types.QValueKindArrayNumeric {
8186
precision, scale := datatypes.ParseNumericTypmod(fd.TypeModifier)
8287
qfields[i] = types.QField{
8388
Name: fd.Name,
8489
Type: ctype,
85-
Nullable: true,
90+
Nullable: false,
8691
Precision: precision,
8792
Scale: scale,
8893
}
8994
} else {
9095
qfields[i] = types.QField{
9196
Name: fd.Name,
9297
Type: ctype,
93-
Nullable: true,
98+
Nullable: false,
9499
}
95100
}
101+
nullPointers[attId{
102+
relid: fd.TableOID,
103+
num: fd.TableAttributeNumber,
104+
}] = &qfields[i].Nullable
96105
}
97-
return types.NewQRecordSchema(qfields)
106+
rows.Close()
107+
tableOIDs := slices.Collect(maps.Keys(tableOIDset))
108+
109+
rows, err = tx.Query(ctx, "SELECT a.attrelid,a.attnum FROM pg_attribute a WHERE a.attrelid = ANY($1) AND NOT a.attnotnull", tableOIDs)
110+
if err != nil {
111+
return types.QRecordSchema{}, fmt.Errorf("failed to query schema for field descriptions: %w", err)
112+
}
113+
114+
var att attId
115+
if _, err := pgx.ForEachRow(rows, []any{&att.relid, &att.num}, func() error {
116+
if nullPointer, ok := nullPointers[att]; ok {
117+
*nullPointer = true
118+
}
119+
return nil
120+
}); err != nil {
121+
return types.QRecordSchema{}, fmt.Errorf("failed to process schema for field descriptions: %w", err)
122+
}
123+
124+
return types.NewQRecordSchema(qfields), nil
98125
}
99126

100127
func (qe *QRepQueryExecutor) processRowsStream(
@@ -152,20 +179,17 @@ func (qe *QRepQueryExecutor) processFetchedRows(
152179
fetchSize int,
153180
stream *model.QRecordStream,
154181
) (int64, int64, error) {
155-
rows, err := qe.executeQueryInTx(ctx, tx, cursorName, fetchSize)
182+
qe.logger.Info("[pg_query_executor] fetching from cursor", slog.String("cursor", cursorName))
183+
184+
rows, err := tx.Query(ctx, fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName))
156185
if err != nil {
157-
qe.logger.Error("[pg_query_executor] failed to execute query in tx",
186+
qe.logger.Error("[pg_query_executor] failed to fetch cursor in tx",
158187
slog.Any("error", err), slog.String("query", query))
159188
return 0, 0, fmt.Errorf("[pg_query_executor] failed to execute query in tx: %w", err)
160189
}
161190
defer rows.Close()
162191

163192
fieldDescriptions := rows.FieldDescriptions()
164-
if !stream.IsSchemaSet() {
165-
schema := qe.fieldDescriptionsToSchema(fieldDescriptions)
166-
stream.SetSchema(schema)
167-
}
168-
169193
numRows, numBytes, err := qe.processRowsStream(ctx, cursorName, stream, rows, fieldDescriptions)
170194
if err != nil {
171195
qe.logger.Error("[pg_query_executor] failed to process rows", slog.Any("error", err))

flow/connectors/postgres/sink_q.go

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
4545
randomUint := rand.Uint64()
4646

4747
cursorName := fmt.Sprintf("peerdb_cursor_%d", randomUint)
48-
fetchSize := shared.FetchAndChannelSize
4948
cursorQuery := fmt.Sprintf("DECLARE %s CURSOR FOR %s", cursorName, query)
5049

5150
if _, err := tx.Exec(ctx, cursorQuery, args...); err != nil {
@@ -60,10 +59,18 @@ func (stream RecordStreamSink) ExecuteQueryWithTx(
6059
slog.String("query", query),
6160
slog.Int("channelLen", len(stream.Records)))
6261

62+
if !stream.IsSchemaSet() {
63+
schema, err := qe.cursorToSchema(ctx, tx, cursorName)
64+
if err != nil {
65+
return 0, 0, err
66+
}
67+
stream.SetSchema(schema)
68+
}
69+
6370
var totalNumRows int64
6471
var totalNumBytes int64
6572
for {
66-
numRows, numBytes, err := qe.processFetchedRows(ctx, query, tx, cursorName, fetchSize, stream.QRecordStream)
73+
numRows, numBytes, err := qe.processFetchedRows(ctx, query, tx, cursorName, shared.FetchAndChannelSize, stream.QRecordStream)
6774
if err != nil {
6875
qe.logger.Error("[pg_query_executor] failed to process fetched rows", slog.Any("error", err))
6976
return totalNumRows, totalNumBytes, err

flow/connectors/snowflake/snowflake.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,6 @@ const (
3535
_PEERDB_RECORD_TYPE INTEGER NOT NULL, _PEERDB_MATCH_DATA STRING,_PEERDB_BATCH_ID INT,
3636
_PEERDB_UNCHANGED_TOAST_COLUMNS STRING)`
3737
createDummyTableSQL = "CREATE TABLE IF NOT EXISTS %s.%s(_PEERDB_DUMMY_COL STRING)"
38-
rawTableMultiValueInsertSQL = "INSERT INTO %s.%s VALUES%s"
3938
createNormalizedTableSQL = "CREATE TABLE IF NOT EXISTS %s(%s)"
4039
createOrReplaceNormalizedTableSQL = "CREATE OR REPLACE TABLE %s(%s)"
4140
toVariantColumnName = "VAR_COLS"

flow/e2e/snowflake_test.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Numeric() {
9090
_, err := s.Conn().Exec(s.t.Context(), fmt.Sprintf(`
9191
CREATE TABLE IF NOT EXISTS %s (
9292
id INT PRIMARY KEY,
93-
num1 NUMERIC(100, 50) NOT NULL,
94-
num2 NUMERIC(100, 50) NOT NULL
93+
num1 NUMERIC(100, 50),
94+
num2 NUMERIC(100, 50)
9595
);
9696
`, srcTableName))
9797
require.NoError(s.t, err)
@@ -115,7 +115,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Numeric() {
115115
env := ExecutePeerflow(s.t, tc, flowConnConfig)
116116
SetupCDCFlowStatusQuery(s.t, env, flowConnConfig)
117117

118-
EnvWaitFor(s.t, env, 3*time.Minute, "normalize shapes", func() bool {
118+
EnvWaitFor(s.t, env, 3*time.Minute, "init", func() bool {
119119
records, err := s.sfHelper.ExecuteAndProcessQuery(s.t.Context(), "select num1, num2 from "+dstTableName+" where id = 1")
120120
if err != nil || len(records.Records) == 0 {
121121
return false
@@ -130,7 +130,7 @@ func (s PeerFlowE2ETestSuiteSF) Test_Invalid_Numeric() {
130130
"9999999999999999")
131131
EnvNoError(s.t, env, err)
132132

133-
EnvWaitFor(s.t, env, 3*time.Minute, "normalize shapes", func() bool {
133+
EnvWaitFor(s.t, env, 3*time.Minute, "cdc", func() bool {
134134
records, err := s.sfHelper.ExecuteAndProcessQuery(s.t.Context(), "select num1, num2 from "+dstTableName+" where id = 2")
135135
if err != nil || len(records.Records) == 0 {
136136
return false

0 commit comments

Comments
 (0)