diff --git a/ihp/IHP/TypedSql.hs b/ihp/IHP/TypedSql.hs new file mode 100644 index 000000000..a5e8cb4b3 --- /dev/null +++ b/ihp/IHP/TypedSql.hs @@ -0,0 +1,35 @@ +{-# LANGUAGE ImplicitParams #-} +{-# LANGUAGE RecordWildCards #-} + +module IHP.TypedSql + ( typedSql + , TypedQuery (..) + , sqlQueryTyped + , sqlExecTyped + ) where + +import qualified Hasql.Decoders as HasqlDecoders +import qualified Hasql.DynamicStatements.Snippet as Snippet +import IHP.ModelSupport (ModelContext, sqlQueryHasql, withHasqlOrPgSimple) +import IHP.Prelude + +import IHP.TypedSql.Quoter (typedSql) +import IHP.TypedSql.Types (TypedQuery (..)) + +-- | Run a typed query and return all rows. +-- High-level: executes the generated hasql snippet with its decoder. +sqlQueryTyped :: (?modelContext :: ModelContext) => TypedQuery result -> IO [result] +sqlQueryTyped TypedQuery { tqSnippet, tqResultDecoder } = + runTypedSqlSession tqSnippet (HasqlDecoders.rowList tqResultDecoder) + +-- | Run a typed statement (INSERT/UPDATE/DELETE) and return affected row count. +-- High-level: executes the generated hasql snippet and decodes rows affected. +sqlExecTyped :: (?modelContext :: ModelContext) => TypedQuery result -> IO Int64 +sqlExecTyped TypedQuery { tqSnippet } = + runTypedSqlSession tqSnippet HasqlDecoders.rowsAffected + +runTypedSqlSession :: (?modelContext :: ModelContext) => Snippet.Snippet -> HasqlDecoders.Result result -> IO result +runTypedSqlSession snippet decoder = + withHasqlOrPgSimple + (\pool -> sqlQueryHasql pool snippet decoder) + (fail "typedSql: requires hasql pool and does not support pg-simple transactions or RLS contexts") diff --git a/ihp/IHP/TypedSql/Bootstrap.hs b/ihp/IHP/TypedSql/Bootstrap.hs new file mode 100644 index 000000000..2291ab56f --- /dev/null +++ b/ihp/IHP/TypedSql/Bootstrap.hs @@ -0,0 +1,185 @@ +module IHP.TypedSql.Bootstrap + ( describeUsingBootstrap + ) where + +import Control.Exception (bracket_) +import Control.Monad (when) +import qualified Data.ByteString as BS +import qualified Data.ByteString.Char8 as BS8 +import Data.Maybe (catMaybes) +import qualified Data.String.Conversions as CS +import System.Directory (canonicalizePath, createDirectoryIfMissing, + doesDirectoryExist, doesFileExist, + findExecutable, removeDirectoryRecursive) +import System.Environment (lookupEnv) +import System.FilePath (isRelative, takeDirectory, takeFileName, ()) +import System.IO (Handle, hIsEOF) +import System.IO.Temp (withSystemTempDirectory) +import qualified System.Process as Process + +import IHP.Prelude +import IHP.TypedSql.Metadata (DescribeResult, describeStatementWith) + +-- | Resolved schema inputs for bootstrap mode. +-- typedSql uses this to run a temporary DB from SQL schema files. +data BootstrapConfig = BootstrapConfig + { bcAppSchemaPath :: !FilePath + , bcIhpSchemaPath :: !(Maybe FilePath) + } + +-- | Paths to postgres tools needed for bootstrapping. +-- These are resolved from PATH to keep bootstrap hermetic. +data PgTools = PgTools + { pgInitdb :: !FilePath + , pgPostgres :: !FilePath + , pgCreatedb :: !FilePath + , pgPsql :: !FilePath + } + +-- | Describe a query by bootstrapping a temporary database from schema files. +-- This is used when IHP_TYPED_SQL_BOOTSTRAP is enabled. +describeUsingBootstrap :: FilePath -> String -> IO DescribeResult +describeUsingBootstrap sourcePath sqlText = do + config <- resolveBootstrapConfig sourcePath + withBootstrapDatabase config \dbUrl -> + describeStatementWith dbUrl (CS.cs sqlText) + +-- | Resolve schema paths relative to the source file that contains typedSql. +resolveBootstrapConfig :: FilePath -> IO BootstrapConfig +resolveBootstrapConfig sourcePath = do + sourceDir <- canonicalizePath (takeDirectory sourcePath) + appSchemaPath <- resolveSchemaPath sourceDir + ihpSchemaPath <- resolveIhpSchemaPath sourceDir + pure BootstrapConfig + { bcAppSchemaPath = appSchemaPath + , bcIhpSchemaPath = ihpSchemaPath + } + +-- | Locate the application schema (Application/Schema.sql) for bootstrapping. +resolveSchemaPath :: FilePath -> IO FilePath +resolveSchemaPath sourceDir = do + envSchema <- lookupEnv "IHP_TYPED_SQL_SCHEMA" + case envSchema of + Just path -> resolveRelativePath sourceDir path >>= ensureFileExists "IHP_TYPED_SQL_SCHEMA" + Nothing -> do + findUpwards sourceDir ("Application" "Schema.sql") >>= \case + Just found -> pure found + Nothing -> + fail "typedSql: could not find Application/Schema.sql. Set IHP_TYPED_SQL_SCHEMA to an absolute path." + +-- | Locate the IHP schema (IHPSchema.sql) for bootstrapping, if present. +resolveIhpSchemaPath :: FilePath -> IO (Maybe FilePath) +resolveIhpSchemaPath sourceDir = do + envSchema <- lookupEnv "IHP_TYPED_SQL_IHP_SCHEMA" + case envSchema of + Just path -> Just <$> (resolveRelativePath sourceDir path >>= ensureFileExists "IHP_TYPED_SQL_IHP_SCHEMA") + Nothing -> do + envLib <- lookupEnv "IHP_LIB" + fromLib <- case envLib of + Just libPath -> do + let candidate = libPath "IHPSchema.sql" + exists <- doesFileExist candidate + pure (if exists then Just candidate else Nothing) + Nothing -> pure Nothing + case fromLib of + Just _ -> pure fromLib + Nothing -> findUpwards sourceDir ("ihp-ide" "data" "IHPSchema.sql") + +-- | Resolve a possibly relative schema path to an absolute path. +resolveRelativePath :: FilePath -> FilePath -> IO FilePath +resolveRelativePath baseDir path = do + let resolved = if isRelative path then baseDir path else path + canonicalizePath resolved + +-- | Verify that a schema file exists; fail with a typedSql-specific message otherwise. +ensureFileExists :: String -> FilePath -> IO FilePath +ensureFileExists label path = do + exists <- doesFileExist path + if exists + then pure path + else fail ("typedSql: " <> label <> " points to missing file: " <> path) + +-- | Search upwards for a schema file starting from the given directory. +findUpwards :: FilePath -> FilePath -> IO (Maybe FilePath) +findUpwards startDir relativePath = go startDir + where + go current = do + let candidate = current relativePath + exists <- doesFileExist candidate + if exists + then Just <$> canonicalizePath candidate + else do + let parent = takeDirectory current + if parent == current + then pure Nothing + else go parent + +-- | Start a temporary postgres, load schemas, and run a metadata action. +withBootstrapDatabase :: BootstrapConfig -> (BS.ByteString -> IO a) -> IO a +withBootstrapDatabase BootstrapConfig { bcAppSchemaPath, bcIhpSchemaPath } action = do + PgTools { pgInitdb, pgPostgres, pgCreatedb, pgPsql } <- resolvePgTools + withSystemTempDirectory "ihp-typed-sql" \tempDir -> do + let dataDir = tempDir "state" + let socketDir = "/tmp" takeFileName tempDir + let cleanupSocket = do + exists <- doesDirectoryExist socketDir + when exists (removeDirectoryRecursive socketDir) + bracket_ (createDirectoryIfMissing True socketDir) cleanupSocket do + Process.callProcess pgInitdb [dataDir, "--no-locale", "--encoding", "UTF8"] + + let params = + (Process.proc pgPostgres ["-D", dataDir, "-k", socketDir, "-c", "listen_addresses="]) + { Process.std_in = Process.CreatePipe + , Process.std_out = Process.CreatePipe + , Process.std_err = Process.CreatePipe + } + Process.withCreateProcess params \_ _ stderrHandle processHandle -> do + errHandle <- maybe (fail "typedSql: unable to read postgres logs") pure stderrHandle + let stop = do + Process.terminateProcess processHandle + _ <- Process.waitForProcess processHandle + pure () + let start = do + waitUntilReady errHandle + Process.callProcess pgCreatedb ["app", "-h", socketDir] + let loadSchema file = Process.callProcess pgPsql ["-h", socketDir, "-d", "app", "-v", "ON_ERROR_STOP=1", "-f", file] + forM_ (catMaybes [bcIhpSchemaPath, Just bcAppSchemaPath]) loadSchema + bracket_ start stop do + let dbUrl = CS.cs ("postgresql:///app?host=" <> socketDir) + action dbUrl + +-- | Resolve postgres tool paths from PATH (or adjacent to postgres binary). +resolvePgTools :: IO PgTools +resolvePgTools = do + pgPostgres <- requireExecutable "postgres" + let binDir = takeDirectory pgPostgres + pgInitdb <- findInBinOrPath binDir "initdb" + pgCreatedb <- findInBinOrPath binDir "createdb" + pgPsql <- findInBinOrPath binDir "psql" + pure PgTools { pgInitdb, pgPostgres, pgCreatedb, pgPsql } + +-- | Prefer a tool in the same bin dir as postgres, fallback to PATH. +findInBinOrPath :: FilePath -> String -> IO FilePath +findInBinOrPath binDir name = do + let candidate = binDir name + exists <- doesFileExist candidate + if exists then pure candidate else requireExecutable name + +-- | Require a tool to exist in PATH, otherwise fail with a bootstrap-specific error. +requireExecutable :: String -> IO FilePath +requireExecutable name = + findExecutable name >>= \case + Just path -> pure path + Nothing -> fail ("typedSql: bootstrap requires '" <> name <> "' in PATH") + +-- | Block until postgres reports readiness in its stderr log. +waitUntilReady :: Handle -> IO () +waitUntilReady handle = do + done <- hIsEOF handle + if done + then fail "typedSql: postgres exited before it was ready" + else do + line <- BS8.hGetLine handle + if "database system is ready to accept connections" `BS8.isInfixOf` line + then pure () + else waitUntilReady handle diff --git a/ihp/IHP/TypedSql/Decoders.hs b/ihp/IHP/TypedSql/Decoders.hs new file mode 100644 index 000000000..db53eaade --- /dev/null +++ b/ihp/IHP/TypedSql/Decoders.hs @@ -0,0 +1,283 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE OverloadedStrings #-} + +module IHP.TypedSql.Decoders + ( resultDecoderForColumns + ) where + +import Control.Monad (guard) +import qualified Data.List as List +import qualified Data.Map.Strict as Map +import Data.Maybe (mapMaybe) +import qualified Data.Set as Set +import qualified Data.String.Conversions as CS +import qualified Database.PostgreSQL.LibPQ as PQ +import qualified Hasql.Decoders as HasqlDecoders +import qualified Language.Haskell.TH as TH +import IHP.Hasql.FromRow as HasqlFromRow +import IHP.ModelSupport.Types (Id' (..)) +import IHP.Prelude + +import IHP.TypedSql.Metadata (ColumnMeta (..), DescribeColumn (..), PgTypeInfo (..), TableMeta (..)) + +-- | Build a hasql result decoder for the described SQL columns. +-- For full-table selections we reuse FromRowHasql; otherwise we decode a scalar/tuple. +resultDecoderForColumns :: Map.Map PQ.Oid PgTypeInfo -> Map.Map PQ.Oid TableMeta -> [DescribeColumn] -> TH.ExpQ +resultDecoderForColumns typeInfo tables columns = do + case detectFullTable tables columns of + Just _ -> + pure (TH.VarE 'HasqlFromRow.hasqlRowDecoder) + Nothing -> do + rowDecoder <- case columns of + [] -> pure (TH.AppE (TH.VarE 'pure) (TH.ConE '())) + [column] -> rowDecoderForColumn typeInfo tables column + _ -> tupleRowDecoderForColumns typeInfo tables columns + pure rowDecoder + +-- | Detect whether the columns represent a full table selection (table.* in table column order). +-- When this is true, typedSql can decode via the generated FromRowHasql model instance. +detectFullTable :: Map.Map PQ.Oid TableMeta -> [DescribeColumn] -> Maybe Text +detectFullTable tables cols = do + guard (not (null cols)) + let grouped = + cols + |> List.groupBy (\a b -> dcTable a == dcTable b) + |> mapMaybe (\group -> case List.uncons group of + Just (first, _) -> Just (dcTable first, group) + Nothing -> Nothing + ) + case grouped of + [(tableOid, colGroup)] | tableOid /= PQ.Oid 0 -> do + TableMeta { tmColumnOrder } <- Map.lookup tableOid tables + let attnums = mapMaybe dcAttnum colGroup + guard (attnums == tmColumnOrder) + TableMeta { tmName } <- Map.lookup tableOid tables + pure tmName + _ -> Nothing + +tupleRowDecoderForColumns :: Map.Map PQ.Oid PgTypeInfo -> Map.Map PQ.Oid TableMeta -> [DescribeColumn] -> TH.ExpQ +tupleRowDecoderForColumns typeInfo tables columns = do + columnDecoders <- mapM (rowDecoderForColumn typeInfo tables) columns + case columnDecoders of + [] -> pure (TH.AppE (TH.VarE 'pure) (TH.ConE '())) + firstDecoder:restDecoders -> do + let tupleConstructor = TH.ConE (TH.tupleDataName (length columnDecoders)) + let withFirst = TH.AppE (TH.AppE (TH.VarE '(<$>)) tupleConstructor) firstDecoder + pure (foldl (\acc decoder -> TH.AppE (TH.AppE (TH.VarE '(<*>)) acc) decoder) withFirst restDecoders) + +rowDecoderForColumn :: Map.Map PQ.Oid PgTypeInfo -> Map.Map PQ.Oid TableMeta -> DescribeColumn -> TH.ExpQ +rowDecoderForColumn typeInfo tables DescribeColumn { dcType, dcTable, dcAttnum } = + case (Map.lookup dcTable tables, dcAttnum) of + (Just TableMeta { tmPrimaryKeys, tmForeignKeys, tmColumns }, Just attnum) + | attnum `Set.member` tmPrimaryKeys -> do + let nullable = maybe True (not . cmNotNull) (Map.lookup attnum tmColumns) + columnTypeOid <- maybe (failText (missingColumnType attnum dcTable)) (pure . cmTypeOid) (Map.lookup attnum tmColumns) + decodeIdColumn typeInfo nullable columnTypeOid + | attnum `Map.member` tmForeignKeys -> do + let nullable = maybe True (not . cmNotNull) (Map.lookup attnum tmColumns) + columnTypeOid <- maybe (failText (missingColumnType attnum dcTable)) (pure . cmTypeOid) (Map.lookup attnum tmColumns) + decodeIdColumn typeInfo nullable columnTypeOid + | otherwise -> do + let nullable = maybe True (not . cmNotNull) (Map.lookup attnum tmColumns) + columnTypeOid <- maybe (failText (missingColumnType attnum dcTable)) (pure . cmTypeOid) (Map.lookup attnum tmColumns) + decodeColumnByOid typeInfo nullable columnTypeOid + _ -> + decodeColumnByOid typeInfo True dcType + where + missingColumnType attnum tableOid = + "typedSql: missing column metadata for attnum " <> show attnum <> " on table oid " <> show tableOid + +decodeIdColumn :: Map.Map PQ.Oid PgTypeInfo -> Bool -> PQ.Oid -> TH.ExpQ +decodeIdColumn typeInfo nullable oid = do + baseDecoder <- decodeColumnByOid typeInfo nullable oid + if nullable + then pure (TH.AppE (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'fmap) (TH.ConE 'Id))) baseDecoder) + else pure (TH.AppE (TH.AppE (TH.VarE 'fmap) (TH.ConE 'Id)) baseDecoder) + +decodeColumnByOid :: Map.Map PQ.Oid PgTypeInfo -> Bool -> PQ.Oid -> TH.ExpQ +decodeColumnByOid typeInfo nullable oid = + case Map.lookup oid typeInfo of + Nothing -> failText ("typedSql: missing type information for column oid " <> show oid) + Just pgTypeInfo -> decodeColumnByTypeInfo typeInfo nullable pgTypeInfo + +decodeColumnByTypeInfo :: Map.Map PQ.Oid PgTypeInfo -> Bool -> PgTypeInfo -> TH.ExpQ +decodeColumnByTypeInfo typeInfo nullable PgTypeInfo { ptiName, ptiElem } = + case ptiElem of + Just elementOid -> decodeArrayColumn typeInfo nullable elementOid + Nothing -> decodeScalarColumn nullable ptiName + +decodeArrayColumn :: Map.Map PQ.Oid PgTypeInfo -> Bool -> PQ.Oid -> TH.ExpQ +decodeArrayColumn typeInfo nullable elementOid = + case Map.lookup elementOid typeInfo of + Nothing -> failText ("typedSql: missing array element type for oid " <> show elementOid) + Just elementType -> + case ptiName elementType of + "int2" -> decodeIntArray nullable + "int4" -> decodeIntArray nullable + "int8" -> decodeIntegerArray nullable + "text" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.text) + "varchar" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.text) + "bpchar" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.text) + "citext" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.text) + "bool" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.bool) + "uuid" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.uuid) + "float4" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.float4) + "float8" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.float8) + "numeric" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.numeric) + "json" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.json) + "jsonb" -> decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.jsonb) + "bytea" -> decodeByteaArray nullable + unsupported -> + failText ("typedSql: unsupported array element type for hasql decoder: " <> unsupported) + +decodeScalarColumn :: Bool -> Text -> TH.ExpQ +decodeScalarColumn nullable typeName = + case typeName of + "int2" -> decodeIntScalar nullable + "int4" -> decodeIntScalar nullable + "int8" -> decodeIntegerScalar nullable + "text" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.text) + "varchar" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.text) + "bpchar" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.text) + "citext" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.text) + "bool" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.bool) + "uuid" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.uuid) + "timestamptz" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.timestamptz) + "timestamp" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.timestamp) + "date" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.date) + "time" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.time) + "json" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.json) + "jsonb" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.jsonb) + "float4" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.float4) + "float8" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.float8) + "numeric" -> decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.numeric) + "bytea" -> decodeByteaScalar nullable + unsupported -> + failText ("typedSql: unsupported column type for hasql decoder: " <> unsupported) + +decodeSimpleScalar :: Bool -> TH.Exp -> TH.ExpQ +decodeSimpleScalar nullable valueDecoder = + pure (TH.AppE (TH.VarE 'HasqlDecoders.column) (nullabilityWrapper nullable valueDecoder)) + +decodeSimpleArray :: Bool -> TH.Exp -> TH.ExpQ +decodeSimpleArray nullable valueDecoder = + pure + ( TH.AppE + (TH.VarE 'HasqlDecoders.column) + ( nullabilityWrapper + nullable + ( TH.AppE + (TH.VarE 'HasqlDecoders.listArray) + (TH.AppE (TH.VarE 'HasqlDecoders.nonNullable) valueDecoder) + ) + ) + ) + +decodeIntScalar :: Bool -> TH.ExpQ +decodeIntScalar nullable = + if nullable + then pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'fmap) (TH.VarE 'fromIntegral))) + (TH.AppE (TH.VarE 'HasqlDecoders.column) (nullabilityWrapper True (TH.VarE 'HasqlDecoders.int4))) + ) + else pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.VarE 'fromIntegral)) + (TH.AppE (TH.VarE 'HasqlDecoders.column) (nullabilityWrapper False (TH.VarE 'HasqlDecoders.int4))) + ) + +decodeIntegerScalar :: Bool -> TH.ExpQ +decodeIntegerScalar nullable = + if nullable + then pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'fmap) (TH.VarE 'fromIntegral))) + (TH.AppE (TH.VarE 'HasqlDecoders.column) (nullabilityWrapper True (TH.VarE 'HasqlDecoders.int8))) + ) + else pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.VarE 'fromIntegral)) + (TH.AppE (TH.VarE 'HasqlDecoders.column) (nullabilityWrapper False (TH.VarE 'HasqlDecoders.int8))) + ) + +decodeIntArray :: Bool -> TH.ExpQ +decodeIntArray nullable = + if nullable + then pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'map) (TH.VarE 'fromIntegral)))) + ( TH.AppE + (TH.VarE 'HasqlDecoders.column) + ( nullabilityWrapper + True + ( TH.AppE + (TH.VarE 'HasqlDecoders.listArray) + (TH.AppE (TH.VarE 'HasqlDecoders.nonNullable) (TH.VarE 'HasqlDecoders.int4)) + ) + ) + ) + ) + else pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'map) (TH.VarE 'fromIntegral))) + ( TH.AppE + (TH.VarE 'HasqlDecoders.column) + ( nullabilityWrapper + False + ( TH.AppE + (TH.VarE 'HasqlDecoders.listArray) + (TH.AppE (TH.VarE 'HasqlDecoders.nonNullable) (TH.VarE 'HasqlDecoders.int4)) + ) + ) + ) + ) + +decodeIntegerArray :: Bool -> TH.ExpQ +decodeIntegerArray nullable = + if nullable + then pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'map) (TH.VarE 'fromIntegral)))) + ( TH.AppE + (TH.VarE 'HasqlDecoders.column) + ( nullabilityWrapper + True + ( TH.AppE + (TH.VarE 'HasqlDecoders.listArray) + (TH.AppE (TH.VarE 'HasqlDecoders.nonNullable) (TH.VarE 'HasqlDecoders.int8)) + ) + ) + ) + ) + else pure + ( TH.AppE + (TH.AppE (TH.VarE 'fmap) (TH.AppE (TH.VarE 'map) (TH.VarE 'fromIntegral))) + ( TH.AppE + (TH.VarE 'HasqlDecoders.column) + ( nullabilityWrapper + False + ( TH.AppE + (TH.VarE 'HasqlDecoders.listArray) + (TH.AppE (TH.VarE 'HasqlDecoders.nonNullable) (TH.VarE 'HasqlDecoders.int8)) + ) + ) + ) + ) + +decodeByteaScalar :: Bool -> TH.ExpQ +decodeByteaScalar nullable = + decodeSimpleScalar nullable (TH.VarE 'HasqlDecoders.bytea) + +decodeByteaArray :: Bool -> TH.ExpQ +decodeByteaArray nullable = + decodeSimpleArray nullable (TH.VarE 'HasqlDecoders.bytea) + +nullabilityWrapper :: Bool -> TH.Exp -> TH.Exp +nullabilityWrapper nullable valueDecoder = + TH.AppE + (TH.VarE (if nullable then 'HasqlDecoders.nullable else 'HasqlDecoders.nonNullable)) + valueDecoder + +failText :: Text -> TH.Q a +failText = fail . CS.cs diff --git a/ihp/IHP/TypedSql/Metadata.hs b/ihp/IHP/TypedSql/Metadata.hs new file mode 100644 index 000000000..e56dae8c7 --- /dev/null +++ b/ihp/IHP/TypedSql/Metadata.hs @@ -0,0 +1,356 @@ +{-# LANGUAGE NamedFieldPuns #-} + +module IHP.TypedSql.Metadata + ( DescribeResult (..) + , DescribeColumn (..) + , ColumnMeta (..) + , TableMeta (..) + , PgTypeInfo (..) + , toOidInt32 + , fromOidInt32 + , describeStatement + , describeStatementWith + ) where + +import Control.Exception (bracket) +import Control.Monad (guard) +import Data.Int (Int32) +import qualified Data.ByteString as BS +import qualified Data.List as List +import qualified Data.Map.Strict as Map +import Data.Maybe (catMaybes, mapMaybe) +import qualified Data.Set as Set +import qualified Data.String.Conversions as CS +import qualified Data.Text as Text +import qualified Database.PostgreSQL.LibPQ as PQ +import qualified Hasql.Connection as HasqlConnection +import qualified Hasql.Connection.Settings as HasqlSettings +import qualified Hasql.Decoders as HasqlDecoders +import qualified Hasql.Encoders as HasqlEncoders +import qualified Hasql.Session as HasqlSession +import qualified Hasql.Statement as HasqlStatement +import IHP.FrameworkConfig (defaultDatabaseUrl) +import IHP.Prelude + +-- | Result of describing a statement. +-- High-level: this is the central metadata bundle for typedSql inference. +data DescribeResult = DescribeResult + { drParams :: ![PQ.Oid] + , drColumns :: ![DescribeColumn] + , drTables :: !(Map.Map PQ.Oid TableMeta) + , drTypes :: !(Map.Map PQ.Oid PgTypeInfo) + } + +-- | Metadata for a column in the result set. +-- This drives result type inference and row parser selection. +data DescribeColumn = DescribeColumn + { dcName :: !BS.ByteString + , dcType :: !PQ.Oid + , dcTable :: !PQ.Oid + , dcAttnum :: !(Maybe Int) + } + +-- | Column details extracted from pg_attribute. +-- These are used to map columns to IHP Id' and nullable types. +data ColumnMeta = ColumnMeta + { cmAttnum :: !Int + , cmName :: !Text + , cmTypeOid :: !PQ.Oid + , cmNotNull :: !Bool + } + +-- | Table metadata, including columns and key relationships. +-- This is used for table.* detection and key-aware typing. +data TableMeta = TableMeta + { tmOid :: !PQ.Oid + , tmName :: !Text + , tmColumns :: !(Map.Map Int ColumnMeta) + , tmColumnOrder :: ![Int] + , tmPrimaryKeys :: !(Set.Set Int) + , tmForeignKeys :: !(Map.Map Int PQ.Oid) + } + +-- | Postgres type metadata needed for Haskell mapping. +-- Used to convert OIDs to concrete Haskell types. +data PgTypeInfo = PgTypeInfo + { ptiOid :: !PQ.Oid + , ptiName :: !Text + , ptiElem :: !(Maybe PQ.Oid) + , ptiType :: !(Maybe Char) + , ptiNamespace :: !(Maybe Text) + } + +-- | Convert libpq Oid to Int32 for Hasql parameter encoding. +toOidInt32 :: PQ.Oid -> Int32 +toOidInt32 (PQ.Oid oid) = fromIntegral oid + +-- | Convert Hasql-decoded Oid value back to libpq Oid. +fromOidInt32 :: Int32 -> PQ.Oid +fromOidInt32 oid = PQ.Oid (fromIntegral oid) + +-- | Describe a statement by asking a real Postgres server. +-- typedSql uses this when not running in bootstrap mode. +describeStatement :: BS.ByteString -> IO DescribeResult +describeStatement sql = do + dbUrl <- defaultDatabaseUrl + describeStatementWith dbUrl sql + +-- | Describe a statement using an explicit database URL. +-- This is the core path for metadata lookup in typedSql. +describeStatementWith :: BS.ByteString -> BS.ByteString -> IO DescribeResult +describeStatementWith dbUrl sql = do + conn <- PQ.connectdb dbUrl + status <- PQ.status conn + unless (status == PQ.ConnectionOk) do + err <- PQ.errorMessage conn + fail ("typedSql: could not connect to database: " <> CS.cs (fromMaybe "" err)) + + let statementName = "ihp_typed_sql_stmt" + _ <- ensureOk "prepare" =<< PQ.prepare conn statementName sql Nothing + desc <- ensureOk "describe" =<< PQ.describePrepared conn statementName + + paramCount <- PQ.nparams desc + paramTypes <- mapM (PQ.paramtype desc) [0 .. paramCount - 1] + + columnCount <- PQ.nfields desc + let PQ.Col columnCountCInt = columnCount + let columnCountInt = fromIntegral columnCountCInt :: Int + columns <- mapM (\i -> do + let colIndex = PQ.Col (fromIntegral i) + name <- fromMaybe "" <$> PQ.fname desc colIndex + colType <- PQ.ftype desc colIndex + tableOid <- PQ.ftable desc colIndex + attnumRaw <- PQ.ftablecol desc colIndex + let PQ.Col attnumCInt = attnumRaw + let attnumInt = fromIntegral attnumCInt :: Int + let attnum = + if tableOid == PQ.Oid 0 || attnumInt <= 0 + then Nothing + else Just attnumInt + pure DescribeColumn { dcName = name, dcType = colType, dcTable = tableOid, dcAttnum = attnum } + ) [0 .. columnCountInt - 1] + + let tableOids = Set.fromList (map dcTable columns) |> Set.delete (PQ.Oid 0) + typeOids = Set.fromList paramTypes <> Set.fromList (map dcType columns) + + tables <- loadTableMeta dbUrl (Set.toList tableOids) + let referencedOids = + tables + |> Map.elems + |> foldl' + (\acc TableMeta { tmForeignKeys } -> + acc <> Set.fromList (Map.elems tmForeignKeys) + ) + mempty + let missingRefs = referencedOids `Set.difference` Map.keysSet tables + extraTables <- loadTableMeta dbUrl (Set.toList missingRefs) + let tables' = tables <> extraTables + types <- loadTypeInfo dbUrl (Set.toList typeOids) + + _ <- PQ.exec conn ("DEALLOCATE " <> statementName) + PQ.finish conn + + pure DescribeResult { drParams = paramTypes, drColumns = columns, drTables = tables', drTypes = types } + +-- | Ensure libpq returned a successful result. +-- Errors here surface as typedSql compile-time failures. +ensureOk :: String -> Maybe PQ.Result -> IO PQ.Result +ensureOk actionName = \case + Nothing -> fail ("typedSql: " <> actionName <> " returned no result") + Just res -> do + status <- PQ.resultStatus res + case status of + PQ.CommandOk -> pure res + PQ.TuplesOk -> pure res + _ -> do + msg <- PQ.resultErrorMessage res + fail ("typedSql: " <> actionName <> " failed: " <> CS.cs (fromMaybe "" msg)) + +-- | Run a Hasql session for metadata queries (pg_catalog). +-- This keeps metadata lookups on the same hasql stack as typedSql execution. +runHasqlMetadataSession :: BS.ByteString -> HasqlSession.Session a -> IO a +runHasqlMetadataSession dbUrl session = do + let settings = HasqlSettings.connectionString (CS.cs dbUrl) + hasqlConnection <- + HasqlConnection.acquire settings >>= \case + Left connectionError -> + fail (CS.cs ("typedSql: could not connect to database: " <> tshow connectionError)) + Right connection -> + pure connection + bracket (pure hasqlConnection) HasqlConnection.release \connection -> + HasqlConnection.use connection session >>= \case + Left sessionError -> + fail (CS.cs ("typedSql: metadata query failed: " <> tshow sessionError)) + Right result -> + pure result + +-- | Encoder for passing OID arrays into pg_catalog queries. +oidArrayParamsEncoder :: HasqlEncoders.Params [Int32] +oidArrayParamsEncoder = + HasqlEncoders.param + (HasqlEncoders.nonNullable + (HasqlEncoders.foldableArray + (HasqlEncoders.nonNullable HasqlEncoders.oid) + ) + ) + +-- | Query to load column metadata for a set of table OIDs. +tableColumnsStatement :: HasqlStatement.Statement [Int32] [(Int32, Text, Int32, Text, Int32, Bool)] +tableColumnsStatement = + HasqlStatement.preparable + (mconcat + [ "SELECT c.oid::int4, c.relname::text, a.attnum::int4, a.attname::text, a.atttypid::int4, a.attnotnull " + , "FROM pg_class c " + , "JOIN pg_namespace ns ON ns.oid = c.relnamespace " + , "JOIN pg_attribute a ON a.attrelid = c.oid " + , "WHERE c.oid = ANY($1) AND a.attnum > 0 AND NOT a.attisdropped " + , "ORDER BY c.oid, a.attnum" + ]) + oidArrayParamsEncoder + (HasqlDecoders.rowList + ((,,,,,) + <$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.bool) + )) + +-- | Query to load primary key columns for a set of tables. +primaryKeysStatement :: HasqlStatement.Statement [Int32] [(Int32, Int32)] +primaryKeysStatement = + HasqlStatement.preparable + "SELECT conrelid::int4, unnest(conkey)::int4 as attnum FROM pg_constraint WHERE contype = 'p' AND conrelid = ANY($1)" + oidArrayParamsEncoder + (HasqlDecoders.rowList + ((,) + <$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + )) + +-- | Query to load single-column foreign keys for a set of tables. +foreignKeysStatement :: HasqlStatement.Statement [Int32] [(Int32, Int32, Int32)] +foreignKeysStatement = + HasqlStatement.preparable + (mconcat + [ "SELECT conrelid::int4, conkey[1]::int4 as attnum, confrelid::int4 " + , "FROM pg_constraint " + , "WHERE contype = 'f' AND array_length(conkey,1) = 1 AND conrelid = ANY($1)" + ]) + oidArrayParamsEncoder + (HasqlDecoders.rowList + ((,,) + <$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + )) + +-- | Query to load type information for a set of type OIDs. +typeInfoStatement :: HasqlStatement.Statement [Int32] [(Int32, Text, Int32, Maybe Text, Maybe Text)] +typeInfoStatement = + HasqlStatement.preparable + (mconcat + [ "SELECT oid::int4, typname::text, typelem::int4, typtype::text, typnamespace::regnamespace::text " + , "FROM pg_type " + , "WHERE oid = ANY($1)" + ]) + oidArrayParamsEncoder + (HasqlDecoders.rowList + ((,,,,) + <$> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text) + <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4) + <*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.text) + <*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.text) + )) + +-- | Load table metadata for all referenced tables. +-- High-level: read pg_catalog to map table/column info for typedSql inference. +loadTableMeta :: BS.ByteString -> [PQ.Oid] -> IO (Map.Map PQ.Oid TableMeta) +loadTableMeta _ [] = pure mempty +loadTableMeta dbUrl tableOids = do + let tableOidParams = map toOidInt32 tableOids + rows <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams tableColumnsStatement) + primaryKeys <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams primaryKeysStatement) + foreignKeys <- runHasqlMetadataSession dbUrl (HasqlSession.statement tableOidParams foreignKeysStatement) + + let pkMap = primaryKeys + |> foldl' (\acc (relid, attnum) -> + Map.insertWith Set.union (fromOidInt32 relid) (Set.singleton (fromIntegral attnum)) acc + ) mempty + + fkMap = foreignKeys + |> foldl' (\acc (relid, attnum, ref) -> + Map.insertWith Map.union (fromOidInt32 relid) (Map.singleton (fromIntegral attnum) (fromOidInt32 ref)) acc + ) mempty + + tableGroups = + rows + |> map (\(relid, name, attnum, attname, atttypid, attnotnull) -> + ( fromOidInt32 relid + , ColumnMeta + { cmAttnum = fromIntegral attnum + , cmName = attname + , cmTypeOid = fromOidInt32 atttypid + , cmNotNull = attnotnull + } + , name + ) + ) + |> List.groupBy (\(l, _, _) (r, _, _) -> l == r) + + pure $ tableGroups + |> foldl' + (\acc group -> + case group of + [] -> acc + (tableOid, _, tableName):_ -> + let cols = group + |> map (\(_, column, _) -> (cmAttnum column, column)) + |> Map.fromList + order = group |> map (\(_, column, _) -> cmAttnum column) + pks = Map.findWithDefault mempty tableOid pkMap + fks = Map.findWithDefault mempty tableOid fkMap + meta = TableMeta + { tmOid = tableOid + , tmName = tableName + , tmColumns = cols + , tmColumnOrder = order + , tmPrimaryKeys = pks + , tmForeignKeys = fks + } + in Map.insert tableOid meta acc + ) + mempty + +-- | Load type information for the given OIDs. +-- High-level: fetch pg_type metadata recursively for arrays. +loadTypeInfo :: BS.ByteString -> [PQ.Oid] -> IO (Map.Map PQ.Oid PgTypeInfo) +loadTypeInfo _ [] = pure mempty +loadTypeInfo dbUrl typeOids = do + let requested = Set.fromList typeOids + rows <- runHasqlMetadataSession dbUrl (HasqlSession.statement (map toOidInt32 typeOids) typeInfoStatement) + let (typeMap, missing) = + rows + |> foldl' + (\(acc, missingAcc) (oid, name, elemOid, typtype, nsp) -> + let thisOid = fromOidInt32 oid + elemOid' = if elemOid == 0 then Nothing else Just (fromOidInt32 elemOid) + nextMissing = case elemOid' of + Just o | o `Set.notMember` requested -> o : missingAcc + _ -> missingAcc + in ( Map.insert thisOid PgTypeInfo + { ptiOid = thisOid + , ptiName = name + , ptiElem = elemOid' + , ptiType = typtype >>= (listToMaybe . CS.cs) + , ptiNamespace = nsp + } + acc + , nextMissing + ) + ) + (mempty, []) + extras <- loadTypeInfo dbUrl missing + pure (typeMap <> extras) diff --git a/ihp/IHP/TypedSql/ParamHints.hs b/ihp/IHP/TypedSql/ParamHints.hs new file mode 100644 index 000000000..4740f7d2e --- /dev/null +++ b/ihp/IHP/TypedSql/ParamHints.hs @@ -0,0 +1,310 @@ +module IHP.TypedSql.ParamHints + ( SqlToken (..) + , ParamHint (..) + , tokenizeSql + , buildAliasMap + , collectParamHints + , resolveParamHintTypes + ) where + +import Control.Monad (guard) +import qualified Data.Char as Char +import qualified Data.List as List +import qualified Data.Map.Strict as Map +import Data.Maybe (catMaybes, mapMaybe) +import qualified Data.Set as Set +import qualified Data.Text as Text +import qualified Data.String.Conversions as CS +import qualified Database.PostgreSQL.LibPQ as PQ +import qualified Language.Haskell.TH as TH +import IHP.Prelude + +import IHP.TypedSql.Metadata (ColumnMeta (..), DescribeColumn (..), + PgTypeInfo, TableMeta (..)) +import IHP.TypedSql.TypeMapping (hsTypeForColumn) + +-- | Minimal SQL token used for placeholder context inspection. +-- This is a lightweight scanner, not a full SQL parser. +data SqlToken + = TokIdent !Text + | TokSymbol !Char + | TokParam !Int + deriving (Eq, Show) + +-- | A derived hint about the expected type of a placeholder. +-- The quasiquoter uses this to coerce ${...} to a column-compatible type. +data ParamHint = ParamHint + { phIndex :: !Int + , phTable :: !Text + , phColumn :: !Text + , phArray :: !Bool + } + deriving (Eq, Show) + +-- | Tokenize SQL just enough to locate placeholders and nearby identifiers. +-- This feeds alias detection and parameter hint extraction in typedSql. +tokenizeSql :: String -> [SqlToken] +tokenizeSql = go [] where + go acc [] = reverse acc + go acc ('-':'-':rest) = go acc (dropLineComment rest) + go acc ('/':'*':rest) = go acc (dropBlockComment rest) + go acc ('\'':rest) = go acc (dropStringLiteral rest) + go acc ('"':rest) = + let (ident, remaining) = parseQuotedIdent rest + in go (TokIdent ident : acc) remaining + go acc ('$':rest) = + let (digits, remaining) = span Char.isDigit rest + in if null digits + then go acc remaining + else go (TokParam (digitsToInt digits) : acc) remaining + go acc (c:rest) + | Char.isSpace c = go acc rest + | isIdentStart c = + let (identTail, remaining) = span isIdentChar rest + identText = Text.toLower (CS.cs (c : identTail)) + in go (TokIdent identText : acc) remaining + | isSymbolToken c = go (TokSymbol c : acc) rest + | otherwise = go acc rest + + isIdentStart ch = Char.isLetter ch || ch == '_' + isIdentChar ch = Char.isAlphaNum ch || ch == '_' || ch == '$' + isSymbolToken ch = ch `elem` ['.', '=', '(', ')', ','] + + dropLineComment = dropWhile (/= '\n') + dropBlockComment = dropUntil "*/" + dropStringLiteral = dropSingleQuoted + + dropUntil _ [] = [] + dropUntil pattern@(p1:p2:_) (x:y:rest) + | x == p1 && y == p2 = rest + | otherwise = dropUntil pattern (y:rest) + dropUntil _ rest = rest + + dropSingleQuoted [] = [] + dropSingleQuoted ('\'':'\'':xs) = dropSingleQuoted xs + dropSingleQuoted ('\'':xs) = xs + dropSingleQuoted (_:xs) = dropSingleQuoted xs + + parseQuotedIdent = go "" where + go acc [] = (Text.toLower (CS.cs (reverse acc)), []) + go acc ('"':'"':xs) = go ('"':acc) xs + go acc ('"':xs) = (Text.toLower (CS.cs (reverse acc)), xs) + go acc (x:xs) = go (x:acc) xs + +-- | Convert a list of digit chars to an Int for $1/$2 token indices. +digitsToInt :: String -> Int +digitsToInt = foldl' (\acc digit -> acc * 10 + Char.digitToInt digit) 0 + +-- | Safe indexing helper for the token stream. +tokenAtIndex :: [a] -> Int -> Maybe a +tokenAtIndex xs ix = + case List.drop ix xs of + (value:_) -> Just value + [] -> Nothing + +-- | Keywords that should not be treated as table aliases. +reservedKeywords :: Set.Set Text +reservedKeywords = + Set.fromList (map Text.pack + [ "as", "where", "join", "inner", "left", "right", "full", "cross" + , "on", "group", "order", "limit", "offset", "having", "union" + , "intersect", "except", "returning", "set", "values", "from", "update" + , "delete", "insert", "select" + ]) + +-- | Keywords that introduce a table reference. +clauseKeywords :: Set.Set Text +clauseKeywords = Set.fromList (map Text.pack ["from", "join", "update", "into"]) + +-- | Build a map of aliases to base table names from FROM/JOIN clauses. +-- Used to resolve qualified columns when inferring parameter types. +buildAliasMap :: [SqlToken] -> Map.Map Text Text +buildAliasMap tokens = go tokens Map.empty where + go [] acc = acc + go (TokIdent keyword : rest) acc + | keyword `Set.member` clauseKeywords = + case parseTable rest of + Nothing -> go rest acc + Just (tableName, afterTable) -> + let (alias, afterAlias) = parseAlias afterTable + acc' = Map.insert tableName tableName acc + acc'' = maybe acc' (\name -> Map.insert name tableName acc') alias + in go afterAlias acc'' + | otherwise = go rest acc + go (_:rest) acc = go rest acc + + parseTable (TokIdent _schemaName : TokSymbol '.' : TokIdent tableName : rest) = + Just (tableName, rest) + parseTable (TokIdent tableName : rest) = + Just (tableName, rest) + parseTable _ = Nothing + + parseAlias (TokIdent aliasKeyword : TokIdent alias : rest) + | aliasKeyword == Text.pack "as" = (Just alias, rest) + parseAlias (TokIdent alias : rest) + | alias `Set.notMember` reservedKeywords = (Just alias, rest) + parseAlias rest = (Nothing, rest) + +-- | Find placeholder sites that look like column comparisons and capture their types. +-- This is how typedSql infers a more precise parameter type than the DB-provided OID. +collectParamHints :: [SqlToken] -> Map.Map Text Text -> Map.Map Int ParamHint +collectParamHints tokens aliasMap = + let defaultTable = singleTable aliasMap + in tokens + |> zip [0..] + |> mapMaybe (hintForToken aliasMap defaultTable) + |> foldl' mergeHints Map.empty + |> Map.mapMaybe id + where + tokenAt ix + | ix < 0 = Nothing + | otherwise = tokenAtIndex tokens ix + + singleTable aliases = + case Set.toList (Set.fromList (Map.elems aliases)) of + [table] -> Just table + _ -> Nothing + + hasDotBefore ix = + case tokenAt (ix - 1) of + Just (TokSymbol '.') -> True + _ -> False + + hasDotAfter ix = + case tokenAt (ix + 1) of + Just (TokSymbol '.') -> True + _ -> False + + hintForToken aliases defaultTable (ix, TokParam index) = + let matches = catMaybes + [ matchEqRight aliases ix index + , matchEqLeft aliases ix index + , matchInRight aliases ix index + , matchAnyRight aliases ix index + , matchEqRightUnqualified defaultTable ix index + , matchEqLeftUnqualified defaultTable ix index + , matchInRightUnqualified defaultTable ix index + , matchAnyRightUnqualified defaultTable ix index + ] + in listToMaybe matches + hintForToken _ _ _ = Nothing + + matchEqRight aliases ix index = do + TokSymbol '=' <- tokenAt (ix - 1) + TokIdent column <- tokenAt (ix - 2) + TokSymbol '.' <- tokenAt (ix - 3) + TokIdent tableRef <- tokenAt (ix - 4) + tableName <- Map.lookup tableRef aliases + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = False } + + matchEqRightUnqualified defaultTable ix index = do + tableName <- defaultTable + TokSymbol '=' <- tokenAt (ix - 1) + TokIdent column <- tokenAt (ix - 2) + guard (not (hasDotBefore (ix - 2))) + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = False } + + matchEqLeft aliases ix index = do + TokSymbol '=' <- tokenAt (ix + 1) + TokIdent tableRef <- tokenAt (ix + 2) + TokSymbol '.' <- tokenAt (ix + 3) + TokIdent column <- tokenAt (ix + 4) + tableName <- Map.lookup tableRef aliases + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = False } + + matchEqLeftUnqualified defaultTable ix index = do + tableName <- defaultTable + TokSymbol '=' <- tokenAt (ix + 1) + TokIdent column <- tokenAt (ix + 2) + guard (not (hasDotAfter (ix + 2))) + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = False } + + matchInRight aliases ix index = do + TokSymbol '(' <- tokenAt (ix - 1) + TokIdent keyword <- tokenAt (ix - 2) + guard (keyword == Text.pack "in") + TokIdent column <- tokenAt (ix - 3) + TokSymbol '.' <- tokenAt (ix - 4) + TokIdent tableRef <- tokenAt (ix - 5) + tableName <- Map.lookup tableRef aliases + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = True } + + matchInRightUnqualified defaultTable ix index = do + tableName <- defaultTable + TokSymbol '(' <- tokenAt (ix - 1) + TokIdent keyword <- tokenAt (ix - 2) + guard (keyword == Text.pack "in") + TokIdent column <- tokenAt (ix - 3) + guard (not (hasDotBefore (ix - 3))) + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = True } + + matchAnyRight aliases ix index = do + TokSymbol '(' <- tokenAt (ix - 1) + TokIdent keyword <- tokenAt (ix - 2) + guard (keyword == Text.pack "any") + TokSymbol '=' <- tokenAt (ix - 3) + TokIdent column <- tokenAt (ix - 4) + TokSymbol '.' <- tokenAt (ix - 5) + TokIdent tableRef <- tokenAt (ix - 6) + tableName <- Map.lookup tableRef aliases + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = True } + + matchAnyRightUnqualified defaultTable ix index = do + tableName <- defaultTable + TokSymbol '(' <- tokenAt (ix - 1) + TokIdent keyword <- tokenAt (ix - 2) + guard (keyword == Text.pack "any") + TokSymbol '=' <- tokenAt (ix - 3) + TokIdent column <- tokenAt (ix - 4) + guard (not (hasDotBefore (ix - 4))) + pure ParamHint { phIndex = index, phTable = tableName, phColumn = column, phArray = True } + + mergeHints acc hint = + Map.alter (mergeHint hint) (phIndex hint) acc + + mergeHint hint Nothing = Just (Just hint) + mergeHint hint (Just Nothing) = Just Nothing + mergeHint hint (Just (Just existing)) + | existing == hint = Just (Just existing) + | otherwise = Just Nothing + +-- | Convert parameter hints into concrete Haskell types using table metadata. +-- The quasiquoter uses this to apply per-placeholder type annotations. +resolveParamHintTypes :: Map.Map PQ.Oid TableMeta -> Map.Map PQ.Oid PgTypeInfo -> Map.Map Int ParamHint -> TH.Q (Map.Map Int TH.Type) +resolveParamHintTypes tables typeInfo hints = do + let tablesByName = tables + |> Map.toList + |> mapMaybe (\(oid, table@TableMeta { tmName }) -> Just (tmName, (oid, table))) + |> Map.fromList + resolved <- mapM (resolveHint tablesByName) (Map.toList hints) + pure (Map.fromList (catMaybes resolved)) + where + resolveHint tablesByName (index, ParamHint { phTable, phColumn, phArray }) = do + case Map.lookup phTable tablesByName of + Nothing -> pure Nothing + Just (tableOid, table@TableMeta { tmColumns }) -> + case findColumn tmColumns phColumn of + Nothing -> pure Nothing + Just (attnum, ColumnMeta { cmTypeOid }) -> do + baseType <- hsTypeForColumn typeInfo tables DescribeColumn + { dcName = CS.cs phColumn + , dcType = cmTypeOid + , dcTable = tableOid + , dcAttnum = Just attnum + } + let stripped = stripMaybeType baseType + let hintedType = if phArray then TH.AppT TH.ListT stripped else stripped + pure (Just (index, hintedType)) + + findColumn columns columnName = + columns + |> Map.toList + |> List.find (\(_, ColumnMeta { cmName }) -> Text.toLower cmName == Text.toLower columnName) + |> fmap (\(attnum, column) -> (attnum, column)) + +-- | Strip a top-level Maybe wrapper to get the base column type. +-- Used when parameter hints should be non-nullable inputs. +stripMaybeType :: TH.Type -> TH.Type +stripMaybeType (TH.AppT (TH.ConT maybeName) inner) + | maybeName == ''Maybe = inner +stripMaybeType other = other diff --git a/ihp/IHP/TypedSql/Placeholders.hs b/ihp/IHP/TypedSql/Placeholders.hs new file mode 100644 index 000000000..c535180c3 --- /dev/null +++ b/ihp/IHP/TypedSql/Placeholders.hs @@ -0,0 +1,57 @@ +{-# LANGUAGE TemplateHaskell #-} + +module IHP.TypedSql.Placeholders + ( PlaceholderPlan (..) + , planPlaceholders + , parseExpr + ) where + +import qualified Data.String.Conversions as CS +import qualified Language.Haskell.Meta.Parse as HaskellMeta +import qualified Language.Haskell.TH as TH +import IHP.Prelude + +-- | Output of placeholder parsing used by the typedSql quasiquoter. +-- It carries the SQL variant for describe, the SQL variant for runtime execution, +-- and the original Haskell placeholder expressions to splice. +data PlaceholderPlan = PlaceholderPlan + { ppDescribeSql :: !String -- ^ SQL with $1/$2 placeholders for the describe step. + , ppRuntimeSql :: !String -- ^ SQL with ? placeholders for hasql snippet execution. + , ppExprs :: ![String] -- ^ Raw Haskell expressions from ${...}. + } + +-- | Replace ${expr} placeholders with PostgreSQL-style $1 for describe and ? for runtime. +-- High-level: turns a templated SQL string into SQL strings plus expr list. +planPlaceholders :: String -> PlaceholderPlan +planPlaceholders = go 1 "" "" [] where + go _ accDescribe accRuntime exprs [] = + PlaceholderPlan + { ppDescribeSql = reverse accDescribe + , ppRuntimeSql = reverse accRuntime + , ppExprs = reverse exprs + } + go n accDescribe accRuntime exprs ('$':'{':rest) = + let (expr, after) = breakOnClosing 0 "" rest -- parse until matching } + describeToken = reverse ('$' : CS.cs (show n)) + in go (n + 1) + (describeToken <> accDescribe) + ('?' : accRuntime) + (expr : exprs) + after + go n accDescribe accRuntime exprs (c:rest) = + go n (c : accDescribe) (c : accRuntime) exprs rest + + breakOnClosing depth acc [] = (reverse acc, []) -- no closing brace found + breakOnClosing depth acc ('{':xs) = breakOnClosing (depth + 1) ('{':acc) xs -- nested { increases depth + breakOnClosing depth acc ('}':xs) + | depth == 0 = (reverse acc, xs) -- close the current placeholder + | otherwise = breakOnClosing (depth - 1) ('}':acc) xs -- close a nested brace + breakOnClosing depth acc (x:xs) = breakOnClosing depth (x:acc) xs -- accumulate placeholder chars + +-- | Parse a placeholder expression into TH. +-- Used by typedSql to turn ${...} splices into typed expressions. +parseExpr :: String -> TH.ExpQ +parseExpr exprText = + case HaskellMeta.parseExp exprText of + Left err -> fail ("typedSql: failed to parse expression {" <> exprText <> "}: " <> err) -- parse error + Right expr -> pure expr -- success: return parsed TH expression diff --git a/ihp/IHP/TypedSql/Quoter.hs b/ihp/IHP/TypedSql/Quoter.hs new file mode 100644 index 000000000..227ce2a69 --- /dev/null +++ b/ihp/IHP/TypedSql/Quoter.hs @@ -0,0 +1,136 @@ +{-# LANGUAGE NamedFieldPuns #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE TemplateHaskell #-} + +module IHP.TypedSql.Quoter + ( typedSql + ) where + +import Control.Monad (when) +import Data.Coerce (coerce) +import qualified Data.Char as Char +import qualified Data.Map.Strict as Map +import qualified Data.String.Conversions as CS +import qualified Hasql.DynamicStatements.Snippet as Snippet +import qualified Language.Haskell.TH as TH +import qualified Language.Haskell.TH.Quote as TH +import System.Environment (lookupEnv) +import IHP.Prelude +import IHP.Hasql.Encoders () + +import IHP.TypedSql.Bootstrap (describeUsingBootstrap) +import IHP.TypedSql.Decoders (resultDecoderForColumns) +import IHP.TypedSql.Metadata (DescribeColumn (..), DescribeResult (..), PgTypeInfo (..), + describeStatement) +import IHP.TypedSql.ParamHints (buildAliasMap, collectParamHints, + resolveParamHintTypes, tokenizeSql) +import IHP.TypedSql.Placeholders (PlaceholderPlan (..), parseExpr, + planPlaceholders) +import IHP.TypedSql.TypeMapping (hsTypeForColumns, hsTypeForParam) +import IHP.TypedSql.Types (TypedQuery (..)) + +-- | QuasiQuoter entry point for typed SQL. +-- High-level: produces a TH expression that builds a TypedQuery at compile time. +typedSql :: TH.QuasiQuoter +typedSql = + TH.QuasiQuoter + { TH.quoteExp = typedSqlExp + , TH.quotePat = \_ -> fail "typedSql: not supported in patterns" + , TH.quoteType = \_ -> fail "typedSql: not supported in types" + , TH.quoteDec = \_ -> fail "typedSql: not supported at top-level" + } + +-- | Build the TH expression for a typed SQL quasiquote. +-- This is the heart of typedSql: parse placeholders, describe SQL, and assemble a TypedQuery. +typedSqlExp :: String -> TH.ExpQ +typedSqlExp rawSql = do + let PlaceholderPlan { ppDescribeSql, ppRuntimeSql, ppExprs } = planPlaceholders rawSql + parsedExprs <- mapM parseExpr ppExprs + + bootstrapEnv <- TH.runIO (lookupEnv "IHP_TYPED_SQL_BOOTSTRAP") + loc <- TH.location + let useBootstrap = isBootstrapEnabled bootstrapEnv + describeResult <- TH.runIO $ + if useBootstrap + then describeUsingBootstrap (TH.loc_filename loc) ppDescribeSql + else describeStatement (CS.cs ppDescribeSql) + + let DescribeResult { drParams, drColumns, drTables, drTypes } = describeResult + when (length drParams /= length parsedExprs) $ + fail (CS.cs ("typedSql: placeholder count mismatch. SQL expects " <> show (length drParams) <> " parameters but found " <> show (length parsedExprs) <> " ${..} expressions.")) + + paramTypes <- mapM (hsTypeForParam drTypes) drParams + + let sqlTokens = tokenizeSql ppDescribeSql + let aliasMap = buildAliasMap sqlTokens + let paramHints = collectParamHints sqlTokens aliasMap + paramHintTypes <- resolveParamHintTypes drTables drTypes paramHints + + let annotatedParams = + zipWith3 + (\index expr paramTy -> + let expectedType = fromMaybe paramTy (Map.lookup index paramHintTypes) + in TH.SigE (TH.AppE (TH.VarE 'coerce) expr) expectedType + ) + [1..] + parsedExprs + paramTypes + + resultType <- hsTypeForColumns drTypes drTables drColumns + + let isCompositeColumn = + case drColumns of + [DescribeColumn { dcType }] -> + case Map.lookup dcType drTypes of + Just PgTypeInfo { ptiType = Just 'c' } -> True + _ -> False + _ -> False + when (length drColumns == 1 && isCompositeColumn) $ + fail + ("typedSql: composite columns must be expanded (use SELECT table.* " + <> "or list columns explicitly)") + resultDecoder <- resultDecoderForColumns drTypes drTables drColumns + snippetExpr <- buildSnippetExpression ppRuntimeSql annotatedParams + let typedQueryExpr = + TH.AppE + (TH.AppE + (TH.ConE 'TypedQuery) + snippetExpr + ) + resultDecoder + + pure (TH.SigE typedQueryExpr (TH.AppT (TH.ConT ''TypedQuery) resultType)) + +buildSnippetExpression :: String -> [TH.Exp] -> TH.ExpQ +buildSnippetExpression sql params = do + let chunks = splitOnQuestion sql + when (length chunks /= length params + 1) do + fail "typedSql: internal error while building hasql snippet" + let sqlSnippets = map (TH.AppE (TH.VarE 'Snippet.sql) . TH.LitE . TH.StringL) chunks + let paramSnippets = map (TH.AppE (TH.VarE 'Snippet.param)) params + let pieces = interleave sqlSnippets paramSnippets + case pieces of + [] -> pure (TH.AppE (TH.VarE 'Snippet.sql) (TH.LitE (TH.StringL ""))) + firstPiece:restPieces -> + pure (foldl (\acc piece -> TH.InfixE (Just acc) (TH.VarE '(<>) ) (Just piece)) firstPiece restPieces) + +splitOnQuestion :: String -> [String] +splitOnQuestion input = go "" [] input + where + go current acc [] = reverse (reverse current : acc) + go current acc ('?':rest) = go "" (reverse current : acc) rest + go current acc (char:rest) = go (char:current) acc rest + +interleave :: [a] -> [a] -> [a] +interleave [] ys = ys +interleave xs [] = xs +interleave (x:xs) (y:ys) = x : y : interleave xs ys + +-- | Interpret IHP_TYPED_SQL_BOOTSTRAP to decide between live DB and bootstrap mode. +isBootstrapEnabled :: Maybe String -> Bool +isBootstrapEnabled = \case + Nothing -> False + Just raw -> + let value = map Char.toLower raw + in not (value `elem` ["", "0", "false", "no", "off"]) diff --git a/ihp/IHP/TypedSql/TypeMapping.hs b/ihp/IHP/TypedSql/TypeMapping.hs new file mode 100644 index 000000000..43edc54f2 --- /dev/null +++ b/ihp/IHP/TypedSql/TypeMapping.hs @@ -0,0 +1,140 @@ +{-# LANGUAGE NamedFieldPuns #-} + +module IHP.TypedSql.TypeMapping + ( hsTypeForParam + , hsTypeForColumns + , hsTypeForColumn + ) where + +import Control.Monad (guard) +import qualified Data.Aeson as Aeson +import qualified Data.ByteString as BS +import qualified Data.List as List +import qualified Data.Map.Strict as Map +import Data.Maybe (mapMaybe) +import Data.Scientific (Scientific) +import qualified Data.Set as Set +import qualified Data.String.Conversions as CS +import qualified Data.Text as Text +import Data.Time (LocalTime, TimeOfDay, UTCTime) +import Data.Time.Calendar (Day) +import Data.UUID (UUID) +import qualified Database.PostgreSQL.LibPQ as PQ +import qualified Language.Haskell.TH as TH +import Net.IP (IP) +import IHP.ModelSupport.Types (Id') +import IHP.NameSupport (tableNameToModelName) +import IHP.Prelude +import qualified IHP.Postgres.Point as PGPoint +import qualified IHP.Postgres.Polygon as PGPolygon +import qualified IHP.Postgres.TimeParser as PGTime +import qualified IHP.Postgres.TSVector as PGTs + +import IHP.TypedSql.Metadata (ColumnMeta (..), DescribeColumn (..), PgTypeInfo (..), TableMeta (..)) + +-- | Build the Haskell type for a parameter, based on its OID. +-- High-level: map a PG type OID into a TH Type. +hsTypeForParam :: Map.Map PQ.Oid PgTypeInfo -> PQ.Oid -> TH.TypeQ +hsTypeForParam typeInfo oid = maybe (fail (CS.cs unknown)) (hsTypeForPg typeInfo False) (Map.lookup oid typeInfo) + where + unknown = "typedSql: missing type information for parameter oid " <> show oid + +-- | Build the result type for the described columns. +-- High-level: pick a model type for table.* or a tuple type for ad-hoc select lists. +hsTypeForColumns :: Map.Map PQ.Oid PgTypeInfo -> Map.Map PQ.Oid TableMeta -> [DescribeColumn] -> TH.TypeQ +hsTypeForColumns typeInfo tables cols = do + case detectFullTable tables cols of + Just tableName -> + pure (TH.ConT (TH.mkName (CS.cs (tableNameToModelName tableName)))) + Nothing -> do + hsCols <- mapM (hsTypeForColumn typeInfo tables) cols + case hsCols of + [single] -> pure single + _ -> pure $ foldl TH.AppT (TH.TupleT (length hsCols)) hsCols + +-- | Detect whether the columns represent a full table selection (table.* with all columns in order). +-- High-level: if yes, we can return the model type directly. +detectFullTable :: Map.Map PQ.Oid TableMeta -> [DescribeColumn] -> Maybe Text +detectFullTable tables cols = do + guard (not (null cols)) + let grouped = + cols + |> List.groupBy (\a b -> dcTable a == dcTable b) + |> mapMaybe (\group -> case List.uncons group of + Just (first, _) -> Just (dcTable first, group) + Nothing -> Nothing + ) + case grouped of + [(tableOid, colGroup)] | tableOid /= PQ.Oid 0 -> do + TableMeta { tmColumnOrder } <- Map.lookup tableOid tables + let attnums = mapMaybe dcAttnum colGroup + guard (attnums == tmColumnOrder) + TableMeta { tmName } <- Map.lookup tableOid tables + pure tmName + _ -> Nothing + +-- | Map a single column into a Haskell type, with key-aware rules. +hsTypeForColumn :: Map.Map PQ.Oid PgTypeInfo -> Map.Map PQ.Oid TableMeta -> DescribeColumn -> TH.TypeQ +hsTypeForColumn typeInfo tables DescribeColumn { dcType, dcTable, dcAttnum } = + case (Map.lookup dcTable tables, dcAttnum) of + (Just TableMeta { tmName = tableName, tmPrimaryKeys, tmForeignKeys, tmColumns }, Just attnum) -> do + let baseType = Map.lookup attnum tmColumns >>= \ColumnMeta { cmTypeOid } -> Map.lookup cmTypeOid typeInfo + let nullable = maybe True (not . cmNotNull) (Map.lookup attnum tmColumns) + case () of + _ | attnum `Set.member` tmPrimaryKeys -> + pure (wrapNull nullable (idType tableName)) + | Just refTable <- Map.lookup attnum tmForeignKeys -> + case Map.lookup refTable tables of + Just TableMeta { tmName = refName } -> + pure (wrapNull nullable (idType refName)) + Nothing -> + maybe (fail (CS.cs missingType)) (hsTypeForPg typeInfo nullable) baseType + | otherwise -> + maybe (fail (CS.cs missingType)) (hsTypeForPg typeInfo nullable) baseType + where + missingType = "typedSql: missing type info for column " <> show attnum <> " of table " <> tableName + _ -> + maybe (fail (CS.cs ("typedSql: missing type info for column oid " <> show dcType))) (hsTypeForPg typeInfo True) (Map.lookup dcType typeInfo) + +-- | Wrap a type in Maybe when nullable. +wrapNull :: Bool -> TH.Type -> TH.Type +wrapNull nullable ty = if nullable then TH.AppT (TH.ConT ''Maybe) ty else ty + +-- | Build the Id' type for a table name. +idType :: Text -> TH.Type +idType tableName = TH.AppT (TH.ConT ''Id') (TH.LitT (TH.StrTyLit (CS.cs tableName))) + +-- | Map Postgres type metadata to a Haskell type. +-- This is the core mapping used for both parameters and results. +hsTypeForPg :: Map.Map PQ.Oid PgTypeInfo -> Bool -> PgTypeInfo -> TH.TypeQ +hsTypeForPg typeInfo nullable PgTypeInfo { ptiName, ptiElem, ptiType } = do + base <- case () of + _ | Just elemOid <- ptiElem -> do + elemInfo <- maybe (fail (CS.cs ("typedSql: missing array element type for " <> ptiName))) pure (Map.lookup elemOid typeInfo) + elemTy <- hsTypeForPg typeInfo False elemInfo + pure (TH.AppT TH.ListT elemTy) + _ | ptiName `elem` ["int2", "int4"] -> pure (TH.ConT ''Int) + _ | ptiName == "int8" -> pure (TH.ConT ''Integer) + _ | ptiName `elem` ["text", "varchar", "bpchar", "citext"] -> pure (TH.ConT ''Text) + _ | ptiName == "bool" -> pure (TH.ConT ''Bool) + _ | ptiName == "uuid" -> pure (TH.ConT ''UUID) + _ | ptiName == "timestamptz" -> pure (TH.ConT ''UTCTime) + _ | ptiName == "timestamp" -> pure (TH.ConT ''LocalTime) + _ | ptiName == "date" -> pure (TH.ConT ''Day) + _ | ptiName == "time" -> pure (TH.ConT ''TimeOfDay) + _ | ptiName `elem` ["json", "jsonb"] -> pure (TH.ConT ''Aeson.Value) + _ | ptiName == "bytea" -> pure (TH.ConT ''BS.ByteString) + _ | ptiName == "float4" -> pure (TH.ConT ''Float) + _ | ptiName == "float8" -> pure (TH.ConT ''Double) + _ | ptiName == "numeric" -> pure (TH.ConT ''Scientific) + _ | ptiName == "point" -> pure (TH.ConT ''PGPoint.Point) + _ | ptiName == "polygon" -> pure (TH.ConT ''PGPolygon.Polygon) + _ | ptiName == "inet" -> pure (TH.ConT ''IP) + _ | ptiName == "tsvector" -> pure (TH.ConT ''PGTs.TSVector) + _ | ptiName == "interval" -> pure (TH.ConT ''PGTime.PGInterval) + _ | ptiType == Just 'e' -> + pure (TH.ConT (TH.mkName (CS.cs (tableNameToModelName ptiName)))) + _ | ptiType == Just 'c' -> + pure (TH.ConT (TH.mkName (CS.cs (tableNameToModelName ptiName)))) + _ -> pure (TH.ConT (TH.mkName (CS.cs (tableNameToModelName ptiName)))) + pure (wrapNull nullable base) diff --git a/ihp/IHP/TypedSql/Types.hs b/ihp/IHP/TypedSql/Types.hs new file mode 100644 index 000000000..00692fe72 --- /dev/null +++ b/ihp/IHP/TypedSql/Types.hs @@ -0,0 +1,13 @@ +module IHP.TypedSql.Types + ( TypedQuery (..) + ) where + +import qualified Hasql.Decoders as HasqlDecoders +import qualified Hasql.DynamicStatements.Snippet as Snippet + +-- | Prepared query with a custom row parser. +-- High-level: this is the runtime value produced by the typed SQL quasiquoter. +data TypedQuery result = TypedQuery + { tqSnippet :: !Snippet.Snippet + , tqResultDecoder :: !(HasqlDecoders.Row result) + } diff --git a/ihp/Test/Test/Main.hs b/ihp/Test/Test/Main.hs index ca6c6ad13..90ce8aaf0 100644 --- a/ihp/Test/Test/Main.hs +++ b/ihp/Test/Test/Main.hs @@ -20,6 +20,7 @@ import qualified Test.ViewSupportSpec import qualified Test.FileStorage.ControllerFunctionsSpec import qualified Test.PGListenerSpec import qualified Test.MockingSpec +import qualified Test.TypedSqlSpec main :: IO () main = hspec do @@ -40,3 +41,4 @@ main = hspec do Test.Controller.CookieSpec.tests Test.PGListenerSpec.tests Test.MockingSpec.tests + Test.TypedSqlSpec.tests diff --git a/ihp/Test/Test/TypedSqlSpec.hs b/ihp/Test/Test/TypedSqlSpec.hs new file mode 100644 index 000000000..407c0a344 --- /dev/null +++ b/ihp/Test/Test/TypedSqlSpec.hs @@ -0,0 +1,1226 @@ +module Test.TypedSqlSpec where + +import qualified Control.Exception as Exception +import qualified Data.Text as Text +import qualified Data.Text.IO as Text +import IHP.Log.Types +import IHP.ModelSupport (Id', ModelContext, + createModelContext, + releaseModelContext, + sqlExec) +import IHP.Prelude +import System.Directory (doesFileExist, + getCurrentDirectory) +import System.Environment (getEnvironment, lookupEnv) +import System.FilePath (takeDirectory, ()) +import System.Process (CreateProcess (..), proc, + readCreateProcessWithExitCode) +import System.IO.Temp (withSystemTempDirectory) +import Test.Hspec +import qualified Prelude + +tests :: Spec +tests = do + describe "TypedSql macro compile-time checks" do + it "compiles valid typedSql queries with inferred types" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compilePassModule + assertGhciSuccess ghciOutput + + it "fails when a scalar parameter has the wrong type" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailWrongScalarParameter + assertGhciFailure ghciOutput [] + + it "fails when a foreign-key parameter has the wrong type" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailWrongForeignKeyParameter + assertGhciFailure ghciOutput [] + + it "fails when an IN parameter has the wrong element type" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailWrongInParameter + assertGhciFailure ghciOutput [] + + it "fails when a placeholder expression is invalid Haskell" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailInvalidPlaceholderExpression + assertGhciFailure ghciOutput ["failed to parse expression"] + + it "fails when SQL parameter count does not match ${...} placeholders" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailPlaceholderCountMismatch + assertGhciFailure ghciOutput ["placeholder count mismatch"] + + it "fails when selecting a single composite value without expansion" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailSingleCompositeColumn + assertGhciFailure ghciOutput ["composite columns must be expanded"] + + it "fails when SQL references an unknown column" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailUnknownColumn + assertGhciFailure ghciOutput ["does not exist"] + + it "fails when primary-key result type is annotated as UUID instead of Id" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailPrimaryKeyResultAnnotation + assertGhciFailure ghciOutput [] + + it "fails when nullable column result is annotated as non-Maybe" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailNullableResultAnnotation + assertGhciFailure ghciOutput [] + + it "fails when LEFT JOIN result is annotated as Maybe" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailLeftJoinMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when RIGHT JOIN result is annotated as Maybe" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailRightJoinMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when tuple arity does not match selected columns" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailTupleArityMismatch + assertGhciFailure ghciOutput [] + + it "fails when boolean expression result is annotated as Int" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailBooleanResultAnnotation + assertGhciFailure ghciOutput [] + + it "fails when boolean expression result is annotated as non-Maybe Bool" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailBooleanNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when COUNT(*) result is annotated as non-Maybe Integer" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailCountNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when COALESCE expression is annotated as non-Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailCoalesceNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when literal expression result is annotated as non-Maybe Int" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailLiteralNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when arithmetic expression result is annotated as non-Maybe Int" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailArithmeticNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when CASE expression result is annotated as non-Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailCaseNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when EXISTS expression result is annotated as non-Maybe Bool" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailExistsNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when NULL literal result is annotated as non-Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailNullLiteralNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when CTE result is annotated as Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailCteMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when subquery result is annotated as Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailSubqueryMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when UNION result is annotated as non-Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailUnionNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when window function result is annotated as non-Maybe Integer" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailWindowNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when grouped COUNT(*) result is annotated as non-Maybe Integer" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailGroupedCountNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when array literal result is annotated as non-Maybe [Text]" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailArrayLiteralNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + it "fails when NULLIF expression result is annotated as non-Maybe Text" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciLoadModule compileFailNullIfNonMaybeAnnotation + assertGhciFailure ghciOutput [] + + describe "TypedSql macro runtime execution" do + it "executes typedSql queries end-to-end via ghci" do + requirePostgresTestHook + withTestModelContext do + setupSchema + ghciOutput <- ghciRunModule runtimeModule + assertGhciSuccess ghciOutput + ghciOutput `shouldContainText` "RUNTIME_OK" + +requirePostgresTestHook :: IO () +requirePostgresTestHook = do + maybePgHost <- lookupEnv "PGHOST" + when (isNothing maybePgHost) do + pendingWith "requires postgresqlTestHook / withTestPostgres (PGHOST is not set)" + +withTestModelContext :: ((?modelContext :: ModelContext) => IO a) -> IO a +withTestModelContext action = do + logger <- newLogger def { level = Warn } + modelContext <- createModelContext 10 1 "" logger + let ?modelContext = modelContext + action `Exception.finally` releaseModelContext modelContext + +setupSchema :: (?modelContext :: ModelContext) => IO () +setupSchema = do + sqlExec "DROP TABLE IF EXISTS typed_sql_test_items" () + sqlExec "DROP TABLE IF EXISTS typed_sql_test_authors" () + sqlExec "DROP TYPE IF EXISTS typed_sql_test_pair" () + + sqlExec "CREATE TYPE typed_sql_test_pair AS (name TEXT, views INT)" () + + sqlExec + "CREATE TABLE typed_sql_test_authors (id UUID PRIMARY KEY, name TEXT NOT NULL)" + () + + sqlExec + "CREATE TABLE typed_sql_test_items (id UUID PRIMARY KEY, author_id UUID REFERENCES typed_sql_test_authors(id), name TEXT NOT NULL, views INT NOT NULL, score DOUBLE PRECISION, tags TEXT[] NOT NULL DEFAULT '{}')" + () + + sqlExec + "INSERT INTO typed_sql_test_authors (id, name) VALUES ('00000000-0000-0000-0000-000000000001'::uuid, 'Alice')" + () + + sqlExec + "INSERT INTO typed_sql_test_authors (id, name) VALUES ('00000000-0000-0000-0000-000000000002'::uuid, 'Bob')" + () + + sqlExec + "INSERT INTO typed_sql_test_items (id, author_id, name, views, score, tags) VALUES ('10000000-0000-0000-0000-000000000001'::uuid, '00000000-0000-0000-0000-000000000001'::uuid, 'First', 5, 1.5, ARRAY['red', 'blue'])" + () + + sqlExec + "INSERT INTO typed_sql_test_items (id, author_id, name, views, score, tags) VALUES ('10000000-0000-0000-0000-000000000002'::uuid, '00000000-0000-0000-0000-000000000001'::uuid, 'Second', 8, NULL, ARRAY['green'])" + () + + pure () + +ghciLoadModule :: Text -> IO Text +ghciLoadModule source = + ghciRun source [":set -fno-code"] [] + +ghciRunModule :: Text -> IO Text +ghciRunModule source = + ghciRun source [] ["main"] + +ghciRun :: Text -> [Text] -> [Text] -> IO Text +ghciRun source preLoadCommands postLoadCommands = + withSystemTempDirectory "typed-sql-ghci" \tempDir -> do + packageRoot <- findIhpPackageRoot + let repoRoot = takeDirectory packageRoot + useRepoGhci <- doesFileExist (repoRoot ".ghci") + env <- ghciEnvironment + + let modulePath = tempDir "TypedSqlCase.hs" + Text.writeFile modulePath source + + let commands = + ghciDefaultExtensionCommands + <> preLoadCommands + <> [":l " <> tshow modulePath] + <> postLoadCommands + <> [":quit"] + + let ghciArgs = + if useRepoGhci + then ["-v0"] + else ["-ignore-dot-ghci", "-v0", "-i" <> packageRoot] + + let process = (proc "ghci" ghciArgs) + { cwd = Just (if useRepoGhci then repoRoot else packageRoot) + , env = Just env + } + + (_exitCode, stdOut, stdErr) <- readCreateProcessWithExitCode process (cs (Text.unlines commands)) + pure (cs stdOut <> cs stdErr) + +ghciDefaultExtensionCommands :: [Text] +ghciDefaultExtensionCommands = + map (":set " <>) + [ "-XGHC2021" + , "-XNoImplicitPrelude" + , "-XImplicitParams" + , "-XOverloadedStrings" + , "-XDisambiguateRecordFields" + , "-XDuplicateRecordFields" + , "-XOverloadedLabels" + , "-XDataKinds" + , "-XQuasiQuotes" + , "-XTypeFamilies" + , "-XPackageImports" + , "-XRecordWildCards" + , "-XDefaultSignatures" + , "-XFunctionalDependencies" + , "-XPartialTypeSignatures" + , "-XBlockArguments" + , "-XLambdaCase" + , "-XTemplateHaskell" + , "-XOverloadedRecordDot" + , "-XDeepSubsumption" + , "-XExplicitNamespaces" + ] + +findIhpPackageRoot :: IO FilePath +findIhpPackageRoot = do + currentDirectory <- getCurrentDirectory + + let inPackageRoot = currentDirectory "IHP" "TypedSql.hs" + inPackageExists <- doesFileExist inPackageRoot + if inPackageExists + then pure currentDirectory + else do + let fromRepoRoot = currentDirectory "ihp" "IHP" "TypedSql.hs" + fromRepoExists <- doesFileExist fromRepoRoot + if fromRepoExists + then pure (currentDirectory "ihp") + else fail "TypedSqlSpec: could not locate ihp package root" + +ghciEnvironment :: IO [(String, String)] +ghciEnvironment = do + baseEnvironment <- getEnvironment + + pgHost <- fromMaybe "" <$> lookupEnv "PGHOST" + pgDatabase <- fromMaybe "" <$> lookupEnv "PGDATABASE" + pgUser <- fromMaybe "" <$> lookupEnv "PGUSER" + pgPort <- lookupEnv "PGPORT" + + let databaseUrlParts :: [String] + databaseUrlParts = + [ "host=" <> pgHost + , "dbname=" <> pgDatabase + , "user=" <> pgUser + ] <> case pgPort of + Just port | not (null port) -> ["port=" <> port] + _ -> [] + + let overrides :: [(String, String)] + overrides = + [ ("DATABASE_URL", Prelude.unwords databaseUrlParts) + ] + + pure (applyEnvironmentOverrides overrides baseEnvironment) + +applyEnvironmentOverrides :: [(String, String)] -> [(String, String)] -> [(String, String)] +applyEnvironmentOverrides overrides base = + overrides <> filter (\(name, _) -> name `notElem` map fst overrides) base + +assertGhciSuccess :: Text -> IO () +assertGhciSuccess output = + when (containsCompileError output) do + expectationFailure ("expected ghci load/run to succeed, but got:\n" <> cs output) + +assertGhciFailure :: Text -> [Text] -> IO () +assertGhciFailure output expectedFragments = do + when (not (containsCompileError output)) do + expectationFailure ("expected ghci load to fail, but got:\n" <> cs output) + + forM_ expectedFragments \fragment -> + when (not (Text.toLower fragment `Text.isInfixOf` Text.toLower output)) do + expectationFailure + ( "expected ghci output to contain fragment: " + <> cs fragment + <> "\nactual output:\n" + <> cs output + ) + +containsCompileError :: Text -> Bool +containsCompileError output = + let lower = Text.toLower output + in " error:" `Text.isInfixOf` lower + || "\nerror:" `Text.isInfixOf` lower + +shouldContainText :: Text -> Text -> Expectation +shouldContainText haystack needle = + when (not (needle `Text.isInfixOf` haystack)) do + expectationFailure + ( "expected text output to contain: " + <> cs needle + <> "\nactual output:\n" + <> cs haystack + ) + +compilePassModule :: Text +compilePassModule = Text.unlines + [ "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "{-# LANGUAGE TypeApplications #-}" + , "{-# LANGUAGE TypeFamilies #-}" + , "module TypedSqlCompilePass where" + , "" + , "import IHP.Prelude" + , "import IHP.ModelSupport (Id'(..), PrimaryKey)" + , "import IHP.Hasql.FromRow (FromRowHasql (..))" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "import qualified Hasql.Decoders as HasqlDecoders" + , "" + , "type instance PrimaryKey \"typed_sql_test_items\" = UUID" + , "type instance PrimaryKey \"typed_sql_test_authors\" = UUID" + , "" + , "data TypedSqlTestItem = TypedSqlTestItem" + , " { typedSqlTestItemId :: Id' \"typed_sql_test_items\"" + , " , typedSqlTestItemAuthorId :: Maybe (Id' \"typed_sql_test_authors\")" + , " , typedSqlTestItemName :: Text" + , " , typedSqlTestItemViews :: Int" + , " , typedSqlTestItemScore :: Maybe Double" + , " , typedSqlTestItemTags :: [Text]" + , " } deriving (Eq, Show)" + , "" + , "instance FromRowHasql TypedSqlTestItem where" + , " hasqlRowDecoder =" + , " TypedSqlTestItem" + , " <$> (fmap Id (HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.uuid)))" + , " <*> (fmap (fmap Id) (HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.uuid)))" + , " <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text)" + , " <*> (fmap fromIntegral (HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)))" + , " <*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.float8)" + , " <*> HasqlDecoders.column (HasqlDecoders.nonNullable (HasqlDecoders.listArray (HasqlDecoders.nonNullable HasqlDecoders.text)))" + , "" + , "qName :: TypedQuery Text" + , "qName = [typedSql| SELECT name FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qAllFields :: TypedQuery TypedSqlTestItem" + , "qAllFields = [typedSql| SELECT typed_sql_test_items.* FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qAllFieldsAlias :: TypedQuery TypedSqlTestItem" + , "qAllFieldsAlias = [typedSql| SELECT i.* FROM typed_sql_test_items i JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + , "" + , "qPrimaryKey :: TypedQuery (Id' \"typed_sql_test_items\")" + , "qPrimaryKey = [typedSql| SELECT id FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qForeignKey :: TypedQuery (Maybe (Id' \"typed_sql_test_authors\"))" + , "qForeignKey = [typedSql| SELECT author_id FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qNullable :: TypedQuery (Maybe Double)" + , "qNullable = [typedSql| SELECT score FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qArray :: TypedQuery [Text]" + , "qArray = [typedSql| SELECT tags FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qTuple :: TypedQuery (Id' \"typed_sql_test_items\", Text, Int)" + , "qTuple = [typedSql| SELECT id, name, views FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qEqParam :: TypedQuery Text" + , "qEqParam = [typedSql| SELECT name FROM typed_sql_test_items WHERE views = ${5 :: Int} LIMIT 1 |]" + , "" + , "qForeignKeyParamHint :: TypedQuery Text" + , "qForeignKeyParamHint =" + , " let authorId = (\"00000000-0000-0000-0000-000000000001\" :: Id' \"typed_sql_test_authors\")" + , " in [typedSql| SELECT name FROM typed_sql_test_items WHERE author_id = ${authorId} LIMIT 1 |]" + , "" + , "qInParamHint :: TypedQuery Text" + , "qInParamHint =" + , " let authorIds = [ (\"00000000-0000-0000-0000-000000000001\" :: Id' \"typed_sql_test_authors\") ]" + , " in [typedSql| SELECT name FROM typed_sql_test_items WHERE author_id IN (${authorIds}) LIMIT 1 |]" + , "" + , "qAnyParamHint :: TypedQuery Text" + , "qAnyParamHint =" + , " let itemIds =" + , " [ (\"10000000-0000-0000-0000-000000000001\" :: Id' \"typed_sql_test_items\")" + , " , (\"10000000-0000-0000-0000-000000000002\" :: Id' \"typed_sql_test_items\")" + , " ]" + , " in [typedSql| SELECT name FROM typed_sql_test_items WHERE id = ANY(${itemIds}) ORDER BY name LIMIT 1 |]" + , "" + , "qCompositeExpanded :: TypedQuery (Maybe Text, Maybe Int)" + , "qCompositeExpanded = [typedSql| SELECT (ROW(name, views)::typed_sql_test_pair).* FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qBoolExpr :: TypedQuery (Maybe Bool)" + , "qBoolExpr = [typedSql| SELECT author_id IS NULL FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qCountExpr :: TypedQuery (Maybe Integer)" + , "qCountExpr = [typedSql| SELECT COUNT(*) FROM typed_sql_test_items |]" + , "" + , "qLiteralInt :: TypedQuery (Maybe Int)" + , "qLiteralInt = [typedSql| SELECT 1 |]" + , "" + , "qArithmeticExpr :: TypedQuery (Maybe Int)" + , "qArithmeticExpr = [typedSql| SELECT views + 1 FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qCaseExpr :: TypedQuery (Maybe Text)" + , "qCaseExpr = [typedSql| SELECT CASE WHEN views > 5 THEN name ELSE 'low' END FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qExistsExpr :: TypedQuery (Maybe Bool)" + , "qExistsExpr = [typedSql| SELECT EXISTS(SELECT 1 FROM typed_sql_test_items WHERE views > 7) |]" + , "" + , "qNullLiteral :: TypedQuery (Maybe Text)" + , "qNullLiteral = [typedSql| SELECT NULL::text |]" + , "" + , "qCte :: TypedQuery Text" + , "qCte = [typedSql| WITH item_names AS (SELECT name FROM typed_sql_test_items WHERE views > 6) SELECT name FROM item_names LIMIT 1 |]" + , "" + , "qSubquery :: TypedQuery Text" + , "qSubquery = [typedSql| SELECT name FROM (SELECT name FROM typed_sql_test_items WHERE views < 6) sub LIMIT 1 |]" + , "" + , "qUnion :: TypedQuery (Maybe Text)" + , "qUnion = [typedSql| SELECT name FROM typed_sql_test_items WHERE views > 6 UNION ALL SELECT name FROM typed_sql_test_items WHERE views < 6 |]" + , "" + , "qWindow :: TypedQuery (Maybe Integer)" + , "qWindow = [typedSql| SELECT row_number() OVER (ORDER BY name) FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qGroupedCount :: TypedQuery (Text, Maybe Integer)" + , "qGroupedCount = [typedSql| SELECT name, COUNT(*) FROM typed_sql_test_items GROUP BY name ORDER BY name LIMIT 1 |]" + , "" + , "qArrayLiteral :: TypedQuery (Maybe [Text])" + , "qArrayLiteral = [typedSql| SELECT ARRAY['x','y']::text[] |]" + , "" + , "qNullIfExpr :: TypedQuery (Maybe Text)" + , "qNullIfExpr = [typedSql| SELECT NULLIF(name, 'First') FROM typed_sql_test_items LIMIT 1 |]" + , "" + , "qSchemaQualified :: TypedQuery Text" + , "qSchemaQualified = [typedSql| SELECT name FROM public.typed_sql_test_items LIMIT 1 |]" + , "" + , "qQuotedIdentifiers :: TypedQuery Text" + , "qQuotedIdentifiers = [typedSql| SELECT \"name\" FROM \"typed_sql_test_items\" LIMIT 1 |]" + , "" + , "qInnerJoin :: TypedQuery (Text, Text)" + , "qInnerJoin = [typedSql| SELECT i.name, a.name FROM typed_sql_test_items i INNER JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + , "" + , "qLeftJoin :: TypedQuery (Text, Text)" + , "qLeftJoin = [typedSql| SELECT i.name, a.name FROM typed_sql_test_items i LEFT JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + , "" + , "qRightJoin :: TypedQuery (Text, Text)" + , "qRightJoin = [typedSql| SELECT i.name, a.name FROM typed_sql_test_items i RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + , "" + , "qRightJoinCoalesced :: TypedQuery (Maybe Text, Text)" + , "qRightJoinCoalesced = [typedSql| SELECT COALESCE(i.name, '(no-item)'), a.name FROM typed_sql_test_items i RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + ] + +compileFailWrongScalarParameter :: Text +compileFailWrongScalarParameter = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailWrongScalarParameter where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name FROM typed_sql_test_items WHERE views = ${(\"not an int\" :: Text)} LIMIT 1 |]" + ] + +compileFailWrongForeignKeyParameter :: Text +compileFailWrongForeignKeyParameter = Text.unlines + [ "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "{-# LANGUAGE TypeFamilies #-}" + , "module TypedSqlCompileFailWrongForeignKeyParameter where" + , "" + , "import IHP.Prelude" + , "import IHP.ModelSupport (PrimaryKey)" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "type instance PrimaryKey \"typed_sql_test_items\" = UUID" + , "type instance PrimaryKey \"typed_sql_test_authors\" = UUID" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name FROM typed_sql_test_items WHERE author_id = ${(\"not-an-id\" :: Text)} LIMIT 1 |]" + ] + +compileFailWrongInParameter :: Text +compileFailWrongInParameter = Text.unlines + [ "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "{-# LANGUAGE TypeFamilies #-}" + , "module TypedSqlCompileFailWrongInParameter where" + , "" + , "import IHP.Prelude" + , "import IHP.ModelSupport (PrimaryKey)" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "type instance PrimaryKey \"typed_sql_test_items\" = UUID" + , "type instance PrimaryKey \"typed_sql_test_authors\" = UUID" + , "" + , "bad :: TypedQuery Text" + , "bad =" + , " let authorIds = [\"one\" :: Text, \"two\" :: Text]" + , " in [typedSql| SELECT name FROM typed_sql_test_items WHERE author_id IN (${authorIds}) LIMIT 1 |]" + ] + +compileFailInvalidPlaceholderExpression :: Text +compileFailInvalidPlaceholderExpression = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailInvalidPlaceholderExpression where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name FROM typed_sql_test_items WHERE views = ${(} LIMIT 1 |]" + ] + +compileFailPlaceholderCountMismatch :: Text +compileFailPlaceholderCountMismatch = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailPlaceholderCountMismatch where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name FROM typed_sql_test_items WHERE views = $1 LIMIT 1 |]" + ] + +compileFailSingleCompositeColumn :: Text +compileFailSingleCompositeColumn = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailSingleCompositeColumn where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT ROW(name, views)::typed_sql_test_pair FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailUnknownColumn :: Text +compileFailUnknownColumn = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailUnknownColumn where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT no_such_column FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailPrimaryKeyResultAnnotation :: Text +compileFailPrimaryKeyResultAnnotation = Text.unlines + [ "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "{-# LANGUAGE TypeFamilies #-}" + , "module TypedSqlCompileFailPrimaryKeyResultAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.ModelSupport (PrimaryKey)" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "type instance PrimaryKey \"typed_sql_test_items\" = UUID" + , "" + , "bad :: TypedQuery UUID" + , "bad = [typedSql| SELECT id FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailNullableResultAnnotation :: Text +compileFailNullableResultAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailNullableResultAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Double" + , "bad = [typedSql| SELECT score FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailLeftJoinMaybeAnnotation :: Text +compileFailLeftJoinMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailLeftJoinMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Text, Maybe Text)" + , "bad = [typedSql| SELECT i.name, a.name FROM typed_sql_test_items i LEFT JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + ] + +compileFailRightJoinMaybeAnnotation :: Text +compileFailRightJoinMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailRightJoinMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Maybe Text, Text)" + , "bad = [typedSql| SELECT i.name, a.name FROM typed_sql_test_items i RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id ORDER BY a.name LIMIT 1 |]" + ] + +compileFailTupleArityMismatch :: Text +compileFailTupleArityMismatch = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailTupleArityMismatch where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name, views FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailBooleanResultAnnotation :: Text +compileFailBooleanResultAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailBooleanResultAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Int" + , "bad = [typedSql| SELECT author_id IS NULL FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailBooleanNonMaybeAnnotation :: Text +compileFailBooleanNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailBooleanNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Bool" + , "bad = [typedSql| SELECT author_id IS NULL FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailCountNonMaybeAnnotation :: Text +compileFailCountNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailCountNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Integer" + , "bad = [typedSql| SELECT COUNT(*) FROM typed_sql_test_items |]" + ] + +compileFailCoalesceNonMaybeAnnotation :: Text +compileFailCoalesceNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailCoalesceNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Text, Text)" + , "bad = [typedSql| SELECT COALESCE(i.name, '(no-item)'), a.name FROM typed_sql_test_items i RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id LIMIT 1 |]" + ] + +compileFailLiteralNonMaybeAnnotation :: Text +compileFailLiteralNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailLiteralNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Int" + , "bad = [typedSql| SELECT 1 |]" + ] + +compileFailArithmeticNonMaybeAnnotation :: Text +compileFailArithmeticNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailArithmeticNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Int" + , "bad = [typedSql| SELECT views + 1 FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailCaseNonMaybeAnnotation :: Text +compileFailCaseNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailCaseNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT CASE WHEN views > 5 THEN name ELSE 'low' END FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailExistsNonMaybeAnnotation :: Text +compileFailExistsNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailExistsNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Bool" + , "bad = [typedSql| SELECT EXISTS(SELECT 1 FROM typed_sql_test_items WHERE views > 7) |]" + ] + +compileFailNullLiteralNonMaybeAnnotation :: Text +compileFailNullLiteralNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailNullLiteralNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT NULL::text |]" + ] + +compileFailCteMaybeAnnotation :: Text +compileFailCteMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailCteMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Maybe Text)" + , "bad = [typedSql| WITH item_names AS (SELECT name FROM typed_sql_test_items WHERE views > 6) SELECT name FROM item_names LIMIT 1 |]" + ] + +compileFailSubqueryMaybeAnnotation :: Text +compileFailSubqueryMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailSubqueryMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Maybe Text)" + , "bad = [typedSql| SELECT name FROM (SELECT name FROM typed_sql_test_items WHERE views < 6) sub LIMIT 1 |]" + ] + +compileFailUnionNonMaybeAnnotation :: Text +compileFailUnionNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailUnionNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT name FROM typed_sql_test_items WHERE views > 6 UNION ALL SELECT name FROM typed_sql_test_items WHERE views < 6 |]" + ] + +compileFailWindowNonMaybeAnnotation :: Text +compileFailWindowNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailWindowNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Integer" + , "bad = [typedSql| SELECT row_number() OVER (ORDER BY name) FROM typed_sql_test_items LIMIT 1 |]" + ] + +compileFailGroupedCountNonMaybeAnnotation :: Text +compileFailGroupedCountNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailGroupedCountNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery (Text, Integer)" + , "bad = [typedSql| SELECT name, COUNT(*) FROM typed_sql_test_items GROUP BY name ORDER BY name LIMIT 1 |]" + ] + +compileFailArrayLiteralNonMaybeAnnotation :: Text +compileFailArrayLiteralNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailArrayLiteralNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery [Text]" + , "bad = [typedSql| SELECT ARRAY['x','y']::text[] |]" + ] + +compileFailNullIfNonMaybeAnnotation :: Text +compileFailNullIfNonMaybeAnnotation = Text.unlines + [ "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "module TypedSqlCompileFailNullIfNonMaybeAnnotation where" + , "" + , "import IHP.Prelude" + , "import IHP.TypedSql (TypedQuery, typedSql)" + , "" + , "bad :: TypedQuery Text" + , "bad = [typedSql| SELECT NULLIF(name, 'First') FROM typed_sql_test_items LIMIT 1 |]" + ] + +runtimeModule :: Text +runtimeModule = Text.unlines + [ "{-# LANGUAGE DataKinds #-}" + , "{-# LANGUAGE ImplicitParams #-}" + , "{-# LANGUAGE NoImplicitPrelude #-}" + , "{-# LANGUAGE OverloadedStrings #-}" + , "{-# LANGUAGE QuasiQuotes #-}" + , "{-# LANGUAGE TypeFamilies #-}" + , "module Main where" + , "" + , "import qualified Control.Exception as Exception" + , "import IHP.Prelude" + , "import IHP.Log.Types" + , "import IHP.ModelSupport (Id'(..), ModelContext, PrimaryKey, createModelContext, releaseModelContext)" + , "import IHP.Hasql.FromRow (FromRowHasql (..))" + , "import IHP.TypedSql (sqlExecTyped, sqlQueryTyped, typedSql)" + , "import qualified Hasql.Decoders as HasqlDecoders" + , "" + , "type instance PrimaryKey \"typed_sql_test_items\" = UUID" + , "type instance PrimaryKey \"typed_sql_test_authors\" = UUID" + , "" + , "data TypedSqlTestItem = TypedSqlTestItem" + , " { typedSqlTestItemId :: Id' \"typed_sql_test_items\"" + , " , typedSqlTestItemAuthorId :: Maybe (Id' \"typed_sql_test_authors\")" + , " , typedSqlTestItemName :: Text" + , " , typedSqlTestItemViews :: Int" + , " , typedSqlTestItemScore :: Maybe Double" + , " , typedSqlTestItemTags :: [Text]" + , " } deriving (Eq, Show)" + , "" + , "instance FromRowHasql TypedSqlTestItem where" + , " hasqlRowDecoder =" + , " TypedSqlTestItem" + , " <$> (fmap Id (HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.uuid)))" + , " <*> (fmap (fmap Id) (HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.uuid)))" + , " <*> HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.text)" + , " <*> (fmap fromIntegral (HasqlDecoders.column (HasqlDecoders.nonNullable HasqlDecoders.int4)))" + , " <*> HasqlDecoders.column (HasqlDecoders.nullable HasqlDecoders.float8)" + , " <*> HasqlDecoders.column (HasqlDecoders.nonNullable (HasqlDecoders.listArray (HasqlDecoders.nonNullable HasqlDecoders.text)))" + , "" + , "main :: IO ()" + , "main = do" + , " logger <- newLogger def { level = Warn }" + , " modelContext <- createModelContext 10 1 \"\" logger" + , " let ?modelContext = modelContext" + , " flip Exception.finally (releaseModelContext modelContext) do" + , " let authorId = (\"00000000-0000-0000-0000-000000000001\" :: UUID)" + , " let itemId1 = (\"10000000-0000-0000-0000-000000000001\" :: UUID)" + , " let itemId2 = (\"10000000-0000-0000-0000-000000000002\" :: UUID)" + , "" + , " _ <- sqlExecTyped [typedSql| DELETE FROM typed_sql_test_items |]" + , "" + , " _ <- sqlExecTyped [typedSql|" + , " INSERT INTO typed_sql_test_items (id, author_id, name, views, score, tags)" + , " VALUES (${itemId1}, ${authorId}, ${(\"First\" :: Text)}, ${5 :: Int}, ${(1.5 :: Double)}, ${([\"red\", \"blue\"] :: [Text])})" + , " |]" + , "" + , " _ <- sqlExecTyped [typedSql|" + , " INSERT INTO typed_sql_test_items (id, author_id, name, views, score, tags)" + , " VALUES (${itemId2}, ${authorId}, ${(\"Second\" :: Text)}, ${8 :: Int}, ${(2.0 :: Double)}, ${([\"green\"] :: [Text])})" + , " |]" + , "" + , " names <- sqlQueryTyped [typedSql|" + , " SELECT name FROM typed_sql_test_items" + , " WHERE views > ${3 :: Int}" + , " ORDER BY name" + , " |]" + , "" + , " when ((names :: [Text]) /= [\"First\", \"Second\"]) do" + , " error (\"unexpected names from typedSql: \" <> show names)" + , "" + , " namesViaTypedSql <- sqlQueryTyped [typedSql|" + , " SELECT name FROM typed_sql_test_items" + , " WHERE views >= ${5 :: Int}" + , " ORDER BY name" + , " |]" + , "" + , " when ((namesViaTypedSql :: [Text]) /= [\"First\", \"Second\"]) do" + , " error (\"unexpected names from typedSql second query: \" <> show namesViaTypedSql)" + , "" + , " allItems <- sqlQueryTyped [typedSql|" + , " SELECT typed_sql_test_items.*" + , " FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " let expectedItems =" + , " [ TypedSqlTestItem (Id itemId1) (Just (Id authorId)) \"First\" 5 (Just 1.5) [\"red\", \"blue\"]" + , " , TypedSqlTestItem (Id itemId2) (Just (Id authorId)) \"Second\" 8 (Just 2.0) [\"green\"]" + , " ]" + , " when ((allItems :: [TypedSqlTestItem]) /= expectedItems) do" + , " error (\"unexpected rows from table.* query: \" <> show allItems)" + , "" + , " boolExprRows <- sqlQueryTyped [typedSql|" + , " SELECT author_id IS NULL" + , " FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " when ((boolExprRows :: [Maybe Bool]) /= [Just False, Just False]) do" + , " error (\"unexpected rows from bool expression query: \" <> show boolExprRows)" + , "" + , " countRows <- sqlQueryTyped [typedSql| SELECT COUNT(*) FROM typed_sql_test_items |]" + , "" + , " when ((countRows :: [Maybe Integer]) /= [Just 2]) do" + , " error (\"unexpected rows from count query: \" <> show countRows)" + , "" + , " literalRows <- sqlQueryTyped [typedSql| SELECT 1 |]" + , "" + , " when ((literalRows :: [Maybe Int]) /= [Just 1]) do" + , " error (\"unexpected rows from literal query: \" <> show literalRows)" + , "" + , " arithmeticRows <- sqlQueryTyped [typedSql|" + , " SELECT views + 1 FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " when ((arithmeticRows :: [Maybe Int]) /= [Just 6, Just 9]) do" + , " error (\"unexpected rows from arithmetic query: \" <> show arithmeticRows)" + , "" + , " caseRows <- sqlQueryTyped [typedSql|" + , " SELECT CASE WHEN views > 5 THEN name ELSE 'low' END" + , " FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " when ((caseRows :: [Maybe Text]) /= [Just \"low\", Just \"Second\"]) do" + , " error (\"unexpected rows from CASE query: \" <> show caseRows)" + , "" + , " existsRows <- sqlQueryTyped [typedSql| SELECT EXISTS(SELECT 1 FROM typed_sql_test_items WHERE views > 7) |]" + , "" + , " when ((existsRows :: [Maybe Bool]) /= [Just True]) do" + , " error (\"unexpected rows from EXISTS query: \" <> show existsRows)" + , "" + , " nullLiteralRows <- sqlQueryTyped [typedSql| SELECT NULL::text |]" + , "" + , " when ((nullLiteralRows :: [Maybe Text]) /= [Nothing]) do" + , " error (\"unexpected rows from NULL literal query: \" <> show nullLiteralRows)" + , "" + , " cteRows <- sqlQueryTyped [typedSql|" + , " WITH item_names AS (SELECT name FROM typed_sql_test_items WHERE views > 6)" + , " SELECT name FROM item_names ORDER BY name" + , " |]" + , "" + , " when ((cteRows :: [Text]) /= [\"Second\"]) do" + , " error (\"unexpected rows from CTE query: \" <> show cteRows)" + , "" + , " subqueryRows <- sqlQueryTyped [typedSql|" + , " SELECT name FROM (SELECT name FROM typed_sql_test_items WHERE views < 6) sub" + , " ORDER BY name" + , " |]" + , "" + , " when ((subqueryRows :: [Text]) /= [\"First\"]) do" + , " error (\"unexpected rows from subquery: \" <> show subqueryRows)" + , "" + , " unionRows <- sqlQueryTyped [typedSql|" + , " SELECT name FROM typed_sql_test_items WHERE views > 6" + , " UNION ALL" + , " SELECT name FROM typed_sql_test_items WHERE views < 6" + , " ORDER BY name" + , " |]" + , "" + , " when ((unionRows :: [Maybe Text]) /= [Just \"First\", Just \"Second\"]) do" + , " error (\"unexpected rows from UNION: \" <> show unionRows)" + , "" + , " windowRows <- sqlQueryTyped [typedSql|" + , " SELECT row_number() OVER (ORDER BY name)" + , " FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " when ((windowRows :: [Maybe Integer]) /= [Just 1, Just 2]) do" + , " error (\"unexpected rows from window function: \" <> show windowRows)" + , "" + , " groupedCountRows <- sqlQueryTyped [typedSql|" + , " SELECT name, COUNT(*)" + , " FROM typed_sql_test_items" + , " GROUP BY name" + , " ORDER BY name" + , " |]" + , "" + , " when ((groupedCountRows :: [(Text, Maybe Integer)]) /= [(\"First\", Just 1), (\"Second\", Just 1)]) do" + , " error (\"unexpected rows from grouped count: \" <> show groupedCountRows)" + , "" + , " arrayLiteralRows <- sqlQueryTyped [typedSql| SELECT ARRAY['x','y']::text[] |]" + , "" + , " when ((arrayLiteralRows :: [Maybe [Text]]) /= [Just [\"x\", \"y\"]]) do" + , " error (\"unexpected rows from array literal: \" <> show arrayLiteralRows)" + , "" + , " nullIfRows <- sqlQueryTyped [typedSql|" + , " SELECT NULLIF(name, 'First')" + , " FROM typed_sql_test_items" + , " ORDER BY name" + , " |]" + , "" + , " when ((nullIfRows :: [Maybe Text]) /= [Nothing, Just \"Second\"]) do" + , " error (\"unexpected rows from NULLIF: \" <> show nullIfRows)" + , "" + , " innerJoinRows <- sqlQueryTyped [typedSql|" + , " SELECT i.name, a.name" + , " FROM typed_sql_test_items i" + , " INNER JOIN typed_sql_test_authors a ON a.id = i.author_id" + , " ORDER BY i.name" + , " |]" + , "" + , " when ((innerJoinRows :: [(Text, Text)]) /= [(\"First\", \"Alice\"), (\"Second\", \"Alice\")]) do" + , " error (\"unexpected rows from inner join: \" <> show innerJoinRows)" + , "" + , " leftJoinRows <- sqlQueryTyped [typedSql|" + , " SELECT i.name, a.name" + , " FROM typed_sql_test_items i" + , " LEFT JOIN typed_sql_test_authors a ON a.id = i.author_id" + , " ORDER BY i.name" + , " |]" + , "" + , " when ((leftJoinRows :: [(Text, Text)]) /= [(\"First\", \"Alice\"), (\"Second\", \"Alice\")]) do" + , " error (\"unexpected rows from left join: \" <> show leftJoinRows)" + , "" + , " rightJoinRows <- sqlQueryTyped [typedSql|" + , " SELECT i.name, a.name" + , " FROM typed_sql_test_items i" + , " RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id" + , " WHERE i.id IS NOT NULL" + , " ORDER BY a.name, i.name" + , " |]" + , "" + , " when ((rightJoinRows :: [(Text, Text)]) /= [(\"First\", \"Alice\"), (\"Second\", \"Alice\")]) do" + , " error (\"unexpected rows from right join: \" <> show rightJoinRows)" + , "" + , " rightJoinCoalescedRows <- sqlQueryTyped [typedSql|" + , " SELECT COALESCE(i.name, '(no-item)'), a.name" + , " FROM typed_sql_test_items i" + , " RIGHT JOIN typed_sql_test_authors a ON a.id = i.author_id" + , " ORDER BY a.name, i.name NULLS LAST" + , " |]" + , "" + , " when ((rightJoinCoalescedRows :: [(Maybe Text, Text)]) /= [(Just \"First\", \"Alice\"), (Just \"Second\", \"Alice\"), (Just \"(no-item)\", \"Bob\")]) do" + , " error (\"unexpected rows from right join with COALESCE: \" <> show rightJoinCoalescedRows)" + , "" + , " putStrLn \"RUNTIME_OK\"" + ] diff --git a/ihp/ihp.cabal b/ihp/ihp.cabal index 4ad09e23a..bf6c495d1 100644 --- a/ihp/ihp.cabal +++ b/ihp/ihp.cabal @@ -54,6 +54,7 @@ common shared-properties , text , postgresql-simple , ihp-pglistener + , postgresql-libpq , wai-app-static , wai-util , bytestring @@ -214,6 +215,7 @@ library , IHP.QueryBuilder.Union , IHP.QueryBuilder.HasqlCompiler , IHP.QueryBuilder.HasqlHelpers + , IHP.TypedSql , IHP.Fetch , IHP.RouterPrelude , IHP.Server @@ -259,7 +261,16 @@ library , IHP.PGVersion reexported-modules: IHP.PGListener autogen-modules: Paths_ihp - other-modules: Paths_ihp + other-modules: + Paths_ihp + IHP.TypedSql.Types + IHP.TypedSql.Quoter + IHP.TypedSql.Placeholders + IHP.TypedSql.ParamHints + IHP.TypedSql.Metadata + IHP.TypedSql.Bootstrap + IHP.TypedSql.TypeMapping + IHP.TypedSql.Decoders source-repository head type: git @@ -289,3 +300,4 @@ test-suite tests Test.Controller.CookieSpec Test.PGListenerSpec Test.MockingSpec + Test.TypedSqlSpec