diff --git a/Sources/StructuredQueriesCore/AggregateFunctions.swift b/Sources/StructuredQueriesCore/AggregateFunctions.swift index e8c4bf61..ea44bdb8 100644 --- a/Sources/StructuredQueriesCore/AggregateFunctions.swift +++ b/Sources/StructuredQueriesCore/AggregateFunctions.swift @@ -20,7 +20,7 @@ extension QueryExpression where QueryValue: QueryBindable { distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "count", isDistinct: isDistinct, [queryFragment], @@ -51,7 +51,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "group_concat", separator.map { [queryFragment, $0.queryFragment] } ?? [queryFragment], order: order?.queryFragment, @@ -74,7 +74,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "group_concat", isDistinct: isDistinct, [queryFragment], @@ -97,7 +97,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable public func max( filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("max", [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("max", [queryFragment], filter: filter?.queryFragment) } /// A minimum aggregate of this expression. @@ -112,7 +112,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable public func min( filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("min", [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("min", [queryFragment], filter: filter?.queryFragment) } } @@ -134,7 +134,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment) } /// An sum aggregate of this expression. @@ -156,7 +156,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric // NB: We must explicitly erase here to avoid a runtime crash with opaque return types // TODO: Report issue to Swift team. SQLQueryExpression( - AggregateFunction( + AggregateFunctionExpression( "sum", isDistinct: isDistinct, [queryFragment], @@ -182,7 +182,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "total", isDistinct: isDistinct, [queryFragment], @@ -191,7 +191,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric } } -extension QueryExpression where Self == AggregateFunction { +extension QueryExpression where Self == AggregateFunctionExpression { /// A `count(*)` aggregate. /// /// ```swift @@ -204,18 +204,34 @@ extension QueryExpression where Self == AggregateFunction { public static func count( filter: (any QueryExpression)? = nil ) -> Self { - AggregateFunction("count", ["*"], filter: filter?.queryFragment) + AggregateFunctionExpression("count", ["*"], filter: filter?.queryFragment) } } /// A query expression of an aggregate function. -public struct AggregateFunction: QueryExpression, Sendable { +public struct AggregateFunctionExpression: QueryExpression, Sendable { var name: QueryFragment var isDistinct: Bool var arguments: [QueryFragment] var order: QueryFragment? var filter: QueryFragment? + public init( + _ name: String, + distinct isDistinct: Bool = false, + _ arguments: repeat each Argument, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) { + self.init( + QueryFragment(quote: name), + isDistinct: isDistinct, + Array(repeat each arguments), + order: order?.queryFragment, + filter: filter?.queryFragment + ) + } + package init( _ name: QueryFragment, isDistinct: Bool = false, diff --git a/Sources/StructuredQueriesCore/ScalarFunctions.swift b/Sources/StructuredQueriesCore/ScalarFunctions.swift index 460d05cb..7a2b2db2 100644 --- a/Sources/StructuredQueriesCore/ScalarFunctions.swift +++ b/Sources/StructuredQueriesCore/ScalarFunctions.swift @@ -319,8 +319,7 @@ extension QueryExpression where QueryValue == [UInt8] { } } -/// A query expression of a generalized query function. -public struct QueryFunction: QueryExpression { +package struct QueryFunction: QueryExpression { let name: QueryFragment let arguments: [QueryFragment] diff --git a/Sources/StructuredQueriesSQLite/Macros.swift b/Sources/StructuredQueriesSQLite/Macros.swift index 33d0deec..c2f96755 100644 --- a/Sources/StructuredQueriesSQLite/Macros.swift +++ b/Sources/StructuredQueriesSQLite/Macros.swift @@ -54,3 +54,41 @@ public macro DatabaseFunction( module: "StructuredQueriesSQLiteMacros", type: "DatabaseFunctionMacro" ) + +/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction`` +/// protocol. +/// +/// - Parameters +/// - name: The function's name. Defaults to the name of the function the macro is applied to. +/// - representableFunctionType: The function as represented in a query. +/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially +/// transparent"), _i.e._ given an input it will always return the same output. +@attached(peer, names: overloaded, prefixed(`$`)) +public macro DatabaseFunction( + _ name: String = "", + as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type, + isDeterministic: Bool = false +) = + #externalMacro( + module: "StructuredQueriesSQLiteMacros", + type: "DatabaseFunctionMacro" + ) + +/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction`` +/// protocol. +/// +/// - Parameters +/// - name: The function's name. Defaults to the name of the function the macro is applied to. +/// - representableFunctionType: The function as represented in a query. +/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially +/// transparent"), _i.e._ given an input it will always return the same output. +@attached(peer, names: overloaded, prefixed(`$`)) +public macro DatabaseFunction( + _ name: String = "", + as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type, + isDeterministic: Bool = false +) = + #externalMacro( + module: "StructuredQueriesSQLiteMacros", + type: "DatabaseFunctionMacro" + ) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index a7b774a0..71584ab0 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -1,7 +1,7 @@ /// A type representing a database function. /// -/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate -/// a conformance. +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a conformance. public protocol DatabaseFunction { /// A type representing the function's arguments. associatedtype Input @@ -22,8 +22,8 @@ public protocol DatabaseFunction { /// A type representing a scalar database function. /// -/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate -/// a conformance. +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a conformance. public protocol ScalarDatabaseFunction: DatabaseFunction { /// The function body. Uses a query decoder to process the input of a database function into a /// bindable value. @@ -50,3 +50,68 @@ extension ScalarDatabaseFunction { } } } + +/// A type representing an aggregate database function. +/// +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a +/// conformance. +public protocol AggregateDatabaseFunction: DatabaseFunction { + /// A type representing one row of input to the aggregate function. + associatedtype Element = Input + + /// Decodes a row into an element to aggregate a result from. + /// + /// - Parameter decoder: A query decoder. + /// - Returns: An element to append to the sequence sent to the aggregate function. + func step(_ decoder: inout some QueryDecoder) throws -> Element + + /// Aggregates elements into a bindable value. + /// + /// - Parameter arguments: A sequence of elements to aggregate from. + /// - Returns: A binding returned from the aggregate function. + func invoke(_ arguments: some Sequence) throws -> QueryBinding +} + +extension AggregateDatabaseFunction { + /// An aggregate function call expression. + /// + /// - Parameters + /// - input: Expressions representing the arguments of the function. + /// - isDistinct: Whether or not to include a `DISTINCT` clause, which filters duplicates from + /// the aggregation. + /// - order: An `ORDER BY` clause to apply to the aggregation. + /// - filter: A `FILTER` clause to apply to the aggregation. + /// - Returns: An expression representing the function call. + @_disfavoredOverload + public func callAsFunction( + _ input: some QueryExpression, + distinct isDistinct: Bool = false, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) -> some QueryExpression + where Input: QueryBindable { + $_isSelecting.withValue(false) { + AggregateFunctionExpression(name, distinct: isDistinct, input, order: order, filter: filter) + } + } + + /// An aggregate function call expression. + /// + /// - Parameters + /// - input: Expressions representing the arguments of the function. + /// - order: An `ORDER BY` clause to apply to the aggregation. + /// - filter: A `FILTER` clause to apply to the aggregation. + /// - Returns: An expression representing the function call. + @_disfavoredOverload + public func callAsFunction( + _ input: repeat each T, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) -> some QueryExpression + where Input == (repeat (each T).QueryValue) { + $_isSelecting.withValue(false) { + AggregateFunctionExpression(name, repeat each input, order: order, filter: filter) + } + } +} diff --git a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md index 7d4b76b8..edb05ab5 100644 --- a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md +++ b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md @@ -5,6 +5,8 @@ from SQLite. ## Overview +### Scalar functions + StructuredQueries defines a macro specifically for defining Swift functions that can be called from a query. It's called `@DatabaseFunction`, and can annotate any function that works with query-representable types. @@ -18,11 +20,14 @@ func exclaim(_ string: String) -> String { } ``` +This defines a "scalar" function, which is called on a value for each row in a query, returning its +result. + > Note: If your project is using [default main actor isolation] then you further need to annotate > your function as `nonisolated`. [default main actor isolation]: https://github.com/swiftlang/swift-evolution/blob/main/proposals/0466-control-default-actor-isolation.md -And will be immediately callable in a query by prefixing the function with `$`: +Once defined, the function is immediately callable in a query by prefixing the function with `$`: ```swift Reminder.select { $exclaim($0.title) } @@ -52,9 +57,52 @@ configuration.prepareDatabase { db in > } > ``` +### Aggregate functions + +It is also possible to define a Swift function that builds a single result from multiple rows of a +query. The function must simply take a _sequence_ of query-representable types. + +For example, suppose you want to compute the most common priority used across all reminders. This +computation is called the "mode" in statistics, and unfortunately SQLite does not supply such +a function. But it is quite easy to write this function in plain Swift: + +```swift +@DatabaseFunction +func mode(priority priorities: some Sequence) -> Priority? { + var occurrences: [Priority: Int] = [:] + for priority in priorities { + guard let priority + else { continue } + occurrences[priority, default: 0] += 1 + } + return occurrences.max { $0.value < $1.value }?.key +} +``` + +This defines an "aggregate" function, and the sequence `priorities` that is passed to it represents +all of the data from the database passed to it while aggregating. It is now straightforward +to compute the mode of priorities across all reminders: + +```swift +Reminder + .select { $mode(priority: $0.priority) } +``` + +> Tip: Be sure to install the function in the database connection as discussed in +> above. + +You can also compute the mode of priorities inside each reminders list: + +```swift +RemindersList + .group(by: \.id) + .leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) } + .select { ($0.title, $mode(priority: $1.priority)) } +``` + ### Custom representations -To define a type that works with a custom representation, i.e. anytime you use `@Column(as:)` in +To define a type that works with a custom representation, _i.e._ anytime you use `@Column(as:)` in your data type, you can use the `as` parameter of the macro to specify those types. For example, if your model holds onto a date and you want to store that date as a [unix timestamp]() (i.e. double), @@ -93,9 +141,22 @@ func jsonArrayExclaim(_ strings: [String]) -> [String] { } ``` +It is also possible to do this with aggregate functions, too, but you must describe the sequence as +an `any Sequence` instead of a `some Sequence`: + +```swift +@DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self +) +func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) +} +``` + ## Topics ### Custom functions - ``DatabaseFunction`` - ``ScalarDatabaseFunction`` +- ``AggregateDatabaseFunction`` diff --git a/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift b/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift index 5e62a42b..d842250b 100644 --- a/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift +++ b/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift @@ -46,7 +46,7 @@ extension QueryExpression where QueryValue: Codable & QueryBindable { order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression<[QueryValue].JSONRepresentation> { - AggregateFunction( + AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [queryFragment], @@ -112,7 +112,7 @@ extension PrimaryKeyedTableDefinition where QueryValue: Codable { order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression<[QueryValue].JSONRepresentation> { - AggregateFunction( + AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [jsonObject().queryFragment], @@ -200,7 +200,7 @@ where } else { primaryKeyFilter.queryFragment } - return AggregateFunction( + return AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [QueryValue.columns.jsonObject().queryFragment], diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index ff0f7bb7..c25887c5 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -110,45 +110,177 @@ extension DatabaseFunctionMacro: PeerMacro { var invocationArgumentTypes: [TypeSyntax] = [] var parameters: [String] = [] var argumentBindings: [String] = [] - var offset = 0 - var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator() var decodings: [String] = [] var decodingUnwrappings: [String] = [] + var canThrowInvalidInvocation = false - for index in signature.parameterClause.parameters.indices { - defer { offset += 1 } - var parameter = signature.parameterClause.parameters[index] - if let ellipsis = parameter.ellipsis { - context.diagnose( - Diagnostic( - node: ellipsis, - message: MacroExpansionErrorMessage("Variadic arguments are not supported") + let isAggregate: Bool + var representableInputType: String + var rowType = "" + let projectedCallSyntax: ExprSyntax + + if signature.parameterClause.parameters.count == 1, + let parameter = signature.parameterClause.parameters.first, + var someOrAnyParameterType = parameter.type.as(SomeOrAnyTypeSyntax.self), + someOrAnyParameterType.someOrAnySpecifier.tokenKind == .keyword(.some), + let parameterType = someOrAnyParameterType.constraint.as(IdentifierTypeSyntax.self), + ["Sequence", "Swift.Sequence"].contains(parameterType.name.text), + let genericArgumentClause = parameterType.genericArgumentClause, + genericArgumentClause.arguments.count == 1, + let genericArgument = genericArgumentClause.arguments.first + { + isAggregate = true + + someOrAnyParameterType.someOrAnySpecifier.tokenKind = .keyword(.any) + let bodySignature = + signature + .with( + \.parameterClause.parameters[signature.parameterClause.parameters.startIndex], + parameter + .with(\.firstName, .wildcardToken(trailingTrivia: .space)) + .with(\.type, TypeSyntax(someOrAnyParameterType)) + ) + bodyArguments.append("\(bodySignature.parameterClause.parameters)") + + var parameterClause = signature.parameterClause.with(\.parameters, []) + let firstName = parameter.firstName.tokenKind == .wildcard ? nil : parameter.firstName + + let tupleType = + genericArgument.argument.as(TupleTypeSyntax.self) + ?? TupleTypeSyntax( + elements: [ + TupleTypeElementSyntax( + firstName: firstName, + secondName: parameter.secondName, + type: genericArgument.argument.cast(TypeSyntax.self) + ) + ] + ) + + let representableInputGeneric = functionRepresentation? + .parameters.first? + .type.as(SomeOrAnyTypeSyntax.self)? + .constraint.as(IdentifierTypeSyntax.self)? + .genericArgumentClause? + .arguments.first + let representableInputGenericArgument = representableInputGeneric?.argument + + representableInputType = "\(representableInputGeneric ?? genericArgument)" + rowType = "\(genericArgument)" + + let representableInputArguments = + representableInputGenericArgument?.as(TupleTypeSyntax.self)?.elements.map(\.type) + ?? (representableInputGenericArgument?.cast(TypeSyntax.self)).map { [$0] } + var representableInputArgumentsIterator = representableInputArguments?.makeIterator() + + var offset = 0 + for var element in tupleType.elements { + defer { offset += 1 } + var type = representableInputArgumentsIterator?.next() ?? element.type + element.type = type.asQueryExpression() + type = type.trimmed + representableInputTypes.append(type.description) + invocationArgumentTypes.append(type) + let firstName = element.firstName?.trimmedDescription + let secondName = element.secondName?.trimmedDescription ?? firstName ?? "p\(offset)" + parameters.append(secondName) + argumentBindings.append(secondName) + + argumentCounts.append("\(type)") + decodings.append("let \(secondName) = try decoder.decode(\(type).self)") + decodingUnwrappings.append( + "guard let \(secondName) else { throw InvalidInvocation() }" + ) + canThrowInvalidInvocation = true + + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: firstName.map { .identifier($0) } ?? .wildcardToken(), + secondName: firstName == secondName + ? nil + : .identifier(secondName, leadingTrivia: .space), + colon: .colonToken(), + type: element.type, + trailingComma: .commaToken(), + trailingTrivia: .space ) ) - return [] } - bodyArguments.append("\(parameter.type.trimmed)") - var type = (functionRepresentationIterator?.next()?.type ?? parameter.type) - parameter.type = type.asQueryExpression() - type = type.trimmed - representableInputTypes.append(type.description) - if let defaultValue = parameter.defaultValue, - defaultValue.value.is(NilLiteralExprSyntax.self) - { - parameter.defaultValue?.value = "\(type).none" + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: "order", + colon: .colonToken(), + type: "(some QueryExpression)?" as TypeSyntax, + defaultValue: InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: "Bool?.none" as ExprSyntax + ), + trailingComma: .commaToken(), + trailingTrivia: .space + ) + ) + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: "filter", + colon: .colonToken(trailingTrivia: .space), + type: "(some QueryExpression)?" as TypeSyntax, + defaultValue: InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: "Bool?.none" as ExprSyntax + ) + ) + ) + signature.parameterClause = parameterClause + projectedCallSyntax = """ + \(functionTypeName) { + \(raw: declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")\ + \(declaration.name.trimmed)(\(raw: firstName.map { "\($0.trimmedDescription): " } ?? "")$0) + } + """ + } else { + isAggregate = false + var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator() + + for index in signature.parameterClause.parameters.indices { + var parameter = signature.parameterClause.parameters[index] + if let ellipsis = parameter.ellipsis { + context.diagnose( + Diagnostic( + node: ellipsis, + message: MacroExpansionErrorMessage("Variadic arguments are not supported") + ) + ) + return [] + } + bodyArguments.append("\(parameter.type.trimmed)") + var type = (functionRepresentationIterator?.next()?.type ?? parameter.type) + parameter.type = type.asQueryExpression() + type = type.trimmed + representableInputTypes.append(type.description) + if let defaultValue = parameter.defaultValue, + defaultValue.value.is(NilLiteralExprSyntax.self) + { + parameter.defaultValue?.value = "\(type).none" + } + signature.parameterClause.parameters[index] = parameter + invocationArgumentTypes.append(type) + let parameterName = (parameter.secondName ?? parameter.firstName).trimmedDescription + parameters.append(parameterName) + argumentBindings.append(parameterName) + + argumentCounts.append("\(type)") + decodings.append("let \(parameterName) = try decoder.decode(\(type).self)") + decodingUnwrappings.append("guard let \(parameterName) else { throw InvalidInvocation() }") + canThrowInvalidInvocation = true } - signature.parameterClause.parameters[index] = parameter - invocationArgumentTypes.append(type) - let parameterName = (parameter.secondName ?? parameter.firstName).trimmedDescription - parameters.append(parameterName) - argumentBindings.append(parameterName) - - argumentCounts.append("\(type)") - decodings.append("let \(parameterName) = try decoder.decode(\(type).self)") - decodingUnwrappings.append("guard let \(parameterName) else { throw InvalidInvocation() }") + representableInputType = representableInputTypes.joined(separator: ", ") + representableInputType = + representableInputTypes.count == 1 + ? representableInputType + : "(\(representableInputType))" + projectedCallSyntax = "\(functionTypeName)(\(declaration.name.trimmed))" } - var representableInputType = representableInputTypes.joined(separator: ", ") let isVoidReturning = signature.returnClause == nil let outputType = returnClause.type.trimmed signature.returnClause = returnClause @@ -161,36 +293,9 @@ extension DatabaseFunctionMacro: PeerMacro { \(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\ \(bodyReturnClause) """ - let bodyInvocation = """ - \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\ - \(argumentBindings.joined(separator: ", "))\ - ) - """ // TODO: Diagnose 'asyncClause'? signature.effectSpecifiers?.throwsClause = nil - var invocationBody = - isVoidReturning - ? """ - \(bodyInvocation) - return .null - """ - : """ - return \(functionRepresentation?.returnClause.type ?? outputType)( - queryOutput: \(bodyInvocation) - ) - .queryBinding - """ - if declaration.signature.effectSpecifiers?.throwsClause != nil { - invocationBody = """ - do { - \(invocationBody) - } catch { - return .invalid(error) - } - """ - } - var attributes = declaration.attributes if let index = attributes.firstIndex(where: { $0.as(AttributeSyntax.self)?.attributeName.as(IdentifierTypeSyntax.self)?.name.text @@ -210,10 +315,6 @@ extension DatabaseFunctionMacro: PeerMacro { continue } } - representableInputType = - representableInputTypes.count == 1 - ? representableInputType - : "(\(representableInputType))" let argumentCount = argumentCounts.isEmpty @@ -224,15 +325,136 @@ extension DatabaseFunctionMacro: PeerMacro { return argumentCount """ + var methods: [DeclSyntax] = [] + if isAggregate { + var parameter = declaration.signature.parameterClause.parameters[ + declaration.signature.parameterClause.parameters.startIndex + ] + parameter.firstName = .wildcardToken(trailingTrivia: .space) + parameter.secondName = "arguments" + + methods.append( + """ + public func callAsFunction\(signature.trimmed) { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, \ + \(raw: parameters.joined(separator: ", ")), \ + order: order, \ + filter: filter + ) + } + } + """ + ) + + let stepReturnClause: String + switch parameters.count { + case 0: stepReturnClause = "" + case 1: stepReturnClause = "return \(parameters[0])\n" + default: stepReturnClause = "return (\(parameters.joined(separator: ", ")))\n" + } + + methods.append( + """ + public func step( + _ decoder: inout some QueryDecoder + ) throws -> \(raw: rowType) { + \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ + \(raw: stepReturnClause)\ + } + """ + ) + + let bodyInvocation = """ + \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")\ + self.body(arguments) + """ + var invocationBody = + isVoidReturning + ? """ + \(bodyInvocation) + return .null + """ + : "return \(representableOutputType)(queryOutput: \(bodyInvocation)).queryBinding" + if declaration.signature.effectSpecifiers?.throwsClause != nil { + invocationBody = """ + do { + \(invocationBody) + } catch { + return .invalid(error) + } + """ + } + methods.append( + """ + public func invoke(\(parameter)) -> QueryBinding { + \(raw: invocationBody) + } + """ + ) + } else { + methods.append( + """ + public func callAsFunction\(signature.trimmed) { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.SQLQueryExpression( + "\\(quote: self.name)(\(raw: parameters.map { "\\(\($0))" }.joined(separator: ", ")))" + ) + } + } + """ + ) + + let bodyInvocation = """ + \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\ + \(argumentBindings.joined(separator: ", "))\ + ) + """ + var invocationBody = + isVoidReturning + ? """ + \(bodyInvocation) + return .null + """ + : """ + return \(functionRepresentation?.returnClause.type ?? outputType)( + queryOutput: \(bodyInvocation) + ) + .queryBinding + """ + if declaration.signature.effectSpecifiers?.throwsClause != nil { + invocationBody = """ + do { + \(invocationBody) + } catch { + return .invalid(error) + } + """ + } + + methods.append( + """ + public func invoke( + _ decoder: inout some QueryDecoder + ) throws -> StructuredQueriesCore.QueryBinding { + \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ + \(raw: invocationBody) + } + """ + ) + } + return [ """ - \(attributes)\(access)\(`static`)\(nonisolated)var $\(raw: declarationName): \(functionTypeName) { - \(functionTypeName)(\(declaration.name.trimmed)) + \(attributes)\(access)\(`static`)\(nonisolated)var $\(raw: declarationName): \ + \(functionTypeName) { + \(projectedCallSyntax) } """, """ \(attributes)\(access)\(nonisolated)struct \(functionTypeName): \ - StructuredQueriesSQLiteCore.ScalarDatabaseFunction { + StructuredQueriesSQLiteCore.\(raw: isAggregate ? "Aggregate" : "Scalar")DatabaseFunction { public typealias Input = \(raw: representableInputType) public typealias Output = \(representableOutputType) public let name = \(databaseFunctionName) @@ -244,20 +466,8 @@ extension DatabaseFunctionMacro: PeerMacro { public init(_ body: @escaping \(raw: bodyType)) { self.body = body } - public func callAsFunction\(signature.trimmed) { - StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.SQLQueryExpression( - "\\(quote: self.name)(\(raw: parameters.map { "\\(\($0))" }.joined(separator: ", ")))" - ) - } - } - public func invoke( - _ decoder: inout some QueryDecoder - ) throws -> StructuredQueriesCore.QueryBinding { - \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ - \(raw: invocationBody) - } - private struct InvalidInvocation: Error {} + \(raw: methods.map(\.description).joined(separator: "\n"))\ + \(raw: canThrowInvalidInvocation ? "\nprivate struct InvalidInvocation: Error {}" : "") } """, ] diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 5fcf840e..ea5cb9fb 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -2,17 +2,17 @@ import Foundation extension ScalarDatabaseFunction { public func install(_ db: OpaquePointer) { - let box = Unmanaged.passRetained(ScalarDatabaseFunctionBox(self)).toOpaque() + let body = Unmanaged.passRetained(ScalarDatabaseFunctionDefinition(self)).toOpaque() sqlite3_create_function_v2( db, name, Int32(argumentCount ?? -1), SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0), - box, + body, { context, argumentCount, arguments in do { var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments) - try Unmanaged + try Unmanaged .fromOpaque(sqlite3_user_data(context)) .takeUnretainedValue() .function @@ -26,19 +26,178 @@ extension ScalarDatabaseFunction { nil, { context in guard let context else { return } - Unmanaged.fromOpaque(context).release() + Unmanaged.fromOpaque(context).release() } ) } } -private final class ScalarDatabaseFunctionBox { +private final class ScalarDatabaseFunctionDefinition { let function: any ScalarDatabaseFunction init(_ function: some ScalarDatabaseFunction) { self.function = function } } +extension AggregateDatabaseFunction { + public func install(_ db: OpaquePointer) { + let body = Unmanaged.passRetained(AggregateDatabaseFunctionDefinition(self)).toOpaque() + sqlite3_create_function_v2( + db, + name, + Int32(argumentCount ?? -1), + SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0), + body, + nil, + { context, argumentCount, arguments in + var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments) + let function = AggregateDatabaseFunctionContext[context].takeUnretainedValue() + do { + try function.iterator.step(&decoder) + } catch { + sqlite3_result_error(context, error.localizedDescription, -1) + } + }, + { context in + let unmanagedFunction = AggregateDatabaseFunctionContext[context] + let function = unmanagedFunction.takeUnretainedValue() + unmanagedFunction.release() + function.iterator.finish() + do { + try function.iterator.result.result(db: context) + } catch { + sqlite3_result_error(context, error.localizedDescription, -1) + } + }, + { context in + guard let context else { return } + Unmanaged.fromOpaque(context).release() + } + ) + } +} + +private final class AggregateDatabaseFunctionDefinition { + let function: any AggregateDatabaseFunction + init(_ function: some AggregateDatabaseFunction) { + self.function = function + } +} + +private final class AggregateDatabaseFunctionContext { + static subscript(context: OpaquePointer?) -> Unmanaged { + let size = MemoryLayout>.size + let pointer = sqlite3_aggregate_context(context, Int32(size))! + if pointer.load(as: Int.self) == 0 { + let definition = Unmanaged + .fromOpaque(sqlite3_user_data(context)) + .takeUnretainedValue() + let context = AggregateDatabaseFunctionContext(definition.function) + let unmanagedContext = Unmanaged.passRetained(context) + pointer + .assumingMemoryBound(to: Unmanaged.self) + .pointee = unmanagedContext + return unmanagedContext + } else { + return + pointer + .assumingMemoryBound(to: Unmanaged.self) + .pointee + } + } + let iterator: any AggregateDatabaseFunctionIteratorProtocol + init(_ body: some AggregateDatabaseFunction) { + self.iterator = AggregateDatabaseFunctionIterator(body) + } +} + +private protocol AggregateDatabaseFunctionIteratorProtocol { + associatedtype Body: AggregateDatabaseFunction + + var body: Body { get } + var stream: Stream { get } + func start() + func step(_ decoder: inout some QueryDecoder) throws + func finish() + var result: QueryBinding { get throws } +} + +private final class AggregateDatabaseFunctionIterator< + Body: AggregateDatabaseFunction +>: AggregateDatabaseFunctionIteratorProtocol { + let body: Body + let stream = Stream() + let queue: DispatchQueue + var _result: QueryBinding? + init(_ body: Body) { + self.body = body + self.queue = DispatchQueue( + label: "co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction.\(body.name)" + ) + nonisolated(unsafe) let iterator: any AggregateDatabaseFunctionIteratorProtocol = self + queue.async { + iterator.start() + } + } + func start() { + do { + _result = try body.invoke(stream) + } catch { + _result = .invalid(error) + } + } + func step(_ decoder: inout some QueryDecoder) throws { + try stream.send(body.step(&decoder)) + } + func finish() { + stream.finish() + } + var result: QueryBinding { + get throws { + while true { + if let result = queue.sync(execute: { _result }) { + return result + } + } + } + } +} + +private final class Stream: Sequence { + let condition = NSCondition() + private var buffer: [Element] = [] + private var isFinished = false + + func send(_ element: Element) { + condition.withLock { + buffer.append(element) + condition.signal() + } + } + + func finish() { + condition.withLock { + isFinished = true + condition.broadcast() + } + } + + func makeIterator() -> Iterator { Iterator(base: self) } + + struct Iterator: IteratorProtocol { + fileprivate let base: Stream + mutating func next() -> Element? { + base.condition.withLock { + while base.buffer.isEmpty && !base.isFinished { + base.condition.wait() + } + guard !base.buffer.isEmpty else { return nil } + return base.buffer.removeFirst() + } + } + } +} + extension QueryBinding { fileprivate func result(db: OpaquePointer?) { switch self { diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index 09f50bde..c7f1f883 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -50,8 +50,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -102,8 +100,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -212,8 +208,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -638,8 +632,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -694,8 +686,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -746,8 +736,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -798,8 +786,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -873,8 +859,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -925,8 +909,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -975,8 +957,6 @@ extension SnapshotTests { self.body() return .null } - private struct InvalidInvocation: Error { - } } """# } @@ -1026,8 +1006,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -1232,5 +1210,404 @@ extension SnapshotTests { """# } } + + @Suite struct AggregateTests { + @Test func basics() { + assertMacro { + """ + @DatabaseFunction + func sum(_ xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + """ + } expansion: { + """ + func sum(_ xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + nonisolated var $sum: __macro_local_3sumfMu_ { + __macro_local_3sumfMu_ { + sum($0) + } + } + + nonisolated struct __macro_local_3sumfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Int + public let name = "sum" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Int + public init(_ body: @escaping (_ xs: any Sequence) -> Int) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + return Int(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func namedArgument() { + assertMacro { + """ + @DatabaseFunction + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + """ + } expansion: { + """ + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + nonisolated var $sum: __macro_local_3sumfMu_ { + __macro_local_3sumfMu_ { + sum(of: $0) + } + } + + nonisolated struct __macro_local_3sumfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Int + public let name = "sum" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Int + public init(_ body: @escaping (_ xs: any Sequence) -> Int) { + self.body = body + } + public func callAsFunction(of xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + return Int(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func multipleArguments() { + assertMacro { + """ + @DatabaseFunction + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var iterator = arguments.makeIterator() + guard var (result, _) = iterator.next() else { return nil } + while let (string, separator) = iterator.next() { + result.append(separator) + result.append(string) + } + return result + } + """ + } expansion: { + """ + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var iterator = arguments.makeIterator() + guard var (result, _) = iterator.next() else { return nil } + while let (string, separator) = iterator.next() { + result.append(separator) + result.append(string) + } + return result + } + + nonisolated var $joined: __macro_local_6joinedfMu_ { + __macro_local_6joinedfMu_ { + joined($0) + } + } + + nonisolated struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = (String, separator: String) + public typealias Output = String? + public let name = "joined" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += String._columnWidth + argumentCount += String._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ arguments: any Sequence<(String, separator: String)>) -> String? + public init(_ body: @escaping (_ arguments: any Sequence<(String, separator: String)>) -> String?) { + self.body = body + } + public func callAsFunction(_ p0: some StructuredQueriesCore.QueryExpression, separator: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, p0, separator, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> (String, separator: String) { + let p0 = try decoder.decode(String.self) + let separator = try decoder.decode(String.self) + guard let p0 else { + throw InvalidInvocation() + } + guard let separator else { + throw InvalidInvocation() + } + return (p0, separator) + } + public func invoke(_ arguments: some Sequence<(String, separator: String)>) -> QueryBinding { + return String?(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func customRepresentations() { + assertMacro { + #""" + @DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self + ) + func joined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + """# + } expansion: { + #""" + func joined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + + nonisolated var $joined: __macro_local_6joinedfMu_ { + __macro_local_6joinedfMu_ { + joined($0) + } + } + + nonisolated struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = [String].JSONRepresentation + public typealias Output = [String].JSONRepresentation + public let name = "joined" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += [String].JSONRepresentation._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ arrays: any Sequence<[String]>) -> [String] + public init(_ body: @escaping (_ arrays: any Sequence<[String]>) -> [String]) { + self.body = body + } + public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, arrays, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> [String] { + let arrays = try decoder.decode([String].JSONRepresentation.self) + guard let arrays else { + throw InvalidInvocation() + } + return arrays + } + public func invoke(_ arguments: some Sequence<[String]>) -> QueryBinding { + return [String].JSONRepresentation(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """# + } + } + + @Test func voidReturning() { + assertMacro { + """ + @DatabaseFunction + func print(_ xs: some Sequence) { + for x in xs { + Swift.print(x) + } + } + """ + } expansion: { + """ + func print(_ xs: some Sequence) { + for x in xs { + Swift.print(x) + } + } + + nonisolated var $print: __macro_local_5printfMu_ { + __macro_local_5printfMu_ { + print($0) + } + } + + nonisolated struct __macro_local_5printfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Swift.Void + public let name = "print" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Swift.Void + public init(_ body: @escaping (_ xs: any Sequence) -> Swift.Void) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + self.body(arguments) + return .null + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func throwing() { + assertMacro { + """ + @DatabaseFunction + func validatePositive(_ xs: some Sequence) throws { + for x in xs { + guard x.sign == .plus else { + throw NegativeError() + } + } + } + """ + } expansion: { + """ + func validatePositive(_ xs: some Sequence) throws { + for x in xs { + guard x.sign == .plus else { + throw NegativeError() + } + } + } + + nonisolated var $validatePositive: __macro_local_16validatePositivefMu_ { + __macro_local_16validatePositivefMu_ { + try validatePositive($0) + } + } + + nonisolated struct __macro_local_16validatePositivefMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Swift.Void + public let name = "validatePositive" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) throws -> Swift.Void + public init(_ body: @escaping (_ xs: any Sequence) throws -> Swift.Void) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunctionExpression( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + do { + try self.body(arguments) + return .null + } catch { + return .invalid(error) + } + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + } } } diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index c13c0a9d..d57ffe24 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -451,5 +451,200 @@ extension SnapshotTests { """ } } + + @DatabaseFunction + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + @Test func aggregate() { + $sum.install(database.handle) + + assertQuery( + Reminder.select { $sum(of: $0.id) } + ) { + """ + SELECT "sum"("reminders"."id") + FROM "reminders" + """ + } results: { + """ + ┌────┐ + │ 55 │ + └────┘ + """ + } + assertQuery( + Reminder.select { $sum($0.id, distinct: true) } + ) { + """ + SELECT "sum"(DISTINCT "reminders"."id") + FROM "reminders" + """ + } results: { + """ + ┌────┐ + │ 55 │ + └────┘ + """ + } + } + + @DatabaseFunction + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var isFirst = true + var result = "" + for (string, separator) in arguments { + defer { isFirst = false } + if !isFirst { + result.append(separator) + } + result.append(string) + } + return result + } + + @Test func multiAggregate() { + $joined.install(database.handle) + + assertQuery( + Tag.select { $joined($0.title, separator: ", ") } + ) { + """ + SELECT "joined"("tags"."title", ', ') + FROM "tags" + """ + } results: { + """ + ┌────────────────────────────────┐ + │ "car, kids, someday, optional" │ + └────────────────────────────────┘ + """ + } + } + + @DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self + ) + func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + + @Test func aggregateRepresentation() { + $jsonJoined.install(database.handle) + + assertQuery( + Reminder.select { + $jsonJoined(#sql("json_array(\($0.title.lower()), \($0.title.upper()))")) + } + ) { + """ + SELECT "jsonJoined"(json_array(lower("reminders"."title"), upper("reminders"."title"))) + FROM "reminders" + """ + } results: { + """ + ┌─────────────────────────────────────┐ + │ [ │ + │ [0]: "groceries", │ + │ [1]: "GROCERIES", │ + │ [2]: "haircut", │ + │ [3]: "HAIRCUT", │ + │ [4]: "doctor appointment", │ + │ [5]: "DOCTOR APPOINTMENT", │ + │ [6]: "take a walk", │ + │ [7]: "TAKE A WALK", │ + │ [8]: "buy concert tickets", │ + │ [9]: "BUY CONCERT TICKETS", │ + │ [10]: "pick up kids from school", │ + │ [11]: "PICK UP KIDS FROM SCHOOL", │ + │ [12]: "get laundry", │ + │ [13]: "GET LAUNDRY", │ + │ [14]: "take out trash", │ + │ [15]: "TAKE OUT TRASH", │ + │ [16]: "call accountant", │ + │ [17]: "CALL ACCOUNTANT", │ + │ [18]: "send weekly emails", │ + │ [19]: "SEND WEEKLY EMAILS" │ + │ ] │ + └─────────────────────────────────────┘ + """ + } + } + + @DatabaseFunction + func tagged(_ tags: some Sequence) -> String { + tags.map { "#\($0.title)" }.joined(separator: " ") + } + + @Test func selectionTableAggregate() { + $tagged.install(database.handle) + + assertQuery( + Tag.select { $tagged($0) } + ) { + """ + SELECT "tagged"("tags"."id", "tags"."title") + FROM "tags" + """ + } results: { + """ + ┌─────────────────────────────────┐ + │ "#car #kids #someday #optional" │ + └─────────────────────────────────┘ + """ + } + } + + @DatabaseFunction + func mode(priority priorities: some Sequence) -> Priority? { + var occurences: [Priority: Int] = [:] + for priority in priorities { + guard let priority + else { continue } + occurences[priority, default: 0] += 1 + } + return occurences.max { $0.value < $1.value }?.key + } + @Test func modePriorityAggregate() { + $mode.install(database.handle) + + assertQuery( + Reminder + .select { $mode(priority: $0.priority) } + ) { + """ + SELECT "mode"("reminders"."priority") + FROM "reminders" + """ + } results: { + """ + ┌───────┐ + │ .high │ + └───────┘ + """ + } + assertQuery( + RemindersList + .group(by: \.id) + .leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) } + .select { ($0.title, $mode(priority: $1.priority)) } + ) { + """ + SELECT "remindersLists"."title", "mode"("reminders"."priority") + FROM "remindersLists" + LEFT JOIN "reminders" ON ("remindersLists"."id") = ("reminders"."remindersListID") + GROUP BY "remindersLists"."id" + """ + } results: { + """ + ┌────────────┬─────────┐ + │ "Personal" │ .high │ + │ "Family" │ .high │ + │ "Business" │ .medium │ + └────────────┴─────────┘ + """ + } + } } }