|
4 | 4 | "context" |
5 | 5 | "fmt" |
6 | 6 | "log/slog" |
| 7 | + "maps" |
| 8 | + "slices" |
7 | 9 |
|
8 | 10 | "github.com/jackc/pgx/v5" |
9 | 11 | "github.com/jackc/pgx/v5/pgconn" |
@@ -58,43 +60,68 @@ func (qe *QRepQueryExecutor) ExecuteQuery(ctx context.Context, query string, arg |
58 | 60 | return rows, nil |
59 | 61 | } |
60 | 62 |
|
61 | | -func (qe *QRepQueryExecutor) executeQueryInTx(ctx context.Context, tx pgx.Tx, cursorName string, fetchSize int) (pgx.Rows, error) { |
62 | | - qe.logger.Info("Executing query in transaction") |
63 | | - q := fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName) |
| 63 | +// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema. |
| 64 | +func (qe *QRepQueryExecutor) cursorToSchema( |
| 65 | + ctx context.Context, |
| 66 | + tx pgx.Tx, |
| 67 | + cursorName string, |
| 68 | +) (types.QRecordSchema, error) { |
| 69 | + type attId struct { |
| 70 | + relid uint32 |
| 71 | + num uint16 |
| 72 | + } |
64 | 73 |
|
65 | | - rows, err := tx.Query(ctx, q) |
| 74 | + rows, err := tx.Query(ctx, "FETCH 0 FROM "+cursorName) |
66 | 75 | if err != nil { |
67 | | - qe.logger.Error("[pg_query_executor] failed to execute query in tx", slog.Any("error", err)) |
68 | | - return nil, err |
| 76 | + return types.QRecordSchema{}, fmt.Errorf("failed to fetch 0 for field descriptions: %w", err) |
69 | 77 | } |
70 | | - |
71 | | - return rows, nil |
72 | | -} |
73 | | - |
74 | | -// FieldDescriptionsToSchema converts a slice of pgconn.FieldDescription to a QRecordSchema. |
75 | | -func (qe *QRepQueryExecutor) fieldDescriptionsToSchema(fds []pgconn.FieldDescription) types.QRecordSchema { |
| 78 | + fds := rows.FieldDescriptions() |
| 79 | + tableOIDset := make(map[uint32]struct{}) |
| 80 | + nullPointers := make(map[attId]*bool, len(fds)) |
76 | 81 | qfields := make([]types.QField, len(fds)) |
77 | 82 | for i, fd := range fds { |
| 83 | + tableOIDset[fd.TableOID] = struct{}{} |
78 | 84 | ctype := qe.postgresOIDToQValueKind(fd.DataTypeOID, qe.customTypeMapping, qe.version) |
79 | | - // there isn't a way to know if a column is nullable or not |
80 | 85 | if ctype == types.QValueKindNumeric || ctype == types.QValueKindArrayNumeric { |
81 | 86 | precision, scale := datatypes.ParseNumericTypmod(fd.TypeModifier) |
82 | 87 | qfields[i] = types.QField{ |
83 | 88 | Name: fd.Name, |
84 | 89 | Type: ctype, |
85 | | - Nullable: true, |
| 90 | + Nullable: false, |
86 | 91 | Precision: precision, |
87 | 92 | Scale: scale, |
88 | 93 | } |
89 | 94 | } else { |
90 | 95 | qfields[i] = types.QField{ |
91 | 96 | Name: fd.Name, |
92 | 97 | Type: ctype, |
93 | | - Nullable: true, |
| 98 | + Nullable: false, |
94 | 99 | } |
95 | 100 | } |
| 101 | + nullPointers[attId{ |
| 102 | + relid: fd.TableOID, |
| 103 | + num: fd.TableAttributeNumber, |
| 104 | + }] = &qfields[i].Nullable |
96 | 105 | } |
97 | | - return types.NewQRecordSchema(qfields) |
| 106 | + rows.Close() |
| 107 | + tableOIDs := slices.Collect(maps.Keys(tableOIDset)) |
| 108 | + |
| 109 | + rows, err = tx.Query(ctx, "SELECT a.attrelid,a.attnum FROM pg_attribute a WHERE a.attrelid = ANY($1) AND NOT a.attnotnull", tableOIDs) |
| 110 | + if err != nil { |
| 111 | + return types.QRecordSchema{}, fmt.Errorf("failed to query schema for field descriptions: %w", err) |
| 112 | + } |
| 113 | + |
| 114 | + var att attId |
| 115 | + if _, err := pgx.ForEachRow(rows, []any{&att.relid, &att.num}, func() error { |
| 116 | + if nullPointer, ok := nullPointers[att]; ok { |
| 117 | + *nullPointer = true |
| 118 | + } |
| 119 | + return nil |
| 120 | + }); err != nil { |
| 121 | + return types.QRecordSchema{}, fmt.Errorf("failed to process schema for field descriptions: %w", err) |
| 122 | + } |
| 123 | + |
| 124 | + return types.NewQRecordSchema(qfields), nil |
98 | 125 | } |
99 | 126 |
|
100 | 127 | func (qe *QRepQueryExecutor) processRowsStream( |
@@ -152,20 +179,17 @@ func (qe *QRepQueryExecutor) processFetchedRows( |
152 | 179 | fetchSize int, |
153 | 180 | stream *model.QRecordStream, |
154 | 181 | ) (int64, int64, error) { |
155 | | - rows, err := qe.executeQueryInTx(ctx, tx, cursorName, fetchSize) |
| 182 | + qe.logger.Info("[pg_query_executor] fetching from cursor", slog.String("cursor", cursorName)) |
| 183 | + |
| 184 | + rows, err := tx.Query(ctx, fmt.Sprintf("FETCH %d FROM %s", fetchSize, cursorName)) |
156 | 185 | if err != nil { |
157 | | - qe.logger.Error("[pg_query_executor] failed to execute query in tx", |
| 186 | + qe.logger.Error("[pg_query_executor] failed to fetch cursor in tx", |
158 | 187 | slog.Any("error", err), slog.String("query", query)) |
159 | 188 | return 0, 0, fmt.Errorf("[pg_query_executor] failed to execute query in tx: %w", err) |
160 | 189 | } |
161 | 190 | defer rows.Close() |
162 | 191 |
|
163 | 192 | fieldDescriptions := rows.FieldDescriptions() |
164 | | - if !stream.IsSchemaSet() { |
165 | | - schema := qe.fieldDescriptionsToSchema(fieldDescriptions) |
166 | | - stream.SetSchema(schema) |
167 | | - } |
168 | | - |
169 | 193 | numRows, numBytes, err := qe.processRowsStream(ctx, cursorName, stream, rows, fieldDescriptions) |
170 | 194 | if err != nil { |
171 | 195 | qe.logger.Error("[pg_query_executor] failed to process rows", slog.Any("error", err)) |
|
0 commit comments