Skip to content

Commit 810f7b3

Browse files
committed
add: use pg_basetype() for domain type resolution on PG 17+
1 parent 8f34afd commit 810f7b3

File tree

9 files changed

+128
-85
lines changed

9 files changed

+128
-85
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ All notable changes to this project will be documented in this file. From versio
66

77
### Added
88

9+
- Use `pg_basetype()` for domain type resolution on PostgreSQL 17+ by @joelonsql in #XXXX
910
- Log error when `db-schemas` config contains schema `pg_catalog` or `information_schema` by @taimoorzaeem in #4359
1011

1112
### Fixed

src/PostgREST/App.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,8 @@ postgrestResponse appState conf@AppConfig{..} maybeSchemaCache authResult@AuthRe
149149
(parseTime, apiReq@ApiRequest{..}) <- withTiming $ liftEither . mapLeft Error.ApiRequestError $ ApiRequest.userApiRequest conf prefs req body
150150
(planTime, plan) <- withTiming $ liftEither $ Plan.actionPlan iAction conf apiReq sCache
151151

152-
let mainQ = Query.mainQuery plan conf apiReq authResult configDbPreRequest
152+
pgVer <- lift $ AppState.getPgVersion appState
153+
let mainQ = Query.mainQuery pgVer plan conf apiReq authResult configDbPreRequest
153154
tx = MainTx.mainTx mainQ conf authResult apiReq plan sCache
154155
obsQuery s = when configLogQuery $ observer $ QueryObs mainQ s
155156

src/PostgREST/AppState.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -401,9 +401,10 @@ retryingSchemaCacheLoad appState@AppState{stateObserver=observer, stateMainThrea
401401
qSchemaCache :: IO (Maybe SchemaCache)
402402
qSchemaCache = do
403403
conf@AppConfig{..} <- getConfig appState
404+
pgVer <- getPgVersion appState
404405
(resultTime, result) <-
405406
let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in
406-
timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf)
407+
timeItT $ usePool appState (transaction SQL.ReadCommitted SQL.Read $ querySchemaCache pgVer conf)
407408
case result of
408409
Left e -> do
409410
putSCacheStatus appState SCPending

src/PostgREST/CLI.hs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,11 @@ runAppCommand conf@AppConfig{..} runCmd = do
6161
dumpSchema :: AppState -> IO LBS.ByteString
6262
dumpSchema appState = do
6363
conf@AppConfig{..} <- AppState.getConfig appState
64+
pgVer <- AppState.getPgVersion appState
6465
result <-
6566
let transaction = if configDbPreparedStatements then SQL.transaction else SQL.unpreparedTransaction in
6667
AppState.usePool appState
67-
(transaction SQL.ReadCommitted SQL.Read $ querySchemaCache conf)
68+
(transaction SQL.ReadCommitted SQL.Read $ querySchemaCache pgVer conf)
6869
case result of
6970
Left e -> do
7071
let observer = AppState.getObserver appState

src/PostgREST/Query.hs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import PostgREST.ApiRequest.Preferences (Preferences (..),
2323
shouldExplainCount)
2424
import PostgREST.Auth.Types (AuthResult (..))
2525
import PostgREST.Config (AppConfig (..))
26+
import PostgREST.Config.PgVersion (PgVersion)
2627
import PostgREST.Plan (ActionPlan (..),
2728
CrudPlan (..),
2829
DbActionPlan (..),
@@ -41,9 +42,9 @@ data MainQuery = MainQuery
4142
, mqExplain :: Maybe SQL.Snippet -- ^ the explain query that gets generated for the "Prefer: count=estimated" case
4243
}
4344

