Skip to content

Commit 22fb4e4

Browse files
committed
Load types using a single SQL query
When loading even a single type into pgx's type map, multiple SQL queries are performed in series. Over a slow link, this is not ideal. Worse, if multiple types are being registered, the whole sequence is repeated for each type. This commit adds LoadTypes, which retrieves the type mapping information for multiple types in a single SQL call, including recursive fetching of dependent types. RegisterTypes then performs the second stage of the operation: registering the loaded types in the type map.
1 parent 9907b87 commit 22fb4e4

File tree

8 files changed

+413
-8
lines changed

8 files changed

+413
-8
lines changed

conn.go

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,11 +41,16 @@ type ConnConfig struct {
4141
DefaultQueryExecMode QueryExecMode
4242

4343
createdByParseConfig bool // Used to enforce created by ParseConfig rule.
44+
45+
// automatically call LoadTypes with these values
46+
AutoLoadTypes []string
4447
}
4548

4649
// ParseConfigOptions contains options that control how a config is built such as getsslpassword.
4750
type ParseConfigOptions struct {
4851
pgconn.ParseConfigOptions
52+
53+
// AutoLoadTypes lists type names; ParseConfigWithOptions copies it into
// ConnConfig.AutoLoadTypes, which connect passes to LoadTypes when non-nil.
AutoLoadTypes []string
4954
}
5055

5156
// Copy returns a deep copy of the config that is safe to use and modify.
@@ -107,8 +112,10 @@ var (
107112
ErrTooManyRows = errors.New("too many rows in result set")
108113
)
109114

110-
var errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
111-
var errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
115+
// Errors describing the use of a caching query exec mode while the
// corresponding cache is disabled (see each message for the mode involved).
var (
116+
errDisabledStatementCache = fmt.Errorf("cannot use QueryExecModeCacheStatement with disabled statement cache")
117+
errDisabledDescriptionCache = fmt.Errorf("cannot use QueryExecModeCacheDescribe with disabled description cache")
118+
)
112119

113120
// Connect establishes a connection with a PostgreSQL server with a connection string. See
114121
// pgconn.Connect for details.
@@ -194,6 +201,7 @@ func ParseConfigWithOptions(connString string, options ParseConfigOptions) (*Con
194201
DescriptionCacheCapacity: descriptionCacheCapacity,
195202
DefaultQueryExecMode: defaultQueryExecMode,
196203
connString: connString,
204+
AutoLoadTypes: options.AutoLoadTypes,
197205
}
198206

199207
return connConfig, nil
@@ -271,6 +279,14 @@ func connect(ctx context.Context, config *ConnConfig) (c *Conn, err error) {
271279
c.descriptionCache = stmtcache.NewLRUCache(c.config.DescriptionCacheCapacity)
272280
}
273281

282+
if c.config.AutoLoadTypes != nil {
283+
if types, err := LoadTypes(ctx, c, c.config.AutoLoadTypes); err == nil {
284+
c.TypeMap().RegisterTypes(types)
285+
} else {
286+
return nil, err
287+
}
288+
}
289+
274290
return c, nil
275291
}
276292

