Skip to content

Commit f238d1d

Browse files
committed
Load types using a single SQL query
When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, this is repeated multiple times. This commit changes the internal implementation of LoadType to use a single SQL query. It also added a LoadTypes, which can retrieve type mapping information for multiple types in a single SQL call. Additionally, LoadTypes will recursively load any related types, avoiding the need to explicitly list everything.
1 parent 9907b87 commit f238d1d

File tree

2 files changed

+264
-8
lines changed

2 files changed

+264
-8
lines changed

conn.go

Lines changed: 218 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"encoding/hex"
77
"errors"
88
"fmt"
9+
"regexp"
910
"strconv"
1011
"strings"
1112
"time"
@@ -107,8 +108,10 @@ var (
107108
ErrTooManyRows = errors.New("too many rows in result set")
108109
)
109110

110-
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
111-
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
111+
var (
112+
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
113+
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
114+
)
112115

113116
// Connect establishes a connection with a PostgreSQL server with a connection string. See
114117
// pgconn.Connect for details.
@@ -843,7 +846,6 @@ func (c *Conn) getStatementDescription(
843846
mode QueryExecMode,
844847
sql string,
845848
) (sd *pgconn.StatementDescription, err error) {
846-
847849
switch mode {
848850
case QueryExecModeCacheStatement:
849851
if c.statementCache == nil {
@@ -1393,3 +1395,216 @@ func (c *Conn) deallocateInvalidatedCachedStatements(ctx context.Context) error
13931395

13941396
return nil
13951397
}
1398+
1399+
/*
1400+
buildLoadTypesSQL generates the correct query for retrieving type information.
1401+
1402+
pgVersion: the major version of the PostgreSQL server
1403+
isCockroachDB: is this cockroachDB?
1404+
typeNames: the names of the types to load. If nil, load all types.
1405+
*/
1406+
func buildLoadTypesSQL(pgVersion int64, isCockroachDB bool, typeNames []string) string {
1407+
supportsMultirange := (pgVersion >= 14)
1408+
var typeNamesClause string
1409+
if typeNames == nil {
1410+
typeNamesClause = "IS NOT NULL"
1411+
} else {
1412+
typeNamesClause = "= ANY($1)"
1413+
}
1414+
parts := make([]string, 0, 10)
1415+
1416+
parts = append(parts, `
1417+
WITH RECURSIVE
1418+
selected_classes(oid,reltype) AS (
1419+
SELECT pg_class.oid, pg_class.reltype
1420+
FROM pg_catalog.pg_class
1421+
LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
1422+
WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
1423+
AND relname `, typeNamesClause, `
1424+
UNION ALL
1425+
SELECT pg_class.oid, pg_class.reltype
1426+
FROM pg_class
1427+
INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
1428+
WHERE nspname || '.' || relname `, typeNamesClause, `
1429+
),
1430+
selected_types(oid) AS (
1431+
SELECT reltype AS oid
1432+
FROM selected_classes
1433+
UNION ALL
1434+
SELECT oid
1435+
FROM pg_type
1436+
WHERE typname `, typeNamesClause, `
1437+
),
1438+
relationships(parent, child, depth) AS (
1439+
SELECT DISTINCT 0::OID, selected_types.oid, 0
1440+
FROM selected_types
1441+
UNION ALL
1442+
SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
1443+
FROM selected_classes c
1444+
inner join pg_type ON (c.reltype = pg_type.oid)
1445+
inner join pg_attribute on (c.oid = pg_attribute.attrelid)
1446+
UNION ALL
1447+
SELECT parent.oid, parent.typbasetype, relationships.depth + 1
1448+
FROM pg_type parent
1449+
INNER JOIN relationships ON (parent.oid = relationships.child)
1450+
WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
1451+
),
1452+
composite AS (
1453+
SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
1454+
FROM pg_attribute
1455+
INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
1456+
INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
1457+
WHERE NOT attisdropped
1458+
AND attnum > 0
1459+
GROUP BY pg_type.oid
1460+
)
1461+
SELECT typname,
1462+
typtype,
1463+
typbasetype,
1464+
typelem,
1465+
pg_type.oid,`)
1466+
if supportsMultirange {
1467+
parts = append(parts, `
1468+
COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
1469+
} else {
1470+
parts = append(parts, `
1471+
0 AS rngtypid,`)
1472+
}
1473+
parts = append(parts, `
1474+
COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
1475+
attnames, atttypids
1476+
FROM relationships
1477+
INNER JOIN pg_type ON (pg_type.oid IN ( relationships.child,relationships.parent) )
1478+
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)
1479+
if supportsMultirange {
1480+
parts = append(parts, `
1481+
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
1482+
}
1483+
1484+
parts = append(parts, `
1485+
LEFT OUTER JOIN composite USING (oid)
1486+
WHERE typtype != 'b'`)
1487+
parts = append(parts, `
1488+
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)
1489+
if supportsMultirange {
1490+
parts = append(parts, `
1491+
multirange.rngtypid,`)
1492+
}
1493+
parts = append(parts, `
1494+
attnames, atttypids
1495+
ORDER BY MAX(depth) desc, typname;`)
1496+
return strings.Join(parts, "")
1497+
}
1498+
1499+
// LoadTypes inspects the database for []typeNames and automatically registers all discovered
1500+
// types. Any types referenced by these will also be included in the registration.
1501+
func (c *Conn) LoadAndRegisterTypes(ctx context.Context, typeNames []string) error {
1502+
if typeNames == nil || len(typeNames) == 0 {
1503+
return fmt.Errorf("No type names were supplied.")
1504+
}
1505+
return c.loadAndRegisterTypes(ctx, typeNames, c.TypeMap())
1506+
}
1507+
1508+
func (c *Conn) loadAndRegisterTypes(ctx context.Context, typeNames []string, registerWith *pgtype.Map) error {
1509+
if registerWith == nil {
1510+
return fmt.Errorf("Type map must be supplied")
1511+
}
1512+
serverVersion, err := c.ServerVersion()
1513+
if err != nil {
1514+
return fmt.Errorf("Unexpected server version error: %w", err)
1515+
}
1516+
sql := buildLoadTypesSQL(serverVersion, c.IsCockroachDB(), typeNames)
1517+
var rows Rows
1518+
if typeNames == nil {
1519+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
1520+
} else {
1521+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
1522+
}
1523+
if err != nil {
1524+
return fmt.Errorf("While generating load types query: %w", err)
1525+
}
1526+
defer rows.Close()
1527+
for rows.Next() {
1528+
var oid uint32
1529+
var typeName, typtype string
1530+
var typbasetype, typelem uint32
1531+
var rngsubtype, rngtypid uint32
1532+
attnames := make([]string, 0, 0)
1533+
atttypids := make([]uint32, 0, 0)
1534+
err = rows.Scan(&typeName, &typtype, &typbasetype, &typelem, &oid, &rngtypid, &rngsubtype, &attnames, &atttypids)
1535+
if err != nil {
1536+
return fmt.Errorf("While scanning type information: %w", err)
1537+
}
1538+
1539+
switch typtype {
1540+
case "b": // array
1541+
dt, ok := c.TypeMap().TypeForOID(typelem)
1542+
if !ok {
1543+
return fmt.Errorf("array element OID not registered while loading %v", typeName)
1544+
}
1545+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.ArrayCodec{ElementType: dt}})
1546+
case "c": // composite
1547+
var fields []pgtype.CompositeCodecField
1548+
for i, fieldName := range attnames {
1549+
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
1550+
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
1551+
//}
1552+
dt, ok := c.TypeMap().TypeForOID(atttypids[i])
1553+
if !ok {
1554+
return fmt.Errorf("unknown composite type field OID %v (%v)", atttypids[i], fieldName)
1555+
}
1556+
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
1557+
}
1558+
if err != nil {
1559+
return fmt.Errorf("While parsing %v: %w", typeName, err)
1560+
}
1561+
1562+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.CompositeCodec{Fields: fields}})
1563+
case "d": // domain
1564+
dt, ok := c.TypeMap().TypeForOID(typbasetype)
1565+
if !ok {
1566+
return fmt.Errorf("domain base type OID not registered for %v", typeName)
1567+
}
1568+
1569+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: dt.Codec})
1570+
case "e": // enum
1571+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.EnumCodec{}})
1572+
case "r": // range
1573+
dt, ok := c.TypeMap().TypeForOID(rngsubtype)
1574+
if !ok {
1575+
return fmt.Errorf("range element OID %v not registered for %v", rngsubtype, typeName)
1576+
}
1577+
1578+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.RangeCodec{ElementType: dt}})
1579+
case "m": // multirange
1580+
dt, ok := c.TypeMap().TypeForOID(rngtypid)
1581+
if !ok {
1582+
return fmt.Errorf("multirange element OID %v not registered while loading %v", rngtypid, typeName)
1583+
}
1584+
1585+
registerWith.RegisterType(&pgtype.Type{Name: typeName, OID: oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}})
1586+
default:
1587+
return fmt.Errorf("unknown typtype %v for %v", typtype, typeName)
1588+
}
1589+
}
1590+
return nil
1591+
}
1592+
1593+
func (conn *Conn) IsCockroachDB() bool {
1594+
return conn.PgConn().ParameterStatus("crdb_version") != ""
1595+
}
1596+
1597+
func (conn *Conn) ServerVersion() (int64, error) {
1598+
serverVersionStr := conn.PgConn().ParameterStatus("server_version")
1599+
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
1600+
// if not PostgreSQL do nothing
1601+
if serverVersionStr == "" {
1602+
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
1603+
}
1604+
1605+
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
1606+
if err != nil {
1607+
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
1608+
}
1609+
return serverVersion, nil
1610+
}

