Skip to content

Commit 6e1e9eb

Browse files
authored
Merge pull request jackc#2510 from abrightwell/abrightwell-tsvector
Add `tsvector` type support
2 parents 2c62512 + ea6b093 commit 6e1e9eb

File tree

6 files changed

+1075
-0
lines changed

6 files changed

+1075
-0
lines changed

copy_from_test.go

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010

1111
"github.com/jackc/pgx/v5"
1212
"github.com/jackc/pgx/v5/pgconn"
13+
"github.com/jackc/pgx/v5/pgtype"
1314
"github.com/jackc/pgx/v5/pgxtest"
1415
"github.com/stretchr/testify/require"
1516
)
@@ -452,6 +453,106 @@ func TestConnCopyFromJSON(t *testing.T) {
452453
ensureConnValid(t, conn)
453454
}
454455

456+
func TestConnCopyFromTSVector(t *testing.T) {
457+
t.Parallel()
458+
459+
ctx, cancel := context.WithTimeout(context.Background(), 120*time.Second)
460+
defer cancel()
461+
462+
conn := mustConnectString(t, os.Getenv("PGX_TEST_DATABASE"))
463+
defer closeConn(t, conn)
464+
465+
pgxtest.SkipCockroachDB(t, conn, "CockroachDB handles tsvector escaping differently")
466+
467+
tx, err := conn.Begin(ctx)
468+
require.NoError(t, err)
469+
defer tx.Rollback(ctx)
470+
471+
_, err = tx.Exec(ctx, `create temporary table tmp_tsv (id int, t tsvector)`)
472+
require.NoError(t, err)
473+
474+
inputRows := [][]any{
475+
// Text format: core functionality.
476+
{1, `'a':1A 'cat':5 'fat':2B,4C`}, // Multiple lexemes with positions and weights.
477+
{2, `'bare'`}, // Single lexeme with no positions.
478+
{3, `'multi':1,2,3,4,5`}, // Multiple positions (default weight D).
479+
{4, `'test':1A,2B,3C,4D`}, // All four weights on one lexeme.
480+
{5, `'word':1D`}, // Explicit weight D (normalizes to no suffix).
481+
{6, `'high':16383A`}, // High position number (near 14-bit max).
482+
483+
// Text format: escaping.
484+
{7, `'don''t'`}, // Quote escaping (doubled single quote).
485+
{8, `'don\'t'`}, // Quote escaping (backslash).
486+
{9, `'ab\\c'`}, // Backslash in lexeme.
487+
{10, `'\ foo'`}, // Escaped space.
488+
489+
// Text format: special characters.
490+
{11, `'café' 'naïve'`}, // Unicode lexemes.
491+
{12, `'a:b' 'c,d'`}, // Delimiter-like characters (colon, comma).
492+
493+
// Struct format: tests binary encoding path.
494+
{13, pgtype.TSVector{
495+
Lexemes: []pgtype.TSVectorLexeme{
496+
{Word: "alpha", Positions: []pgtype.TSVectorPosition{{Position: 1, Weight: pgtype.TSVectorWeightA}}},
497+
{Word: "beta", Positions: []pgtype.TSVectorPosition{{Position: 2, Weight: pgtype.TSVectorWeightB}}},
498+
{Word: "gamma", Positions: nil},
499+
},
500+
Valid: true,
501+
}},
502+
{14, pgtype.TSVector{Valid: true}}, // Empty valid tsvector (no lexemes).
503+
504+
// NULL handling.
505+
{15, pgtype.TSVector{Valid: false}}, // Invalid (NULL) TSVector struct.
506+
{16, nil}, // Nil value.
507+
}
508+
509+
copyCount, err := conn.CopyFrom(ctx, pgx.Identifier{"tmp_tsv"}, []string{"id", "t"}, pgx.CopyFromRows(inputRows))
510+
require.NoError(t, err)
511+
require.EqualValues(t, len(inputRows), copyCount)
512+
513+
rows, err := conn.Query(ctx, "select id, t::text from tmp_tsv order by id nulls last")
514+
require.NoError(t, err)
515+
516+
var outputRows [][]any
517+
for rows.Next() {
518+
row, err := rows.Values()
519+
require.NoError(t, err)
520+
outputRows = append(outputRows, row)
521+
}
522+
require.NoError(t, rows.Err())
523+
524+
expectedOutputRows := [][]any{
525+
// Text format: core functionality.
526+
{int32(1), `'a':1A 'cat':5 'fat':2B,4C`},
527+
{int32(2), `'bare'`},
528+
{int32(3), `'multi':1,2,3,4,5`},
529+
{int32(4), `'test':1A,2B,3C,4`},
530+
{int32(5), `'word':1`},
531+
{int32(6), `'high':16383A`},
532+
533+
// Text format: escaping.
534+
{int32(7), `'don''t'`},
535+
{int32(8), `'don''t'`},
536+
{int32(9), `'ab\\c'`},
537+
{int32(10), `' foo'`},
538+
539+
// Text format: special characters.
540+
{int32(11), `'café' 'naïve'`},
541+
{int32(12), `'a:b' 'c,d'`},
542+
543+
// Struct format.
544+
{int32(13), `'alpha':1A 'beta':2B 'gamma'`},
545+
{int32(14), ``},
546+
547+
// NULL handling.
548+
{int32(15), nil},
549+
{int32(16), nil},
550+
}
551+
require.Equal(t, expectedOutputRows, outputRows)
552+
553+
ensureConnValid(t, conn)
554+
}
555+
455556
type clientFailSource struct {
456557
count int
457558
err error

pgconn/pgconn_test.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -170,6 +170,7 @@ func TestConnectOAuthError(t *testing.T) {
170170
_, err = pgconn.ConnectConfig(context.Background(), config)
171171
require.Error(t, err, "connect should return error for invalid token")
172172
}
173+
173174
func TestConnectTLSPasswordProtectedClientCertWithSSLPassword(t *testing.T) {
174175
t.Parallel()
175176

pgtype/pgtype.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,8 @@ const (
9696
RecordArrayOID = 2287
9797
UUIDOID = 2950
9898
UUIDArrayOID = 2951
99+
TSVectorOID = 3614
100+
TSVectorArrayOID = 3643
99101
JSONBOID = 3802
100102
JSONBArrayOID = 3807
101103
DaterangeOID = 3912

pgtype/pgtype_default.go

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ func initDefaultMap() {
8181
defaultMap.RegisterType(&Type{Name: "record", OID: RecordOID, Codec: RecordCodec{}})
8282
defaultMap.RegisterType(&Type{Name: "text", OID: TextOID, Codec: TextCodec{}})
8383
defaultMap.RegisterType(&Type{Name: "tid", OID: TIDOID, Codec: TIDCodec{}})
84+
defaultMap.RegisterType(&Type{Name: "tsvector", OID: TSVectorOID, Codec: TSVectorCodec{}})
8485
defaultMap.RegisterType(&Type{Name: "time", OID: TimeOID, Codec: TimeCodec{}})
8586
defaultMap.RegisterType(&Type{Name: "timestamp", OID: TimestampOID, Codec: &TimestampCodec{}})
8687
defaultMap.RegisterType(&Type{Name: "timestamptz", OID: TimestamptzOID, Codec: &TimestamptzCodec{}})
@@ -164,6 +165,7 @@ func initDefaultMap() {
164165
defaultMap.RegisterType(&Type{Name: "_record", OID: RecordArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[RecordOID]}})
165166
defaultMap.RegisterType(&Type{Name: "_text", OID: TextArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TextOID]}})
166167
defaultMap.RegisterType(&Type{Name: "_tid", OID: TIDArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TIDOID]}})
168+
defaultMap.RegisterType(&Type{Name: "_tsvector", OID: TSVectorArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TSVectorOID]}})
167169
defaultMap.RegisterType(&Type{Name: "_time", OID: TimeArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimeOID]}})
168170
defaultMap.RegisterType(&Type{Name: "_timestamp", OID: TimestampArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestampOID]}})
169171
defaultMap.RegisterType(&Type{Name: "_timestamptz", OID: TimestamptzArrayOID, Codec: &ArrayCodec{ElementType: defaultMap.oidToType[TimestamptzOID]}})
@@ -242,6 +244,7 @@ func initDefaultMap() {
242244
registerDefaultPgTypeVariants[Multirange[Range[Timestamp]]](defaultMap, "tsmultirange")
243245
registerDefaultPgTypeVariants[Range[Timestamptz]](defaultMap, "tstzrange")
244246
registerDefaultPgTypeVariants[Multirange[Range[Timestamptz]]](defaultMap, "tstzmultirange")
247+
registerDefaultPgTypeVariants[TSVector](defaultMap, "tsvector")
245248
registerDefaultPgTypeVariants[UUID](defaultMap, "uuid")
246249

247250
defaultMap.buildReflectTypeToType()

0 commit comments

Comments
 (0)