diff --git a/Sources/StructuredQueriesSQLite/Macros.swift b/Sources/StructuredQueriesSQLite/Macros.swift index 79166b6c..019d0b62 100644 --- a/Sources/StructuredQueriesSQLite/Macros.swift +++ b/Sources/StructuredQueriesSQLite/Macros.swift @@ -35,3 +35,22 @@ public macro DatabaseFunction( module: "StructuredQueriesSQLiteMacros", type: "DatabaseFunctionMacro" ) + +/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction`` +/// protocol. +/// +/// - Parameters +/// - name: The function's name. Defaults to the name of the function the macro is applied to. +/// - representableFunctionType: The function as represented in a query. +/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially +/// transparent"), _i.e._ given an input it will always return the same output. +@attached(peer, names: overloaded, prefixed(`$`)) +public macro DatabaseFunction( + _ name: String = "", + as representableFunctionType: ((repeat each T) -> Void).Type, + isDeterministic: Bool = false +) = +#externalMacro( + module: "StructuredQueriesSQLiteMacros", + type: "DatabaseFunctionMacro" +) diff --git a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift index 31f18bbd..91c1304f 100644 --- a/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift +++ b/Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift @@ -25,31 +25,11 @@ extension DatabaseFunctionMacro: PeerMacro { return [] } - guard declaration.signature.returnClause != nil else { - context.diagnose( - Diagnostic( - node: declaration.signature, - position: declaration.signature.endPositionBeforeTrailingTrivia, - message: MacroExpansionErrorMessage( - "Missing required return type" - ), - fixIt: .replace( - message: MacroExpansionFixItMessage("Insert '-> <#QueryBindable#>'"), - oldNode: declaration.signature, - newNode: declaration.signature.with( - \.returnClause, - ReturnClauseSyntax( - type: IdentifierTypeSyntax(name: "<#QueryBindable#>") - .with(\.leadingTrivia, .space) - .with(\.trailingTrivia, .space) - ) - ) - ) - ) + let returnClause = + declaration.signature.returnClause + ?? ReturnClauseSyntax( + type: "Swift.Void" as TypeSyntax ) - return [] - } - let declarationName = declaration.name.trimmedDescription.trimmingBackticks() var functionName = declarationName var functionRepresentation: FunctionTypeSyntax? @@ -158,43 +138,45 @@ extension DatabaseFunctionMacro: PeerMacro { argumentBindings.append((parameterName, "\(type)(queryBinding: arguments[\(offset)])")) } var inputType = bodyArguments.joined(separator: ", ") - let bodyReturnClause: String - let outputType: TypeSyntax - if let returnClause = signature.returnClause { - outputType = returnClause.type.trimmed - signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type - .asQueryExpression() - bodyReturnClause = " \(returnClause.trimmedDescription)" - } else { - outputType = "Void" - bodyReturnClause = " -> Void" - } + let isVoidReturning = signature.returnClause == nil + let outputType = returnClause.type.trimmed + signature.returnClause = returnClause + signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type + .asQueryExpression() + let bodyReturnClause = " \(returnClause.trimmedDescription)" let bodyType = """ (\(inputType))\ \(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\ \(bodyReturnClause) """ + let bodyInvocation = """ + \(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\ + \(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\ + ) + """ // TODO: Diagnose 'asyncClause'? signature.effectSpecifiers?.throwsClause = nil - var invocationBody = """ - \(functionRepresentation?.returnClause.type ?? outputType)( - queryOutput: self.body(\ - \(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\ - ) + var invocationBody = + isVoidReturning + ? """ + \(bodyInvocation) + return .null + """ + : """ + return \(functionRepresentation?.returnClause.type ?? outputType)( + queryOutput: \(bodyInvocation) ) .queryBinding """ if declaration.signature.effectSpecifiers?.throwsClause != nil { invocationBody = """ do { - return try \(invocationBody) + \(invocationBody) } catch { return .invalid(error) } """ - } else { - invocationBody = "return \(invocationBody)" } var attributes = declaration.attributes diff --git a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift index 634f06aa..e2f1a070 100644 --- a/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift +++ b/Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift @@ -572,8 +572,8 @@ extension SnapshotTests { return .invalid(InvalidInvocation()) } do { - return try Date( - queryOutput: self.body() + return Date( + queryOutput: try self.body() ) .queryBinding } catch { @@ -627,8 +627,8 @@ extension SnapshotTests { return .invalid(InvalidInvocation()) } do { - return try Date( - queryOutput: self.body() + return Date( + queryOutput: try self.body() ) .queryBinding } catch { @@ -869,7 +869,7 @@ extension SnapshotTests { } } - @Test func returnTypeDiagnostic() { + @Test func voidReturnType() { assertMacro { """ @DatabaseFunction @@ -877,27 +877,56 @@ extension SnapshotTests { print("...") } """ - } diagnostics: { - """ - @DatabaseFunction + } expansion: { + #""" public func void() { - ──┬ - ╰─ 🛑 Missing required return type - ✏️ Insert '-> <#QueryBindable#>' print("...") } - """ - } fixes: { + + public var $void: __macro_local_4voidfMu_ { + __macro_local_4voidfMu_(void) + } + + public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction { + public typealias Input = () + public typealias Output = Swift.Void + public let name = "void" + public let argumentCount: Int? = 0 + public let isDeterministic = false + public let body: () -> Swift.Void + public init(_ body: @escaping () -> Swift.Void) { + self.body = body + } + public func callAsFunction() -> some StructuredQueriesCore.QueryExpression { + StructuredQueriesCore.SQLQueryExpression( + "\(quote: self.name)()" + ) + } + public func invoke( + _ arguments: [StructuredQueriesCore.QueryBinding] + ) -> StructuredQueriesCore.QueryBinding { + guard self.argumentCount == nil || self.argumentCount == arguments.count else { + return .invalid(InvalidInvocation()) + } + self.body() + return .null + } + private struct InvalidInvocation: Error { + } + } + """# + } + assertMacro { """ @DatabaseFunction - public func void() -> <#QueryBindable#> { - print("...") + public func void() throws { + throw Failure() } """ } expansion: { #""" - public func void() -> <#QueryBindable#> { - print("...") + public func void() throws { + throw Failure() } public var $void: __macro_local_4voidfMu_ { @@ -906,15 +935,15 @@ extension SnapshotTests { public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction { public typealias Input = () - public typealias Output = <#QueryBindable#> + public typealias Output = Swift.Void public let name = "void" public let argumentCount: Int? = 0 public let isDeterministic = false - public let body: () -> <#QueryBindable#> - public init(_ body: @escaping () -> <#QueryBindable#>) { + public let body: () throws -> Swift.Void + public init(_ body: @escaping () throws -> Swift.Void) { self.body = body } - public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<<#QueryBindable#>> { + public func callAsFunction() -> some StructuredQueriesCore.QueryExpression { StructuredQueriesCore.SQLQueryExpression( "\(quote: self.name)()" ) @@ -925,10 +954,12 @@ extension SnapshotTests { guard self.argumentCount == nil || self.argumentCount == arguments.count else { return .invalid(InvalidInvocation()) } - return <#QueryBindable#>( - queryOutput: self.body() - ) - .queryBinding + do { + try self.body() + return .null + } catch { + return .invalid(error) + } } private struct InvalidInvocation: Error { } diff --git a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift index 526f721c..d88c8fad 100644 --- a/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift +++ b/Tests/StructuredQueriesTests/DatabaseFunctionTests.swift @@ -10,12 +10,13 @@ import _StructuredQueriesSQLite extension SnapshotTests { @Suite struct DatabaseFunctionTests { + @Dependency(\.defaultDatabase) var database + @DatabaseFunction func isEnabled() -> Bool { true } @Test func customIsEnabled() { - @Dependency(\.defaultDatabase) var database $isEnabled.install(database.handle) assertQuery( Values($isEnabled()) @@ -37,7 +38,6 @@ extension SnapshotTests { Date(timeIntervalSince1970: 0) } @Test func customDateTime() { - @Dependency(\.defaultDatabase) var database $dateTime.install(database.handle) assertQuery( Values($dateTime()) @@ -59,7 +59,6 @@ extension SnapshotTests { first + second } @Test func customConcat() { - @Dependency(\.defaultDatabase) var database $concat.install(database.handle) assertQuery( Values($concat(first: "foo", second: "bar")) @@ -77,7 +76,6 @@ extension SnapshotTests { } @Test func erasedConcat() { - @Dependency(\.defaultDatabase) var database $concat.install(database.handle) assertQuery( Values($concat("foo", "bar")) @@ -104,7 +102,6 @@ extension SnapshotTests { throw Failure() } @Test func customThrowing() { - @Dependency(\.defaultDatabase) var database $throwing.install(database.handle) assertQuery( Values($throwing()) @@ -132,7 +129,6 @@ extension SnapshotTests { completion == .incomplete ? .completing : .incomplete } @Test func customToggle() { - @Dependency(\.defaultDatabase) var database $toggle.install(database.handle) assertQuery( Values($toggle(Completion.incomplete)) @@ -155,7 +151,6 @@ extension SnapshotTests { } @Test func customRepresentation() { - @Dependency(\.defaultDatabase) var database $jsonCapitalize.install(database.handle) assertQuery( Values($jsonCapitalize(#bind(["hello", "world"]))) @@ -184,7 +179,6 @@ extension SnapshotTests { } @Test func customMixedRepresentation() { - @Dependency(\.defaultDatabase) var database $jsonDropFirst.install(database.handle) assertQuery( Values($jsonDropFirst(#bind(["hello", "world", "goodnight", "moon"]), 2)) @@ -215,7 +209,6 @@ extension SnapshotTests { } @Test func customNilRepresentation() { - @Dependency(\.defaultDatabase) var database $jsonCount.install(database.handle) assertQuery( Values($jsonCount(#bind(["hello", "world", "goodnight", "moon"]))) @@ -249,5 +242,34 @@ extension SnapshotTests { """ } } + + final class Logger { + var messages: [String] = [] + + @DatabaseFunction + func log(_ message: String) { + messages.append(message) + } + } + + @Test func voidState() { + let logger = Logger() + logger.$log.install(database.handle) + + assertQuery( + Values(logger.$log("Hello, world!")) + ) { + """ + SELECT "log"('Hello, world!') + """ + } results: { + """ + ┌──┐ + └──┘ + """ + } + + #expect(logger.messages == ["Hello, world!"]) + } } }