From de31bffb2c3cc2337d70cc8f0e8feccbcf1ea2f6 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Tue, 7 Oct 2025 14:06:19 -0700 Subject: [PATCH 01/19] wip --- .../DatabaseFunction.swift | 38 +++++++++ .../DatabaseFunction.swift | 80 +++++++++++++++++-- .../DatabaseFunctionTests.swift | 78 ++++++++++++++++++ 3 files changed, 191 insertions(+), 5 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index a7b774a0..986ee091 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -50,3 +50,41 @@ extension ScalarDatabaseFunction { } } } + +/// A type representing an aggregate database function. +/// +/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate +/// a conformance. +public protocol AggregateDatabaseFunction: DatabaseFunction { + /// The function body. Uses a query decoder to process the input of a database function into a + /// bindable value. + /// + /// - Parameter decoder: A query decoder. + /// - Returns: A binding returned from the database function. + mutating func invoke(_ decoder: inout some QueryDecoder) throws + + var result: QueryBinding { get throws } +} + +extension AggregateDatabaseFunction { + /// A function call expression. + /// + /// - Parameter input: Expressions representing the arguments of the function. + /// - Returns: An expression representing the function call. + @_disfavoredOverload + public func callAsFunction( + _ input: repeat each T, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) -> some QueryExpression + where Input == (repeat (each T).QueryValue) { + $_isSelecting.withValue(false) { + AggregateFunction( + QueryFragment(quote: name), + Array(repeat each input), + order: order?.queryFragment, + filter: filter?.queryFragment + ) + } + } +} diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 5fcf840e..37f3859d 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -2,17 +2,17 @@ import Foundation extension ScalarDatabaseFunction { public func install(_ db: OpaquePointer) { - let box = Unmanaged.passRetained(ScalarDatabaseFunctionBox(self)).toOpaque() + let body = Unmanaged.passRetained(ScalarDatabaseFunctionDefinition(self)).toOpaque() sqlite3_create_function_v2( db, name, Int32(argumentCount ?? -1), SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0), - box, + body, { context, argumentCount, arguments in do { var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments) - try Unmanaged + try Unmanaged .fromOpaque(sqlite3_user_data(context)) .takeUnretainedValue() .function @@ -26,19 +26,89 @@ extension ScalarDatabaseFunction { nil, { context in guard let context else { return } - Unmanaged.fromOpaque(context).release() + Unmanaged.fromOpaque(context).release() } ) } } -private final class ScalarDatabaseFunctionBox { +private final class ScalarDatabaseFunctionDefinition { let function: any ScalarDatabaseFunction init(_ function: some ScalarDatabaseFunction) { self.function = function } } +extension AggregateDatabaseFunction { + public func install(_ db: OpaquePointer) { + let body = Unmanaged.passRetained(AggregateDatabaseFunctionDefinition(self)).toOpaque() + sqlite3_create_function_v2( + db, + name, + Int32(argumentCount ?? -1), + SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0), + body, + nil, + { context, argumentCount, arguments in + var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments) + let function = AggregateDatabaseFunctionContext[context].takeUnretainedValue() + do { + try function.body.invoke(&decoder) + } catch { + sqlite3_result_error(context, error.localizedDescription, -1) + } + }, + { context in + let unmanagedFunction = AggregateDatabaseFunctionContext[context] + let function = unmanagedFunction.takeUnretainedValue() + unmanagedFunction.release() + do { + try function.body.result.result(db: context) + } catch { + sqlite3_result_error(context, error.localizedDescription, -1) + } + }, + { context in + guard let context else { return } + Unmanaged.fromOpaque(context).release() + } + ) + } +} + +private final class AggregateDatabaseFunctionDefinition { + let function: any AggregateDatabaseFunction + init(_ function: some AggregateDatabaseFunction) { + self.function = function + } +} + +private final class AggregateDatabaseFunctionContext { + static subscript(context: OpaquePointer?) -> Unmanaged { + let size = MemoryLayout>.size + let pointer = sqlite3_aggregate_context(context, Int32(size))! + if pointer.load(as: Int.self) == 0 { + let definition = Unmanaged + .fromOpaque(sqlite3_user_data(context)) + .takeUnretainedValue() + let context = AggregateDatabaseFunctionContext(definition.function) + let unmanagedContext = Unmanaged.passRetained(context) + pointer + .assumingMemoryBound(to: Unmanaged.self) + .pointee = unmanagedContext + return unmanagedContext + } else { + return pointer + .assumingMemoryBound(to: Unmanaged.self) + .pointee + } + } + var body: any AggregateDatabaseFunction + init(_ body: some AggregateDatabaseFunction) { + self.body = body + } +} + extension QueryBinding { fileprivate func result(db: OpaquePointer?) { switch self { diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index c13c0a9d..db74aa9f 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -451,5 +451,83 @@ extension SnapshotTests { """ } } + + // ... + + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var iterator = arguments.makeIterator() + guard var (result, _) = iterator.next() else { return nil } + while let (string, separator) = iterator.next() { + result.append(separator) + result.append(string) + } + return result + } + + var _$joined: Joined { + Joined { joined($0) } + } + + struct Joined: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = (String, separator: String) + public typealias Output = String? + public let name = "joined" + public let argumentCount: Int? = 2 + public let isDeterministic = true + public let body: (any Sequence<(String, separator: String)>) -> String? + public init(_ body: @escaping (any Sequence<(String, separator: String)>) -> String?) { + self.body = body + } + private var rows: [Input] = [] + public func callAsFunction( + _ n0: some StructuredQueriesCore.QueryExpression, + separator: some StructuredQueriesCore.QueryExpression, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) -> some StructuredQueriesCore.QueryExpression { + $_isSelecting.withValue(false) { + AggregateFunction( + QueryFragment(quote: name), + [n0.queryFragment, separator.queryFragment], + order: order?.queryFragment, + filter: filter?.queryFragment + ) + } + } + public mutating func invoke(_ decoder: inout some QueryDecoder) throws { + var decoder = decoder + let p0 = try decoder.decode(String.self) + let separator = try decoder.decode(String.self) + guard let p0 else { throw InvalidInvocation() } + guard let separator else { throw InvalidInvocation() } + rows.append((p0, separator)) + } + public var result: QueryBinding { + get throws { + body(rows).queryBinding + } + } + private struct InvalidInvocation: Error { + } + } + + @Test func aggregate() { + _$joined.install(database.handle) + + assertQuery( + Tag.select { _$joined($0.title, separator: ", ") } + ) { + """ + SELECT "joined"("tags"."title", ', ') + FROM "tags" + """ + } results: { + """ + ┌────────────────────────────────┐ + │ "car, kids, someday, optional" │ + └────────────────────────────────┘ + """ + } + } } } From 20631e64a8d7e4ea80b934f1f4b986d2f22399cb Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Wed, 8 Oct 2025 10:37:31 -0700 Subject: [PATCH 02/19] wip --- Tests/StructuredQueriesTests/DatabaseFunctionTests.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index db74aa9f..3f7b964c 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -515,16 +515,16 @@ extension SnapshotTests { _$joined.install(database.handle) assertQuery( - Tag.select { _$joined($0.title, separator: ", ") } + Tag.select { _$joined($0.title, separator: ", ", order: $0.title) } ) { """ - SELECT "joined"("tags"."title", ', ') + SELECT "joined"("tags"."title", ', ' ORDER BY "tags"."title") FROM "tags" """ } results: { """ ┌────────────────────────────────┐ - │ "car, kids, someday, optional" │ + │ "car, kids, optional, someday" │ └────────────────────────────────┘ """ } From 9b4ed723ba6aee897e5089e8e8df2065b43ef414 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Wed, 8 Oct 2025 17:13:22 -0700 Subject: [PATCH 03/19] wip --- .../DatabaseFunction.swift | 9 +- .../DatabaseFunction.swift | 89 ++++++++++++++++++- .../DatabaseFunctionTests.swift | 16 ++-- 3 files changed, 94 insertions(+), 20 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index 986ee091..9f56a18e 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -56,14 +56,9 @@ extension ScalarDatabaseFunction { /// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate /// a conformance. public protocol AggregateDatabaseFunction: DatabaseFunction { - /// The function body. Uses a query decoder to process the input of a database function into a - /// bindable value. - /// - /// - Parameter decoder: A query decoder. - /// - Returns: A binding returned from the database function. - mutating func invoke(_ decoder: inout some QueryDecoder) throws + func step(_ decoder: inout some QueryDecoder) throws -> Input - var result: QueryBinding { get throws } + func invoke(_ sequence: some Sequence) throws -> QueryBinding } extension AggregateDatabaseFunction { diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 37f3859d..ec6df699 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -53,7 +53,7 @@ extension AggregateDatabaseFunction { var decoder = SQLiteFunctionDecoder(argumentCount: argumentCount, arguments: arguments) let function = AggregateDatabaseFunctionContext[context].takeUnretainedValue() do { - try function.body.invoke(&decoder) + try function.iterator.step(&decoder) } catch { sqlite3_result_error(context, error.localizedDescription, -1) } @@ -62,8 +62,9 @@ extension AggregateDatabaseFunction { let unmanagedFunction = AggregateDatabaseFunctionContext[context] let function = unmanagedFunction.takeUnretainedValue() unmanagedFunction.release() + function.iterator.finish() do { - try function.body.result.result(db: context) + try function.iterator.result.result(db: context) } catch { sqlite3_result_error(context, error.localizedDescription, -1) } @@ -103,9 +104,91 @@ private final class AggregateDatabaseFunctionContext { .pointee } } - var body: any AggregateDatabaseFunction + let iterator: any AggregateDatabaseFunctionIteratorProtocol init(_ body: some AggregateDatabaseFunction) { + self.iterator = AggregateDatabaseFunctionIterator(body) + } +} + +private protocol AggregateDatabaseFunctionIteratorProtocol { + associatedtype Body: AggregateDatabaseFunction + + var body: Body { get } + var stream: Stream { get } + func start() + func step(_ decoder: inout some QueryDecoder) throws + func finish() + var result: QueryBinding { get throws } +} + +private final class AggregateDatabaseFunctionIterator< + Body: AggregateDatabaseFunction +>: AggregateDatabaseFunctionIteratorProtocol { + let body: Body + let stream = Stream() + let queue = DispatchQueue.global() + var _result: QueryBinding? + init(_ body: Body) { self.body = body + nonisolated(unsafe) let iterator: any AggregateDatabaseFunctionIteratorProtocol = self + queue.async { + iterator.start() + } + } + func start() { + do { + _result = try body.invoke(stream) + } catch { + _result = .invalid(error) + } + } + func step(_ decoder: inout some QueryDecoder) throws { + try stream.send(body.step(&decoder)) + } + func finish() { + stream.finish() + } + var result: QueryBinding { + get throws { + while true { + if let _result { return _result } + } + } + } +} + +private final class Stream: Sequence { + let condition = NSCondition() + private var buffer: [Element] = [] + private var isFinished = false + + func send(_ element: Element) { + condition.withLock { + buffer.append(element) + condition.signal() + } + } + + func finish() { + condition.withLock { + isFinished = true + condition.broadcast() + } + } + + func makeIterator() -> Iterator { Iterator(base: self) } + + struct Iterator: IteratorProtocol { + fileprivate let base: Stream + mutating func next() -> Element? { + base.condition.withLock { + while base.buffer.isEmpty && !base.isFinished { + base.condition.wait() + } + guard !base.buffer.isEmpty else { return nil } + return base.buffer.removeFirst() + } + } } } diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 3f7b964c..3429bf52 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -474,11 +474,10 @@ extension SnapshotTests { public let name = "joined" public let argumentCount: Int? = 2 public let isDeterministic = true - public let body: (any Sequence<(String, separator: String)>) -> String? - public init(_ body: @escaping (any Sequence<(String, separator: String)>) -> String?) { + public let body: (any Sequence) -> String? + public init(_ body: @escaping (any Sequence<(Input)>) -> String?) { self.body = body } - private var rows: [Input] = [] public func callAsFunction( _ n0: some StructuredQueriesCore.QueryExpression, separator: some StructuredQueriesCore.QueryExpression, @@ -494,18 +493,15 @@ extension SnapshotTests { ) } } - public mutating func invoke(_ decoder: inout some QueryDecoder) throws { - var decoder = decoder + public func step(_ decoder: inout some QueryDecoder) throws -> Input { let p0 = try decoder.decode(String.self) let separator = try decoder.decode(String.self) guard let p0 else { throw InvalidInvocation() } guard let separator else { throw InvalidInvocation() } - rows.append((p0, separator)) + return (p0, separator) } - public var result: QueryBinding { - get throws { - body(rows).queryBinding - } + public func invoke(_ sequence: some Sequence) -> QueryBinding { + self.body(sequence).queryBinding } private struct InvalidInvocation: Error { } From b63fe53dd2dcea38ffd2829d804dc9f6288e0e19 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Wed, 8 Oct 2025 22:35:37 -0700 Subject: [PATCH 04/19] wip --- Sources/_StructuredQueriesSQLite/DatabaseFunction.swift | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index ec6df699..6b3160d5 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -126,7 +126,7 @@ private final class AggregateDatabaseFunctionIterator< >: AggregateDatabaseFunctionIteratorProtocol { let body: Body let stream = Stream() - let queue = DispatchQueue.global() + let queue = DispatchQueue.global(qos: .userInitiated) var _result: QueryBinding? init(_ body: Body) { self.body = body @@ -151,7 +151,9 @@ private final class AggregateDatabaseFunctionIterator< var result: QueryBinding { get throws { while true { - if let _result { return _result } + if let result = queue.sync(execute: { _result }) { + return result + } } } } From 9a7ca492c18d0070063a12979215607e8297b2d1 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Thu, 9 Oct 2025 14:21:08 -0700 Subject: [PATCH 05/19] wip --- .../AggregateFunctions.swift | 16 + Sources/StructuredQueriesSQLite/Macros.swift | 24 + .../DatabaseFunction.swift | 2 +- .../DatabaseFunctionMacro.swift | 337 +++++++++++--- .../DatabaseFunctionMacroTests.swift | 422 +++++++++++++++++- .../TableMacroTests.swift | 40 +- .../DatabaseFunctionTests.swift | 76 ++-- 7 files changed, 762 insertions(+), 155 deletions(-) diff --git a/Sources/StructuredQueriesCore/AggregateFunctions.swift b/Sources/StructuredQueriesCore/AggregateFunctions.swift index e8c4bf61..e6023dde 100644 --- a/Sources/StructuredQueriesCore/AggregateFunctions.swift +++ b/Sources/StructuredQueriesCore/AggregateFunctions.swift @@ -216,6 +216,22 @@ public struct AggregateFunction: QueryExpression, Sendable { var order: QueryFragment? var filter: QueryFragment? + public init( + _ name: String, + distinct isDistinct: Bool = false, + _ arguments: repeat each Argument, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) { + self.init( + QueryFragment(quote: name), + isDistinct: false, + Array(repeat each arguments), + order: order?.queryFragment, + filter: filter?.queryFragment + ) + } + package init( _ name: QueryFragment, isDistinct: Bool = false, diff --git a/Sources/StructuredQueriesSQLite/Macros.swift b/Sources/StructuredQueriesSQLite/Macros.swift index 33d0deec..c8b45dce 100644 --- a/Sources/StructuredQueriesSQLite/Macros.swift +++ b/Sources/StructuredQueriesSQLite/Macros.swift @@ -54,3 +54,27 @@ public macro DatabaseFunction( module: "StructuredQueriesSQLiteMacros", type: "DatabaseFunctionMacro" ) + +// TODO: +// @attached(peer, names: overloaded, prefixed(`$`)) +// public macro DatabaseFunction( +// _ name: String = "", +// as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type, +// isDeterministic: Bool = false +// ) = +// #externalMacro( +// module: "StructuredQueriesSQLiteMacros", +// type: "DatabaseFunctionMacro" +// ) +// +// @attached(peer, names: overloaded, prefixed(`$`)) +// public macro DatabaseFunction( +// _ name: String = "", +// as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type, +// isDeterministic: Bool = false +// ) = +// #externalMacro( +// module: "StructuredQueriesSQLiteMacros", +// type: "DatabaseFunctionMacro" +// ) +// diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index 9f56a18e..6e8f51ab 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -58,7 +58,7 @@ extension ScalarDatabaseFunction { public protocol AggregateDatabaseFunction: DatabaseFunction { func step(_ decoder: inout some QueryDecoder) throws -> Input - func invoke(_ sequence: some Sequence) throws -> QueryBinding + func invoke(_ arguments: some Sequence) throws -> QueryBinding } extension AggregateDatabaseFunction { diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index 5fc5190a..aa827997 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -110,45 +110,157 @@ extension DatabaseFunctionMacro: PeerMacro { var invocationArgumentTypes: [TypeSyntax] = [] var parameters: [String] = [] var argumentBindings: [String] = [] - var offset = 0 var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator() var decodings: [String] = [] var decodingUnwrappings: [String] = [] + var canThrowInvalidInvocation = false - for index in signature.parameterClause.parameters.indices { - defer { offset += 1 } - var parameter = signature.parameterClause.parameters[index] - if let ellipsis = parameter.ellipsis { - context.diagnose( - Diagnostic( - node: ellipsis, - message: MacroExpansionErrorMessage("Variadic arguments are not supported") + let isAggregate: Bool + var representableInputType: String + let projectedCallSyntax: ExprSyntax + + if signature.parameterClause.parameters.count == 1, + let parameter = signature.parameterClause.parameters.first, + var someOrAnyParameterType = parameter.type.as(SomeOrAnyTypeSyntax.self), + someOrAnyParameterType.someOrAnySpecifier.tokenKind == .keyword(.some), + let parameterType = someOrAnyParameterType.constraint.as(IdentifierTypeSyntax.self), + ["Sequence", "Swift.Sequence"].contains(parameterType.name.text), + let genericArgumentClause = parameterType.genericArgumentClause, + genericArgumentClause.arguments.count == 1, + let genericArgument = genericArgumentClause.arguments.first + { + isAggregate = true + representableInputType = "\(genericArgument)" + + someOrAnyParameterType.someOrAnySpecifier.tokenKind = .keyword(.any) + let bodySignature = + signature + .with( + \.parameterClause.parameters[signature.parameterClause.parameters.startIndex], + parameter + .with(\.firstName, .wildcardToken(trailingTrivia: .space)) + .with(\.type, TypeSyntax(someOrAnyParameterType)) + ) + bodyArguments.append("\(bodySignature.parameterClause.parameters)") + + var parameterClause = signature.parameterClause.with(\.parameters, []) + let firstName = parameter.firstName.tokenKind == .wildcard ? nil : parameter.firstName + + let tupleType = + genericArgument.argument.as(TupleTypeSyntax.self) + ?? TupleTypeSyntax( + elements: [ + TupleTypeElementSyntax( + firstName: firstName, + secondName: parameter.secondName, + type: genericArgument.argument.cast(TypeSyntax.self) + ) + ] + ) + + var offset = 0 + for var element in tupleType.elements { + defer { offset += 1 } + var type = (functionRepresentationIterator?.next()?.type ?? element.type) + element.type = type.asQueryExpression() + type = type.trimmed + representableInputTypes.append(type.description) + invocationArgumentTypes.append(type) + let firstName = element.firstName?.trimmedDescription + let secondName = element.secondName?.trimmedDescription ?? firstName ?? "p\(offset)" + parameters.append(secondName) + argumentBindings.append(secondName) + + argumentCounts.append("\(type)") + decodings.append("let \(secondName) = try decoder.decode(\(type).self)") + decodingUnwrappings.append( + "guard let \(secondName) else { throw InvalidInvocation() }" + ) + canThrowInvalidInvocation = true + + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: firstName.map { .identifier($0) } ?? .wildcardToken(), + secondName: firstName == secondName + ? nil + : .identifier(secondName, leadingTrivia: .space), + colon: .colonToken(), + type: element.type, + trailingComma: .commaToken(), + trailingTrivia: .space ) ) - return [] } - bodyArguments.append("\(parameter.type.trimmed)") - var type = (functionRepresentationIterator?.next()?.type ?? parameter.type) - parameter.type = type.asQueryExpression() - type = type.trimmed - representableInputTypes.append(type.description) - if let defaultValue = parameter.defaultValue, - defaultValue.value.is(NilLiteralExprSyntax.self) - { - parameter.defaultValue?.value = "\(type).none" + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: "order", + colon: .colonToken(), + type: "(some QueryExpression)?" as TypeSyntax, + defaultValue: InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: "Bool?.none" as ExprSyntax + ), + trailingComma: .commaToken(), + trailingTrivia: .space + ) + ) + parameterClause.parameters.append( + FunctionParameterSyntax( + firstName: "filter", + colon: .colonToken(trailingTrivia: .space), + type: "(some QueryExpression)?" as TypeSyntax, + defaultValue: InitializerClauseSyntax( + equal: .equalToken(leadingTrivia: .space, trailingTrivia: .space), + value: "Bool?.none" as ExprSyntax + ) + ) + ) + signature.parameterClause = parameterClause + projectedCallSyntax = """ + \(functionTypeName) { + \(raw: declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")\ + \(declaration.name.trimmed)(\(raw: firstName.map { "\($0.trimmedDescription): " } ?? "")$0) + } + """ + } else { + isAggregate = false + + for index in signature.parameterClause.parameters.indices { + var parameter = signature.parameterClause.parameters[index] + if let ellipsis = parameter.ellipsis { + context.diagnose( + Diagnostic( + node: ellipsis, + message: MacroExpansionErrorMessage("Variadic arguments are not supported") + ) + ) + return [] + } + bodyArguments.append("\(parameter.type.trimmed)") + var type = (functionRepresentationIterator?.next()?.type ?? parameter.type) + parameter.type = type.asQueryExpression() + type = type.trimmed + representableInputTypes.append(type.description) + if let defaultValue = parameter.defaultValue, + defaultValue.value.is(NilLiteralExprSyntax.self) + { + parameter.defaultValue?.value = "\(type).none" + } + signature.parameterClause.parameters[index] = parameter + invocationArgumentTypes.append(type) + let parameterName = (parameter.secondName ?? parameter.firstName).trimmedDescription + parameters.append(parameterName) + argumentBindings.append(parameterName) + + argumentCounts.append("\(type)") + decodings.append("let \(parameterName) = try decoder.decode(\(type).self)") + decodingUnwrappings.append("guard let \(parameterName) else { throw InvalidInvocation() }") + canThrowInvalidInvocation = true } - signature.parameterClause.parameters[index] = parameter - invocationArgumentTypes.append(type) - let parameterName = (parameter.secondName ?? parameter.firstName).trimmedDescription - parameters.append(parameterName) - argumentBindings.append(parameterName) - - argumentCounts.append("\(type)") - decodings.append("let \(parameterName) = try decoder.decode(\(type).self)") - decodingUnwrappings.append("guard let \(parameterName) else { throw InvalidInvocation() }") + representableInputType = representableInputTypes.joined(separator: ", ") + projectedCallSyntax = "\(functionTypeName)(\(declaration.name.trimmed))" } - var representableInputType = representableInputTypes.joined(separator: ", ") let isVoidReturning = signature.returnClause == nil let outputType = returnClause.type.trimmed signature.returnClause = returnClause @@ -161,36 +273,9 @@ extension DatabaseFunctionMacro: PeerMacro { \(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\ \(bodyReturnClause) """ - let bodyInvocation = """ - \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\ - \(argumentBindings.joined(separator: ", "))\ - ) - """ // TODO: Diagnose 'asyncClause'? signature.effectSpecifiers?.throwsClause = nil - var invocationBody = - isVoidReturning - ? """ - \(bodyInvocation) - return .null - """ - : """ - return \(functionRepresentation?.returnClause.type ?? outputType)( - queryOutput: \(bodyInvocation) - ) - .queryBinding - """ - if declaration.signature.effectSpecifiers?.throwsClause != nil { - invocationBody = """ - do { - \(invocationBody) - } catch { - return .invalid(error) - } - """ - } - var attributes = declaration.attributes if let index = attributes.firstIndex(where: { $0.as(AttributeSyntax.self)?.attributeName.as(IdentifierTypeSyntax.self)?.name.text @@ -224,15 +309,135 @@ extension DatabaseFunctionMacro: PeerMacro { return argumentCount """ + var methods: [DeclSyntax] = [] + if isAggregate { + var parameter = declaration.signature.parameterClause.parameters[ + declaration.signature.parameterClause.parameters.startIndex + ] + parameter.firstName = .wildcardToken(trailingTrivia: .space) + parameter.secondName = "arguments" + + methods.append( + """ + public func callAsFunction\(signature.trimmed) { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, \ + \(raw: parameters.joined(separator: ", ")), \ + order: order, \ + filter: filter + ) + } + } + """ + ) + + let stepReturnClause: String + switch parameters.count { + case 0: stepReturnClause = "" + case 1: stepReturnClause = "return \(parameters[0])\n" + default: stepReturnClause = "return (\(parameters.joined(separator: ", ")))\n" + } + + methods.append( + """ + public func step( + _ decoder: inout some QueryDecoder + ) throws -> \(raw: representableInputType) { + \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ + \(raw: stepReturnClause)\ + } + """ + ) + + let bodyInvocation = """ + \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")\ + self.body(arguments) + """ + var invocationBody = + isVoidReturning + ? """ + \(bodyInvocation) + return .null + """ + : "return \(representableOutputType)(queryOutput: \(bodyInvocation)).queryBinding" + if declaration.signature.effectSpecifiers?.throwsClause != nil { + invocationBody = """ + do { + \(invocationBody) + } catch { + return .invalid(error) + } + """ + } + methods.append( + """ + public func invoke(\(parameter)) -> QueryBinding { + \(raw: invocationBody) + } + """ + ) + } else { + methods.append( + """ + public func callAsFunction\(signature.trimmed) { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.SQLQueryExpression( + "\\(quote: self.name)(\(raw: parameters.map { "\\(\($0))" }.joined(separator: ", ")))" + ) + } + } + """ + ) + + let bodyInvocation = """ + \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\ + \(argumentBindings.joined(separator: ", "))\ + ) + """ + var invocationBody = + isVoidReturning + ? """ + \(bodyInvocation) + return .null + """ + : """ + return \(functionRepresentation?.returnClause.type ?? outputType)( + queryOutput: \(bodyInvocation) + ) + .queryBinding + """ + if declaration.signature.effectSpecifiers?.throwsClause != nil { + invocationBody = """ + do { + \(invocationBody) + } catch { + return .invalid(error) + } + """ + } + + methods.append( + """ + public func invoke( + _ decoder: inout some QueryDecoder + ) throws -> StructuredQueriesCore.QueryBinding { + \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ + \(raw: invocationBody) + } + """ + ) + } + return [ """ \(attributes)\(access)\(`static`)var $\(raw: declarationName): \(functionTypeName) { - \(functionTypeName)(\(declaration.name.trimmed)) + \(projectedCallSyntax) } """, """ \(attributes)\(access)struct \(functionTypeName): \ - StructuredQueriesSQLiteCore.ScalarDatabaseFunction { + StructuredQueriesSQLiteCore.\(raw: isAggregate ? "Aggregate" : "Scalar")DatabaseFunction { public typealias Input = \(raw: representableInputType) public typealias Output = \(representableOutputType) public let name = \(databaseFunctionName) @@ -244,20 +449,8 @@ extension DatabaseFunctionMacro: PeerMacro { public init(_ body: @escaping \(raw: bodyType)) { self.body = body } - public func callAsFunction\(signature.trimmed) { - StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.SQLQueryExpression( - "\\(quote: self.name)(\(raw: parameters.map { "\\(\($0))" }.joined(separator: ", ")))" - ) - } - } - public func invoke( - _ decoder: inout some QueryDecoder - ) throws -> StructuredQueriesCore.QueryBinding { - \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ - \(raw: invocationBody) - } - private struct InvalidInvocation: Error {} + \(raw: methods.map(\.description).joined(separator: "\n"))\ + \(raw: canThrowInvalidInvocation ? "\nprivate struct InvalidInvocation: Error {}" : "") } """, ] diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index 4ffdce05..f4677988 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -50,8 +50,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -102,8 +100,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -212,8 +208,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -638,8 +632,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -694,8 +686,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -746,8 +736,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -798,8 +786,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -873,8 +859,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -925,8 +909,6 @@ extension SnapshotTests { ) .queryBinding } - private struct InvalidInvocation: Error { - } } """# } @@ -975,8 +957,6 @@ extension SnapshotTests { self.body() return .null } - private struct InvalidInvocation: Error { - } } """# } @@ -1026,8 +1006,6 @@ extension SnapshotTests { return .invalid(error) } } - private struct InvalidInvocation: Error { - } } """# } @@ -1232,5 +1210,405 @@ extension SnapshotTests { """# } } + + @Suite struct AggregateTests { + @Test func basics() { + assertMacro { + """ + @DatabaseFunction + func sum(_ xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + """ + } expansion: { + """ + func sum(_ xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + var $sum: __macro_local_3sumfMu_ { + __macro_local_3sumfMu_ { + sum($0) + } + } + + struct __macro_local_3sumfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Int + public let name = "sum" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Int + public init(_ body: @escaping (_ xs: any Sequence) -> Int) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + Int(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func namedArgument() { + assertMacro { + """ + @DatabaseFunction + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + """ + } expansion: { + """ + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + var $sum: __macro_local_3sumfMu_ { + __macro_local_3sumfMu_ { + sum(of: $0) + } + } + + struct __macro_local_3sumfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Int + public let name = "sum" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Int + public init(_ body: @escaping (_ xs: any Sequence) -> Int) { + self.body = body + } + public func callAsFunction(of xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + Int(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func multipleArguments() { + assertMacro { + """ + @DatabaseFunction + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var iterator = arguments.makeIterator() + guard var (result, _) = iterator.next() else { return nil } + while let (string, separator) = iterator.next() { + result.append(separator) + result.append(string) + } + return result + } + """ + } expansion: { + """ + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + var iterator = arguments.makeIterator() + guard var (result, _) = iterator.next() else { return nil } + while let (string, separator) = iterator.next() { + result.append(separator) + result.append(string) + } + return result + } + + var $joined: __macro_local_6joinedfMu_ { + __macro_local_6joinedfMu_ { + joined($0) + } + } + + struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = ((String, separator: String)) + public typealias Output = String? + public let name = "joined" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += String._columnWidth + argumentCount += String._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ arguments: any Sequence<(String, separator: String)>) -> String? + public init(_ body: @escaping (_ arguments: any Sequence<(String, separator: String)>) -> String?) { + self.body = body + } + public func callAsFunction(_ p0: some StructuredQueriesCore.QueryExpression, separator separator: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, p0, separator, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> ((String, separator: String)) { + let p0 = try decoder.decode(String.self) + let separator = try decoder.decode(String.self) + guard let p0 else { + throw InvalidInvocation() + } + guard let separator else { + throw InvalidInvocation() + } + return (p0, separator) + } + public func invoke(_ arguments: some Sequence<(String, separator: String)>) -> QueryBinding { + return String?(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + // TODO + @Test func customRepresentations() { + assertMacro { + #""" + @DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self + ) + func joined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + """# + } expansion: { + #""" + func joined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + + var $joined: __macro_local_6joinedfMu_ { + __macro_local_6joinedfMu_ { + joined($0) + } + } + + struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = [String] + public typealias Output = [String].JSONRepresentation + public let name = "joined" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += any Sequence<[String].JSONRepresentation>._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ arrays: any Sequence<[String]>) -> [String] + public init(_ body: @escaping (_ arrays: any Sequence<[String]>) -> [String]) { + self.body = body + } + public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, arrays, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> [String] { + let arrays = try decoder.decode(any Sequence<[String].JSONRepresentation>.self) + guard let arrays else { + throw InvalidInvocation() + } + return arrays + } + public func invoke(_ arguments: some Sequence<[String]>) -> QueryBinding { + [String].JSONRepresentation(queryOutput: self.body(arguments)).queryBinding + } + private struct InvalidInvocation: Error { + } + } + """# + } + } + + @Test func voidReturning() { + assertMacro { + """ + @DatabaseFunction + func print(_ xs: some Sequence) { + for x in xs { + Swift.print(x) + } + } + """ + } expansion: { + """ + func print(_ xs: some Sequence) { + for x in xs { + Swift.print(x) + } + } + + var $print: __macro_local_5printfMu_ { + __macro_local_5printfMu_ { + print($0) + } + } + + struct __macro_local_5printfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Swift.Void + public let name = "print" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) -> Swift.Void + public init(_ body: @escaping (_ xs: any Sequence) -> Swift.Void) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + self.body(arguments) + return .null + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + + @Test func throwing() { + assertMacro { + """ + @DatabaseFunction + func validatePositive(_ xs: some Sequence) throws { + for x in xs { + guard x.sign == .plus else { + throw NegativeError() + } + } + } + """ + } expansion: { + """ + func validatePositive(_ xs: some Sequence) throws { + for x in xs { + guard x.sign == .plus else { + throw NegativeError() + } + } + } + + var $validatePositive: __macro_local_16validatePositivefMu_ { + __macro_local_16validatePositivefMu_ { + validatePositive($0) + } + } + + struct __macro_local_16validatePositivefMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { + public typealias Input = Int + public typealias Output = Swift.Void + public let name = "validatePositive" + public var argumentCount: Int? { + var argumentCount = 0 + argumentCount += Int._columnWidth + return argumentCount + } + public let isDeterministic = false + public let body: (_ xs: any Sequence) throws -> Swift.Void + public init(_ body: @escaping (_ xs: any Sequence) throws -> Swift.Void) { + self.body = body + } + public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.$_isSelecting.withValue(false) { + StructuredQueriesCore.AggregateFunction( + self.name, xs, order: order, filter: filter + ) + } + } + public func step( + _ decoder: inout some QueryDecoder + ) throws -> Int { + let xs = try decoder.decode(Int.self) + guard let xs else { + throw InvalidInvocation() + } + return xs + } + public func invoke(_ arguments: some Sequence) -> QueryBinding { + do { + try self.body(arguments) + return .null + } catch { + return .invalid(error) + } + } + private struct InvalidInvocation: Error { + } + } + """ + } + } + } } } diff --git a/Tests/StructuredQueriesMacrosTests/TableMacroTests.swift b/Tests/StructuredQueriesMacrosTests/TableMacroTests.swift index b6466046..e3c31215 100644 --- a/Tests/StructuredQueriesMacrosTests/TableMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/TableMacroTests.swift @@ -2453,7 +2453,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func photo( @@ -2482,7 +2482,10 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [Photo._columnWidth, String._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += Photo._columnWidth + columnWidth += String._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" @@ -2557,7 +2560,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func photo( @@ -2586,7 +2589,10 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [Photo._columnWidth, String._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += Photo._columnWidth + columnWidth += String._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" @@ -2657,7 +2663,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func photo( @@ -2686,7 +2692,10 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [Photo._columnWidth, String._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += Photo._columnWidth + columnWidth += String._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" @@ -2742,7 +2751,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func photo( @@ -2771,7 +2780,10 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [Photo._columnWidth, String._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += Photo._columnWidth + columnWidth += String._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" @@ -2823,7 +2835,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func note( @@ -2843,7 +2855,9 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [String._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += String._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" @@ -2893,7 +2907,7 @@ extension SnapshotTests { } } - public struct Selection: StructuredQueriesCore.TableExpression { + public nonisolated struct Selection: StructuredQueriesCore.TableExpression { public typealias QueryValue = Post public let allColumns: [any StructuredQueriesCore.QueryExpression] public static func timestamp( @@ -2913,7 +2927,9 @@ extension SnapshotTests { TableColumns() } public nonisolated static var _columnWidth: Int { - [Date.UnixTimeRepresentation._columnWidth].reduce(0, +) + var columnWidth = 0 + columnWidth += Date.UnixTimeRepresentation._columnWidth + return columnWidth } public nonisolated static var tableName: String { "posts" diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 3429bf52..3eb48b42 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -452,9 +452,32 @@ extension SnapshotTests { } } - // ... + @DatabaseFunction + func sum(of xs: some Sequence) -> Int { + xs.reduce(into: 0, +=) + } + + @Test func aggregate() { + $sum.install(database.handle) - func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { + assertQuery( + Reminder.select { $sum(of: $0.id) } + ) { + """ + SELECT "sum"("reminders"."id") + FROM "reminders" + """ + } results: { + """ + ┌────┐ + │ 55 │ + └────┘ + """ + } + } + + @DatabaseFunction + func joined(_ arguments: some Sequence<(String, separator: String)>) throws -> String? { var iterator = arguments.makeIterator() guard var (result, _) = iterator.next() else { return nil } while let (string, separator) = iterator.next() { @@ -464,54 +487,11 @@ extension SnapshotTests { return result } - var _$joined: Joined { - Joined { joined($0) } - } - - struct Joined: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { - public typealias Input = (String, separator: String) - public typealias Output = String? - public let name = "joined" - public let argumentCount: Int? = 2 - public let isDeterministic = true - public let body: (any Sequence) -> String? - public init(_ body: @escaping (any Sequence<(Input)>) -> String?) { - self.body = body - } - public func callAsFunction( - _ n0: some StructuredQueriesCore.QueryExpression, - separator: some StructuredQueriesCore.QueryExpression, - order: (some QueryExpression)? = Bool?.none, - filter: (some QueryExpression)? = Bool?.none - ) -> some StructuredQueriesCore.QueryExpression { - $_isSelecting.withValue(false) { - AggregateFunction( - QueryFragment(quote: name), - [n0.queryFragment, separator.queryFragment], - order: order?.queryFragment, - filter: filter?.queryFragment - ) - } - } - public func step(_ decoder: inout some QueryDecoder) throws -> Input { - let p0 = try decoder.decode(String.self) - let separator = try decoder.decode(String.self) - guard let p0 else { throw InvalidInvocation() } - guard let separator else { throw InvalidInvocation() } - return (p0, separator) - } - public func invoke(_ sequence: some Sequence) -> QueryBinding { - self.body(sequence).queryBinding - } - private struct InvalidInvocation: Error { - } - } - - @Test func aggregate() { - _$joined.install(database.handle) + @Test func multiAggregate() { + $joined.install(database.handle) assertQuery( - Tag.select { _$joined($0.title, separator: ", ", order: $0.title) } + Tag.select { $joined($0.title, separator: ", ", order: $0.title) } ) { """ SELECT "joined"("tags"."title", ', ' ORDER BY "tags"."title") From 0092915aad2c087e44e29f8469c926bb1ceadc2b Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Thu, 9 Oct 2025 15:14:51 -0700 Subject: [PATCH 06/19] wip --- .../DatabaseFunctionMacro.swift | 8 ++++---- .../DatabaseFunctionMacroTests.swift | 14 +++++++------- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index aa827997..37ae484e 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -259,6 +259,10 @@ extension DatabaseFunctionMacro: PeerMacro { canThrowInvalidInvocation = true } representableInputType = representableInputTypes.joined(separator: ", ") + representableInputType = + representableInputTypes.count == 1 + ? representableInputType + : "(\(representableInputType))" projectedCallSyntax = "\(functionTypeName)(\(declaration.name.trimmed))" } let isVoidReturning = signature.returnClause == nil @@ -295,10 +299,6 @@ extension DatabaseFunctionMacro: PeerMacro { continue } } - representableInputType = - representableInputTypes.count == 1 - ? representableInputType - : "(\(representableInputType))" let argumentCount = argumentCounts.isEmpty diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index f4677988..48af2847 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -1263,7 +1263,7 @@ extension SnapshotTests { return xs } public func invoke(_ arguments: some Sequence) -> QueryBinding { - Int(queryOutput: self.body(arguments)).queryBinding + return Int(queryOutput: self.body(arguments)).queryBinding } private struct InvalidInvocation: Error { } @@ -1323,7 +1323,7 @@ extension SnapshotTests { return xs } public func invoke(_ arguments: some Sequence) -> QueryBinding { - Int(queryOutput: self.body(arguments)).queryBinding + return Int(queryOutput: self.body(arguments)).queryBinding } private struct InvalidInvocation: Error { } @@ -1365,7 +1365,7 @@ extension SnapshotTests { } struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { - public typealias Input = ((String, separator: String)) + public typealias Input = (String, separator: String) public typealias Output = String? public let name = "joined" public var argumentCount: Int? { @@ -1379,7 +1379,7 @@ extension SnapshotTests { public init(_ body: @escaping (_ arguments: any Sequence<(String, separator: String)>) -> String?) { self.body = body } - public func callAsFunction(_ p0: some StructuredQueriesCore.QueryExpression, separator separator: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { + public func callAsFunction(_ p0: some StructuredQueriesCore.QueryExpression, separator: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { StructuredQueriesCore.AggregateFunction( self.name, p0, separator, order: order, filter: filter @@ -1388,7 +1388,7 @@ extension SnapshotTests { } public func step( _ decoder: inout some QueryDecoder - ) throws -> ((String, separator: String)) { + ) throws -> (String, separator: String) { let p0 = try decoder.decode(String.self) let separator = try decoder.decode(String.self) guard let p0 else { @@ -1463,7 +1463,7 @@ extension SnapshotTests { return arrays } public func invoke(_ arguments: some Sequence<[String]>) -> QueryBinding { - [String].JSONRepresentation(queryOutput: self.body(arguments)).queryBinding + return [String].JSONRepresentation(queryOutput: self.body(arguments)).queryBinding } private struct InvalidInvocation: Error { } @@ -1561,7 +1561,7 @@ extension SnapshotTests { var $validatePositive: __macro_local_16validatePositivefMu_ { __macro_local_16validatePositivefMu_ { - validatePositive($0) + try validatePositive($0) } } From 74ec4d1f49888c93240b81677ef978ee931ce8fc Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Thu, 9 Oct 2025 16:31:54 -0700 Subject: [PATCH 07/19] wip --- Sources/StructuredQueriesSQLite/Macros.swift | 44 ++++++++-------- .../DatabaseFunction.swift | 6 ++- .../DatabaseFunctionMacro.swift | 24 +++++++-- .../DatabaseFunction.swift | 4 +- .../DatabaseFunctionMacroTests.swift | 9 ++-- .../DatabaseFunctionTests.swift | 51 ++++++++++++++++++- 6 files changed, 101 insertions(+), 37 deletions(-) diff --git a/Sources/StructuredQueriesSQLite/Macros.swift b/Sources/StructuredQueriesSQLite/Macros.swift index c8b45dce..44137a08 100644 --- a/Sources/StructuredQueriesSQLite/Macros.swift +++ b/Sources/StructuredQueriesSQLite/Macros.swift @@ -55,26 +55,24 @@ public macro DatabaseFunction( type: "DatabaseFunctionMacro" ) -// TODO: -// @attached(peer, names: overloaded, prefixed(`$`)) -// public macro DatabaseFunction( -// _ name: String = "", -// as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type, -// isDeterministic: Bool = false -// ) = -// #externalMacro( -// module: "StructuredQueriesSQLiteMacros", -// type: "DatabaseFunctionMacro" -// ) -// -// @attached(peer, names: overloaded, prefixed(`$`)) -// public macro DatabaseFunction( -// _ name: String = "", -// as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type, -// isDeterministic: Bool = false -// ) = -// #externalMacro( -// module: "StructuredQueriesSQLiteMacros", -// type: "DatabaseFunctionMacro" -// ) -// +@attached(peer, names: overloaded, prefixed(`$`)) +public macro DatabaseFunction( + _ name: String = "", + as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type, + isDeterministic: Bool = false +) = + #externalMacro( + module: "StructuredQueriesSQLiteMacros", + type: "DatabaseFunctionMacro" + ) + +@attached(peer, names: overloaded, prefixed(`$`)) +public macro DatabaseFunction( + _ name: String = "", + as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type, + isDeterministic: Bool = false +) = + #externalMacro( + module: "StructuredQueriesSQLiteMacros", + type: "DatabaseFunctionMacro" + ) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index 6e8f51ab..63baf167 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -56,9 +56,11 @@ extension ScalarDatabaseFunction { /// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate /// a conformance. public protocol AggregateDatabaseFunction: DatabaseFunction { - func step(_ decoder: inout some QueryDecoder) throws -> Input + associatedtype Row - func invoke(_ arguments: some Sequence) throws -> QueryBinding + func step(_ decoder: inout some QueryDecoder) throws -> Row + + func invoke(_ arguments: some Sequence) throws -> QueryBinding } extension AggregateDatabaseFunction { diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index 37ae484e..a7d9573c 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -110,7 +110,6 @@ extension DatabaseFunctionMacro: PeerMacro { var invocationArgumentTypes: [TypeSyntax] = [] var parameters: [String] = [] var argumentBindings: [String] = [] - var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator() var decodings: [String] = [] var decodingUnwrappings: [String] = [] @@ -118,6 +117,7 @@ extension DatabaseFunctionMacro: PeerMacro { let isAggregate: Bool var representableInputType: String + var rowType = "" let projectedCallSyntax: ExprSyntax if signature.parameterClause.parameters.count == 1, @@ -131,7 +131,6 @@ extension DatabaseFunctionMacro: PeerMacro { let genericArgument = genericArgumentClause.arguments.first { isAggregate = true - representableInputType = "\(genericArgument)" someOrAnyParameterType.someOrAnySpecifier.tokenKind = .keyword(.any) let bodySignature = @@ -159,10 +158,26 @@ extension DatabaseFunctionMacro: PeerMacro { ] ) + let representableInputGeneric = functionRepresentation? + .parameters.first? + .type.as(SomeOrAnyTypeSyntax.self)? + .constraint.as(IdentifierTypeSyntax.self)? + .genericArgumentClause? + .arguments.first + let representableInputGenericArgument = representableInputGeneric?.argument + + representableInputType = "\(representableInputGeneric ?? genericArgument)" + rowType = "\(genericArgument)" + + let representableInputArguments = + representableInputGenericArgument?.as(TupleTypeSyntax.self)?.elements.map(\.type) + ?? (representableInputGenericArgument?.cast(TypeSyntax.self)).map { [$0] } + var representableInputArgumentsIterator = representableInputArguments?.makeIterator() + var offset = 0 for var element in tupleType.elements { defer { offset += 1 } - var type = (functionRepresentationIterator?.next()?.type ?? element.type) + var type = representableInputArgumentsIterator?.next() ?? element.type element.type = type.asQueryExpression() type = type.trimmed representableInputTypes.append(type.description) @@ -225,6 +240,7 @@ extension DatabaseFunctionMacro: PeerMacro { """ } else { isAggregate = false + var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator() for index in signature.parameterClause.parameters.indices { var parameter = signature.parameterClause.parameters[index] @@ -343,7 +359,7 @@ extension DatabaseFunctionMacro: PeerMacro { """ public func step( _ decoder: inout some QueryDecoder - ) throws -> \(raw: representableInputType) { + ) throws -> \(raw: rowType) { \(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\ \(raw: stepReturnClause)\ } diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 6b3160d5..755d52d9 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -114,7 +114,7 @@ private protocol AggregateDatabaseFunctionIteratorProtocol { associatedtype Body: AggregateDatabaseFunction var body: Body { get } - var stream: Stream { get } + var stream: Stream { get } func start() func step(_ decoder: inout some QueryDecoder) throws func finish() @@ -125,7 +125,7 @@ private final class AggregateDatabaseFunctionIterator< Body: AggregateDatabaseFunction >: AggregateDatabaseFunctionIteratorProtocol { let body: Body - let stream = Stream() + let stream = Stream() let queue = DispatchQueue.global(qos: .userInitiated) var _result: QueryBinding? init(_ body: Body) { diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index 48af2847..fca5fc5a 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -1409,7 +1409,6 @@ extension SnapshotTests { } } - // TODO @Test func customRepresentations() { assertMacro { #""" @@ -1433,12 +1432,12 @@ extension SnapshotTests { } struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction { - public typealias Input = [String] + public typealias Input = [String].JSONRepresentation public typealias Output = [String].JSONRepresentation public let name = "joined" public var argumentCount: Int? { var argumentCount = 0 - argumentCount += any Sequence<[String].JSONRepresentation>._columnWidth + argumentCount += [String].JSONRepresentation._columnWidth return argumentCount } public let isDeterministic = false @@ -1446,7 +1445,7 @@ extension SnapshotTests { public init(_ body: @escaping (_ arrays: any Sequence<[String]>) -> [String]) { self.body = body } - public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> { + public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> { StructuredQueriesCore.$_isSelecting.withValue(false) { StructuredQueriesCore.AggregateFunction( self.name, arrays, order: order, filter: filter @@ -1456,7 +1455,7 @@ extension SnapshotTests { public func step( _ decoder: inout some QueryDecoder ) throws -> [String] { - let arrays = try decoder.decode(any Sequence<[String].JSONRepresentation>.self) + let arrays = try decoder.decode([String].JSONRepresentation.self) guard let arrays else { throw InvalidInvocation() } diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 3eb48b42..38abae36 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -477,7 +477,7 @@ extension SnapshotTests { } @DatabaseFunction - func joined(_ arguments: some Sequence<(String, separator: String)>) throws -> String? { + func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { var iterator = arguments.makeIterator() guard var (result, _) = iterator.next() else { return nil } while let (string, separator) = iterator.next() { @@ -505,5 +505,54 @@ extension SnapshotTests { """ } } + + @DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self + ) + func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) + } + + @Test func aggregateRepresentation() { + $jsonJoined.install(database.handle) + + assertQuery( + Reminder.select { + $jsonJoined(#sql("json_array(\($0.title.lower()), \($0.title.upper()))")) + } + ) { + """ + SELECT "jsonJoined"(json_array(lower("reminders"."title"), upper("reminders"."title"))) + FROM "reminders" + """ + } results: { + """ + ┌─────────────────────────────────────┐ + │ [ │ + │ [0]: "groceries", │ + │ [1]: "GROCERIES", │ + │ [2]: "haircut", │ + │ [3]: "HAIRCUT", │ + │ [4]: "doctor appointment", │ + │ [5]: "DOCTOR APPOINTMENT", │ + │ [6]: "take a walk", │ + │ [7]: "TAKE A WALK", │ + │ [8]: "buy concert tickets", │ + │ [9]: "BUY CONCERT TICKETS", │ + │ [10]: "pick up kids from school", │ + │ [11]: "PICK UP KIDS FROM SCHOOL", │ + │ [12]: "get laundry", │ + │ [13]: "GET LAUNDRY", │ + │ [14]: "take out trash", │ + │ [15]: "TAKE OUT TRASH", │ + │ [16]: "call accountant", │ + │ [17]: "CALL ACCOUNTANT", │ + │ [18]: "send weekly emails", │ + │ [19]: "SEND WEEKLY EMAILS" │ + │ ] │ + └─────────────────────────────────────┘ + """ + } + } } } From 21d6cea75f12700d62e3c446ab3a6c8bbbf5cccb Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Fri, 10 Oct 2025 10:29:35 -0700 Subject: [PATCH 08/19] wip --- .../DatabaseFunction.swift | 4 +++- .../DatabaseFunctionTests.swift | 24 +++++++++++++++++++ 2 files changed, 27 insertions(+), 1 deletion(-) diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 755d52d9..1e78b66a 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -126,7 +126,9 @@ private final class AggregateDatabaseFunctionIterator< >: AggregateDatabaseFunctionIteratorProtocol { let body: Body let stream = Stream() - let queue = DispatchQueue.global(qos: .userInitiated) + let queue = DispatchQueue( + label: "co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction" + ) var _result: QueryBinding? init(_ body: Body) { self.body = body diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 38abae36..1980e414 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -554,5 +554,29 @@ extension SnapshotTests { """ } } + + @DatabaseFunction + func tagged(_ tags: some Sequence) -> String { + tags.map { "#\($0.title)" }.joined(separator: " ") + } + + @Test func selectionTableAggregate() { + $tagged.install(database.handle) + + assertQuery( + Tag.select { $tagged($0) } + ) { + """ + SELECT "tagged"("tags"."id", "tags"."title") + FROM "tags" + """ + } results: { + """ + ┌─────────────────────────────────┐ + │ "#car #kids #someday #optional" │ + └─────────────────────────────────┘ + """ + } + } } } From 5ed3d2d0ba2373e326443e2eb4bdf615c2d5c317 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Fri, 10 Oct 2025 10:44:06 -0700 Subject: [PATCH 09/19] wip --- Tests/StructuredQueriesTests/DatabaseFunctionTests.swift | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 1980e414..cfff64fc 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -491,16 +491,16 @@ extension SnapshotTests { $joined.install(database.handle) assertQuery( - Tag.select { $joined($0.title, separator: ", ", order: $0.title) } + Tag.select { $joined($0.title, separator: ", ") } ) { """ - SELECT "joined"("tags"."title", ', ' ORDER BY "tags"."title") + SELECT "joined"("tags"."title", ', ') FROM "tags" """ } results: { """ ┌────────────────────────────────┐ - │ "car, kids, optional, someday" │ + │ "car, kids, someday, optional" │ └────────────────────────────────┘ """ } From 793c5faf01a1762fcd13f37c980411cafbf8afb3 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Sat, 11 Oct 2025 11:42:21 -0700 Subject: [PATCH 10/19] wip --- .../_StructuredQueriesSQLite/DatabaseFunction.swift | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 1e78b66a..0fb8fbaf 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -99,7 +99,8 @@ private final class AggregateDatabaseFunctionContext { .pointee = unmanagedContext return unmanagedContext } else { - return pointer + return + pointer .assumingMemoryBound(to: Unmanaged.self) .pointee } @@ -126,12 +127,13 @@ private final class AggregateDatabaseFunctionIterator< >: AggregateDatabaseFunctionIteratorProtocol { let body: Body let stream = Stream() - let queue = DispatchQueue( - label: "co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction" - ) + let queue: DispatchQueue var _result: QueryBinding? init(_ body: Body) { self.body = body + self.queue = DispatchQueue( + label: "co.pointfree.StructuredQueriesSQLite.AggregateDatabaseFunction.\(body.name)" + ) nonisolated(unsafe) let iterator: any AggregateDatabaseFunctionIteratorProtocol = self queue.async { iterator.start() From 71525233c7374777dee154e035cc0493e9131383 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Sat, 11 Oct 2025 13:32:20 -0700 Subject: [PATCH 11/19] wip --- .../DatabaseFunction.swift | 15 ++++++++++++--- .../DatabaseFunction.swift | 4 ++-- 2 files changed, 14 insertions(+), 5 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index 63baf167..f5f5318e 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -56,11 +56,20 @@ extension ScalarDatabaseFunction { /// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate /// a conformance. public protocol AggregateDatabaseFunction: DatabaseFunction { - associatedtype Row + /// A type representing a row's input to the aggregate function. + associatedtype Element = Input - func step(_ decoder: inout some QueryDecoder) throws -> Row + /// Decodes rows into elements to aggregate a result from. + /// + /// - Parameter decoder: A query decoder. + /// - Returns: An element to append to the sequence sent to the aggregate function. + func step(_ decoder: inout some QueryDecoder) throws -> Element - func invoke(_ arguments: some Sequence) throws -> QueryBinding + /// Aggregates elements into a bindable value. + /// + /// - Parameter arguments: A sequence of elements to aggregate from. + /// - Returns: A binding returned from the aggregate function. + func invoke(_ arguments: some Sequence) throws -> QueryBinding } extension AggregateDatabaseFunction { diff --git a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift index 0fb8fbaf..ea5cb9fb 100644 --- a/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift +++ b/Sources/_StructuredQueriesSQLite/DatabaseFunction.swift @@ -115,7 +115,7 @@ private protocol AggregateDatabaseFunctionIteratorProtocol { associatedtype Body: AggregateDatabaseFunction var body: Body { get } - var stream: Stream { get } + var stream: Stream { get } func start() func step(_ decoder: inout some QueryDecoder) throws func finish() @@ -126,7 +126,7 @@ private final class AggregateDatabaseFunctionIterator< Body: AggregateDatabaseFunction >: AggregateDatabaseFunctionIteratorProtocol { let body: Body - let stream = Stream() + let stream = Stream() let queue: DispatchQueue var _result: QueryBinding? init(_ body: Body) { From 8cede9d4a357b545639cd16a3b722705fe9ad83d Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Sat, 11 Oct 2025 13:36:03 -0700 Subject: [PATCH 12/19] wip --- Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index f5f5318e..93dde21e 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -56,10 +56,10 @@ extension ScalarDatabaseFunction { /// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate /// a conformance. public protocol AggregateDatabaseFunction: DatabaseFunction { - /// A type representing a row's input to the aggregate function. + /// A type representing one row of input to the aggregate function. associatedtype Element = Input - /// Decodes rows into elements to aggregate a result from. + /// Decodes a row into an element to aggregate a result from. /// /// - Parameter decoder: A query decoder. /// - Returns: An element to append to the sequence sent to the aggregate function. From 598f2d2422dd0ad1070e74f6ef6ea571c5d35cb9 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Mon, 13 Oct 2025 11:54:27 -0700 Subject: [PATCH 13/19] wip --- .../AggregateFunctions.swift | 2 +- .../DatabaseFunction.swift | 36 ++++++++++++++----- .../DatabaseFunctionTests.swift | 14 ++++++++ 3 files changed, 43 insertions(+), 9 deletions(-) diff --git a/Sources/StructuredQueriesCore/AggregateFunctions.swift b/Sources/StructuredQueriesCore/AggregateFunctions.swift index e6023dde..3152d388 100644 --- a/Sources/StructuredQueriesCore/AggregateFunctions.swift +++ b/Sources/StructuredQueriesCore/AggregateFunctions.swift @@ -225,7 +225,7 @@ public struct AggregateFunction: QueryExpression, Sendable { ) { self.init( QueryFragment(quote: name), - isDistinct: false, + isDistinct: isDistinct, Array(repeat each arguments), order: order?.queryFragment, filter: filter?.queryFragment diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index 93dde21e..f47c69b5 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -73,9 +73,34 @@ public protocol AggregateDatabaseFunction: DatabaseFunction { } extension AggregateDatabaseFunction { - /// A function call expression. + /// An aggregate function call expression. /// - /// - Parameter input: Expressions representing the arguments of the function. + /// - Parameters + /// - input: Expressions representing the arguments of the function. + /// - isDistinct: Whether or not to include a `DISTINCT` clause, which filters duplicates from + /// the aggregation. + /// - order: An `ORDER BY` clause to apply to the aggregation. + /// - filter: A `FILTER` clause to apply to the aggregation. + /// - Returns: An expression representing the function call. + @_disfavoredOverload + public func callAsFunction( + _ input: some QueryExpression, + distinct isDistinct: Bool = false, + order: (some QueryExpression)? = Bool?.none, + filter: (some QueryExpression)? = Bool?.none + ) -> some QueryExpression + where Input: QueryBindable { + $_isSelecting.withValue(false) { + AggregateFunction(name, distinct: isDistinct, input, order: order, filter: filter) + } + } + + /// An aggregate function call expression. + /// + /// - Parameters + /// - input: Expressions representing the arguments of the function. + /// - order: An `ORDER BY` clause to apply to the aggregation. + /// - filter: A `FILTER` clause to apply to the aggregation. /// - Returns: An expression representing the function call. @_disfavoredOverload public func callAsFunction( @@ -85,12 +110,7 @@ extension AggregateDatabaseFunction { ) -> some QueryExpression where Input == (repeat (each T).QueryValue) { $_isSelecting.withValue(false) { - AggregateFunction( - QueryFragment(quote: name), - Array(repeat each input), - order: order?.queryFragment, - filter: filter?.queryFragment - ) + AggregateFunction(name, repeat each input, order: order, filter: filter) } } } diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index cfff64fc..fa870d26 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -474,6 +474,20 @@ extension SnapshotTests { └────┘ """ } + assertQuery( + Reminder.select { $sum($0.id, distinct: true) } + ) { + """ + SELECT "sum"(DISTINCT "reminders"."id") + FROM "reminders" + """ + } results: { + """ + ┌────┐ + │ 55 │ + └────┘ + """ + } } @DatabaseFunction From 71bb30d995a758a754cb140da05133f0c2a09aeb Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Mon, 13 Oct 2025 11:57:41 -0700 Subject: [PATCH 14/19] wip --- Sources/StructuredQueriesSQLite/Macros.swift | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/Sources/StructuredQueriesSQLite/Macros.swift b/Sources/StructuredQueriesSQLite/Macros.swift index 44137a08..c2f96755 100644 --- a/Sources/StructuredQueriesSQLite/Macros.swift +++ b/Sources/StructuredQueriesSQLite/Macros.swift @@ -55,6 +55,14 @@ public macro DatabaseFunction( type: "DatabaseFunctionMacro" ) +/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction`` +/// protocol. +/// +/// - Parameters +/// - name: The function's name. Defaults to the name of the function the macro is applied to. +/// - representableFunctionType: The function as represented in a query. +/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially +/// transparent"), _i.e._ given an input it will always return the same output. @attached(peer, names: overloaded, prefixed(`$`)) public macro DatabaseFunction( _ name: String = "", @@ -66,6 +74,14 @@ public macro DatabaseFunction( _ name: String = "", From 4f23c4f76e0d003c02dce3f31dfb3b66eef29ee5 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Mon, 13 Oct 2025 17:50:44 -0700 Subject: [PATCH 15/19] wip --- .../AggregateFunctions.swift | 22 +++++++++---------- .../ScalarFunctions.swift | 3 +-- .../DatabaseFunction.swift | 4 ++-- .../JSONFunctions.swift | 6 ++--- .../DatabaseFunctionMacro.swift | 2 +- .../DatabaseFunctionMacroTests.swift | 12 +++++----- 6 files changed, 24 insertions(+), 25 deletions(-) diff --git a/Sources/StructuredQueriesCore/AggregateFunctions.swift b/Sources/StructuredQueriesCore/AggregateFunctions.swift index 3152d388..ea44bdb8 100644 --- a/Sources/StructuredQueriesCore/AggregateFunctions.swift +++ b/Sources/StructuredQueriesCore/AggregateFunctions.swift @@ -20,7 +20,7 @@ extension QueryExpression where QueryValue: QueryBindable { distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "count", isDistinct: isDistinct, [queryFragment], @@ -51,7 +51,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "group_concat", separator.map { [queryFragment, $0.queryFragment] } ?? [queryFragment], order: order?.queryFragment, @@ -74,7 +74,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped == Strin order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "group_concat", isDistinct: isDistinct, [queryFragment], @@ -97,7 +97,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable public func max( filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("max", [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("max", [queryFragment], filter: filter?.queryFragment) } /// A minimum aggregate of this expression. @@ -112,7 +112,7 @@ extension QueryExpression where QueryValue: QueryBindable & _OptionalPromotable public func min( filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("min", [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("min", [queryFragment], filter: filter?.queryFragment) } } @@ -134,7 +134,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment) + AggregateFunctionExpression("avg", isDistinct: isDistinct, [queryFragment], filter: filter?.queryFragment) } /// An sum aggregate of this expression. @@ -156,7 +156,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric // NB: We must explicitly erase here to avoid a runtime crash with opaque return types // TODO: Report issue to Swift team. SQLQueryExpression( - AggregateFunction( + AggregateFunctionExpression( "sum", isDistinct: isDistinct, [queryFragment], @@ -182,7 +182,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric distinct isDistinct: Bool = false, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression { - AggregateFunction( + AggregateFunctionExpression( "total", isDistinct: isDistinct, [queryFragment], @@ -191,7 +191,7 @@ where QueryValue: _OptionalPromotable, QueryValue._Optionalized.Wrapped: Numeric } } -extension QueryExpression where Self == AggregateFunction { +extension QueryExpression where Self == AggregateFunctionExpression { /// A `count(*)` aggregate. /// /// ```swift @@ -204,12 +204,12 @@ extension QueryExpression where Self == AggregateFunction { public static func count( filter: (any QueryExpression)? = nil ) -> Self { - AggregateFunction("count", ["*"], filter: filter?.queryFragment) + AggregateFunctionExpression("count", ["*"], filter: filter?.queryFragment) } } /// A query expression of an aggregate function. -public struct AggregateFunction: QueryExpression, Sendable { +public struct AggregateFunctionExpression: QueryExpression, Sendable { var name: QueryFragment var isDistinct: Bool var arguments: [QueryFragment] diff --git a/Sources/StructuredQueriesCore/ScalarFunctions.swift b/Sources/StructuredQueriesCore/ScalarFunctions.swift index 460d05cb..7a2b2db2 100644 --- a/Sources/StructuredQueriesCore/ScalarFunctions.swift +++ b/Sources/StructuredQueriesCore/ScalarFunctions.swift @@ -319,8 +319,7 @@ extension QueryExpression where QueryValue == [UInt8] { } } -/// A query expression of a generalized query function. -public struct QueryFunction: QueryExpression { +package struct QueryFunction: QueryExpression { let name: QueryFragment let arguments: [QueryFragment] diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index f47c69b5..aecc04ea 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -91,7 +91,7 @@ extension AggregateDatabaseFunction { ) -> some QueryExpression where Input: QueryBindable { $_isSelecting.withValue(false) { - AggregateFunction(name, distinct: isDistinct, input, order: order, filter: filter) + AggregateFunctionExpression(name, distinct: isDistinct, input, order: order, filter: filter) } } @@ -110,7 +110,7 @@ extension AggregateDatabaseFunction { ) -> some QueryExpression where Input == (repeat (each T).QueryValue) { $_isSelecting.withValue(false) { - AggregateFunction(name, repeat each input, order: order, filter: filter) + AggregateFunctionExpression(name, repeat each input, order: order, filter: filter) } } } diff --git a/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift b/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift index 5e62a42b..d842250b 100644 --- a/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift +++ b/Sources/StructuredQueriesSQLiteCore/JSONFunctions.swift @@ -46,7 +46,7 @@ extension QueryExpression where QueryValue: Codable & QueryBindable { order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression<[QueryValue].JSONRepresentation> { - AggregateFunction( + AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [queryFragment], @@ -112,7 +112,7 @@ extension PrimaryKeyedTableDefinition where QueryValue: Codable { order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none ) -> some QueryExpression<[QueryValue].JSONRepresentation> { - AggregateFunction( + AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [jsonObject().queryFragment], @@ -200,7 +200,7 @@ where } else { primaryKeyFilter.queryFragment } - return AggregateFunction( + return AggregateFunctionExpression( "json_group_array", isDistinct: isDistinct, [QueryValue.columns.jsonObject().queryFragment], diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index 8d03eeaf..c25887c5 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -337,7 +337,7 @@ extension DatabaseFunctionMacro: PeerMacro { """ public func callAsFunction\(signature.trimmed) { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, \ \(raw: parameters.joined(separator: ", ")), \ order: order, \ diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index 6afe06ac..c7f1f883 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -1248,7 +1248,7 @@ extension SnapshotTests { } public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, xs, order: order, filter: filter ) } @@ -1308,7 +1308,7 @@ extension SnapshotTests { } public func callAsFunction(of xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, xs, order: order, filter: filter ) } @@ -1381,7 +1381,7 @@ extension SnapshotTests { } public func callAsFunction(_ p0: some StructuredQueriesCore.QueryExpression, separator: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, p0, separator, order: order, filter: filter ) } @@ -1447,7 +1447,7 @@ extension SnapshotTests { } public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, arrays, order: order, filter: filter ) } @@ -1511,7 +1511,7 @@ extension SnapshotTests { } public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, xs, order: order, filter: filter ) } @@ -1580,7 +1580,7 @@ extension SnapshotTests { } public func callAsFunction(_ xs: some StructuredQueriesCore.QueryExpression, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.$_isSelecting.withValue(false) { - StructuredQueriesCore.AggregateFunction( + StructuredQueriesCore.AggregateFunctionExpression( self.name, xs, order: order, filter: filter ) } From 6453cfd0c071c5d33c8a8a63434270ee07fb70b9 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Mon, 13 Oct 2025 18:01:32 -0700 Subject: [PATCH 16/19] wip --- .../Articles/CustomFunctions.md | 27 +++++++++++++++++-- 1 file changed, 25 insertions(+), 2 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md index 7d4b76b8..ef410f3d 100644 --- a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md +++ b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md @@ -5,6 +5,8 @@ from SQLite. ## Overview +### Scalar functions + StructuredQueries defines a macro specifically for defining Swift functions that can be called from a query. It's called `@DatabaseFunction`, and can annotate any function that works with query-representable types. @@ -18,11 +20,14 @@ func exclaim(_ string: String) -> String { } ``` +This defines a "scalar" function, which is called on a value for each row in a query, returning its +result. + > Note: If your project is using [default main actor isolation] then you further need to annotate > your function as `nonisolated`. [default main actor isolation]: https://github.com/swiftlang/swift-evolution/blob/main/proposals/0466-control-default-actor-isolation.md -And will be immediately callable in a query by prefixing the function with `$`: +Once defined, the function is immediately callable in a query by prefixing the function with `$`: ```swift Reminder.select { $exclaim($0.title) } @@ -52,9 +57,26 @@ configuration.prepareDatabase { db in > } > ``` +### Aggregate functions + +It is also possible to define a Swift function that builds a single result from multiple rows of a +query. The function must simply take a _sequence_ of query-representable types. + +For example, a custom `sum` function could be defined like so: + +```swift +@DatabaseFunction +func sum(_ ints: some Sequence) -> Int { + ints.reduce(into: 0, +=) +} +``` + +This defines an "aggregate" function, where every element in `int` represents a row returned from +the base query. + ### Custom representations -To define a type that works with a custom representation, i.e. anytime you use `@Column(as:)` in +To define a type that works with a custom representation, _i.e._ anytime you use `@Column(as:)` in your data type, you can use the `as` parameter of the macro to specify those types. For example, if your model holds onto a date and you want to store that date as a [unix timestamp]() (i.e. double), @@ -99,3 +121,4 @@ func jsonArrayExclaim(_ strings: [String]) -> [String] { - ``DatabaseFunction`` - ``ScalarDatabaseFunction`` +- ``AggregateDatabaseFunction`` From ab2c4981ee02ab588b47d89695ebbae15efffd72 Mon Sep 17 00:00:00 2001 From: Brandon Williams Date: Wed, 15 Oct 2025 16:24:03 -0500 Subject: [PATCH 17/19] Added a test for mode aggregation. --- .../DatabaseFunction.swift | 13 ++-- .../DatabaseFunctionTests.swift | 62 +++++++++++++++++-- 2 files changed, 65 insertions(+), 10 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift index aecc04ea..71584ab0 100644 --- a/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift +++ b/Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift @@ -1,7 +1,7 @@ /// A type representing a database function. /// -/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate -/// a conformance. +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a conformance. public protocol DatabaseFunction { /// A type representing the function's arguments. associatedtype Input @@ -22,8 +22,8 @@ public protocol DatabaseFunction { /// A type representing a scalar database function. /// -/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate -/// a conformance. +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a conformance. public protocol ScalarDatabaseFunction: DatabaseFunction { /// The function body. Uses a query decoder to process the input of a database function into a /// bindable value. @@ -53,8 +53,9 @@ extension ScalarDatabaseFunction { /// A type representing an aggregate database function. /// -/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate -/// a conformance. +/// Don't conform to this protocol directly. Instead, use the +/// [`@DatabaseFunction`]() macro to generate a +/// conformance. public protocol AggregateDatabaseFunction: DatabaseFunction { /// A type representing one row of input to the aggregate function. associatedtype Element = Input diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index fa870d26..302a47a2 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -492,10 +492,13 @@ extension SnapshotTests { @DatabaseFunction func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? { - var iterator = arguments.makeIterator() - guard var (result, _) = iterator.next() else { return nil } - while let (string, separator) = iterator.next() { - result.append(separator) + var isFirst = true + var result = "" + for (string, separator) in arguments { + defer { isFirst = false } + if !isFirst { + result.append(separator) + } result.append(string) } return result @@ -592,5 +595,56 @@ extension SnapshotTests { """ } } + + @DatabaseFunction + func mode(priority priorities: some Sequence) -> Priority? { + var counts: [Priority: Int] = [:] + for priority in priorities { + guard let priority + else { continue } + counts[priority, default: 0] += 1 + } + return counts.max { $0.value < $1.value }?.key + } + @Test func modePriorityAggregate() { + $mode.install(database.handle) + + assertQuery( + Reminder + .select { $mode(priority: $0.priority) } + ) { + """ + SELECT "mode"("reminders"."priority") + FROM "reminders" + """ + } results: { + """ + ┌───────┐ + │ .high │ + └───────┘ + """ + } + assertQuery( + RemindersList + .group(by: \.id) + .leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) } + .select { ($0.title, $mode(priority: $1.priority)) } + ) { + """ + SELECT "remindersLists"."title", "mode"("reminders"."priority") + FROM "remindersLists" + LEFT JOIN "reminders" ON ("remindersLists"."id") = ("reminders"."remindersListID") + GROUP BY "remindersLists"."id" + """ + } results: { + """ + ┌────────────┬─────────┐ + │ "Personal" │ .high │ + │ "Family" │ .high │ + │ "Business" │ .medium │ + └────────────┴─────────┘ + """ + } + } } } From f28c46c3a73462c8c998612e4cb64081a3ae13f2 Mon Sep 17 00:00:00 2001 From: Brandon Williams Date: Wed, 15 Oct 2025 16:56:49 -0500 Subject: [PATCH 18/19] more docs --- .../Articles/CustomFunctions.md | 48 +++++++++++++++++-- .../DatabaseFunctionTests.swift | 6 +-- 2 files changed, 46 insertions(+), 8 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md index ef410f3d..78d0128e 100644 --- a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md +++ b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md @@ -62,17 +62,43 @@ configuration.prepareDatabase { db in It is also possible to define a Swift function that builds a single result from multiple rows of a query. The function must simply take a _sequence_ of query-representable types. -For example, a custom `sum` function could be defined like so: +For example, suppose you want to compute the most common priority used across all reminders. This +computation is called the "mode" in statistics, and unfortunately SQLite does not supply such +a function. But it is quite easy to write this function in plain Swift: ```swift @DatabaseFunction -func sum(_ ints: some Sequence) -> Int { - ints.reduce(into: 0, +=) +func mode(priority priorities: some Sequence) -> Priority? { + var occurences: [Priority: Int] = [:] + for priority in priorities { + guard let priority + else { continue } + occurences[priority, default: 0] += 1 + } + return occurences.max { $0.value < $1.value }?.key } ``` -This defines an "aggregate" function, where every element in `int` represents a row returned from -the base query. +This defines an "aggregate" function, and the sequence `priorities` that is passed to it represents +all of the data from the database passed to it while aggregating. It is now straightfoward +to compute the mode of priorities across all reminders: + +```swift +Reminder + .select { $mode(priority: $0.priority) } +``` + +> Tip: Be sure to install the function in the database connection as discussed in +> above. + +You can also compute the mode of priorities inside each reminders list: + +```swift +RemindersList + .group(by: \.id) + .leftJoin(Reminder.all) { $0.id.eq($1.remindersListID) } + .select { ($0.title, $mode(priority: $1.priority)) } +``` ### Custom representations @@ -115,6 +141,18 @@ func jsonArrayExclaim(_ strings: [String]) -> [String] { } ``` +It is also possible to do this with aggregate functions, but you must describe the sequence as an +`any Sequence` instead of a `some Sequence`: + +```swift +@DatabaseFunction( + as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self +) +func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] { + arrays.flatMap(\.self) +} +``` + ## Topics ### Custom functions diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 302a47a2..d57ffe24 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -598,13 +598,13 @@ extension SnapshotTests { @DatabaseFunction func mode(priority priorities: some Sequence) -> Priority? { - var counts: [Priority: Int] = [:] + var occurences: [Priority: Int] = [:] for priority in priorities { guard let priority else { continue } - counts[priority, default: 0] += 1 + occurences[priority, default: 0] += 1 } - return counts.max { $0.value < $1.value }?.key + return occurences.max { $0.value < $1.value }?.key } @Test func modePriorityAggregate() { $mode.install(database.handle) From 2bcfaa712414bb2d4ba8ff693a8037afb177b8a3 Mon Sep 17 00:00:00 2001 From: Stephen Celis Date: Wed, 15 Oct 2025 17:54:22 -0700 Subject: [PATCH 19/19] wip --- .../Documentation.docc/Articles/CustomFunctions.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md index 78d0128e..edb05ab5 100644 --- a/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md +++ b/Sources/StructuredQueriesSQLiteCore/Documentation.docc/Articles/CustomFunctions.md @@ -69,18 +69,18 @@ a function. But it is quite easy to write this function in plain Swift: ```swift @DatabaseFunction func mode(priority priorities: some Sequence) -> Priority? { - var occurences: [Priority: Int] = [:] + var occurrences: [Priority: Int] = [:] for priority in priorities { guard let priority else { continue } - occurences[priority, default: 0] += 1 + occurrences[priority, default: 0] += 1 } - return occurences.max { $0.value < $1.value }?.key + return occurrences.max { $0.value < $1.value }?.key } ``` This defines an "aggregate" function, and the sequence `priorities` that is passed to it represents -all of the data from the database passed to it while aggregating. It is now straightfoward +all of the data from the database passed to it while aggregating. It is now straightforward to compute the mode of priorities across all reminders: ```swift @@ -141,8 +141,8 @@ func jsonArrayExclaim(_ strings: [String]) -> [String] { } ``` -It is also possible to do this with aggregate functions, but you must describe the sequence as an -`any Sequence` instead of a `some Sequence`: +It is also possible to do this with aggregate functions, too, but you must describe the sequence as +an `any Sequence` instead of a `some Sequence`: ```swift @DatabaseFunction(