diff --git a/Sources/StructuredQueriesCore/PrimaryKeyed.swift b/Sources/StructuredQueriesCore/PrimaryKeyed.swift index 4e3f5b65..766025a2 100644 --- a/Sources/StructuredQueriesCore/PrimaryKeyed.swift +++ b/Sources/StructuredQueriesCore/PrimaryKeyed.swift @@ -88,7 +88,17 @@ extension PrimaryKeyedTable { public static func find( _ primaryKey: some QueryExpression ) -> Where { - Self.where { $0.primaryKey.eq(primaryKey) } + Self.find([primaryKey]) + } + + /// A where clause filtered by primary keys. + /// + /// - Parameter primaryKey: Primary keys identifying table rows. + /// - Returns: A `WHERE` clause. + public static func find( + _ primaryKeys: some Sequence> + ) -> Where { + Self.where { $0.primaryKey.in(primaryKeys) } } public var primaryKey: PrimaryKey.QueryOutput { @@ -104,9 +114,17 @@ extension TableDraft { public static func find( _ primaryKey: some QueryExpression ) -> Where { - Self.where { _ in - PrimaryTable.columns.primaryKey.eq(primaryKey) - } + Self.find([primaryKey]) + } + + /// A where clause filtered by primary keys. + /// + /// - Parameter primaryKeys: Primary keys identifying table rows. + /// - Returns: A `WHERE` clause. + public static func find( + _ primaryKeys: some Sequence> + ) -> Where { + Self.where { $0.primaryKey.in(primaryKeys) } } } @@ -116,7 +134,17 @@ extension Where where From: PrimaryKeyedTable { /// - Parameter primaryKey: A primary key. /// - Returns: A where clause with the added primary key. public func find(_ primaryKey: some QueryExpression) -> Self { - self.where { $0.primaryKey.eq(primaryKey) } + self.find([primaryKey]) + } + + /// Adds a primary key condition to a where clause. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A where clause with the added primary keys condition. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + Self.where { $0.primaryKey.in(primaryKeys) } } } @@ -125,12 +153,18 @@ extension Where where From: TableDraft { /// /// - Parameter primaryKey: A primary key. /// - Returns: A where clause with the added primary key. - public func find(_ primaryKey: From.PrimaryTable.TableColumns.PrimaryKey.QueryOutput) -> Self { - self.where { _ in - From.PrimaryTable.columns.primaryKey.eq( - From.PrimaryTable.TableColumns.PrimaryKey(queryOutput: primaryKey) - ) - } + public func find(_ primaryKey: some QueryExpression) -> Self { + self.find([primaryKey]) + } + + /// Adds a primary key condition to a where clause. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A where clause with the added primary keys condition. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + Self.where { $0.primaryKey.in(primaryKeys) } } } @@ -142,6 +176,16 @@ extension Select where From: PrimaryKeyedTable { public func find(_ primaryKey: some QueryExpression) -> Self { self.and(From.find(primaryKey)) } + + /// A select statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A select statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.and(From.find(primaryKeys)) + } } extension Select where From: TableDraft { @@ -154,6 +198,16 @@ extension Select where From: TableDraft { ) -> Self { self.and(From.find(primaryKey)) } + + /// A select statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A select statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.and(From.find(primaryKeys)) + } } extension Update where From: PrimaryKeyedTable { @@ -162,7 +216,17 @@ extension Update where From: PrimaryKeyedTable { /// - Parameter primaryKey: A primary key identifying a table row. /// - Returns: An update statement filtered by the given key. public func find(_ primaryKey: some QueryExpression) -> Self { - self.where { $0.primaryKey.eq(primaryKey) } + self.find([primaryKey]) + } + + /// An update statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: An update statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.where { $0.primaryKey.in(primaryKeys) } } } @@ -171,12 +235,18 @@ extension Update where From: TableDraft { /// /// - Parameter primaryKey: A primary key identifying a table row. /// - Returns: An update statement filtered by the given key. - public func find(_ primaryKey: From.PrimaryTable.TableColumns.PrimaryKey.QueryOutput) -> Self { - self.where { _ in - From.PrimaryTable.columns.primaryKey.eq( - From.PrimaryTable.TableColumns.PrimaryKey(queryOutput: primaryKey) - ) - } + public func find(_ primaryKey: some QueryExpression) -> Self { + self.find([primaryKey]) + } + + /// An update statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: An update statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.where { $0.primaryKey.in(primaryKeys) } } } @@ -186,7 +256,17 @@ extension Delete where From: PrimaryKeyedTable { /// - Parameter primaryKey: A primary key identifying a table row. /// - Returns: A delete statement filtered by the given key. public func find(_ primaryKey: some QueryExpression) -> Self { - self.where { $0.primaryKey.eq(primaryKey) } + self.find([primaryKey]) + } + + /// A delete statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A delete statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.where { $0.primaryKey.in(primaryKeys) } } } @@ -195,11 +275,17 @@ extension Delete where From: TableDraft { /// /// - Parameter primaryKey: A primary key identifying a table row. /// - Returns: A delete statement filtered by the given key. - public func find(_ primaryKey: From.PrimaryTable.TableColumns.PrimaryKey.QueryOutput) -> Self { - self.where { _ in - From.PrimaryTable.columns.primaryKey.eq( - From.PrimaryTable.TableColumns.PrimaryKey(queryOutput: primaryKey) - ) - } + public func find(_ primaryKey: some QueryExpression) -> Self { + self.find([primaryKey]) + } + + /// A delete statement filtered by a sequence of primary keys. + /// + /// - Parameter primaryKeys: A sequence of primary keys. + /// - Returns: A delete statement filtered by the given keys. + public func find( + _ primaryKeys: some Sequence> + ) -> Self { + self.where { $0.primaryKey.in(primaryKeys) } } } diff --git a/Tests/StructuredQueriesTests/PrimaryKeyedTableTests.swift b/Tests/StructuredQueriesTests/PrimaryKeyedTableTests.swift index e5615c38..5c7a9fd6 100644 --- a/Tests/StructuredQueriesTests/PrimaryKeyedTableTests.swift +++ b/Tests/StructuredQueriesTests/PrimaryKeyedTableTests.swift @@ -32,7 +32,7 @@ extension SnapshotTests { """ UPDATE "reminders" SET "title" = ("reminders"."title" || '!!!') - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) RETURNING "title" """ } results: { @@ -50,7 +50,7 @@ extension SnapshotTests { """ UPDATE "reminders" SET "title" = ("reminders"."title" || '???') - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) RETURNING "title" """ } results: { @@ -69,7 +69,7 @@ extension SnapshotTests { ) { """ DELETE FROM "reminders" - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) RETURNING "reminders"."id" """ } results: { @@ -86,7 +86,7 @@ extension SnapshotTests { ) { """ DELETE FROM "reminders" - WHERE ("reminders"."id" = 2) + WHERE ("reminders"."id" IN (2)) RETURNING "reminders"."id" """ } results: { @@ -105,7 +105,7 @@ extension SnapshotTests { """ SELECT "reminders"."id", "reminders"."title" FROM "reminders" - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) """ } results: { """ @@ -121,7 +121,7 @@ extension SnapshotTests { """ SELECT "reminders"."id", "reminders"."title" FROM "reminders" - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) """ } results: { """ @@ -137,7 +137,7 @@ extension SnapshotTests { """ SELECT "reminders"."id", "reminders"."title" FROM "reminders" - WHERE ("reminders"."id" = 2) + WHERE ("reminders"."id" IN (2)) """ } results: { """ @@ -147,13 +147,59 @@ extension SnapshotTests { """ } + assertQuery( + Reminder.select { ($0.id, $0.title) }.find([2, 4, 6]) + ) { + """ + SELECT "reminders"."id", "reminders"."title" + FROM "reminders" + WHERE ("reminders"."id" IN (2, 4, 6)) + """ + } results: { + """ + ┌───┬────────────────────────────┐ + │ 2 │ "Haircut" │ + │ 4 │ "Take a walk" │ + │ 6 │ "Pick up kids from school" │ + └───┴────────────────────────────┘ + """ + } + + assertQuery( + Reminder.select { ($0.id, $0.title) }.find(Reminder.select(\.id)) + ) { + """ + SELECT "reminders"."id", "reminders"."title" + FROM "reminders" + WHERE ("reminders"."id" IN (( + SELECT "reminders"."id" + FROM "reminders" + ))) + """ + } results: { + """ + ┌────┬────────────────────────────┐ + │ 1 │ "Groceries" │ + │ 2 │ "Haircut" │ + │ 3 │ "Doctor appointment" │ + │ 4 │ "Take a walk" │ + │ 5 │ "Buy concert tickets" │ + │ 6 │ "Pick up kids from school" │ + │ 7 │ "Get laundry" │ + │ 8 │ "Take out trash" │ + │ 9 │ "Call accountant" │ + │ 10 │ "Send weekly emails" │ + └────┴────────────────────────────┘ + """ + } + assertQuery( Reminder.Draft.select { ($0.id, $0.title) }.find(2) ) { """ SELECT "reminders"."id", "reminders"."title" FROM "reminders" - WHERE ("reminders"."id" = 2) + WHERE ("reminders"."id" IN (2)) """ } results: { """ @@ -175,7 +221,7 @@ extension SnapshotTests { SELECT "reminders"."title", "remindersLists"."title" FROM "reminders" JOIN "remindersLists" ON ("reminders"."remindersListID" = "remindersLists"."id") - WHERE ("reminders"."id" = 2) + WHERE ("reminders"."id" IN (2)) """ } results: { """ @@ -214,7 +260,7 @@ extension SnapshotTests { """ SELECT "rows"."id", "rows"."isDeleted", "rows"."isNotDeleted" FROM "rows" - WHERE ("rows"."id" = '00000000-0000-0000-0000-000000000001') + WHERE ("rows"."id" IN ('00000000-0000-0000-0000-000000000001')) """ } results: { """ diff --git a/Tests/StructuredQueriesTests/UpdateTests.swift b/Tests/StructuredQueriesTests/UpdateTests.swift index cd943a9d..061c6cdc 100644 --- a/Tests/StructuredQueriesTests/UpdateTests.swift +++ b/Tests/StructuredQueriesTests/UpdateTests.swift @@ -365,7 +365,7 @@ extension SnapshotTests { """ UPDATE "reminders" SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' END - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) RETURNING "dueDate" """ } results: { @@ -383,7 +383,7 @@ extension SnapshotTests { """ UPDATE "reminders" SET "dueDate" = CASE WHEN ("reminders"."dueDate" IS NULL) THEN '2018-01-29 00:08:00.000' END - WHERE ("reminders"."id" = 1) + WHERE ("reminders"."id" IN (1)) RETURNING "dueDate" """ } results: { diff --git a/Tests/StructuredQueriesTests/WhereTests.swift b/Tests/StructuredQueriesTests/WhereTests.swift index 9425d6bf..74051cd9 100644 --- a/Tests/StructuredQueriesTests/WhereTests.swift +++ b/Tests/StructuredQueriesTests/WhereTests.swift @@ -225,7 +225,7 @@ extension SnapshotTests { SELECT "remindersLists"."id", "remindersLists"."color", "remindersLists"."title", "remindersLists"."position", "reminders"."id", "reminders"."assignedUserID", "reminders"."dueDate", "reminders"."isCompleted", "reminders"."isFlagged", "reminders"."notes", "reminders"."priority", "reminders"."remindersListID", "reminders"."title", "reminders"."updatedAt" FROM "remindersLists" LEFT JOIN "reminders" ON ("remindersLists"."id" = "reminders"."remindersListID") - WHERE ("remindersLists"."id" = 4) AND "reminders"."isCompleted" + WHERE ("remindersLists"."id" IN (4)) AND "reminders"."isCompleted" """ } results: { """