Skip to content

Commit 699823b

Browse files
add tests and use hasql instead of postgres simple
1 parent 66f4714 commit 699823b

File tree

6 files changed

+955
-62
lines changed

6 files changed

+955
-62
lines changed

devenv-module.nix

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ that is defined in flake-module.nix
9595
http-types
9696
inflections
9797
text
98+
hasql
99+
postgresql-libpq
98100
postgresql-simple
99101
hasql
100102
hasql-notifications

ihp/IHP/Postgres/Typed.hs

Lines changed: 0 additions & 12 deletions
This file was deleted.

ihp/IHP/TypedSql.hs

Lines changed: 130 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ module IHP.TypedSql
1212
, sqlExecTyped -- execute a typed statement and return affected rows
1313
) where
1414

15-
import Control.Exception (bracket_)
15+
import Control.Exception (bracket, bracket_)
1616
import Control.Monad (guard, when)
1717
import qualified Data.Aeson as Aeson
1818
import qualified Data.ByteString as BS
1919
import qualified Data.ByteString.Char8 as BS8
2020
import qualified Data.Char as Char
2121
import Data.Coerce (coerce)
22+
import Data.Int (Int32)
2223
import qualified Data.List as List
2324
import qualified Data.Map.Strict as Map
2425
import Data.Maybe (catMaybes, mapMaybe)
@@ -36,6 +37,12 @@ import qualified Database.PostgreSQL.Simple.FromRow as PGFR
3637
import qualified Database.PostgreSQL.Simple.ToField as PGTF
3738
import qualified Database.PostgreSQL.Simple.ToRow as PGTR
3839
import qualified Database.PostgreSQL.Simple.Types as PG
40+
import qualified Hasql.Connection as HasqlConnection
41+
import qualified Hasql.Connection.Settings as HasqlSettings
42+
import qualified Hasql.Decoders as HasqlDecoders
43+
import qualified Hasql.Encoders as HasqlEncoders
44+
import qualified Hasql.Session as HasqlSession
45+
import qualified Hasql.Statement as HasqlStatement
3946
import IHP.Prelude
4047
import qualified Language.Haskell.Meta.Parse as HaskellMeta
4148
import qualified Language.Haskell.TH as TH
@@ -165,13 +172,13 @@ data PgTypeInfo = PgTypeInfo
165172
, ptiNamespace :: !(Maybe Text) -- low-level: namespace name
166173
}
167174

168-
-- Convert libpq Oid to postgres-simple Oid.
169-
toPGOid :: PQ.Oid -> PG.Oid
170-
toPGOid (PQ.Oid w) = PG.Oid w -- low-level: wrap the same numeric value
175+
-- Convert libpq Oid to Int32 for Hasql parameter encoding.
176+
toOidInt32 :: PQ.Oid -> Int32
177+
toOidInt32 (PQ.Oid oid) = fromIntegral oid
171178

172-
-- Convert postgres-simple Oid to libpq Oid.
173-
toPQOid :: PG.Oid -> PQ.Oid
174-
toPQOid (PG.Oid w) = PQ.Oid w -- low-level: wrap the same numeric value
179+
-- Convert Hasql-decoded Oid value back to libpq Oid.
180+
fromOidInt32 :: Int32 -> PQ.Oid
181+
fromOidInt32 oid = PQ.Oid (fromIntegral oid)
175182

176183
-- Build the TH expression for a typed SQL quasiquote.
177184
typedSqlExp :: String -> TH.ExpQ
@@ -748,8 +755,7 @@ describeStatementWith dbUrl sql = do
748755
let tableOids = Set.fromList (map dcTable columns) |> Set.delete (PQ.Oid 0) -- collect referenced table OIDs
749756
typeOids = Set.fromList paramTypes <> Set.fromList (map dcType columns) -- collect referenced type OIDs
750757