@@ -843,7 +859,6 @@ func (c *Conn) getStatementDescription(
843859
mode QueryExecMode,
844860
sql string,
845861
) (sd *pgconn.StatementDescription, err error) {
846-
847862
switch mode {
848863
case QueryExecModeCacheStatement:
849864
if c.statementCache == nil {

derived_types.go

Lines changed: 256 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,256 @@
1+
package pgx
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"regexp"
7+
"strconv"
8+
"strings"
9+
10+
"github.com/jackc/pgx/v5/pgtype"
11+
)
12+
13+
// buildLoadDerivedTypesSQL generates the query used to retrieve the
// information needed to register a set of types and the types they depend on.
//
// pgVersion is the major version of the PostgreSQL server; from version 14
// onwards the query also resolves the range type behind each multirange type.
//
// typeNames are the names of the types to load. When it is non-nil the query
// expects the names as its single parameter ($1). When it is nil the query
// takes no parameters and selects every type (not currently recommended).
//
// Each name may refer to an entry in pg_class or in pg_type, and may or may
// not carry a schema prefix; both catalogs are searched with and without the
// prefix.
//
// The rows are ordered so that the most deeply nested dependencies come first,
// allowing types to be registered in the order returned.
func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string {
	supportsMultirange := pgVersion >= 14

	typeNamesClause := "= ANY($1)"
	if typeNames == nil {
		// collect all types. Not currently recommended.
		typeNamesClause = "IS NOT NULL"
	}

	parts := make([]string, 0, 16)

	parts = append(parts, `
WITH RECURSIVE
-- find the OIDs in pg_class which match one of the provided type names
selected_classes(oid,reltype) AS (
    -- this query uses the namespace search path, so will match type names without a schema prefix
    SELECT pg_class.oid, pg_class.reltype
    FROM pg_catalog.pg_class
        LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace
    WHERE pg_catalog.pg_table_is_visible(pg_class.oid)
      AND relname `, typeNamesClause, `
UNION ALL
    -- this query will only match type names which include the schema prefix
    SELECT pg_class.oid, pg_class.reltype
    FROM pg_class
    INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid)
    WHERE nspname || '.' || relname `, typeNamesClause, `
),
selected_types(oid) AS (
    -- collect the OIDs from pg_types which correspond to the selected classes
    SELECT reltype AS oid
    FROM selected_classes
UNION ALL
    -- as well as any other type names which match our criteria,
    -- whether or not the name was given with its schema prefix
    SELECT pg_type.oid
    FROM pg_type
    LEFT OUTER JOIN pg_namespace ON (pg_type.typnamespace = pg_namespace.oid)
    WHERE typname `, typeNamesClause, `
       OR nspname || '.' || typname `, typeNamesClause, `
),
-- this builds a parent/child mapping of objects, allowing us to know
-- all the child (ie: dependent) types that a parent (type) requires
-- As can be seen, there are 3 ways this can occur (the last of which
-- is due to being a composite class, where the composite fields are children)
pc(parent, child) AS (
    SELECT parent.oid, parent.typelem
    FROM pg_type parent
    WHERE parent.typtype = 'b' AND parent.typelem != 0
UNION ALL
    SELECT parent.oid, parent.typbasetype
    FROM pg_type parent
    WHERE parent.typtypmod = -1 AND parent.typbasetype != 0
UNION ALL
    SELECT pg_type.oid, atttypid
    FROM pg_attribute
    INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
    INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
    WHERE NOT attisdropped
      AND attnum > 0
),
-- Now construct a recursive query which includes a 'depth' element.
-- This is used to ensure that the "youngest" children are registered before
-- their parents.
relationships(parent, child, depth) AS (
    SELECT DISTINCT 0::OID, selected_types.oid, 0
    FROM selected_types
UNION ALL
    SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1
    FROM selected_classes c
    inner join pg_type ON (c.reltype = pg_type.oid)
    inner join pg_attribute on (c.oid = pg_attribute.attrelid)
UNION ALL
    SELECT pc.parent, pc.child, relationships.depth + 1
    FROM pc
    INNER JOIN relationships ON (pc.parent = relationships.child)
),
-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration
composite AS (
    SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids
    FROM pg_attribute
    INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid)
    INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype)
    WHERE NOT attisdropped
      AND attnum > 0
    GROUP BY pg_type.oid
)
-- Bring together this information, showing all the information which might possibly be required
-- to complete the registration, applying filters to only show the items which relate to the selected
-- types/classes.
SELECT typname,
       typtype,
       typbasetype,
       typelem,
       pg_type.oid,`)

	// pg_range.rngmultitypid only exists on servers that support multiranges;
	// older servers report a constant 0 so the column list stays the same.
	if supportsMultirange {
		parts = append(parts, `
       COALESCE(multirange.rngtypid, 0) AS rngtypid,`)
	} else {
		parts = append(parts, `
       0 AS rngtypid,`)
	}

	parts = append(parts, `
       COALESCE(pg_range.rngsubtype, 0) AS rngsubtype,
       attnames, atttypids
FROM relationships
INNER JOIN pg_type ON (pg_type.oid = relationships.child)
LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`)

	if supportsMultirange {
		parts = append(parts, `
LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`)
	}

	parts = append(parts, `
LEFT OUTER JOIN composite USING (oid)
WHERE NOT (typtype = 'b' AND typelem = 0)`)

	// A type can be reached through several paths; group the duplicates and
	// order by the deepest path so dependencies precede their dependents.
	parts = append(parts, `
GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`)

	if supportsMultirange {
		parts = append(parts, `
         multirange.rngtypid,`)
	}

	parts = append(parts, `
         attnames, atttypids
ORDER BY MAX(depth) desc, typname;`)

	return strings.Join(parts, "")
}
146+
147+
// derivedTypeInfo holds one row of the query produced by
// buildLoadDerivedTypesSQL, as scanned by LoadTypes.
type derivedTypeInfo struct {
	// OIDs describing the type and the types it refers to.
	Oid         uint32 // pg_type.oid of the type itself
	Typbasetype uint32 // pg_type.typbasetype
	Typelem     uint32 // pg_type.typelem
	Rngsubtype  uint32 // pg_range.rngsubtype, or 0 when there is no matching pg_range row
	Rngtypid    uint32 // range type behind a multirange, or 0

	TypeName string // pg_type.typname
	Typtype  string // pg_type.typtype

	// Composite attributes; Attnames and Atttypids are both ordered by attnum.
	Attnames  []string
	Atttypids []uint32
}
153+
154+
// LoadTypes performs a single (complex) query, returning all the required
155+
// information to register the named types, as well as any other types directly
156+
// or indirectly required to complete the registration.
157+
// The result of this call can be passed into RegisterTypes to complete the process.
158+
func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) {
159+
m := c.TypeMap().Copy()
160+
if typeNames == nil || len(typeNames) == 0 {
161+
return nil, fmt.Errorf("No type names were supplied.")
162+
}
163+
164+
serverVersion, err := serverVersion(c)
165+
if err != nil {
166+
return nil, fmt.Errorf("Unexpected server version error: %w", err)
167+
}
168+
sql := buildLoadDerivedTypesSQL(serverVersion, typeNames)
169+
var rows Rows
170+
if typeNames == nil {
171+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol)
172+
} else {
173+
rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames)
174+
}
175+
if err != nil {
176+
return nil, fmt.Errorf("While generating load types query: %w", err)
177+
}
178+
defer rows.Close()
179+
result := make([]*pgtype.Type, 0, 100)
180+
for rows.Next() {
181+
ti := derivedTypeInfo{}
182+
err = rows.Scan(&ti.TypeName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids)
183+
if err != nil {
184+
return nil, fmt.Errorf("While scanning type information: %w", err)
185+
}
186+
var type_ *pgtype.Type
187+
switch ti.Typtype {
188+
case "b": // array
189+
dt, ok := m.TypeForOID(ti.Typelem)
190+
if !ok {
191+
return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName)
192+
}
193+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}}
194+
case "c": // composite
195+
var fields []pgtype.CompositeCodecField
196+
for i, fieldName := range ti.Attnames {
197+
//if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil {
198+
// return nil, fmt.Errorf("While extracting OID used in composite field: %w", err)
199+
//}
200+
dt, ok := m.TypeForOID(ti.Atttypids[i])
201+
if !ok {
202+
return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i])
203+
}
204+
fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt})
205+
}
206+
207+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}}
208+
case "d": // domain
209+
dt, ok := m.TypeForOID(ti.Typbasetype)
210+
if !ok {
211+
return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName)
212+
}
213+
214+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec}
215+
case "e": // enum
216+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}}
217+
case "r": // range
218+
dt, ok := m.TypeForOID(ti.Rngsubtype)
219+
if !ok {
220+
return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName)
221+
}
222+
223+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}}
224+
case "m": // multirange
225+
dt, ok := m.TypeForOID(ti.Rngtypid)
226+
if !ok {
227+
return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName)
228+
}
229+
230+
type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}}
231+
default:
232+
return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName)
233+
}
234+
if type_ != nil {
235+
m.RegisterType(type_)
236+
result = append(result, type_)
237+
}
238+
}
239+
return result, nil
240+
}
241+
242+
// serverVersion returns the postgresql server version.
243+
func serverVersion(c *Conn) (int64, error) {
244+
serverVersionStr := c.PgConn().ParameterStatus("server_version")
245+
serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr)
246+
// if not PostgreSQL do nothing
247+
if serverVersionStr == "" {
248+
return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr)
249+
}
250+
251+
serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64)
252+
if err != nil {
253+
return 0, fmt.Errorf("postgres version parsing failed: %w", err)
254+
}
255+
return serverVersion, nil
256+
}

