Skip to content

Commit 8b6bf61

Browse files
committed
Load types using a single SQL query
When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, this is repeated multiple times. This commit changes the internal implementation of LoadType to use a single SQL query. It also added a LoadTypes, which can retrieve type mapping information for multiple types in a single SQL call. Additionally, LoadTypes will recursively load any related types, avoiding the need to explicitly list everything.
1 parent 9907b87 commit 8b6bf61

File tree

3 files changed

+276
-12
lines changed

3 files changed

+276
-12
lines changed

conn.go

Lines changed: 230 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/hex"
77
"errors"
88
"fmt"
9+
"regexp"
910
"strconv"
1011
"strings"
1112
"time"
@@ -107,8 +108,10 @@ var (
107108
ErrTooManyRows = errors.New("too many rows in result set")
108109
)
109110

110-
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
111-
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
111+
var (
112+
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
113+
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
114+
)
112115

113116
// Connect establishes a connection with a PostgreSQL server with a connection string. See
114117
// pgconn.Connect for details.
@@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription(
843846
mode QueryExecMode,
844847
sql string,
845848
) (sd *pgconn.StatementDescription, err error) {
846-
847849
switch mode {
848850
case QueryExecModeCacheStatement:
849851
if c.statementCache == nil {
@@ -1393,3 +1395,228 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
13931395

13941396
return nil
13951397
}
1398+
1399+
/*
1400+
buildLoadTypesSQL generates the correct query for retrieving type information.
1401+
1402+
pgVersion: the major version of the PostgreSQL server
1403+
typeNames: the names of the types to load. If nil, load all types.
1404+
*/
1405+
func buildLoadTypesSQL(pgVersion int64, typeNames []string) string {
1406+
supportsMultirange := (pgVersion >= 14)
1407+
var typeNamesClause string
1408+
if typeNames == nil {
1409+
typeNamesClause = "IS NOT NULL"
1410+
} else {
1411+
typeNamesClause = "= ANY($1)"
1412+
}
1413+
parts := make([]string, 0, 10)
1414+
1415+
parts = append(parts, `
1416+
WITH RECURSIVE
1417+
selected_classes(oid,reltype) AS (
1418+
SELECT pg_class.oid, pg_class.reltype
1419+
FROM pg_catalog.pg_class
1420+
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
1421+
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
1422+
AND relname `, typeNamesClause, `
1423+
UNION ALL
1424+
SELECT pg_class.oid, pg_class.reltype
1425+
FROM pg_class
1426+
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
1427+
WHERE nspname || '.' || relname `, typeNamesClause, `
1428+
),
1429+
selected_types(oid) AS (
1430+
SELECT reltype AS oid
1431+
FROM selected_classes
1432+
UNION ALL
1433+
SELECT oid
1434+
FROM pg_type
1435+
WHERE typname `, typeNamesClause, `
1436+
),
1437+
pc(parent, child) AS (
1438+
SELECT parent.oid, parent.typelem
1439+
FROM pg_type parent
1440+
WHERE parent.typtype = 'b' AND parent.typelem != 0
1441+
UNION ALL
1442+
SELECT parent.oid, parent.typbasetype
1443+
FROM pg_type parent
1444+
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
1445+
UNION ALL
1446+
SELECT pg_type.oid, atttypid
1447+
FROM pg_attribute
1448+
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
1449+
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
1450+
WHERE NOT attisdropped
1451+
AND attnum > 0
1452+
),
1453+
relationships(parent, child, depth) AS (
1454+
SELECT DISTINCT 0::OID, selected_types.oid, 0
1455+
FROM selected_types
1456+
UNION ALL
1457+
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
1458+
FROM selected_classes c
1459+
inner join pg_type ON (c.reltype = pg_type.oid)
1460+
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
1461+
UNION ALL
1462+
SELECT pc.parent, pc.child, relationships.depth + 1
1463+
FROM pc
1464+
INNER JOIN relationships ON (pc.parent = relationships.child)
1465+
),
1466+
composite AS (
1467+
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
1468+
FROM pg_attribute
1469+
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
1470+
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
1471+
WHERE NOT attisdropped
1472+
AND attnum > 0
1473+
GROUP BY pg_type.oid
1474+
)
1475+
SELECT typname,
1476+
typtype,
1477+
typbasetype,
1478+
typelem,
1479+
pg_type.oid,`)
1480+
if supportsMultirange {
1481+
parts = append(parts, `
1482+
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
1483+
} else {
1484+
parts = append(parts, `
1485+
0 AS rngtypid,`)
1486+
}
1487+
parts = append(parts, `
1488+
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
1489+
attnames, atttypids
1490+
FROM relationships
1491+
INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) )
1492+
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
1493+
if supportsMultirange {
1494+
parts = append(parts, `
1495+
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
1496+
}
1497+
1498+
parts = append(parts, `
1499+
LEFT OUTER JOIN composite USING (oid)
1500+
WHERE NOT (typtype = 'b' AND typelem = 0)`)
1501+
parts = append(parts, `
1502+
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
1503+
if supportsMultirange {
1504+
parts = append(parts, `
1505+
multirange.rngtypid,`)
1506+
}
1507+
parts = append(parts, `
1508+
attnames, atttypids
1509+
ORDER BY MAX(depth) desc, typname;`)
1510+
return strings.Join(parts, "")
1511+
}
1512+
1513+
// LoadAndRegisterTypes inspects the database for []typeNames and automatically registers all discovered
1514+
// types. Any types referenced by these will also be included in the registration.
1515+
func (c *Conn) LoadAndRegisterTypes(ctx context.Context, typeNames []string) error {
1516+
if typeNames == nil || len(typeNames) == 0 {
1517+
return fmt.Errorf("No type names were supplied.")
1518+
}
1519+
return c.loadAndRegisterTypes(ctx, typeNames, c.TypeMap())
1520+
}
1521+
1522+
func (c *Conn) loadAndRegisterTypes(ctx context.Context, typeNames []string, registerWith *pgtype.Map) error {
1523+
if registerWith == nil {
1524+
return fmt.Errorf("Type map must be supplied")
1525+
}
1526+
serverVersion, err := c.ServerVersion()
1527+
if err != nil {
1528+
return fmt.Errorf("Unexpected server version error: %w", err)
1529+
}
1530+
sql := buildLoadTypesSQL(serverVersion, typeNames)
1531+
var rows Rows
1532+
if typeNames == nil {
1533+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
1534+
} else {
1535+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
1536+
}
1537+
if err != nil {
1538+
return fmt.Errorf("While generating load types query: %w", err)
1539+
}
1540+
defer rows.Close()
1541+
for rows.Next() {
1542+
var oid uint32
1543+
var typeName, typtype string
1544+
var typbasetype, typelem uint32
1545+
var rngsubtype, rngtypid uint32
1546+
attnames := make([]string, 0, 0)
1547+
atttypids := make([]uint32, 0, 0)
1548+
err = rows.Scan(&typeName, &typtype, &typbasetype, &typelem, &oid, &rngtypid, &rngsubtype, &attnames, &atttypids)
1549+
fmt.Println(oid, typeName, typtype, typelem, attnames, atttypids)
1550+
if err != nil {
1551+
return fmt.Errorf("While scanning type information: %w", err)
1552+
}
1553+
1554+
switch typtype {
1555+
case "b": // array
1556+
dt, ok := c.TypeMap().TypeForOID(typelem)
1557+
if !ok {
1558+
return fmt.Errorf("array element OID %v not registered while loading for %v", typelem, typeName)
1559+
}
1560+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}})
1561+
case "c": // composite
1562+
var fields []pgtype.CompositeCodecField
1563+
for i, fieldName := range attnames {
1564+
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
1565+
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
1566+
//}
1567+
dt, ok := c.TypeMap().TypeForOID(atttypids[i])
1568+
if !ok {
1569+
return fmt.Errorf("unknown composite type field OID %v (%v)", atttypids[i], fieldName)
1570+
}
1571+
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
1572+
}
1573+
if err != nil {
1574+
return fmt.Errorf("While parsing %v: %w", typeName, err)
1575+
}
1576+
1577+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}})
1578+
case "d": // domain
1579+
dt, ok := c.TypeMap().TypeForOID(typbasetype)
1580+
if !ok {
1581+
return fmt.Errorf("domain base type OID %v was not already registered, needed for %v", typbasetype, typeName)
1582+
}
1583+
1584+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec})
1585+
case "e": // enum
1586+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}})
1587+
case "r": // range
1588+
dt, ok := c.TypeMap().TypeForOID(rngsubtype)
1589+
if !ok {
1590+
return fmt.Errorf("range element OID %v not registered for %v", rngsubtype, typeName)
1591+
}
1592+
1593+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}})
1594+
case "m": // multirange
1595+
dt, ok := c.TypeMap().TypeForOID(rngtypid)
1596+
if !ok {
1597+
return fmt.Errorf("multirange element OID %v not registered while loading %v", rngtypid, typeName)
1598+
}
1599+
1600+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}})
1601+
default:
1602+
return fmt.Errorf("unknown typtype %v for %v", typtype, typeName)
1603+
}
1604+
}
1605+
return nil
1606+
}
1607+
1608+
// ServerVersion returns the postgresql server version.
1609+
func (conn *Conn) ServerVersion() (int64, error) {
1610+
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
1611+
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
1612+
// if not PostgreSQL do nothing
1613+
if serverVersionStr == "" {
1614+
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
1615+
}
1616+
1617+
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
1618+
if err != nil {
1619+
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
1620+
}
1621+
return serverVersion, nil
1622+
}