751-
pgConn <- PG.connectPostgreSQL dbUrl -- open postgres-simple connection for catalog queries
752-
tables <- loadTableMeta pgConn (Set.toList tableOids) -- load table metadata
758+
tables <- loadTableMeta dbUrl (Set.toList tableOids) -- load table metadata
753759
let referencedOids =
754760
tables
755761
|> Map.elems
@@ -759,10 +765,9 @@ describeStatementWith dbUrl sql = do
759765
)
760766
mempty
761767
let missingRefs = referencedOids `Set.difference` Map.keysSet tables
762-
extraTables <- loadTableMeta pgConn (Set.toList missingRefs)
768+
extraTables <- loadTableMeta dbUrl (Set.toList missingRefs)
763769
let tables' = tables <> extraTables
764-
types <- loadTypeInfo pgConn (Set.toList typeOids) -- load type metadata
765-
PG.close pgConn -- close postgres-simple connection
770+
types <- loadTypeInfo dbUrl (Set.toList typeOids) -- load type metadata
766771

767772
_ <- PQ.exec conn ("DEALLOCATE " <> statementName) -- release prepared statement
768773
PQ.finish conn -- close libpq connection
@@ -782,48 +787,130 @@ ensureOk actionName = \case
782787
msg <- PQ.resultErrorMessage res -- read error message
783788
fail ("typedSql: " <> actionName <> " failed: " <> CS.cs (fromMaybe "" msg)) -- abort
784789

785-
-- | Load table metadata for all referenced tables.
786-
-- High-level: read pg_catalog to map table/column info.
787-
loadTableMeta :: PG.Connection -> [PQ.Oid] -> IO (Map.Map PQ.Oid TableMeta)
788-
loadTableMeta _ [] = pure mempty -- no tables requested
789-
loadTableMeta conn tableOids = do
790-
rows <- PG.query conn -- fetch column info for each requested table
790+
runHasqlMetadataSession :: BS.ByteString -> HasqlSession.Session a -> IO a
791+
runHasqlMetadataSession dbUrl session = do
792+
let settings = HasqlSettings.connectionString (CS.cs dbUrl)
793+
hasqlConnection <-
794+
HasqlConnection.acquire settings >>= \case
795+
Left connectionError ->
796+
fail (CS.cs ("typedSql: could not connect to database: " <> tshow connectionError))
797+
Right connection ->
798+
pure connection
799+
bracket (pure hasqlConnection) HasqlConnection.release \connection ->
800+
HasqlConnection.use connection session >>= \case
801+
Left sessionError ->
802+
fail (CS.cs ("typedSql: metadata query failed: " <> tshow sessionError))
803+
Right result ->
804+
pure result
805+
806+
oidArrayParamsEncoder :: HasqlEncoders.Params [Int32]
807+
oidArrayParamsEncoder =
808+
HasqlEncoders.param
809+
(HasqlEncoders.nonNullable
810+
(HasqlEncoders.foldableArray
811+
(HasqlEncoders.nonNullable HasqlEncoders.oid)
812+
)
813+
)
814+
815+
tableColumnsStatement :: HasqlStatement.Statement [Int32] [(Int32, Text, Int32, Text, Int32, Bool)]
816+
tableColumnsStatement =
817+
HasqlStatement.preparable
791818
(mconcat
792-
[ "SELECT c.oid, c.relname, a.attnum, a.attname, a.atttypid, a.attnotnull "
819+
[ "SELECT c.oid::int4, c.relname::text, a.attnum::int4, a.attname::text, a.atttypid::int4, a.attnotnull "
793820
, "FROM pg_class c "
794821
, "JOIN pg_namespace ns ON ns.oid = c.relnamespace "
795822
, "JOIN pg_attribute a ON a.attrelid = c.oid "
796-
, "WHERE c.oid = ANY(?) AND a.attnum > 0 AND NOT a.attisdropped "
823+
, "WHERE c.oid = ANY($1) AND a.attnum > 0 AND NOT a.attisdropped "
797824
, "ORDER BY c.oid, a.attnum"
798825
])
799-
(PG.Only (PG.PGArray (map toPGOid tableOids)) :: PG.Only (PG.PGArray PG.Oid)) -- parameterize the OID list
800-
801-
primaryKeys <- PG.query conn -- fetch primary key columns for each table
802-
"SELECT conrelid, unnest(conkey) as attnum FROM pg_constraint WHERE contype = 'p' AND conrelid = ANY(?)"
803-
(PG.Only (PG.PGArray (map toPGOid tableOids)) :: PG.Only (PG.PGArray PG.Oid)) -- parameterize the OID list
804-
805-
foreignKeys <- PG.query conn -- fetch simple (single-column) foreign keys
826+
oidArrayParamsEncoder
827+
(HasqlDecoders.rowList
828+
((,,,,,)
829+
<$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
830+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text)
831+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
832+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text)
833+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
834+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.bool)
835+
))
836+
837+
primaryKeysStatement :: HasqlStatement.Statement [Int32] [(Int32, Int32)]
838+
primaryKeysStatement =
839+
HasqlStatement.preparable
840+
"SELECT conrelid::int4, unnest(conkey)::int4 as attnum FROM pg_constraint WHERE contype = 'p' AND conrelid = ANY($1)"
841+
oidArrayParamsEncoder
842+
(HasqlDecoders.rowList
843+
((,)
844+
<$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
845+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
846+
))
847+
848+
foreignKeysStatement :: HasqlStatement.Statement [Int32] [(Int32, Int32, Int32)]
849+
foreignKeysStatement =
850+
HasqlStatement.preparable
806851
(mconcat
807-
[ "SELECT conrelid, conkey[1] as attnum, confrelid "
852+
[ "SELECT conrelid::int4, conkey[1]::int4 as attnum, confrelid::int4 "
808853
, "FROM pg_constraint "
809-
, "WHERE contype = 'f' AND array_length(conkey,1) = 1 AND conrelid = ANY(?)"
854+
, "WHERE contype = 'f' AND array_length(conkey,1) = 1 AND conrelid = ANY($1)"
855+
])
856+
oidArrayParamsEncoder
857+
(HasqlDecoders.rowList
858+
((,,)
859+
<$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
860+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
861+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
862+
))
863+
864+
typeInfoStatement :: HasqlStatement.Statement [Int32] [(Int32, Text, Int32, Maybe Text, Maybe Text)]
865+
typeInfoStatement =
866+
HasqlStatement.preparable
867+
(mconcat
868+
[ "SELECT oid::int4, typname::text, typelem::int4, typtype::text, typnamespace::regnamespace::text "
869+
, "FROM pg_type "
870+
, "WHERE oid = ANY($1)"
810871
])
811-
(PG.Only (PG.PGArray (map toPGOid tableOids)) :: PG.Only (PG.PGArray PG.Oid)) -- parameterize the OID list
872+
oidArrayParamsEncoder
873+
(HasqlDecoders.rowList
874+
((,,,,)
875+
<$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
876+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text)
877+
<*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)
878+
<*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.text)
879+
<*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.text)
880+
))
881+
882+
-- | Load table metadata for all referenced tables.
883+
-- High-level: read pg_catalog to map table/column info.
884+
loadTableMeta :: BS.ByteString -> [PQ.Oid] -> IO (Map.Map PQ.Oid TableMeta)
885+
loadTableMeta _ [] = pure mempty -- no tables requested
886+
loadTableMeta dbUrl tableOids = do
887+
let tableOidParams = map toOidInt32 tableOids
888+
rows <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams tableColumnsStatement)
889+
primaryKeys <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams primaryKeysStatement)
890+
foreignKeys <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams foreignKeysStatement)
812891

