Skip to content

Commit e7e89b3

Browse files
committed
Load types using a single SQL query
When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, this is repeated multiple times. This commit changes the internal implementation of LoadType to use a single SQL query. It also added a LoadTypes, which can retrieve type mapping information for multiple types in a single SQL call. Additionally, LoadTypes will recursively load any related types, avoiding the need to explicitly list everything.
1 parent 9907b87 commit e7e89b3

File tree

2 files changed

+260
-8
lines changed

2 files changed

+260
-8
lines changed

conn.go

Lines changed: 214 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/hex"
77
"errors"
88
"fmt"
9+
"regexp"
910
"strconv"
1011
"strings"
1112
"time"
@@ -107,8 +108,10 @@ var (
107108
ErrTooManyRows = errors.New("too many rows in result set")
108109
)
109110

110-
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
111-
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
111+
var (
112+
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
113+
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
114+
)
112115

113116
// Connect establishes a connection with a PostgreSQL server with a connection string. See
114117
// pgconn.Connect for details.
@@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription(
843846
mode QueryExecMode,
844847
sql string,
845848
) (sd *pgconn.StatementDescription, err error) {
846-
847849
switch mode {
848850
case QueryExecModeCacheStatement:
849851
if c.statementCache == nil {
@@ -1393,3 +1395,212 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
13931395

13941396
return nil
13951397
}
1398+
1399+
/*
1400+
buildLoadTypesSQL generates the correct query for retrieving type information.
1401+
1402+
pgVersion: the major version of the PostgreSQL server
1403+
typeNames: the names of the types to load. If nil, load all types.
1404+
*/
1405+
func buildLoadTypesSQL(pgVersion int64, typeNames []string) string {
1406+
supportsMultirange := (pgVersion >= 14)
1407+
var typeNamesClause string
1408+
if typeNames == nil {
1409+
typeNamesClause = "IS NOT NULL"
1410+
} else {
1411+
typeNamesClause = "= ANY($1)"
1412+
}
1413+
parts := make([]string, 0, 10)
1414+
1415+
parts = append(parts, `
1416+
WITH RECURSIVE
1417+
selected_classes(oid,reltype) AS (
1418+
SELECT pg_class.oid, pg_class.reltype
1419+
FROM pg_catalog.pg_class
1420+
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
1421+
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
1422+
AND relname `, typeNamesClause, `
1423+
UNION ALL
1424+
SELECT pg_class.oid, pg_class.reltype
1425+
FROM pg_class
1426+
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
1427+
WHERE nspname || '.' || relname `, typeNamesClause, `
1428+
),
1429+
selected_types(oid) AS (
1430+
SELECT reltype AS oid
1431+
FROM selected_classes
1432+
UNION ALL
1433+
SELECT oid
1434+
FROM pg_type
1435+
WHERE typname `, typeNamesClause, `
1436+
),
1437+
relationships(parent, child, depth) AS (
1438+
SELECT DISTINCT 0::OID, selected_types.oid, 0
1439+
FROM selected_types
1440+
UNION ALL
1441+
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
1442+
FROM selected_classes c
1443+
inner join pg_type ON (c.reltype = pg_type.oid)
1444+
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
1445+
UNION ALL
1446+
SELECT parent.oid, parent.typbasetype, relationships.depth + 1
1447+
FROM pg_type parent
1448+
INNER JOIN relationships ON (parent.oid = relationships.child)
1449+
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
1450+
),
1451+
composite AS (
1452+
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
1453+
FROM pg_attribute
1454+
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
1455+
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
1456+
WHERE NOT attisdropped
1457+
AND attnum > 0
1458+
GROUP BY pg_type.oid
1459+
)
1460+
SELECT typname,
1461+
typtype,
1462+
typbasetype,
1463+
typelem,
1464+
pg_type.oid,`)
1465+
if supportsMultirange {
1466+
parts = append(parts, `
1467+
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
1468+
} else {
1469+
parts = append(parts, `
1470+
0 AS rngtypid,`)
1471+
}
1472+
parts = append(parts, `
1473+
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
1474+
attnames, atttypids
1475+
FROM relationships
1476+
INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) )
1477+
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
1478+
if supportsMultirange {
1479+
parts = append(parts, `
1480+
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
1481+
}
1482+
1483+
parts = append(parts, `
1484+
LEFT OUTER JOIN composite USING (oid)
1485+
WHERE typtype != 'b'`)
1486+
parts = append(parts, `
1487+
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
1488+
if supportsMultirange {
1489+
parts = append(parts, `
1490+
multirange.rngtypid,`)
1491+
}
1492+
parts = append(parts, `
1493+
attnames, atttypids
1494+
ORDER BY MAX(depth) desc, typname;`)
1495+
return strings.Join(parts, "")
1496+
}
1497+
1498+
// LoadAndRegisterTypes inspects the database for []typeNames and automatically registers all discovered
1499+
// types. Any types referenced by these will also be included in the registration.
1500+
func (c *Conn) LoadAndRegisterTypes(ctx context.Context, typeNames []string) error {
1501+
if typeNames == nil || len(typeNames) == 0 {
1502+
return fmt.Errorf("No type names were supplied.")
1503+
}
1504+
return c.loadAndRegisterTypes(ctx, typeNames, c.TypeMap())
1505+
}
1506+
1507+
func (c *Conn) loadAndRegisterTypes(ctx context.Context, typeNames []string, registerWith *pgtype.Map) error {
1508+
if registerWith == nil {
1509+
return fmt.Errorf("Type map must be supplied")
1510+
}
1511+
serverVersion, err := c.ServerVersion()
1512+
if err != nil {
1513+
return fmt.Errorf("Unexpected server version error: %w", err)
1514+
}
1515+
sql := buildLoadTypesSQL(serverVersion, typeNames)
1516+
var rows Rows
1517+
if typeNames == nil {
1518+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
1519+
} else {
1520+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
1521+
}
1522+
if err != nil {
1523+
return fmt.Errorf("While generating load types query: %w", err)
1524+
}
1525+
defer rows.Close()
1526+
for rows.Next() {
1527+
var oid uint32
1528+
var typeName, typtype string
1529+
var typbasetype, typelem uint32
1530+
var rngsubtype, rngtypid uint32
1531+
attnames := make([]string, 0, 0)
1532+
atttypids := make([]uint32, 0, 0)
1533+
err = rows.Scan(&typeName, &typtype, &typbasetype, &typelem, &oid, &rngtypid, &rngsubtype, &attnames, &atttypids)
1534+
if err != nil {
1535+
return fmt.Errorf("While scanning type information: %w", err)
1536+
}
1537+
1538+
switch typtype {
1539+
case "b": // array
1540+
dt, ok := c.TypeMap().TypeForOID(typelem)
1541+
if !ok {
1542+
return fmt.Errorf("array element OID not registered while loading %v", typeName)
1543+
}
1544+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}})
1545+
case "c": // composite
1546+
var fields []pgtype.CompositeCodecField
1547+
for i, fieldName := range attnames {
1548+
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
1549+
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
1550+
//}
1551+
dt, ok := c.TypeMap().TypeForOID(atttypids[i])
1552+
if !ok {
1553+
return fmt.Errorf("unknown composite type field OID %v (%v)", atttypids[i], fieldName)
1554+
}
1555+
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
1556+
}
1557+
if err != nil {
1558+
return fmt.Errorf("While parsing %v: %w", typeName, err)
1559+
}
1560+
1561+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}})
1562+
case "d": // domain
1563+
dt, ok := c.TypeMap().TypeForOID(typbasetype)
1564+
if !ok {
1565+
return fmt.Errorf("domain base type OID not registered for %v", typeName)
1566+
}
1567+
1568+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec})
1569+
case "e": // enum
1570+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}})
1571+
case "r": // range
1572+
dt, ok := c.TypeMap().TypeForOID(rngsubtype)
1573+
if !ok {
1574+
return fmt.Errorf("range element OID %v not registered for %v", rngsubtype, typeName)
1575+
}
1576+
1577+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}})
1578+
case "m": // multirange
1579+
dt, ok := c.TypeMap().TypeForOID(rngtypid)
1580+
if !ok {
1581+
return fmt.Errorf("multirange element OID %v not registered while loading %v", rngtypid, typeName)
1582+
}
1583+
1584+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}})
1585+
default:
1586+
return fmt.Errorf("unknown typtype %v for %v", typtype, typeName)
1587+
}
1588+
}
1589+
return nil
1590+
}
1591+
1592+
// ServerVersion returns the postgresql server version.
1593+
func (conn *Conn) ServerVersion() (int64, error) {
1594+
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
1595+
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
1596+
// if not PostgreSQL do nothing
1597+
if serverVersionStr == "" {
1598+
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
1599+
}
1600+
1601+
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
1602+
if err != nil {
1603+
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
1604+
}
1605+
return serverVersion, nil
1606+
}