go.sum

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
44
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
55
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
66
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
7-
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
8-
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
9-
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
10-
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
117
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
128
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
139
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=

pgtype/composite_test.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,56 @@ import (
1010
"github.com/stretchr/testify/require"
1111
)
1212

13-
func TestCompositeCodecTranscode(t *testing.T) {
13+
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
1414
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
1515

1616
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
17+
_, err := conn.Exec(ctx, `drop domain if exists anotheruint64;
18+
drop type if exists ct_test;
19+
create domain anotheruint64 as numeric(20,0);
20+
21+
create type ct_test as (
22+
a text,
23+
b int4,
24+
c anotheruint64
25+
);`)
26+
require.NoError(t, err)
27+
defer conn.Exec(ctx, "drop type ct_test")
28+
defer conn.Exec(ctx, "drop domain anotheruint64")
29+
30+
err = conn.LoadAndRegisterTypes(ctx, []string{"ct_test"})
31+
require.NoError(t, err)
32+
33+
formats := []struct {
34+
name string
35+
code int16
36+
}{
37+
{name: "TextFormat", code: pgx.TextFormatCode},
38+
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
39+
}
40+
41+
for _, format := range formats {
42+
var a string
43+
var b int32
44+
var c uint64
1745

46+
err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code},
47+
pgtype.CompositeFields{"hi", int32(42), uint64(123)},
48+
).Scan(
49+
pgtype.CompositeFields{&a, &b, &c},
50+
)
51+
require.NoErrorf(t, err, "%v", format.name)
52+
require.EqualValuesf(t, "hi", a, "%v", format.name)
53+
require.EqualValuesf(t, 42, b, "%v", format.name)
54+
require.EqualValuesf(t, 123, c, "%v", format.name)
55+
}
56+
})
57+
}
58+
59+
func TestCompositeCodecTranscode(t *testing.T) {
60+
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
61+
62+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
1863
_, err := conn.Exec(ctx, `drop type if exists ct_test;
1964
2065
create type ct_test as (
@@ -94,7 +139,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) {
94139
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
95140

96141
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
97-
98142
_, err := conn.Exec(ctx, `drop type if exists point3d;
99143
100144
create type point3d as (
@@ -131,7 +175,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
131175
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
132176

133177
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
134-
135178
_, err := conn.Exec(ctx, `drop type if exists point3d;
136179
137180
create type point3d as (
@@ -172,7 +215,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) {
172215
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
173216

174217
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
175-
176218
_, err := conn.Exec(ctx, `drop type if exists point3d;
177219
178220
create type point3d as (
@@ -217,7 +259,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
217259
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
218260

219261
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
220-
221262
_, err := conn.Exec(ctx, `drop table if exists point3d;
222263
223264
create table point3d (

0 commit comments

Comments
 (0)