@@ -12,13 +12,14 @@ module IHP.TypedSql
1212 , sqlExecTyped -- execute a typed statement and return affected rows
1313 ) where
1414
15- import Control.Exception (bracket_ )
15+ import Control.Exception (bracket , bracket_ )
1616import Control.Monad (guard , when )
1717import qualified Data.Aeson as Aeson
1818import qualified Data.ByteString as BS
1919import qualified Data.ByteString.Char8 as BS8
2020import qualified Data.Char as Char
2121import Data.Coerce (coerce )
22+ import Data.Int (Int32 )
2223import qualified Data.List as List
2324import qualified Data.Map.Strict as Map
2425import Data.Maybe (catMaybes , mapMaybe )
@@ -36,6 +37,12 @@ import qualified Database.PostgreSQL.Simple.FromRow as PGFR
3637import qualified Database.PostgreSQL.Simple.ToField as PGTF
3738import qualified Database.PostgreSQL.Simple.ToRow as PGTR
3839import qualified Database.PostgreSQL.Simple.Types as PG
40+ import qualified Hasql.Connection as HasqlConnection
41+ import qualified Hasql.Connection.Settings as HasqlSettings
42+ import qualified Hasql.Decoders as HasqlDecoders
43+ import qualified Hasql.Encoders as HasqlEncoders
44+ import qualified Hasql.Session as HasqlSession
45+ import qualified Hasql.Statement as HasqlStatement
3946import IHP.Prelude
4047import qualified Language.Haskell.Meta.Parse as HaskellMeta
4148import qualified Language.Haskell.TH as TH
@@ -165,13 +172,13 @@ data PgTypeInfo = PgTypeInfo
165172 , ptiNamespace :: ! (Maybe Text ) -- low-level: namespace name
166173 }
167174
168- -- Convert libpq Oid to postgres-simple Oid .
169- toPGOid :: PQ. Oid -> PG. Oid
170- toPGOid (PQ. Oid w ) = PG. Oid w -- low-level: wrap the same numeric value
175+ -- Convert libpq Oid to Int32 for Hasql parameter encoding .
176+ toOidInt32 :: PQ. Oid -> Int32
177+ toOidInt32 (PQ. Oid oid ) = fromIntegral oid
171178
172- -- Convert postgres-simple Oid to libpq Oid.
173- toPQOid :: PG. Oid -> PQ. Oid
174- toPQOid ( PG. Oid w) = PQ. Oid w -- low-level: wrap the same numeric value
179+ -- Convert Hasql-decoded Oid value back to libpq Oid.
180+ fromOidInt32 :: Int32 -> PQ. Oid
181+ fromOidInt32 oid = PQ. Oid ( fromIntegral oid)
175182
176183-- Build the TH expression for a typed SQL quasiquote.
177184typedSqlExp :: String -> TH. ExpQ
@@ -748,8 +755,7 @@ describeStatementWith dbUrl sql = do
748755 let tableOids = Set. fromList (map dcTable columns) |> Set. delete (PQ. Oid 0 ) -- collect referenced table OIDs
749756 typeOids = Set. fromList paramTypes <> Set. fromList (map dcType columns) -- collect referenced type OIDs
750757
751- pgConn <- PG. connectPostgreSQL dbUrl -- open postgres-simple connection for catalog queries
752- tables <- loadTableMeta pgConn (Set. toList tableOids) -- load table metadata
758+ tables <- loadTableMeta dbUrl (Set. toList tableOids) -- load table metadata
753759 let referencedOids =
754760 tables
755761 |> Map. elems
@@ -759,10 +765,9 @@ describeStatementWith dbUrl sql = do
759765 )
760766 mempty
761767 let missingRefs = referencedOids `Set.difference` Map. keysSet tables
762- extraTables <- loadTableMeta pgConn (Set. toList missingRefs)
768+ extraTables <- loadTableMeta dbUrl (Set. toList missingRefs)
763769 let tables' = tables <> extraTables
764- types <- loadTypeInfo pgConn (Set. toList typeOids) -- load type metadata
765- PG. close pgConn -- close postgres-simple connection
770+ types <- loadTypeInfo dbUrl (Set. toList typeOids) -- load type metadata
766771
767772 _ <- PQ. exec conn (" DEALLOCATE " <> statementName) -- release prepared statement
768773 PQ. finish conn -- close libpq connection
@@ -782,48 +787,130 @@ ensureOk actionName = \case
782787 msg <- PQ. resultErrorMessage res -- read error message
783788 fail (" typedSql: " <> actionName <> " failed: " <> CS. cs (fromMaybe " " msg)) -- abort
784789
785- -- | Load table metadata for all referenced tables.
786- -- High-level: read pg_catalog to map table/column info.
787- loadTableMeta :: PG. Connection -> [PQ. Oid ] -> IO (Map. Map PQ. Oid TableMeta )
788- loadTableMeta _ [] = pure mempty -- no tables requested
789- loadTableMeta conn tableOids = do
790- rows <- PG. query conn -- fetch column info for each requested table
790+ runHasqlMetadataSession :: BS. ByteString -> HasqlSession. Session a -> IO a
791+ runHasqlMetadataSession dbUrl session = do
792+ let settings = HasqlSettings. connectionString (CS. cs dbUrl)
793+ hasqlConnection <-
794+ HasqlConnection. acquire settings >>= \ case
795+ Left connectionError ->
796+ fail (CS. cs (" typedSql: could not connect to database: " <> tshow connectionError))
797+ Right connection ->
798+ pure connection
799+ bracket (pure hasqlConnection) HasqlConnection. release \ connection ->
800+ HasqlConnection. use connection session >>= \ case
801+ Left sessionError ->
802+ fail (CS. cs (" typedSql: metadata query failed: " <> tshow sessionError))
803+ Right result ->
804+ pure result
805+
806+ oidArrayParamsEncoder :: HasqlEncoders. Params [Int32 ]
807+ oidArrayParamsEncoder =
808+ HasqlEncoders. param
809+ (HasqlEncoders. nonNullable
810+ (HasqlEncoders. foldableArray
811+ (HasqlEncoders. nonNullable HasqlEncoders. oid)
812+ )
813+ )
814+
815+ tableColumnsStatement :: HasqlStatement. Statement [Int32 ] [(Int32 , Text , Int32 , Text , Int32 , Bool )]
816+ tableColumnsStatement =
817+ HasqlStatement. preparable
791818 (mconcat
792- [ " SELECT c.oid, c.relname, a.attnum, a.attname, a.atttypid, a.attnotnull "
819+ [ " SELECT c.oid::int4 , c.relname::text , a.attnum::int4 , a.attname::text , a.atttypid::int4 , a.attnotnull "
793820 , " FROM pg_class c "
794821 , " JOIN pg_namespace ns ON ns.oid = c.relnamespace "
795822 , " JOIN pg_attribute a ON a.attrelid = c.oid "
796- , " WHERE c.oid = ANY(? ) AND a.attnum > 0 AND NOT a.attisdropped "
823+ , " WHERE c.oid = ANY($1 ) AND a.attnum > 0 AND NOT a.attisdropped "
797824 , " ORDER BY c.oid, a.attnum"
798825 ])
799- (PG. Only (PG. PGArray (map toPGOid tableOids)) :: PG. Only (PG. PGArray PG. Oid )) -- parameterize the OID list
800-
801- primaryKeys <- PG. query conn -- fetch primary key columns for each table
802- " SELECT conrelid, unnest(conkey) as attnum FROM pg_constraint WHERE contype = 'p' AND conrelid = ANY(?)"
803- (PG. Only (PG. PGArray (map toPGOid tableOids)) :: PG. Only (PG. PGArray PG. Oid )) -- parameterize the OID list
804-
805- foreignKeys <- PG. query conn -- fetch simple (single-column) foreign keys
826+ oidArrayParamsEncoder
827+ (HasqlDecoders. rowList
828+ ((,,,,,)
829+ <$> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
830+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. text)
831+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
832+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. text)
833+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
834+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. bool)
835+ ))
836+
837+ primaryKeysStatement :: HasqlStatement. Statement [Int32 ] [(Int32 , Int32 )]
838+ primaryKeysStatement =
839+ HasqlStatement. preparable
840+ " SELECT conrelid::int4, unnest(conkey)::int4 as attnum FROM pg_constraint WHERE contype = 'p' AND conrelid = ANY($1)"
841+ oidArrayParamsEncoder
842+ (HasqlDecoders. rowList
843+ ((,)
844+ <$> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
845+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
846+ ))
847+
848+ foreignKeysStatement :: HasqlStatement. Statement [Int32 ] [(Int32 , Int32 , Int32 )]
849+ foreignKeysStatement =
850+ HasqlStatement. preparable
806851 (mconcat
807- [ " SELECT conrelid, conkey[1] as attnum, confrelid "
852+ [ " SELECT conrelid::int4 , conkey[1]::int4 as attnum, confrelid::int4 "
808853 , " FROM pg_constraint "
809- , " WHERE contype = 'f' AND array_length(conkey,1) = 1 AND conrelid = ANY(?)"
854+ , " WHERE contype = 'f' AND array_length(conkey,1) = 1 AND conrelid = ANY($1)"
855+ ])
856+ oidArrayParamsEncoder
857+ (HasqlDecoders. rowList
858+ ((,,)
859+ <$> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
860+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
861+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
862+ ))
863+
864+ typeInfoStatement :: HasqlStatement. Statement [Int32 ] [(Int32 , Text , Int32 , Maybe Text , Maybe Text )]
865+ typeInfoStatement =
866+ HasqlStatement. preparable
867+ (mconcat
868+ [ " SELECT oid::int4, typname::text, typelem::int4, typtype::text, typnamespace::regnamespace::text "
869+ , " FROM pg_type "
870+ , " WHERE oid = ANY($1)"
810871 ])
811- (PG. Only (PG. PGArray (map toPGOid tableOids)) :: PG. Only (PG. PGArray PG. Oid )) -- parameterize the OID list
872+ oidArrayParamsEncoder
873+ (HasqlDecoders. rowList
874+ ((,,,,)
875+ <$> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
876+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. text)
877+ <*> HasqlDecoders. column (HasqlDecoders. nonNullable HasqlDecoders. int4)
878+ <*> HasqlDecoders. column (HasqlDecoders. nullable HasqlDecoders. text)
879+ <*> HasqlDecoders. column (HasqlDecoders. nullable HasqlDecoders. text)
880+ ))
881+
882+ -- | Load table metadata for all referenced tables.
883+ -- High-level: read pg_catalog to map table/column info.
884+ loadTableMeta :: BS. ByteString -> [PQ. Oid ] -> IO (Map. Map PQ. Oid TableMeta )
885+ loadTableMeta _ [] = pure mempty -- no tables requested
886+ loadTableMeta dbUrl tableOids = do
887+ let tableOidParams = map toOidInt32 tableOids
888+ rows <- runHasqlMetadataSession dbUrl (HasqlSession. statement tableOidParams tableColumnsStatement)
889+ primaryKeys <- runHasqlMetadataSession dbUrl (HasqlSession. statement tableOidParams primaryKeysStatement)
890+ foreignKeys <- runHasqlMetadataSession dbUrl (HasqlSession. statement tableOidParams foreignKeysStatement)
812891
813892 let pkMap = primaryKeys
814- |> foldl' (\ acc (relid :: PG. Oid , att :: Int ) ->
815- Map. insertWith Set. union (toPQOid relid) (Set. singleton att ) acc
893+ |> foldl' (\ acc (relid, attnum ) ->
894+ Map. insertWith Set. union (fromOidInt32 relid) (Set. singleton ( fromIntegral attnum) ) acc
816895 ) mempty -- build map of table -> primary key attnums
817896
818897 fkMap = foreignKeys
819- |> foldl' (\ acc (relid :: PG. Oid , att :: Int , ref :: PG. Oid ) ->
820- Map. insertWith Map. union (toPQOid relid) (Map. singleton att (toPQOid ref)) acc
898+ |> foldl' (\ acc (relid, attnum , ref) ->
899+ Map. insertWith Map. union (fromOidInt32 relid) (Map. singleton ( fromIntegral attnum) (fromOidInt32 ref)) acc
821900 ) mempty -- build map of table -> foreign key attnum -> referenced table
822901
823902 tableGroups =
824903 rows
825- |> map (\ (relid :: PG. Oid , name :: Text , attnum :: Int , attname :: Text , atttypid :: PG. Oid , attnotnull :: Bool ) ->
826- (toPQOid relid, ColumnMeta { cmAttnum = attnum, cmName = attname, cmTypeOid = toPQOid atttypid, cmNotNull = attnotnull }, name)
904+ |> map (\ (relid, name, attnum, attname, atttypid, attnotnull) ->
905+ ( fromOidInt32 relid
906+ , ColumnMeta
907+ { cmAttnum = fromIntegral attnum
908+ , cmName = attname
909+ , cmTypeOid = fromOidInt32 atttypid
910+ , cmNotNull = attnotnull
911+ }
912+ , name
913+ )
827914 ) -- annotate each column row with its table
828915 |> List. groupBy (\ (l, _, _) (r, _, _) -> l == r) -- group by table OID
829916
@@ -853,23 +940,17 @@ loadTableMeta conn tableOids = do
853940
854941-- | Load type information for the given OIDs.
855942-- High-level: fetch pg_type metadata recursively for arrays.
856- loadTypeInfo :: PG. Connection -> [PQ. Oid ] -> IO (Map. Map PQ. Oid PgTypeInfo )
943+ loadTypeInfo :: BS. ByteString -> [PQ. Oid ] -> IO (Map. Map PQ. Oid PgTypeInfo )
857944loadTypeInfo _ [] = pure mempty -- no types requested
858- loadTypeInfo conn typeOids = do
945+ loadTypeInfo dbUrl typeOids = do
859946 let requested = Set. fromList typeOids -- track requested OIDs
860- rows <- PG. query conn -- fetch pg_type rows for requested OIDs
861- (mconcat
862- [ " SELECT oid, typname, typelem, typtype, typnamespace::regnamespace::text "
863- , " FROM pg_type "
864- , " WHERE oid = ANY(?)"
865- ])
866- (PG. Only (PG. PGArray (map toPGOid typeOids)) :: PG. Only (PG. PGArray PG. Oid )) -- parameterize the OID list
947+ rows <- runHasqlMetadataSession dbUrl (HasqlSession. statement (map toOidInt32 typeOids) typeInfoStatement)
867948 let (typeMap, missing) =
868949 rows
869950 |> foldl'
870- (\ (acc, missingAcc) (oid :: PG. Oid , name :: Text , elemOid :: PG. Oid , typtype :: Maybe Text , nsp :: Maybe Text ) ->
871- let thisOid = toPQOid oid -- convert to libpq Oid type
872- elemOid' = if elemOid == PG. Oid 0 then Nothing else Just (toPQOid elemOid) -- ignore 0 elem
951+ (\ (acc, missingAcc) (oid, name, elemOid, typtype, nsp) ->
952+ let thisOid = fromOidInt32 oid -- convert to libpq Oid type
953+ elemOid' = if elemOid == 0 then Nothing else Just (fromOidInt32 elemOid) -- ignore 0 elem
873954 nextMissing = case elemOid' of
874955 Just o | o `Set.notMember` requested -> o : missingAcc -- queue missing element types
875956 _ -> missingAcc
@@ -885,7 +966,7 @@ loadTypeInfo conn typeOids = do
885966 )
886967 )
887968 (mempty , [] ) -- start with empty map and missing list
888- extras <- loadTypeInfo conn missing -- recursively load missing element types
969+ extras <- loadTypeInfo dbUrl missing -- recursively load missing element types
889970 pure (typeMap <> extras) -- merge base and extra type info
890971
891972-- | Build the Haskell type for a parameter, based on its OID.
0 commit comments