Skip to content

Commit 21057aa

Browse files
committed
Query by custom type
However, this is not (yet) type-safe.
1 parent 168813c commit 21057aa

File tree

1 file changed

+60
-15
lines changed

1 file changed

+60
-15
lines changed

squeal-postgresql/exe/Example.hs

Lines changed: 60 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
, TypeOperators
1010
#-}
1111

12-
{-# LANGUAGE ScopedTypeVariables #-}
1312
{-# LANGUAGE AllowAmbiguousTypes #-}
13+
{-# LANGUAGE FlexibleInstances #-}
14+
{-# LANGUAGE MultiParamTypeClasses #-}
15+
{-# LANGUAGE ScopedTypeVariables #-}
16+
{-# LANGUAGE TypeFamilies #-}
1417

1518
module Main (main, main2, upsertUser) where
1619

20+
import Control.Monad.Except (MonadError (throwError))
1721
import Control.Monad.IO.Class (MonadIO (..))
1822
import Data.Int (Int16, Int32)
1923
import Data.Text (Text)
@@ -49,6 +53,7 @@ type OrgSchema =
4953
'[ "pk_organizations" ::: 'PrimaryKey '["id"] ] :=>
5054
'[ "id" ::: 'Def :=> 'NotNull 'PGint4
5155
, "name" ::: 'NoDef :=> 'NotNull 'PGtext
56+
, "type" ::: 'NoDef :=> 'NotNull 'PGtext
5257
])
5358
, "members" ::: 'Table (
5459
'[ "fk_member" ::: 'ForeignKey '["member"] "user" "users" '["id"]
@@ -86,7 +91,8 @@ setup =
8691
>>>
8792
createTable (#org ! #organizations)
8893
( serial `as` #id :*
89-
(text & notNullable) `as` #name )
94+
(text & notNullable) `as` #name :*
95+
(text & notNullable) `as` #type )
9096
( primaryKey #id `as` #pk_organizations )
9197
>>>
9298
createTable (#org ! #members)
@@ -109,9 +115,9 @@ insertEmail :: Manipulation_ Schemas (Int32, Maybe Text) ()
109115
insertEmail = insertInto_ (#user ! #emails)
110116
(Values_ (Default `as` #id :* Set (param @1) `as` #user_id :* Set (param @2) `as` #email))
111117

112-
insertOrganization :: Manipulation_ Schemas (Only Text) (Only Int32)
118+
insertOrganization :: Manipulation_ Schemas (Text, OrganizationType) (Only Int32)
113119
insertOrganization = insertInto (#org ! #organizations)
114-
(Values_ (Default `as` #id :* Set (param @1) `as` #name))
120+
(Values_ (Default `as` #id :* Set (param @1) `as` #name :* Set (param @2) `as` #type))
115121
(OnConflict (OnConstraint #pk_organizations) DoNothing) (Returning_ (#id `as` #fromOnly))
116122

117123
getUsers :: Query_ Schemas () User
@@ -123,7 +129,10 @@ getUsers = select_
123129

124130
getOrganizations :: Query_ Schemas () Organization
125131
getOrganizations = select_
126-
(#o ! #id `as` #orgId :* #o ! #name `as` #orgName)
132+
( #o ! #id `as` #orgId :*
133+
#o ! #name `as` #orgName :*
134+
#o ! #type `as` #orgType
135+
)
127136
(from (table (#org ! #organizations `as` #o)))
128137

129138
getOrganizationsBy ::
@@ -135,11 +144,14 @@ getOrganizationsBy ::
135144
'[]
136145
Schemas
137146
'[NullPG hsty]
138-
'["o" ::: ["id" ::: NotNull PGint4, "name" ::: NotNull PGtext]] ->
147+
'["o" ::: ["id" ::: NotNull PGint4, "name" ::: NotNull PGtext, "type" ::: NotNull PGtext]] ->
139148
Query_ Schemas (Only hsty) Organization
140149
getOrganizationsBy condition =
141150
select_
142-
(#o ! #id `as` #orgId :* #o ! #name `as` #orgName)
151+
( #o ! #id `as` #orgId :*
152+
#o ! #name `as` #orgName :*
153+
#o ! #type `as` #orgType
154+
)
143155
(
144156
from (table (#org ! #organizations `as` #o))
145157
& where_ condition
@@ -173,14 +185,39 @@ data Organization
173185
= Organization
174186
{ orgId :: Int32
175187
, orgName :: Text
188+
, orgType :: OrganizationType
176189
} deriving (Show, GHC.Generic)
177190
instance SOP.Generic Organization
178191
instance SOP.HasDatatypeInfo Organization
179192

193+
data OrganizationType
194+
= ForProfit
195+
| NonProfit
196+
deriving (Show, GHC.Generic)
197+
instance SOP.Generic OrganizationType
198+
instance SOP.HasDatatypeInfo OrganizationType
199+
200+
instance IsPG OrganizationType where
201+
type PG OrganizationType = 'PGtext
202+
instance ToPG db OrganizationType where
203+
toPG = toPG . toText
204+
where
205+
toText ForProfit = "for-profit" :: Text
206+
toText NonProfit = "non-profit" :: Text
207+
208+
instance FromPG OrganizationType where
209+
fromPG = do
210+
value <- fromPG @Text
211+
fromText value
212+
where
213+
fromText "for-profit" = pure ForProfit
214+
fromText "non-profit" = pure NonProfit
215+
fromText value = throwError $ "Invalid organization type: \"" <> value <> "\""
216+
180217
organizations :: [Organization]
181218
organizations =
182-
[ Organization { orgId = 1, orgName = "ACME" }
183-
, Organization { orgId = 2, orgName = "Haskell Foundation" }
219+
[ Organization { orgId = 1, orgName = "ACME", orgType = ForProfit }
220+
, Organization { orgId = 2, orgName = "Haskell Foundation", orgType = NonProfit }
184221
]
185222

186223
session :: (MonadIO pq, MonadPQ Schemas pq) => pq ()
@@ -192,29 +229,37 @@ session = do
192229

193230
orgIdResults <- traversePrepared
194231
insertOrganization
195-
[Only (orgName organization) | organization <- organizations]
232+
[(orgName organization, orgType organization) | organization <- organizations]
196233
_ <- traverse (fmap fromOnly . getRow 0) (orgIdResults :: [Result (Only Int32)])
197234

198235
liftIO $ Char8.putStrLn "===> querying: users"
199236
usersResult <- runQuery getUsers
200237
usersRows <- getRows usersResult
201238
liftIO $ print (usersRows :: [User])
202239

203-
liftIO $ Char8.putStrLn "===> querying: organizations"
204-
organizationsResult <- runQuery getOrganizations
205-
organizationRows <- getRows organizationsResult
206-
liftIO $ print (organizationRows :: [Organization])
240+
liftIO $ Char8.putStrLn "===> querying: organizations: all"
241+
organizationsResult1 <- runQuery getOrganizations
242+
organizationRows1 <- getRows organizationsResult1
243+
liftIO $ print (organizationRows1 :: [Organization])
207244

245+
liftIO $ Char8.putStrLn "===> querying: organizations: by ID (2)"
208246
organizationsResult2 <- runQueryParams
209-
(getOrganizationsBy @Int32 ((#o ! #id) .== param @1)) (Only (1 :: Int32))
247+
(getOrganizationsBy @Int32 ((#o ! #id) .== param @1)) (Only (2 :: Int32))
210248
organizationRows2 <- getRows organizationsResult2
211249
liftIO $ print (organizationRows2 :: [Organization])
212250

251+
liftIO $ Char8.putStrLn "===> querying: organizations: by name (ACME)"
213252
organizationsResult3 <- runQueryParams
214253
(getOrganizationsBy @Text ((#o ! #name) .== param @1)) (Only ("ACME" :: Text))
215254
organizationRows3 <- getRows organizationsResult3
216255
liftIO $ print (organizationRows3 :: [Organization])
217256

257+
liftIO $ Char8.putStrLn "===> querying: organizations: by type (non-profit)"
258+
organizationsResult4 <- runQueryParams
259+
(getOrganizationsBy @Text ((#o ! #type) .== param @1)) (Only NonProfit)
260+
organizationRows4 <- getRows organizationsResult4
261+
liftIO $ print (organizationRows4 :: [Organization])
262+
218263
main :: IO ()
219264
main = do
220265
Char8.putStrLn "===> squeal"

0 commit comments

Comments
 (0)