derived_types_test.go

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
package pgx_test
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/jackc/pgx/v5"
8+
"github.com/stretchr/testify/require"
9+
)
10+
11+
func TestCompositeCodecTranscodeWithLoadTypes(t *testing.T) {
12+
skipCockroachDB(t, "Server does not support composite types (see https://github.com/cockroachdb/cockroach/issues/27792)")
13+
14+
defaultConnTestRunner.RunTest(context.Background(), t, func(ctx context.Context, t testing.TB, conn *pgx.Conn) {
15+
_, err := conn.Exec(ctx, `
16+
drop type if exists dtype_test;
17+
drop domain if exists anotheruint64;
18+
19+
create domain anotheruint64 as numeric(20,0);
20+
create type dtype_test as (
21+
a text,
22+
b int4,
23+
c anotheruint64,
24+
d anotheruint64[]
25+
);`)
26+
require.NoError(t, err)
27+
defer conn.Exec(ctx, "drop type dtype_test")
28+
defer conn.Exec(ctx, "drop domain anotheruint64")
29+
30+
types, err := pgx.LoadTypes(ctx, conn, []string{"dtype_test"})
31+
require.NoError(t, err)
32+
require.Len(t, types, 3)
33+
require.Equal(t, types[0].Name, "anotheruint64")
34+
require.Equal(t, types[1].Name, "_anotheruint64")
35+
require.Equal(t, types[2].Name, "dtype_test")
36+
})
37+
}

go.sum

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,6 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c
44
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
55
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
66
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
7-
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
8-
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
9-
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9 h1:L0QtFUgDarD7Fpv9jeVMgy/+Ec0mtnmYuImjTz6dtDA=
10-
github.com/jackc/pgservicefile v0.0.0-20231201235250-de7065d80cb9/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
117
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
128
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
139
github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=

0 commit comments

Comments
 (0)