|
| 1 | +package pgx |
| 2 | + |
| 3 | +import ( |
| 4 | + "context" |
| 5 | + "fmt" |
| 6 | + "regexp" |
| 7 | + "strconv" |
| 8 | + "strings" |
| 9 | + |
| 10 | + "github.com/jackc/pgx/v5/pgtype" |
| 11 | +) |
| 12 | + |
| 13 | +/* |
| 14 | +buildLoadDerivedTypesSQL generates the correct query for retrieving type information. |
| 15 | +
|
| 16 | + pgVersion: the major version of the PostgreSQL server |
| 17 | + typeNames: the names of the types to load. If nil, load all types. |
| 18 | +*/ |
| 19 | +func buildLoadDerivedTypesSQL(pgVersion int64, typeNames []string) string { |
| 20 | + supportsMultirange := (pgVersion >= 14) |
| 21 | + var typeNamesClause string |
| 22 | + |
| 23 | + if typeNames == nil { |
| 24 | + // collect all types. Not currently recommended. |
| 25 | + typeNamesClause = "IS NOT NULL" |
| 26 | + } else { |
| 27 | + typeNamesClause = "= ANY($1)" |
| 28 | + } |
| 29 | + parts := make([]string, 0, 10) |
| 30 | + |
| 31 | + // Each of the type names provided might be found in pg_class or pg_type. |
| 32 | + // Additionally, it may or may not include a schema portion. |
| 33 | + parts = append(parts, ` |
| 34 | +WITH RECURSIVE |
| 35 | +-- find the OIDs in pg_class which match one of the provided type names |
| 36 | +selected_classes(oid,reltype) AS ( |
| 37 | + -- this query uses the namespace search path, so will match type names without a schema prefix |
| 38 | + SELECT pg_class.oid, pg_class.reltype |
| 39 | + FROM pg_catalog.pg_class |
| 40 | + LEFT JOIN pg_catalog.pg_namespace n ON n.oid = pg_class.relnamespace |
| 41 | + WHERE pg_catalog.pg_table_is_visible(pg_class.oid) |
| 42 | + AND relname `, typeNamesClause, ` |
| 43 | +UNION ALL |
| 44 | + -- this query will only match type names which include the schema prefix |
| 45 | + SELECT pg_class.oid, pg_class.reltype |
| 46 | + FROM pg_class |
| 47 | + INNER JOIN pg_namespace ON (pg_class.relnamespace = pg_namespace.oid) |
| 48 | + WHERE nspname || '.' || relname `, typeNamesClause, ` |
| 49 | +), |
| 50 | +selected_types(oid) AS ( |
| 51 | + -- collect the OIDs from pg_types which correspond to the selected classes |
| 52 | + SELECT reltype AS oid |
| 53 | + FROM selected_classes |
| 54 | +UNION ALL |
| 55 | + -- as well as any other type names which match our criteria |
| 56 | + SELECT oid |
| 57 | + FROM pg_type |
| 58 | + WHERE typname `, typeNamesClause, ` |
| 59 | +), |
| 60 | +-- this builds a parent/child mapping of objects, allowing us to know |
| 61 | +-- all the child (ie: dependent) types that a parent (type) requires |
| 62 | +-- As can be seen, there are 3 ways this can occur (the last of which |
| 63 | +-- is due to being a composite class, where the composite fields are children) |
| 64 | +pc(parent, child) AS ( |
| 65 | + SELECT parent.oid, parent.typelem |
| 66 | + FROM pg_type parent |
| 67 | + WHERE parent.typtype = 'b' AND parent.typelem != 0 |
| 68 | +UNION ALL |
| 69 | + SELECT parent.oid, parent.typbasetype |
| 70 | + FROM pg_type parent |
| 71 | + WHERE parent.typtypmod = -1 AND parent.typbasetype != 0 |
| 72 | +UNION ALL |
| 73 | + SELECT pg_type.oid, atttypid |
| 74 | + FROM pg_attribute |
| 75 | + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) |
| 76 | + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) |
| 77 | + WHERE NOT attisdropped |
| 78 | + AND attnum > 0 |
| 79 | +), |
| 80 | +-- Now construct a recursive query which includes a 'depth' element. |
| 81 | +-- This is used to ensure that the "youngest" children are registered before |
| 82 | +-- their parents. |
| 83 | +relationships(parent, child, depth) AS ( |
| 84 | + SELECT DISTINCT 0::OID, selected_types.oid, 0 |
| 85 | + FROM selected_types |
| 86 | +UNION ALL |
| 87 | + SELECT pg_type.oid AS parent, pg_attribute.atttypid AS child, 1 |
| 88 | + FROM selected_classes c |
| 89 | + inner join pg_type ON (c.reltype = pg_type.oid) |
| 90 | + inner join pg_attribute on (c.oid = pg_attribute.attrelid) |
| 91 | +UNION ALL |
| 92 | + SELECT pc.parent, pc.child, relationships.depth + 1 |
| 93 | + FROM pc |
| 94 | + INNER JOIN relationships ON (pc.parent = relationships.child) |
| 95 | +), |
| 96 | +-- composite fields need to be encapsulated as a couple of arrays to provide the required information for registration |
| 97 | +composite AS ( |
| 98 | + SELECT pg_type.oid, ARRAY_AGG(attname ORDER BY attnum) AS attnames, ARRAY_AGG(atttypid ORDER BY ATTNUM) AS atttypids |
| 99 | + FROM pg_attribute |
| 100 | + INNER JOIN pg_class ON (pg_class.oid = pg_attribute.attrelid) |
| 101 | + INNER JOIN pg_type ON (pg_type.oid = pg_class.reltype) |
| 102 | + WHERE NOT attisdropped |
| 103 | + AND attnum > 0 |
| 104 | + GROUP BY pg_type.oid |
| 105 | +) |
| 106 | +-- Bring together this information, showing all the information which might possibly be required |
| 107 | +-- to complete the registration, applying filters to only show the items which relate to the selected |
| 108 | +-- types/classes. |
| 109 | +SELECT typname, |
| 110 | + typtype, |
| 111 | + typbasetype, |
| 112 | + typelem, |
| 113 | + pg_type.oid,`) |
| 114 | + if supportsMultirange { |
| 115 | + parts = append(parts, ` |
| 116 | + COALESCE(multirange.rngtypid, 0) AS rngtypid,`) |
| 117 | + } else { |
| 118 | + parts = append(parts, ` |
| 119 | + 0 AS rngtypid,`) |
| 120 | + } |
| 121 | + parts = append(parts, ` |
| 122 | + COALESCE(pg_range.rngsubtype, 0) AS rngsubtype, |
| 123 | + attnames, atttypids |
| 124 | + FROM relationships |
| 125 | + INNER JOIN pg_type ON (pg_type.oid = relationships.child) |
| 126 | + LEFT OUTER JOIN pg_range ON (pg_type.oid = pg_range.rngtypid)`) |
| 127 | + if supportsMultirange { |
| 128 | + parts = append(parts, ` |
| 129 | + LEFT OUTER JOIN pg_range multirange ON (pg_type.oid = multirange.rngmultitypid)`) |
| 130 | + } |
| 131 | + |
| 132 | + parts = append(parts, ` |
| 133 | + LEFT OUTER JOIN composite USING (oid) |
| 134 | + WHERE NOT (typtype = 'b' AND typelem = 0)`) |
| 135 | + parts = append(parts, ` |
| 136 | + GROUP BY typname, typtype, typbasetype, typelem, pg_type.oid, pg_range.rngsubtype,`) |
| 137 | + if supportsMultirange { |
| 138 | + parts = append(parts, ` |
| 139 | + multirange.rngtypid,`) |
| 140 | + } |
| 141 | + parts = append(parts, ` |
| 142 | + attnames, atttypids |
| 143 | + ORDER BY MAX(depth) desc, typname;`) |
| 144 | + return strings.Join(parts, "") |
| 145 | +} |
| 146 | + |
| 147 | +type derivedTypeInfo struct { |
| 148 | + Oid, Typbasetype, Typelem, Rngsubtype, Rngtypid uint32 |
| 149 | + TypeName, Typtype string |
| 150 | + Attnames []string |
| 151 | + Atttypids []uint32 |
| 152 | +} |
| 153 | + |
| 154 | +// LoadTypes performs a single (complex) query, returning all the required |
| 155 | +// information to register the named types, as well as any other types directly |
| 156 | +// or indirectly required to complete the registration. |
| 157 | +// The result of this call can be passed into RegisterTypes to complete the process. |
| 158 | +func LoadTypes(ctx context.Context, c *Conn, typeNames []string) ([]*pgtype.Type, error) { |
| 159 | + m := c.TypeMap().Copy() |
| 160 | + if typeNames == nil || len(typeNames) == 0 { |
| 161 | + return nil, fmt.Errorf("No type names were supplied.") |
| 162 | + } |
| 163 | + |
| 164 | + serverVersion, err := serverVersion(c) |
| 165 | + if err != nil { |
| 166 | + return nil, fmt.Errorf("Unexpected server version error: %w", err) |
| 167 | + } |
| 168 | + sql := buildLoadDerivedTypesSQL(serverVersion, typeNames) |
| 169 | + var rows Rows |
| 170 | + if typeNames == nil { |
| 171 | + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol) |
| 172 | + } else { |
| 173 | + rows, err = c.Query(ctx, sql, QueryExecModeSimpleProtocol, typeNames) |
| 174 | + } |
| 175 | + if err != nil { |
| 176 | + return nil, fmt.Errorf("While generating load types query: %w", err) |
| 177 | + } |
| 178 | + defer rows.Close() |
| 179 | + result := make([]*pgtype.Type, 0, 100) |
| 180 | + for rows.Next() { |
| 181 | + ti := derivedTypeInfo{} |
| 182 | + err = rows.Scan(&ti.TypeName, &ti.Typtype, &ti.Typbasetype, &ti.Typelem, &ti.Oid, &ti.Rngtypid, &ti.Rngsubtype, &ti.Attnames, &ti.Atttypids) |
| 183 | + if err != nil { |
| 184 | + return nil, fmt.Errorf("While scanning type information: %w", err) |
| 185 | + } |
| 186 | + var type_ *pgtype.Type |
| 187 | + switch ti.Typtype { |
| 188 | + case "b": // array |
| 189 | + dt, ok := m.TypeForOID(ti.Typelem) |
| 190 | + if !ok { |
| 191 | + return nil, fmt.Errorf("Array element OID %v not registered while loading pgtype %q", ti.Typelem, ti.TypeName) |
| 192 | + } |
| 193 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.ArrayCodec{ElementType: dt}} |
| 194 | + case "c": // composite |
| 195 | + var fields []pgtype.CompositeCodecField |
| 196 | + for i, fieldName := range ti.Attnames { |
| 197 | + //if fieldOID64, err = strconv.ParseUint(composite_fields[i+1], 10, 32); err != nil { |
| 198 | + // return nil, fmt.Errorf("While extracting OID used in composite field: %w", err) |
| 199 | + //} |
| 200 | + dt, ok := m.TypeForOID(ti.Atttypids[i]) |
| 201 | + if !ok { |
| 202 | + return nil, fmt.Errorf("Unknown field for composite type %q: field %q (OID %v) is not already registered.", ti.TypeName, fieldName, ti.Atttypids[i]) |
| 203 | + } |
| 204 | + fields = append(fields, pgtype.CompositeCodecField{Name: fieldName, Type: dt}) |
| 205 | + } |
| 206 | + |
| 207 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.CompositeCodec{Fields: fields}} |
| 208 | + case "d": // domain |
| 209 | + dt, ok := m.TypeForOID(ti.Typbasetype) |
| 210 | + if !ok { |
| 211 | + return nil, fmt.Errorf("Domain base type OID %v was not already registered, needed for %q", ti.Typbasetype, ti.TypeName) |
| 212 | + } |
| 213 | + |
| 214 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: dt.Codec} |
| 215 | + case "e": // enum |
| 216 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.EnumCodec{}} |
| 217 | + case "r": // range |
| 218 | + dt, ok := m.TypeForOID(ti.Rngsubtype) |
| 219 | + if !ok { |
| 220 | + return nil, fmt.Errorf("Range element OID %v was not already registered, needed for %q", ti.Rngsubtype, ti.TypeName) |
| 221 | + } |
| 222 | + |
| 223 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.RangeCodec{ElementType: dt}} |
| 224 | + case "m": // multirange |
| 225 | + dt, ok := m.TypeForOID(ti.Rngtypid) |
| 226 | + if !ok { |
| 227 | + return nil, fmt.Errorf("Multirange element OID %v was not already registered, needed for %q", ti.Rngtypid, ti.TypeName) |
| 228 | + } |
| 229 | + |
| 230 | + type_ = &pgtype.Type{Name: ti.TypeName, OID: ti.Oid, Codec: &pgtype.MultirangeCodec{ElementType: dt}} |
| 231 | + default: |
| 232 | + return nil, fmt.Errorf("Unknown typtype %q was found while registering %q", ti.Typtype, ti.TypeName) |
| 233 | + } |
| 234 | + if type_ != nil { |
| 235 | + m.RegisterType(type_) |
| 236 | + result = append(result, type_) |
| 237 | + } |
| 238 | + } |
| 239 | + return result, nil |
| 240 | +} |
| 241 | + |
| 242 | +// serverVersion returns the postgresql server version. |
| 243 | +func serverVersion(c *Conn) (int64, error) { |
| 244 | + serverVersionStr := c.PgConn().ParameterStatus("server_version") |
| 245 | + serverVersionStr = regexp.MustCompile(`^[0-9]+`).FindString(serverVersionStr) |
| 246 | + // if not PostgreSQL do nothing |
| 247 | + if serverVersionStr == "" { |
| 248 | + return 0, fmt.Errorf("Cannot identify server version in %q", serverVersionStr) |
| 249 | + } |
| 250 | + |
| 251 | + serverVersion, err := strconv.ParseInt(serverVersionStr, 10, 64) |
| 252 | + if err != nil { |
| 253 | + return 0, fmt.Errorf("postgres version parsing failed: %w", err) |
| 254 | + } |
| 255 | + return serverVersion, nil |
| 256 | +} |
0 commit comments