44-
mainQuery :: ActionPlan -> AppConfig -> ApiRequest -> AuthResult -> Maybe QualifiedIdentifier -> MainQuery
45-
mainQuery (NoDb _) _ _ _ _ = MainQuery mempty Nothing mempty (mempty, mempty, mempty) mempty
46-
mainQuery (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} authRes preReq =
45+
mainQuery :: PgVersion -> ActionPlan -> AppConfig -> ApiRequest -> AuthResult -> Maybe QualifiedIdentifier -> MainQuery
46+
mainQuery _ (NoDb _) _ _ _ _ = MainQuery mempty Nothing mempty (mempty, mempty, mempty) mempty
47+
mainQuery pgVer (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preferences{..}} authRes preReq =
4748
let genQ = MainQuery (PreQuery.txVarQuery plan conf authRes apiReq) (PreQuery.preReqQuery <$> preReq) in
4849
case plan of
4950
DbCrud _ WrappedReadPlan{..} ->
@@ -55,4 +56,4 @@ mainQuery (Db plan) conf@AppConfig{..} apiReq@ApiRequest{iPreferences=Preference
5556
DbCrud _ CallReadPlan{..} ->
5657
genQ (Statements.mainCall crProc crCallPlan crReadPlan preferCount pMedia crHandler) (mempty, mempty, mempty) mempty
5758
MayUseDb InspectPlan{ipSchema=tSchema} ->
58-
genQ mempty (SqlFragment.accessibleTables tSchema, SqlFragment.accessibleFuncs tSchema, SqlFragment.schemaDescription tSchema) mempty
59+
genQ mempty (SqlFragment.accessibleTables tSchema, SqlFragment.accessibleFuncs pgVer tSchema, SqlFragment.schemaDescription tSchema) mempty

src/PostgREST/Query/SqlFragment.hs

Lines changed: 51 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ import PostgREST.ApiRequest.Types (AggregateFunction (..),
7474
OrderNulls (..),
7575
QuantOperator (..),
7676
SimpleOperator (..))
77+
import PostgREST.Config.PgVersion (PgVersion, pgVersion170)
7778
import PostgREST.MediaType (MTVndPlanFormat (..),
7879
MTVndPlanOption (..))
7980
import PostgREST.Plan.ReadPlan (JoinCondition (..))
@@ -614,39 +615,60 @@ accessibleTables schema = SQL.sql (encodeUtf8 [trimming|
614615
where
615616
encodedSchema = SQL.encoderAndParam (HE.nonNullable HE.text) schema
616617

617-
accessibleFuncs :: Text -> SQL.Snippet
618-
accessibleFuncs schema = baseFuncSqlQuery <> "AND p.pronamespace = " <> encodedSchema <> "::regnamespace"
618+
accessibleFuncs :: PgVersion -> Text -> SQL.Snippet
619+
accessibleFuncs pgVer schema = baseFuncSqlQuery pgVer <> "AND p.pronamespace = " <> encodedSchema <> "::regnamespace"
619620
where
620621
encodedSchema = SQL.encoderAndParam (HE.nonNullable HE.text) schema
621622

622-
baseFuncSqlQuery :: SQL.Snippet
623-
baseFuncSqlQuery = SQL.sql $ encodeUtf8 [trimming|
623+
baseTypesCte :: PgVersion -> Text
624+
baseTypesCte pgVer
625+
| pgVer >= pgVersion170 = [trimming|
626+
-- Get base types using pg_basetype() (PG 17+)
627+
base_types AS (
628+
SELECT
629+
t.oid,
630+
bt.typnamespace AS base_namespace,
631+
bt.oid AS base_type
632+
FROM pg_type t
633+
JOIN pg_type bt ON bt.oid = pg_basetype(t.oid)
634+
)
635+
|]
636+
| otherwise = [trimming|
637+
-- Recursively get the base types of domains (PG < 17)
638+
base_types AS (
639+
WITH RECURSIVE
640+
recurse AS (
641+
SELECT
642+
oid,
643+
typbasetype,
644+
typnamespace AS base_namespace,
645+
COALESCE(NULLIF(typbasetype, 0), oid) AS base_type
646+
FROM pg_type
647+
UNION
648+
SELECT
649+
t.oid,
650+
b.typbasetype,
651+
b.typnamespace AS base_namespace,
652+
COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type
653+
FROM recurse t
654+
JOIN pg_type b ON t.typbasetype = b.oid
655+
)
656+
SELECT
657+
oid,
658+
base_namespace,
659+
base_type
660+
FROM recurse
661+
WHERE typbasetype = 0
662+
)
663+
|]
664+
665+
-- | SQL query to get accessible functions for OpenAPI.
666+
baseFuncSqlQuery :: PgVersion -> SQL.Snippet
667+
baseFuncSqlQuery pgVer =
668+
let baseCte = baseTypesCte pgVer
669+
in SQL.sql $ encodeUtf8 [trimming|
624670
WITH
625-
base_types AS (
626-
WITH RECURSIVE
627-
recurse AS (
628-
SELECT
629-
oid,
630-
typbasetype,
631-
typnamespace AS base_namespace,
632-
COALESCE(NULLIF(typbasetype, 0), oid) AS base_type
633-
FROM pg_type
634-
UNION
635-
SELECT
636-
t.oid,
637-
b.typbasetype,
638-
b.typnamespace AS base_namespace,
639-
COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type
640-
FROM recurse t
641-
JOIN pg_type b ON t.typbasetype = b.oid
642-
)
643-
SELECT
644-
oid,
645-
base_namespace,
646-
base_type
647-
FROM recurse
648-
WHERE typbasetype = 0
649-
),
671+
$baseCte,
650672
arguments AS (
651673
SELECT
652674
oid,

src/PostgREST/SchemaCache.hs

Lines changed: 60 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ import NeatInterpolation (trimming)
4242
import PostgREST.Config (AppConfig (..))
4343
import PostgREST.Config.Database (TimezoneNames,
4444
toIsolationLevel)
45+
import PostgREST.Config.PgVersion (PgVersion, pgVersion170)
4546
import PostgREST.SchemaCache.Identifiers (FieldName,
4647
QualifiedIdentifier (..),
4748
RelIdentifier (..),
@@ -139,13 +140,13 @@ data KeyDep
139140
type SqlQuery = ByteString
140141

141142

142-
querySchemaCache :: AppConfig -> SQL.Transaction SchemaCache
143-
querySchemaCache conf@AppConfig{..} = do
143+
querySchemaCache :: PgVersion -> AppConfig -> SQL.Transaction SchemaCache
144+
querySchemaCache pgVer conf@AppConfig{..} = do
144145
SQL.sql "set local schema ''" -- This voids the search path. The following queries need this for getting the fully qualified name(schema.name) of every db object
145-
tabs <- SQL.statement conf $ allTables prepared
146+
tabs <- SQL.statement conf $ allTables pgVer prepared
146147
keyDeps <- SQL.statement conf $ allViewsKeyDependencies prepared
147148
m2oRels <- SQL.statement mempty $ allM2OandO2ORels prepared
148-
funcs <- SQL.statement conf $ allFunctions prepared
149+
funcs <- SQL.statement conf $ allFunctions pgVer prepared
149150
cRels <- SQL.statement mempty $ allComputedRels prepared
150151
reps <- SQL.statement conf $ dataRepresentations prepared
151152
mHdlers <- SQL.statement conf $ mediaHandlers prepared
@@ -353,47 +354,61 @@ dataRepresentations = SQL.Statement sql mempty decodeRepresentations
353354
OR (dst_t.typtype = 'd' AND c.castsource IN ('json'::regtype::oid , 'text'::regtype::oid)))
354355
|]
355356

356-
allFunctions :: Bool -> SQL.Statement AppConfig RoutineMap
357-
allFunctions = SQL.Statement funcsSqlQuery params decodeFuncs
357+
allFunctions :: PgVersion -> Bool -> SQL.Statement AppConfig RoutineMap
358+
allFunctions pgVer = SQL.Statement (funcsSqlQuery pgVer) params decodeFuncs
358359
where
359360
params =
360361
(map escapeIdent . toList . configDbSchemas >$< arrayParam HE.text) <>
361362
(configDbHoistedTxSettings >$< arrayParam HE.text)
362363

363-
baseTypesCte :: Text
364-
baseTypesCte = [trimming|
365-
-- Recursively get the base types of domains
366-
base_types AS (
367-
WITH RECURSIVE
368-
recurse AS (
369-
SELECT
370-
oid,
371-
typbasetype,
372-
typnamespace AS base_namespace,
373-
COALESCE(NULLIF(typbasetype, 0), oid) AS base_type
374-
FROM pg_type
375-
UNION
376-
SELECT
377-
t.oid,
378-
b.typbasetype,
379-
b.typnamespace AS base_namespace,
380-
COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type
381-
FROM recurse t
382-
JOIN pg_type b ON t.typbasetype = b.oid
383-
)
384-
SELECT
385-
oid,
386-
base_namespace,
387-
base_type
388-
FROM recurse
389-
WHERE typbasetype = 0
390-
)
391-
|]
364+
baseTypesCte :: PgVersion -> Text
365+
baseTypesCte pgVer
366+
| pgVer >= pgVersion170 = [trimming|
367+
-- Get base types using pg_basetype() (PG 17+)
368+
base_types AS (
369+
SELECT
370+
t.oid,
371+
bt.typnamespace AS base_namespace,
372+
bt.oid AS base_type
373+
FROM pg_type t
374+
JOIN pg_type bt ON bt.oid = pg_basetype(t.oid)
375+
)
376+
|]
377+
| otherwise = [trimming|
378+
-- Recursively get the base types of domains (PG < 17)
379+
base_types AS (
380+
WITH RECURSIVE
381+
recurse AS (
382+
SELECT
383+
oid,
384+
typbasetype,
385+
typnamespace AS base_namespace,
386+
COALESCE(NULLIF(typbasetype, 0), oid) AS base_type
387+
FROM pg_type
388+
UNION
389+
SELECT
390+
t.oid,
391+
b.typbasetype,
392+
b.typnamespace AS base_namespace,
393+
COALESCE(NULLIF(b.typbasetype, 0), b.oid) AS base_type
394+
FROM recurse t
395+
JOIN pg_type b ON t.typbasetype = b.oid
396+
)
397+
SELECT
398+
oid,
399+
base_namespace,
400+
base_type
401+
FROM recurse
402+
WHERE typbasetype = 0
403+
)
404+
|]
392405

393-
funcsSqlQuery :: SqlQuery
394-
funcsSqlQuery = encodeUtf8 [trimming|
406+
funcsSqlQuery :: PgVersion -> SqlQuery
407+
funcsSqlQuery pgVer =
408+
let baseCte = baseTypesCte pgVer
409+
in encodeUtf8 [trimming|
395410
WITH
396-
$baseTypesCte,
411+
$baseCte,
397412
arguments AS (
398413
SELECT
399414
oid,
@@ -566,22 +581,23 @@ addViewPrimaryKeys tabs keyDeps =
566581
takeFirstPK = mapMaybe (head . snd)
567582
indexedDeps = HM.fromListWith (++) $ fmap ((keyDepType &&& keyDepView) &&& pure) keyDeps
568583

569-
allTables :: Bool -> SQL.Statement AppConfig TablesMap
570-
allTables = SQL.Statement tablesSqlQuery params decodeTables
584+
allTables :: PgVersion -> Bool -> SQL.Statement AppConfig TablesMap
585+
allTables pgVer = SQL.Statement (tablesSqlQuery pgVer) params decodeTables
571586
where
572587
params = map escapeIdent . toList . configDbSchemas >$< arrayParam HE.text
573588

574589
-- | Gets tables with their PK cols
575-
tablesSqlQuery :: SqlQuery
576-
tablesSqlQuery =
590+
tablesSqlQuery :: PgVersion -> SqlQuery
591+
tablesSqlQuery pgVer =
577592
-- the tbl_constraints/key_col_usage CTEs are based on the standard "information_schema.table_constraints"/"information_schema.key_column_usage" views,
578593
-- we cannot use those directly as they include the following privilege filter:
579594
-- (pg_has_role(ss.relowner, 'USAGE'::text) OR has_column_privilege(ss.roid, a.attnum, 'SELECT, INSERT, UPDATE, REFERENCES'::text));
580595
-- on the "columns" CTE, left joining on pg_depend and pg_class is used to obtain the sequence name as a column default in case there are GENERATED .. AS IDENTITY,
581596
-- generated columns are only available from pg >= 10 but the query is agnostic to versions. dep.deptype = 'i' is done because there are other 'a' dependencies on PKs
582-
encodeUtf8 [trimming|
597+
let baseCte = baseTypesCte pgVer
598+
in encodeUtf8 [trimming|
583599
WITH
584-
$baseTypesCte,
600+
$baseCte,
585601
columns AS (
586602
SELECT
587603
c.oid AS relid,

test/io/test_io.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -759,7 +759,7 @@ def drain_stdout(proc):
759759
)
760760
infinite_recursion_5xx_regx = r'.+: WITH pgrst_source AS.+SELECT "public"\."infinite_recursion"\.\* FROM "public"\."infinite_recursion".+_postgrest_t'
761761
root_tables_regx = r".+: SELECT n.nspname AS table_schema, .+ FROM pg_class c .+ ORDER BY table_schema, table_name"
762-
root_procs_regx = r".+: WITH base_types AS \(.+\) SELECT pn.nspname AS proc_schema, .+ FROM pg_proc p.+AND p.pronamespace = \$1::regnamespace"
762+
root_procs_regx = r".+: WITH.+base_types AS.+pn\.nspname AS proc_schema.+FROM pg_proc p.+p\.pronamespace = \$1::regnamespace"
763763
root_descr_regx = r".+: SELECT pg_catalog\.obj_description\(\$1::regnamespace, 'pg_namespace'\)"
764764
set_config_regx = (
765765
r".+: select set_config\('search_path', \$1, true\), set_config\("

test/spec/Main.hs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ main = do
8383
actualPgVersion <- either (panic . show) id <$> P.use pool (queryPgVersion False)
8484

8585
-- cached schema cache so most tests run fast
86-
baseSchemaCache <- loadSCache pool testCfg
86+
baseSchemaCache <- loadSCache pool actualPgVersion testCfg
8787
sockets <- AppState.initSockets testCfg
8888
loggerState <- Logger.init
8989
metricsState <- Metrics.init (configDbPoolSize testCfg)
@@ -100,7 +100,7 @@ main = do
100100

101101
-- For tests that run with a different SchemaCache (depends on configSchemas)
102102
appDbs config = do
103-
customSchemaCache <- loadSCache pool config
103+
customSchemaCache <- loadSCache pool actualPgVersion config
104104
initApp customSchemaCache () config
105105

106106
let withApp = app testCfg
@@ -279,5 +279,5 @@ main = do
279279
describe "Feature.Auth.JwtCacheSpec" Feature.Auth.JwtCacheSpec.spec
280280

281281
where
282-
loadSCache pool conf =
283-
either (panic.show) id <$> P.use pool (HT.transaction HT.ReadCommitted HT.Read $ querySchemaCache conf)
282+
loadSCache pool pgVer conf =
283+
either (panic.show) id <$> P.use pool (HT.transaction HT.ReadCommitted HT.Read $ querySchemaCache pgVer conf)

0 commit comments

Comments
 (0)