diff --git a/Sources/StructuredQueries/Macros.swift b/Sources/StructuredQueries/Macros.swift index 80eac032..e905ff93 100644 --- a/Sources/StructuredQueries/Macros.swift +++ b/Sources/StructuredQueries/Macros.swift @@ -91,7 +91,7 @@ public macro Ephemeral() = /// or common table expression. @attached( extension, - conformances: QueryRepresentable, + conformances: _Selection, names: named(Columns), named(init(decoder:)) ) diff --git a/Sources/StructuredQueriesCore/Internal/Deprecations.swift b/Sources/StructuredQueriesCore/Internal/Deprecations.swift index b9920caf..334e7ff7 100644 --- a/Sources/StructuredQueriesCore/Internal/Deprecations.swift +++ b/Sources/StructuredQueriesCore/Internal/Deprecations.swift @@ -65,7 +65,7 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Self], + @InsertValuesBuilder values: () -> [[QueryFragment]], onConflict updates: ((inout Updates) -> Void)? ) -> InsertOf { insert(or: conflictResolution, columns, values: values, onConflictDoUpdate: updates) @@ -75,8 +75,8 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> (TableColumn, repeat TableColumn), - @InsertValuesBuilder<(V1.QueryOutput, repeat (each V2).QueryOutput)> - values: () -> [(V1.QueryOutput, repeat (each V2).QueryOutput)], + @InsertValuesBuilder<(V1, repeat each V2)> + values: () -> [[QueryFragment]], onConflict updates: ((inout Updates) -> Void)? ) -> InsertOf { insert(or: conflictResolution, columns, values: values, onConflictDoUpdate: updates) diff --git a/Sources/StructuredQueriesCore/Seeds.swift b/Sources/StructuredQueriesCore/Seeds.swift index e5929bb7..131a79b7 100644 --- a/Sources/StructuredQueriesCore/Seeds.swift +++ b/Sources/StructuredQueriesCore/Seeds.swift @@ -67,7 +67,7 @@ public struct Seeds: Sequence { /// [SharingGRDB]: https://github.com/pointfreeco/sharing-grdb /// /// - Parameter build: A result builder closure that prepares statements to insert every built row. - public init(@InsertValuesBuilder _ build: () -> [any Table]) { + public init(@SeedsBuilder _ build: () -> [any Table]) { self.seeds = build() } @@ -103,3 +103,46 @@ public struct Seeds: Sequence { } } } + +@resultBuilder +public enum SeedsBuilder { + public static func buildArray(_ components: [[any Table]]) -> [any Table] { + components.flatMap(\.self) + } + + public static func buildBlock(_ components: [any Table]) -> [any Table] { + components + } + + public static func buildEither(first component: [any Table]) -> [any Table] { + component + } + + public static func buildEither(second component: [any Table]) -> [any Table] { + component + } + + public static func buildExpression(_ expression: some Table) -> [any Table] { + [expression] + } + + public static func buildExpression(_ expression: [any Table]) -> [any Table] { + expression + } + + public static func buildLimitedAvailability(_ component: [any Table]) -> [any Table] { + component + } + + public static func buildOptional(_ component: [any Table]?) -> [any Table] { + component ?? [] + } + + public static func buildPartialBlock(first: [any Table]) -> [any Table] { + first + } + + public static func buildPartialBlock(accumulated: [any Table], next: [any Table]) -> [any Table] { + accumulated + next + } +} diff --git a/Sources/StructuredQueriesCore/Statements/Insert.swift b/Sources/StructuredQueriesCore/Statements/Insert.swift index bf77265a..7377912b 100644 --- a/Sources/StructuredQueriesCore/Statements/Insert.swift +++ b/Sources/StructuredQueriesCore/Statements/Insert.swift @@ -57,14 +57,15 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Self], + @InsertValuesBuilder values: () -> [[QueryFragment]], onConflictDoUpdate updates: ((inout Updates, Excluded) -> Void)? = nil, @QueryFragmentBuilder where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } ) -> InsertOf { _insert( or: conflictResolution, - values: values, + columnNames: TableColumns.writableColumns.map(\.name), + values: .values(values()), onConflict: { _ -> ()? in nil }, where: { _ in return [] }, doUpdate: updates, @@ -124,7 +125,7 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Self], + @InsertValuesBuilder values: () -> [[QueryFragment]], onConflictDoUpdate updates: ((inout Updates) -> Void)?, @QueryFragmentBuilder where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } @@ -153,7 +154,7 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Self], + @InsertValuesBuilder values: () -> [[QueryFragment]], onConflict conflictTargets: (TableColumns) -> ( TableColumn, repeat TableColumn ), @@ -166,7 +167,8 @@ extension Table { withoutActuallyEscaping(updates) { updates in _insert( or: conflictResolution, - values: values, + columnNames: TableColumns.writableColumns.map(\.name), + values: .values(values()), onConflict: conflictTargets, where: targetFilter, doUpdate: updates, @@ -190,7 +192,7 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Self], + @InsertValuesBuilder values: () -> [[QueryFragment]], onConflict conflictTargets: (TableColumns) -> ( TableColumn, repeat TableColumn ), @@ -211,40 +213,6 @@ extension Table { ) } - private static func _insert( - or conflictResolution: ConflictResolution?, - @InsertValuesBuilder values: () -> [Self], - onConflict conflictTargets: (TableColumns) -> (repeat TableColumn)?, - @QueryFragmentBuilder - where targetFilter: (TableColumns) -> [QueryFragment] = { _ in [] }, - doUpdate updates: ((inout Updates, Excluded) -> Void)?, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - var valueFragments: [[QueryFragment]] = [] - for value in values() { - var valueFragment: [QueryFragment] = [] - for column in TableColumns.writableColumns { - func open( - _ column: some WritableTableColumnExpression - ) -> QueryFragment { - Value(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment - } - valueFragment.append(open(column)) - } - valueFragments.append(valueFragment) - } - return _insert( - or: conflictResolution, - columnNames: TableColumns.writableColumns.map(\.name), - values: .values(valueFragments), - onConflict: conflictTargets, - where: targetFilter, - doUpdate: updates, - where: updateFilter - ) - } - /// An insert statement for one or more table rows. /// /// This function can be used to create an insert statement for a specified set of columns. @@ -303,8 +271,8 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> (TableColumn, repeat TableColumn), - @InsertValuesBuilder<(V1.QueryOutput, repeat (each V2).QueryOutput)> - values: () -> [(V1.QueryOutput, repeat (each V2).QueryOutput)], + @InsertValuesBuilder<(V1, repeat each V2)> + values: () -> [[QueryFragment]], onConflictDoUpdate updates: ((inout Updates, Excluded) -> Void)? = nil, @QueryFragmentBuilder where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } @@ -333,8 +301,8 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> (TableColumn, repeat TableColumn), - @InsertValuesBuilder<(V1.QueryOutput, repeat (each V2).QueryOutput)> - values: () -> [(V1.QueryOutput, repeat (each V2).QueryOutput)], + @InsertValuesBuilder<(V1, repeat each V2)> + values: () -> [[QueryFragment]], onConflictDoUpdate updates: ((inout Updates) -> Void)?, @QueryFragmentBuilder where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } @@ -363,8 +331,8 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> (TableColumn, repeat TableColumn), - @InsertValuesBuilder<(V1.QueryOutput, repeat (each V2).QueryOutput)> - values: () -> [(V1.QueryOutput, repeat (each V2).QueryOutput)], + @InsertValuesBuilder<(V1, repeat each V2)> + values: () -> [[QueryFragment]], onConflict conflictTargets: (TableColumns) -> ( TableColumn, repeat TableColumn ), @@ -402,8 +370,8 @@ extension Table { public static func insert( or conflictResolution: ConflictResolution? = nil, _ columns: (TableColumns) -> (TableColumn, repeat TableColumn), - @InsertValuesBuilder<(V1.QueryOutput, repeat (each V2).QueryOutput)> - values: () -> [(V1.QueryOutput, repeat (each V2).QueryOutput)], + @InsertValuesBuilder<(V1, repeat each V2)> + values: () -> [[QueryFragment]], onConflict conflictTargets: (TableColumns) -> ( TableColumn, repeat TableColumn ), @@ -427,8 +395,8 @@ extension Table { private static func _insert( or conflictResolution: ConflictResolution?, _ columns: (TableColumns) -> (repeat TableColumn), - @InsertValuesBuilder<(repeat (each Value).QueryOutput)> - values: () -> [(repeat (each Value).QueryOutput)], + @InsertValuesBuilder<(repeat each Value)> + values: () -> [[QueryFragment]], onConflict conflictTargets: (TableColumns) -> (repeat TableColumn)?, @QueryFragmentBuilder where targetFilter: (TableColumns) -> [QueryFragment] = { _ in [] }, @@ -440,18 +408,10 @@ extension Table { for column in repeat each columns(Self.columns) { columnNames.append(column.name) } - var valueFragments: [[QueryFragment]] = [] - for value in values() { - var valueFragment: [QueryFragment] = [] - for (columnType, column) in repeat ((each Value).self, each value) { - valueFragment.append("\(columnType.init(queryOutput: column).queryFragment)") - } - valueFragments.append(valueFragment) - } return _insert( or: conflictResolution, columnNames: columnNames, - values: .values(valueFragments), + values: .values(values()), onConflict: conflictTargets, where: targetFilter, doUpdate: updates, @@ -702,139 +662,6 @@ extension Table { } extension PrimaryKeyedTable { - /// An insert statement for one or more table rows. - /// - /// This function can be used to create an insert statement from a ``Draft`` value. - /// - /// - Parameters: - /// - conflictResolution: A conflict resolution algorithm. - /// - columns: Columns to insert. - /// - values: A builder of row values for the given columns. - /// - updates: Updates to perform in an upsert clause should the insert conflict with an - /// existing row. - /// - updateFilter: A filter to apply to the update clause. - /// - Returns: An insert statement. - public static func insert( - or conflictResolution: ConflictResolution? = nil, - _ columns: (Draft.TableColumns) -> Draft.TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Draft], - onConflictDoUpdate updates: ((inout Updates, Excluded) -> Void)? = nil, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - _insert( - or: conflictResolution, - values: values, - onConflict: { _ -> ()? in nil }, - where: { _ in return [] }, - doUpdate: updates, - where: updateFilter - ) - } - - /// An insert statement for one or more table rows. - /// - /// This function can be used to create an insert statement from a ``Draft`` value. - /// - /// - Parameters: - /// - conflictResolution: A conflict resolution algorithm. - /// - columns: Columns to insert. - /// - values: A builder of row values for the given columns. - /// - updates: Updates to perform in an upsert clause should the insert conflict with an - /// existing row. - /// - updateFilter: A filter to apply to the update clause. - /// - Returns: An insert statement. - public static func insert( - or conflictResolution: ConflictResolution? = nil, - _ columns: (Draft.TableColumns) -> Draft.TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Draft], - onConflictDoUpdate updates: ((inout Updates) -> Void)?, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - insert( - or: conflictResolution, - columns, - values: values, - onConflictDoUpdate: updates.map { updates in { row, _ in updates(&row) } }, - where: updateFilter - ) - } - - /// An insert statement for one or more table rows. - /// - /// This function can be used to create an insert statement from a ``Draft`` value. - /// - /// - Parameters: - /// - conflictResolution: A conflict resolution algorithm. - /// - columns: Columns to insert. - /// - values: A builder of row values for the given columns. - /// - conflictTargets: Indexed columns to target for conflict resolution. - /// - targetFilter: A filter to apply to conflict target columns. - /// - updates: Updates to perform in an upsert clause should the insert conflict with an - /// existing row. - /// - updateFilter: A filter to apply to the update clause. - public static func insert( - or conflictResolution: ConflictResolution? = nil, - _ columns: (Draft.TableColumns) -> Draft.TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Draft], - onConflict conflictTargets: (TableColumns) -> ( - TableColumn, repeat TableColumn - ), - @QueryFragmentBuilder - where targetFilter: (TableColumns) -> [QueryFragment] = { _ in [] }, - doUpdate updates: (inout Updates, Excluded) -> Void = { _, _ in }, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - withoutActuallyEscaping(updates) { updates in - _insert( - or: conflictResolution, - values: values, - onConflict: conflictTargets, - where: targetFilter, - doUpdate: updates, - where: updateFilter - ) - } - } - - /// An insert statement for one or more table rows. - /// - /// This function can be used to create an insert statement from a ``Draft`` value. - /// - /// - Parameters: - /// - conflictResolution: A conflict resolution algorithm. - /// - columns: Columns to insert. - /// - values: A builder of row values for the given columns. - /// - conflictTargets: Indexed columns to target for conflict resolution. - /// - targetFilter: A filter to apply to conflict target columns. - /// - updates: Updates to perform in an upsert clause should the insert conflict with an - /// existing row. - /// - updateFilter: A filter to apply to the update clause. - public static func insert( - or conflictResolution: ConflictResolution? = nil, - _ columns: (Draft.TableColumns) -> Draft.TableColumns = { $0 }, - @InsertValuesBuilder values: () -> [Draft], - onConflict conflictTargets: (TableColumns) -> ( - TableColumn, repeat TableColumn - ), - @QueryFragmentBuilder - where targetFilter: (TableColumns) -> [QueryFragment] = { _ in [] }, - doUpdate updates: (inout Updates) -> Void, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - insert( - or: conflictResolution, - values: values, - onConflict: conflictTargets, - where: targetFilter, - doUpdate: { row, _ in updates(&row) }, - where: updateFilter - ) - } - /// An upsert statement for given drafts. /// /// Generates an insert statement with an upsert clause. Useful for building forms that can both @@ -853,7 +680,7 @@ extension PrimaryKeyedTable { /// - Returns: An insert statement with an upsert clause. public static func upsert( or conflictResolution: ConflictResolution? = nil, - @InsertValuesBuilder values: () -> [Draft] + @InsertValuesBuilder values: () -> [[QueryFragment]] ) -> InsertOf { insert( or: conflictResolution, @@ -867,40 +694,6 @@ extension PrimaryKeyedTable { } ) } - - private static func _insert( - or conflictResolution: ConflictResolution?, - @InsertValuesBuilder values: () -> [Draft], - onConflict conflictTargets: (TableColumns) -> (repeat TableColumn)?, - @QueryFragmentBuilder - where targetFilter: (TableColumns) -> [QueryFragment] = { _ in [] }, - doUpdate updates: ((inout Updates, Excluded) -> Void)?, - @QueryFragmentBuilder - where updateFilter: (TableColumns) -> [QueryFragment] = { _ in [] } - ) -> InsertOf { - var valueFragments: [[QueryFragment]] = [] - for value in values() { - var valueFragment: [QueryFragment] = [] - for column in Draft.TableColumns.writableColumns { - func open( - _ column: some WritableTableColumnExpression - ) -> QueryFragment { - Value(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment - } - valueFragment.append(open(column)) - } - valueFragments.append(valueFragment) - } - return _insert( - or: conflictResolution, - columnNames: Draft.TableColumns.writableColumns.map(\.name), - values: .values(valueFragments), - onConflict: conflictTargets, - where: targetFilter, - doUpdate: updates, - where: updateFilter - ) - } } private enum InsertValues { @@ -1085,43 +878,155 @@ public typealias InsertOf = Insert /// insert any number of rows into a table. @resultBuilder public enum InsertValuesBuilder { - public static func buildArray(_ components: [[Value]]) -> [Value] { - components.flatMap(\.self) + public static func buildExpression(_ expression: [Value]) -> [[QueryFragment]] + where Value: Table { + var valueFragments: [[QueryFragment]] = [] + for value in expression { + var valueFragment: [QueryFragment] = [] + for column in Value.TableColumns.writableColumns { + func open( + _ column: some WritableTableColumnExpression + ) -> QueryFragment { + Member(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment + } + valueFragment.append(open(column)) + } + valueFragments.append(valueFragment) + } + return valueFragments } - public static func buildBlock(_ components: [Value]) -> [Value] { - components + @_disfavoredOverload + public static func buildExpression(_ expression: [Value.Draft]) -> [[QueryFragment]] + where Value: PrimaryKeyedTable { + var valueFragments: [[QueryFragment]] = [] + for value in expression { + var valueFragment: [QueryFragment] = [] + for column in Value.Draft.TableColumns.writableColumns { + func open( + _ column: some WritableTableColumnExpression + ) -> QueryFragment { + Member(queryOutput: (value as! Root)[keyPath: column.keyPath]).queryFragment + } + valueFragment.append(open(column)) + } + valueFragments.append(valueFragment) + } + return valueFragments } - public static func buildEither(first component: [Value]) -> [Value] { - component + @_disfavoredOverload + public static func buildExpression( + _ expression: [V] + ) -> [[QueryFragment]] + where + Value == V.QueryValue, + V.QueryValue: QueryRepresentable & QueryBindable + { + [expression.map(\.queryFragment)] } - public static func buildEither(second component: [Value]) -> [Value] { - component + @_disfavoredOverload + public static func buildExpression( + _ expression: [Value.QueryOutput] + ) -> [[QueryFragment]] + where Value: QueryRepresentable & QueryBindable { + [expression.map { Value(queryOutput: $0).queryFragment }] + } + + public static func buildExpression(_ expression: Value) -> [[QueryFragment]] + where Value: Table { + buildExpression([expression]) + } + + public static func buildExpression(_ expression: Value.Draft) -> [[QueryFragment]] + where Value: PrimaryKeyedTable { + buildExpression([expression]) + } + + @_disfavoredOverload + public static func buildExpression( + _ expression: V + ) -> [[QueryFragment]] + where + Value == V.QueryValue, + V.QueryValue: QueryRepresentable & QueryBindable + { + buildExpression([expression]) + } + + public static func buildExpression( + _ expression: Value.QueryOutput + ) -> [[QueryFragment]] + where Value: QueryRepresentable & QueryBindable { + buildExpression([expression]) + } + + @_disfavoredOverload + public static func buildExpression( + _ expression: (repeat each V) + ) -> [[QueryFragment]] + where + Value == (repeat (each V).QueryValue), + repeat (each V).QueryValue: QueryRepresentable & QueryBindable + { + var valueFragment: [QueryFragment] = [] + for column in repeat each expression { + valueFragment.append(column.queryFragment) + } + return [valueFragment] + } + + public static func buildExpression( + _ expression: (repeat (each V).QueryOutput) + ) -> [[QueryFragment]] + where Value == (repeat each V) { + var valueFragment: [QueryFragment] = [] + for (columnType, column) in repeat ((each V).self, each expression) { + valueFragment.append(columnType.init(queryOutput: column).queryFragment) + } + return [valueFragment] + } + + public static func buildExpression( + _ expression: Value.Columns + ) -> [[QueryFragment]] + where Value: _Selection { + [expression.selection.map(\.expression)] + } + + public static func buildArray(_ components: [[[QueryFragment]]]) -> [[QueryFragment]] { + components.flatMap(\.self) } - public static func buildExpression(_ expression: Value) -> [Value] { - [expression] + public static func buildBlock(_ components: [[QueryFragment]]) -> [[QueryFragment]] { + components } - public static func buildExpression(_ expression: [Value]) -> [Value] { - expression + public static func buildEither(first component: [[QueryFragment]]) -> [[QueryFragment]] { + component + } + + public static func buildEither(second component: [[QueryFragment]]) -> [[QueryFragment]] { + component } - public static func buildLimitedAvailability(_ component: [Value]) -> [Value] { + public static func buildLimitedAvailability(_ component: [[QueryFragment]]) -> [[QueryFragment]] { component } - public static func buildOptional(_ component: [Value]?) -> [Value] { + public static func buildOptional(_ component: [[QueryFragment]]?) -> [[QueryFragment]] { component ?? [] } - public static func buildPartialBlock(first: [Value]) -> [Value] { + public static func buildPartialBlock(first: [[QueryFragment]]) -> [[QueryFragment]] { first } - public static func buildPartialBlock(accumulated: [Value], next: [Value]) -> [Value] { + public static func buildPartialBlock( + accumulated: [[QueryFragment]], + next: [[QueryFragment]] + ) -> [[QueryFragment]] { accumulated + next } } diff --git a/Sources/StructuredQueriesCore/Statements/Select.swift b/Sources/StructuredQueriesCore/Statements/Select.swift index d1704888..7d9d445c 100644 --- a/Sources/StructuredQueriesCore/Statements/Select.swift +++ b/Sources/StructuredQueriesCore/Statements/Select.swift @@ -643,6 +643,7 @@ extension Select { /// - Returns: A new select statement that joins the given table and combines their clauses /// together. @_documentation(visibility: private) + @_disfavoredOverload public func join< each C1: QueryRepresentable, each C2: QueryRepresentable, F: Table, each J: Table >( @@ -681,6 +682,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func join( // TODO: Report issue to Swift team. Using 'some' crashes the compiler. @@ -801,6 +803,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that left-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func leftJoin< each C1: QueryRepresentable, each C2: QueryRepresentable, F: Table, each J: Table @@ -849,6 +852,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that left-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func leftJoin( // TODO: Report issue to Swift team. Using 'some' crashes the compiler. @@ -971,6 +975,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that right-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func rightJoin< each C1: QueryRepresentable, each C2: QueryRepresentable, F: Table, each J: Table @@ -1019,6 +1024,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that right-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func rightJoin( // TODO: Report issue to Swift team. Using 'some' crashes the compiler. @@ -1141,6 +1147,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that full-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func fullJoin< each C1: QueryRepresentable, each C2: QueryRepresentable, F: Table, each J: Table @@ -1189,6 +1196,7 @@ extension Select { /// - constraint: The constraint describing the join. /// - Returns: A new select statement that full-joins the given table and combines their clauses /// together. + @_disfavoredOverload @_documentation(visibility: private) public func fullJoin( // TODO: Report issue to Swift team. Using 'some' crashes the compiler. diff --git a/Sources/StructuredQueriesCore/_Selection.swift b/Sources/StructuredQueriesCore/_Selection.swift new file mode 100644 index 00000000..1d0678fb --- /dev/null +++ b/Sources/StructuredQueriesCore/_Selection.swift @@ -0,0 +1,13 @@ +public protocol _Selection: QueryRepresentable { + associatedtype Columns: _SelectedColumns +} + +public protocol _SelectedColumns: QueryExpression { + var selection: [(aliasName: String, expression: QueryFragment)] { get } +} + +extension _SelectedColumns { + public var queryFragment: QueryFragment { + selection.map { "\($1) AS \(quote: $0)" as QueryFragment }.joined(separator: ", ") + } +} diff --git a/Sources/StructuredQueriesMacros/SelectionMacro.swift b/Sources/StructuredQueriesMacros/SelectionMacro.swift index f57844e7..97f33a7d 100644 --- a/Sources/StructuredQueriesMacros/SelectionMacro.swift +++ b/Sources/StructuredQueriesMacros/SelectionMacro.swift @@ -155,7 +155,7 @@ extension SelectionMacro: ExtensionMacro { } var conformances: [TypeSyntax] = [] - let protocolNames: [TokenSyntax] = ["QueryRepresentable"] + let protocolNames: [TokenSyntax] = ["_Selection"] if let inheritanceClause = declaration.inheritanceClause { for type in protocolNames { if !inheritanceClause.inheritedTypes.contains(where: { @@ -300,8 +300,8 @@ extension SelectionMacro: MemberMacro { } var conformances: [TypeSyntax] = [] - let protocolNames: [TokenSyntax] = ["QueryRepresentable"] - let schemaConformances: [ExprSyntax] = ["\(moduleName).QueryExpression"] + let protocolNames: [TokenSyntax] = ["_Selection"] + let schemaConformances: [ExprSyntax] = ["\(moduleName)._SelectedColumns"] if let inheritanceClause = declaration.inheritanceClause { for type in protocolNames { if !inheritanceClause.inheritedTypes.contains(where: { @@ -324,19 +324,17 @@ extension SelectionMacro: MemberMacro { .joined(separator: ",\n") let initAssignment: [String] = allColumns - .map { #"\(\#($0.name).queryFragment) AS \#($0.name.text.quoted())"# } + .map { #"(\#($0.name.text.quoted()), \#($0.name).queryFragment)"# } return [ """ public struct Columns: \(schemaConformances, separator: ", ") { public typealias QueryValue = \(type.trimmed) - public let queryFragment: \(moduleName).QueryFragment + public let selection: [(aliasName: String, expression: \(moduleName).QueryFragment)] public init( \(raw: initArguments) ) { - self.queryFragment = \"\"\" - \(raw: initAssignment.joined(separator: ", ")) - \"\"\" + self.selection = [\(raw: initAssignment.joined(separator: ", "))] } } """ diff --git a/Tests/StructuredQueriesMacrosTests/SelectionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/SelectionMacroTests.swift index d16bf9bb..0553b615 100644 --- a/Tests/StructuredQueriesMacrosTests/SelectionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/SelectionMacroTests.swift @@ -14,26 +14,24 @@ extension SnapshotTests { } """ } expansion: { - #""" + """ struct PlayerAndTeam { let player: Player let team: Team - public struct Columns: StructuredQueriesCore.QueryExpression { + public struct Columns: StructuredQueriesCore._SelectedColumns { public typealias QueryValue = PlayerAndTeam - public let queryFragment: StructuredQueriesCore.QueryFragment + public let selection: [(aliasName: String, expression: StructuredQueriesCore.QueryFragment)] public init( player: some StructuredQueriesCore.QueryExpression, team: some StructuredQueriesCore.QueryExpression ) { - self.queryFragment = """ - \(player.queryFragment) AS "player", \(team.queryFragment) AS "team" - """ + self.selection = [("player", player.queryFragment), ("team", team.queryFragment)] } } } - extension PlayerAndTeam: StructuredQueriesCore.QueryRepresentable { + extension PlayerAndTeam: StructuredQueriesCore._Selection { public init(decoder: inout some StructuredQueriesCore.QueryDecoder) throws { let player = try decoder.decode(Player.self) let team = try decoder.decode(Team.self) @@ -47,7 +45,7 @@ extension SnapshotTests { self.team = team } } - """# + """ } } @@ -77,26 +75,24 @@ extension SnapshotTests { } """ } expansion: { - #""" + """ struct ReminderTitleAndListTitle { var reminderTitle: String var listTitle: String? - public struct Columns: StructuredQueriesCore.QueryExpression { + public struct Columns: StructuredQueriesCore._SelectedColumns { public typealias QueryValue = ReminderTitleAndListTitle - public let queryFragment: StructuredQueriesCore.QueryFragment + public let selection: [(aliasName: String, expression: StructuredQueriesCore.QueryFragment)] public init( reminderTitle: some StructuredQueriesCore.QueryExpression, listTitle: some StructuredQueriesCore.QueryExpression ) { - self.queryFragment = """ - \(reminderTitle.queryFragment) AS "reminderTitle", \(listTitle.queryFragment) AS "listTitle" - """ + self.selection = [("reminderTitle", reminderTitle.queryFragment), ("listTitle", listTitle.queryFragment)] } } } - extension ReminderTitleAndListTitle: StructuredQueriesCore.QueryRepresentable { + extension ReminderTitleAndListTitle: StructuredQueriesCore._Selection { public init(decoder: inout some StructuredQueriesCore.QueryDecoder) throws { let reminderTitle = try decoder.decode(String.self) let listTitle = try decoder.decode(String.self) @@ -107,7 +103,7 @@ extension SnapshotTests { self.listTitle = listTitle } } - """# + """ } } @@ -120,24 +116,22 @@ extension SnapshotTests { } """ } expansion: { - #""" + """ struct ReminderDate { var date: Date - public struct Columns: StructuredQueriesCore.QueryExpression { + public struct Columns: StructuredQueriesCore._SelectedColumns { public typealias QueryValue = ReminderDate - public let queryFragment: StructuredQueriesCore.QueryFragment + public let selection: [(aliasName: String, expression: StructuredQueriesCore.QueryFragment)] public init( date: some StructuredQueriesCore.QueryExpression ) { - self.queryFragment = """ - \(date.queryFragment) AS "date" - """ + self.selection = [("date", date.queryFragment)] } } } - extension ReminderDate: StructuredQueriesCore.QueryRepresentable { + extension ReminderDate: StructuredQueriesCore._Selection { public init(decoder: inout some StructuredQueriesCore.QueryDecoder) throws { let date = try decoder.decode(Date.UnixTimeRepresentation.self) guard let date else { @@ -146,7 +140,7 @@ extension SnapshotTests { self.date = date } } - """# + """ } } } diff --git a/Tests/StructuredQueriesMacrosTests/TableSelectionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/TableSelectionMacroTests.swift index 60d947e9..cd27cbf1 100644 --- a/Tests/StructuredQueriesMacrosTests/TableSelectionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/TableSelectionMacroTests.swift @@ -35,16 +35,14 @@ extension SnapshotTests { } } - public struct Columns: StructuredQueriesCore.QueryExpression { + public struct Columns: StructuredQueriesCore._SelectedColumns { public typealias QueryValue = ReminderListWithCount - public let queryFragment: StructuredQueriesCore.QueryFragment + public let selection: [(aliasName: String, expression: StructuredQueriesCore.QueryFragment)] public init( reminderList: some StructuredQueriesCore.QueryExpression, remindersCount: some StructuredQueriesCore.QueryExpression ) { - self.queryFragment = """ - \(reminderList.queryFragment) AS "reminderList", \(remindersCount.queryFragment) AS "remindersCount" - """ + self.selection = [("reminderList", reminderList.queryFragment), ("remindersCount", remindersCount.queryFragment)] } } } @@ -60,7 +58,7 @@ extension SnapshotTests { } } - extension ReminderListWithCount: StructuredQueriesCore.QueryRepresentable { + extension ReminderListWithCount: StructuredQueriesCore._Selection { public init(decoder: inout some StructuredQueriesCore.QueryDecoder) throws { let reminderList = try decoder.decode(ReminderList.self) let remindersCount = try decoder.decode(Int.self) diff --git a/Tests/StructuredQueriesTests/InsertTests.swift b/Tests/StructuredQueriesTests/InsertTests.swift index d092cbb0..9bfec1d5 100644 --- a/Tests/StructuredQueriesTests/InsertTests.swift +++ b/Tests/StructuredQueriesTests/InsertTests.swift @@ -13,6 +13,7 @@ extension SnapshotTests { } values: { (1, "Groceries", true, Date(timeIntervalSinceReferenceDate: 0), .high) (2, "Haircut", false, Date(timeIntervalSince1970: 0), .low) + (#sql("3"), #sql("'Schedule doctor appointment'"), #sql("0"), #sql("NULL"), #sql("2")) } onConflictDoUpdate: { $0.title += " Copy" } @@ -22,7 +23,7 @@ extension SnapshotTests { INSERT INTO "reminders" ("remindersListID", "title", "isCompleted", "dueDate", "priority") VALUES - (1, 'Groceries', 1, '2001-01-01 00:00:00.000', 3), (2, 'Haircut', 0, '1970-01-01 00:00:00.000', 1) + (1, 'Groceries', 1, '2001-01-01 00:00:00.000', 3), (2, 'Haircut', 0, '1970-01-01 00:00:00.000', 1), (3, 'Schedule doctor appointment', 0, NULL, 2) ON CONFLICT DO UPDATE SET "title" = ("reminders"."title" || ' Copy') RETURNING "id" """ @@ -31,6 +32,7 @@ extension SnapshotTests { ┌────┐ │ 11 │ │ 12 │ + │ 13 │ └────┘ """ } @@ -247,7 +249,8 @@ extension SnapshotTests { assertQuery( Reminder.insert { - Reminder.Draft(remindersListID: 1, title: "Check voicemail") + Reminder(id: 12, remindersListID: 1, title: "Check voicemail") + Reminder.Draft(remindersListID: 1, title: "Check pager") } .returning(\.id) ) { @@ -255,13 +258,14 @@ extension SnapshotTests { INSERT INTO "reminders" ("id", "assignedUserID", "dueDate", "isCompleted", "isFlagged", "notes", "priority", "remindersListID", "title", "updatedAt") VALUES - (NULL, NULL, NULL, 0, 0, '', NULL, 1, 'Check voicemail', '2040-02-14 23:31:30.000') + (12, NULL, NULL, 0, 0, '', NULL, 1, 'Check voicemail', '2040-02-14 23:31:30.000'), (NULL, NULL, NULL, 0, 0, '', NULL, 1, 'Check pager', '2040-02-14 23:31:30.000') RETURNING "id" """ } results: { """ ┌────┐ │ 12 │ + │ 13 │ └────┘ """ } @@ -285,8 +289,8 @@ extension SnapshotTests { } results: { """ ┌────┐ - │ 13 │ │ 14 │ + │ 15 │ └────┘ """ } @@ -576,6 +580,26 @@ extension SnapshotTests { } } + @Test func selectedColumns() { + assertInlineSnapshot( + of: Item.insert { + Item.Columns( + title: #sql("'Foo'"), + quantity: #sql("42"), + notes: #sql("[]") + ) + }, + as: .sql + ) { + """ + INSERT INTO "items" + ("title", "quantity", "notes") + VALUES + ('Foo', 42, []) + """ + } + } + @Test func onConflictWhereDoUpdateWhere() { assertQuery( Reminder.insert { @@ -671,7 +695,7 @@ extension SnapshotTests { } } -@Table private struct Item { +@Table @Selection private struct Item { var title = "" var quantity = 0 @Column(as: [String].JSONRepresentation.self)