@@ -104,10 +104,13 @@ type Pool struct {
104104 closeOnce sync.Once
105105 closeChan chan struct {}
106106
107- autoLoadTypeNames []string
108- reuseTypeMap bool
109- autoLoadMutex * sync.Mutex
110- autoLoadTypes []* pgtype.Type
107+ autoLoadTypeNames []string
108+ reuseTypeMap bool
109+ autoLoadMutex * sync.Mutex
110+ autoLoadTypes []* pgtype.Type
111+ customRegistrationMap map [string ]CustomRegistrationFunction
112+ customRegistrationMutex * sync.Mutex
113+ customRegistrationOidMap map [string ]uint32
111114}
112115
113116// Config is the configuration struct for creating a pool. It must be created by [ParseConfig] and then it can be
@@ -198,6 +201,10 @@ func New(ctx context.Context, connString string) (*Pool, error) {
198201 return NewWithConfig (ctx , config )
199202}
200203
204+ // CustomRegistrationFunction is capable of registering whatever is necessary for
205+ // a custom type. It is provided with the backend's OID for this type.
206+ type CustomRegistrationFunction func (ctx context.Context , m * pgtype.Map , oid uint32 ) error
207+
201208// NewWithConfig creates a new Pool. config must have been created by [ParseConfig].
202209func NewWithConfig (ctx context.Context , config * Config ) (* Pool , error ) {
203210 // Default values are set in ParseConfig. Enforce initial creation by ParseConfig rather than setting defaults from
@@ -207,23 +214,25 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
207214 }
208215
209216 p := & Pool {
210- config : config ,
211- beforeConnect : config .BeforeConnect ,
212- afterConnect : config .AfterConnect ,
213- autoLoadTypeNames : config .AutoLoadTypes ,
214- reuseTypeMap : config .ReuseTypeMaps ,
215- beforeAcquire : config .BeforeAcquire ,
216- afterRelease : config .AfterRelease ,
217- beforeClose : config .BeforeClose ,
218- minConns : config .MinConns ,
219- maxConns : config .MaxConns ,
220- maxConnLifetime : config .MaxConnLifetime ,
221- maxConnLifetimeJitter : config .MaxConnLifetimeJitter ,
222- maxConnIdleTime : config .MaxConnIdleTime ,
223- healthCheckPeriod : config .HealthCheckPeriod ,
224- healthCheckChan : make (chan struct {}, 1 ),
225- closeChan : make (chan struct {}),
226- autoLoadMutex : new (sync.Mutex ),
217+ config : config ,
218+ beforeConnect : config .BeforeConnect ,
219+ afterConnect : config .AfterConnect ,
220+ autoLoadTypeNames : config .AutoLoadTypes ,
221+ reuseTypeMap : config .ReuseTypeMaps ,
222+ beforeAcquire : config .BeforeAcquire ,
223+ afterRelease : config .AfterRelease ,
224+ beforeClose : config .BeforeClose ,
225+ minConns : config .MinConns ,
226+ maxConns : config .MaxConns ,
227+ maxConnLifetime : config .MaxConnLifetime ,
228+ maxConnLifetimeJitter : config .MaxConnLifetimeJitter ,
229+ maxConnIdleTime : config .MaxConnIdleTime ,
230+ healthCheckPeriod : config .HealthCheckPeriod ,
231+ healthCheckChan : make (chan struct {}, 1 ),
232+ closeChan : make (chan struct {}),
233+ autoLoadMutex : new (sync.Mutex ),
234+ customRegistrationMap : make (map [string ]CustomRegistrationFunction ),
235+ customRegistrationMutex : new (sync.Mutex ),
227236 }
228237
229238 if t , ok := config .ConnConfig .Tracer .(AcquireTracer ); ok {
@@ -265,6 +274,24 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
265274 }
266275 }
267276
277+ if len (p .customRegistrationMap ) > 0 {
278+ oidMap , err := p .getOidMapForCustomRegistration (ctx , conn )
279+ if err != nil {
280+ conn .Close (ctx )
281+ return nil , fmt .Errorf ("While retrieving OIDs for custom type registration: %w" , err )
282+ }
283+ for typeName , f := range p .customRegistrationMap {
284+ if oid , exists := oidMap [typeName ]; exists {
285+ if err := f (ctx , conn .TypeMap (), oid ); err != nil {
286+ return nil , err
287+ }
288+ } else {
289+ return nil , fmt .Errorf ("Type %q does not have an associated OID." , typeName )
290+ }
291+ }
292+
293+ }
294+
268295 if p .autoLoadTypeNames != nil && len (p .autoLoadTypeNames ) > 0 {
269296 types , err := p .loadTypes (ctx , conn , p .autoLoadTypeNames )
270297 if err != nil {
@@ -315,6 +342,51 @@ func NewWithConfig(ctx context.Context, config *Config) (*Pool, error) {
315342 return p , nil
316343}
317344
345+ func (p * Pool ) getOidMapForCustomRegistration (ctx context.Context , conn * pgx.Conn ) (map [string ]uint32 , error ) {
346+ if p .reuseTypeMap {
347+ p .customRegistrationMutex .Lock ()
348+ defer p .customRegistrationMutex .Unlock ()
349+ if p .customRegistrationOidMap != nil {
350+ return p .customRegistrationOidMap , nil
351+ }
352+ oidMap , err := p .fetchOidMapForCustomRegistration (ctx , conn )
353+ if err != nil {
354+ return nil , err
355+ }
356+ p .customRegistrationOidMap = oidMap
357+ return oidMap , nil
358+ }
359+ // Avoid needing to acquire the mutex and allow connections to initialise in parallel
360+ // if we have chosen to not reuse the type mapping
361+ return p .fetchOidMapForCustomRegistration (ctx , conn )
362+ }
363+
364+ func (p * Pool ) fetchOidMapForCustomRegistration (ctx context.Context , conn * pgx.Conn ) (map [string ]uint32 , error ) {
365+ sql := `
366+ SELECT oid, typname
367+ FROM pg_type
368+ WHERE typname = ANY($1)`
369+ result := make (map [string ]uint32 )
370+ typeNames := make ([]string , 0 , len (p .customRegistrationMap ))
371+ for typeName := range p .customRegistrationMap {
372+ typeNames = append (typeNames , typeName )
373+ }
374+ rows , err := conn .Query (ctx , sql , typeNames )
375+ if err != nil {
376+ return nil , fmt .Errorf ("While collecting OIDs for custom registrations: %w" , err )
377+ }
378+ defer rows .Close ()
379+ var typeName string
380+ var oid uint32
381+ for rows .Next () {
382+ if err := rows .Scan (& typeName , & oid ); err != nil {
383+ return nil , fmt .Errorf ("While scanning a row for custom registrations: %w" , err )
384+ }
385+ result [typeName ] = oid
386+ }
387+ return result , nil
388+ }
389+
318390// ParseConfig builds a Config from connString. It parses connString with the same behavior as [pgx.ParseConfig] with the
319391// addition of the following variables:
320392//
@@ -425,6 +497,12 @@ func (p *Pool) Close() {
425497 })
426498}
427499
500+ // RegisterCustomType is used to provide a function capable of performing
501+ // type registration for situations where the autoloader is unable to do so on its own
502+ func (p * Pool ) RegisterCustomType (typeName string , f CustomRegistrationFunction ) {
503+ p .customRegistrationMap [typeName ] = f
504+ }
505+
428506// loadTypes is used internally to autoload the custom types for a connection,
429507// potentially reusing previously-loaded typemap information.
430508func (p * Pool ) loadTypes (ctx context.Context , conn * pgx.Conn , typeNames []string ) ([]* pgtype.Type , error ) {
0 commit comments