pgtype/composite_test.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,56 @@ import (
1010
"github.com/stretchr/testify/require"
1111
)
1212

13-
func TestCompositeCodecTranscode(t *testing.T) {
13+
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
1414
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
1515

1616
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
17+
_, err := conn.Exec(ctx, `drop domain if exists anotheruint64;
18+
drop type if exists ct_test;
19+
create domain anotheruint64 as numeric(20,0);
20+
21+
create type ct_test as (
22+
a text,
23+
b int4,
24+
c anotheruint64
25+
);`)
26+
require.NoError(t, err)
27+
defer conn.Exec(ctx, "drop type ct_test")
28+
defer conn.Exec(ctx, "drop domain anotheruint64")
29+
30+
err = conn.LoadAndRegisterTypes(ctx, []string{"ct_test"})
31+
require.NoError(t, err)
32+
33+
formats := []struct {
34+
name string
35+
code int16
36+
}{
37+
{name: "TextFormat", code: pgx.TextFormatCode},
38+
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
39+
}
40+
41+
for _, format := range formats {
42+
var a string
43+
var b int32
44+
var c uint64
1745

46+
err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code},
47+
pgtype.CompositeFields{"hi", int32(42), uint64(123)},
48+
).Scan(
49+
pgtype.CompositeFields{&a, &b, &c},
50+
)
51+
require.NoErrorf(t, err, "%v", format.name)
52+
require.EqualValuesf(t, "hi", a, "%v", format.name)
53+
require.EqualValuesf(t, 42, b, "%v", format.name)
54+
require.EqualValuesf(t, 123, c, "%v", format.name)
55+
}
56+
})
57+
}
58+
59+
func TestCompositeCodecTranscode(t *testing.T) {
60+
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
61+
62+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
1863
_, err := conn.Exec(ctx, `drop type if exists ct_test;
1964
2065
create type ct_test as (
@@ -94,7 +139,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) {
94139
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
95140

96141
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
97-
98142
_, err := conn.Exec(ctx, `drop type if exists point3d;
99143
100144
create type point3d as (
@@ -131,7 +175,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
131175
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
132176

133177
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
134-
135178
_, err := conn.Exec(ctx, `drop type if exists point3d;
136179
137180
create type point3d as (
@@ -172,7 +215,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) {
172215
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
173216

174217
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
175-
176218
_, err := conn.Exec(ctx, `drop type if exists point3d;
177219
178220
create type point3d as (
@@ -217,7 +259,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
217259
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
218260

219261
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
220-
221262
_, err := conn.Exec(ctx, `drop table if exists point3d;
222263
223264
create table point3d (

0 commit comments

Comments
 (0)