813892
let pkMap = primaryKeys
814-
|> foldl' (\acc (relid :: PG.Oid, att :: Int) ->
815-
Map.insertWith Set.union (toPQOid relid) (Set.singleton att) acc
893+
|> foldl' (\acc (relid, attnum) ->
894+
Map.insertWith Set.union (fromOidInt32 relid) (Set.singleton (fromIntegral attnum)) acc
816895
) mempty -- build map of table -> primary key attnums
817896

818897
fkMap = foreignKeys
819-
|> foldl' (\acc (relid :: PG.Oid, att :: Int, ref :: PG.Oid) ->
820-
Map.insertWith Map.union (toPQOid relid) (Map.singleton att (toPQOid ref)) acc
898+
|> foldl' (\acc (relid, attnum, ref) ->
899+
Map.insertWith Map.union (fromOidInt32 relid) (Map.singleton (fromIntegral attnum) (fromOidInt32 ref)) acc
821900
) mempty -- build map of table -> foreign key attnum -> referenced table
822901

823902
tableGroups =
824903
rows
825-
|> map (\(relid :: PG.Oid, name :: Text, attnum :: Int, attname :: Text, atttypid :: PG.Oid, attnotnull :: Bool) ->
826-
(toPQOid relid, ColumnMeta { cmAttnum = attnum, cmName = attname, cmTypeOid = toPQOid atttypid, cmNotNull = attnotnull }, name)
904+
|> map (\(relid, name, attnum, attname, atttypid, attnotnull) ->
905+
( fromOidInt32 relid
906+
, ColumnMeta
907+
{ cmAttnum = fromIntegral attnum
908+
, cmName = attname
909+
, cmTypeOid = fromOidInt32 atttypid
910+
, cmNotNull = attnotnull
911+
}
912+
, name
913+
)
827914
) -- annotate each column row with its table
828915
|> List.groupBy (\(l, _, _) (r, _, _) -> l == r) -- group by table OID
829916

@@ -853,23 +940,17 @@ loadTableMeta conn tableOids = do
853940

854941
-- | Load type information for the given OIDs.
855942
-- High-level: fetch pg_type metadata recursively for arrays.
856-
loadTypeInfo :: PG.Connection -> [PQ.Oid] -> IO (Map.Map PQ.Oid PgTypeInfo)
943+
loadTypeInfo :: BS.ByteString -> [PQ.Oid] -> IO (Map.Map PQ.Oid PgTypeInfo)
857944
loadTypeInfo _ [] = pure mempty -- no types requested
858-
loadTypeInfo conn typeOids = do
945+
loadTypeInfo dbUrl typeOids = do
859946
let requested = Set.fromList typeOids -- track requested OIDs
860-
rows <- PG.query conn -- fetch pg_type rows for requested OIDs
861-
(mconcat
862-
[ "SELECT oid, typname, typelem, typtype, typnamespace::regnamespace::text "
863-
, "FROM pg_type "
864-
, "WHERE oid = ANY(?)"
865-
])
866-
(PG.Only (PG.PGArray (map toPGOid typeOids)) :: PG.Only (PG.PGArray PG.Oid)) -- parameterize the OID list
947+
rows <- runHasqlMetadataSession dbUrl (HasqlSession.statement (map toOidInt32 typeOids) typeInfoStatement)
867948
let (typeMap, missing) =
868949
rows
869950
|> foldl'
870-
(\(acc, missingAcc) (oid :: PG.Oid, name :: Text, elemOid :: PG.Oid, typtype :: Maybe Text, nsp :: Maybe Text) ->
871-
let thisOid = toPQOid oid -- convert to libpq Oid type
872-
elemOid' = if elemOid == PG.Oid 0 then Nothing else Just (toPQOid elemOid) -- ignore 0 elem
951+
(\(acc, missingAcc) (oid, name, elemOid, typtype, nsp) ->
952+
let thisOid = fromOidInt32 oid -- convert to libpq Oid type
953+
elemOid' = if elemOid == 0 then Nothing else Just (fromOidInt32 elemOid) -- ignore 0 elem
873954
nextMissing = case elemOid' of
874955
Just o | o `Set.notMember` requested -> o : missingAcc -- queue missing element types
875956
_ -> missingAcc
@@ -885,7 +966,7 @@ loadTypeInfo conn typeOids = do
885966
)
886967
)
887968
(mempty, []) -- start with empty map and missing list
888-
extras <- loadTypeInfo conn missing -- recursively load missing element types
969+
extras <- loadTypeInfo dbUrl missing -- recursively load missing element types
889970
pure (typeMap <> extras) -- merge base and extra type info
890971

891972
-- | Build the Haskell type for a parameter, based on its OID.

ihp/Test/Test/Main.hs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import qualified Test.ViewSupportSpec
2020
import qualified Test.FileStorage.ControllerFunctionsSpec
2121
import qualified Test.PGListenerSpec
2222
import qualified Test.MockingSpec
23+
import qualified Test.TypedSqlSpec
2324

2425
main :: IO ()
2526
main = hspec do
@@ -40,3 +41,4 @@ main = hspec do
4041
Test.Controller.CookieSpec.tests
4142
Test.PGListenerSpec.tests
4243
Test.MockingSpec.tests
44+
Test.TypedSqlSpec.tests

0 commit comments

Comments
 (0)