pgtype/composite_test.go

Lines changed: 46 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,11 +10,56 @@ import (
1010
"github.com/stretchr/testify/require"
1111
)
1212

13-
func TestCompositeCodecTranscode(t *testing.T) {
13+
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
1414
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
1515

1616
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
17+
_, err := conn.Exec(ctx, `drop domain if exists anotheruint64;
18+
drop type if exists ct_test;
19+
create domain anotheruint64 as numeric(20,0);
20+
21+
create type ct_test as (
22+
a text,
23+
b int4,
24+
c anotheruint64
25+
);`)
26+
require.NoError(t, err)
27+
defer conn.Exec(ctx, "drop type ct_test")
28+
defer conn.Exec(ctx, "drop domain anotheruint64")
29+
30+
err = conn.LoadAndRegisterTypes(ctx, []string{"ct_test"})
31+
require.NoError(t, err)
32+
33+
formats := []struct {
34+
name string
35+
code int16
36+
}{
37+
{name: "TextFormat", code: pgx.TextFormatCode},
38+
{name: "BinaryFormat", code: pgx.BinaryFormatCode},
39+
}
40+
41+
for _, format := range formats {
42+
var a string
43+
var b int32
44+
var c uint64
1745

46+
err := conn.QueryRow(ctx, "select $1::ct_test", pgx.QueryResultFormats{format.code},
47+
pgtype.CompositeFields{"hi", int32(42), uint64(123)},
48+
).Scan(
49+
pgtype.CompositeFields{&a, &b, &c},
50+
)
51+
require.NoErrorf(t, err, "%v", format.name)
52+
require.EqualValuesf(t, "hi", a, "%v", format.name)
53+
require.EqualValuesf(t, 42, b, "%v", format.name)
54+
require.EqualValuesf(t, 123, c, "%v", format.name)
55+
}
56+
})
57+
}
58+
59+
func TestCompositeCodecTranscode(t *testing.T) {
60+
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
61+
62+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
1863
_, err := conn.Exec(ctx, `drop type if exists ct_test;
1964
2065
create type ct_test as (
@@ -94,7 +139,6 @@ func TestCompositeCodecTranscodeStruct(t *testing.T) {
94139
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
95140

96141
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
97-
98142
_, err := conn.Exec(ctx, `drop type if exists point3d;
99143
100144
create type point3d as (
@@ -131,7 +175,6 @@ func TestCompositeCodecTranscodeStructWrapper(t *testing.T) {
131175
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
132176

133177
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
134-
135178
_, err := conn.Exec(ctx, `drop type if exists point3d;
136179
137180
create type point3d as (
@@ -172,7 +215,6 @@ func TestCompositeCodecDecodeValue(t *testing.T) {
172215
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
173216

174217
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
175-
176218
_, err := conn.Exec(ctx, `drop type if exists point3d;
177219
178220
create type point3d as (
@@ -217,7 +259,6 @@ func TestCompositeCodecTranscodeStructWrapperForTable(t *testing.T) {
217259
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
218260

219261
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
220-
221262
_, err := conn.Exec(ctx, `drop table if exists point3d;
222263
223264
create table point3d (

0 commit comments

Comments
 (0)