Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions Sources/StructuredQueriesSQLite/Macros.swift
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,22 @@ public macro DatabaseFunction<each T: QueryBindable, R: QueryBindable>(
module: "StructuredQueriesSQLiteMacros",
type: "DatabaseFunctionMacro"
)

/// Defines and implements a conformance to the ``/StructuredQueriesSQLiteCore/DatabaseFunction``
/// protocol.
///
/// - Parameters
/// - name: The function's name. Defaults to the name of the function the macro is applied to.
/// - representableFunctionType: The function as represented in a query.
/// - isDeterministic: Whether or not the function is deterministic (or "pure" or "referentially
/// transparent"), _i.e._ given an input it will always return the same output.
@attached(peer, names: overloaded, prefixed(`$`))
public macro DatabaseFunction<each T: QueryBindable>(
_ name: String = "",
as representableFunctionType: ((repeat each T) -> Void).Type,
isDeterministic: Bool = false
) =
#externalMacro(
module: "StructuredQueriesSQLiteMacros",
type: "DatabaseFunctionMacro"
)
68 changes: 25 additions & 43 deletions Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,31 +25,11 @@ extension DatabaseFunctionMacro: PeerMacro {
return []
}

guard declaration.signature.returnClause != nil else {
context.diagnose(
Diagnostic(
node: declaration.signature,
position: declaration.signature.endPositionBeforeTrailingTrivia,
message: MacroExpansionErrorMessage(
"Missing required return type"
),
fixIt: .replace(
message: MacroExpansionFixItMessage("Insert '-> <#QueryBindable#>'"),
oldNode: declaration.signature,
newNode: declaration.signature.with(
\.returnClause,
ReturnClauseSyntax(
type: IdentifierTypeSyntax(name: "<#QueryBindable#>")
.with(\.leadingTrivia, .space)
.with(\.trailingTrivia, .space)
)
)
)
)
let returnClause =
declaration.signature.returnClause
?? ReturnClauseSyntax(
type: "Swift.Void" as TypeSyntax
)
return []
}

let declarationName = declaration.name.trimmedDescription.trimmingBackticks()
var functionName = declarationName
var functionRepresentation: FunctionTypeSyntax?
Expand Down Expand Up @@ -158,43 +138,45 @@ extension DatabaseFunctionMacro: PeerMacro {
argumentBindings.append((parameterName, "\(type)(queryBinding: arguments[\(offset)])"))
}
var inputType = bodyArguments.joined(separator: ", ")
let bodyReturnClause: String
let outputType: TypeSyntax
if let returnClause = signature.returnClause {
outputType = returnClause.type.trimmed
signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type
.asQueryExpression()
bodyReturnClause = " \(returnClause.trimmedDescription)"
} else {
outputType = "Void"
bodyReturnClause = " -> Void"
}
let isVoidReturning = signature.returnClause == nil
let outputType = returnClause.type.trimmed
signature.returnClause = returnClause
signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type
.asQueryExpression()
let bodyReturnClause = " \(returnClause.trimmedDescription)"
let bodyType = """
(\(inputType))\
\(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\
\(bodyReturnClause)
"""
let bodyInvocation = """
\(declaration.signature.effectSpecifiers?.throwsClause != nil ? "try " : "")self.body(\
\(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\
)
"""
// TODO: Diagnose 'asyncClause'?
signature.effectSpecifiers?.throwsClause = nil

var invocationBody = """
\(functionRepresentation?.returnClause.type ?? outputType)(
queryOutput: self.body(\
\(argumentBindings.map { name, _ in "\(name).queryOutput" }.joined(separator: ", "))\
)
var invocationBody =
isVoidReturning
? """
\(bodyInvocation)
return .null
"""
: """
return \(functionRepresentation?.returnClause.type ?? outputType)(
queryOutput: \(bodyInvocation)
)
.queryBinding
"""
if declaration.signature.effectSpecifiers?.throwsClause != nil {
invocationBody = """
do {
return try \(invocationBody)
\(invocationBody)
} catch {
return .invalid(error)
}
"""
} else {
invocationBody = "return \(invocationBody)"
}

var attributes = declaration.attributes
Expand Down
81 changes: 56 additions & 25 deletions Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -572,8 +572,8 @@ extension SnapshotTests {
return .invalid(InvalidInvocation())
}
do {
return try Date(
queryOutput: self.body()
return Date(
queryOutput: try self.body()
)
.queryBinding
} catch {
Expand Down Expand Up @@ -627,8 +627,8 @@ extension SnapshotTests {
return .invalid(InvalidInvocation())
}
do {
return try Date(
queryOutput: self.body()
return Date(
queryOutput: try self.body()
)
.queryBinding
} catch {
Expand Down Expand Up @@ -869,35 +869,64 @@ extension SnapshotTests {
}
}

@Test func returnTypeDiagnostic() {
@Test func voidReturnType() {
assertMacro {
"""
@DatabaseFunction
public func void() {
print("...")
}
"""
} diagnostics: {
"""
@DatabaseFunction
} expansion: {
#"""
public func void() {
──┬
╰─ 🛑 Missing required return type
✏️ Insert '-> <#QueryBindable#>'
print("...")
}
"""
} fixes: {

public var $void: __macro_local_4voidfMu_ {
__macro_local_4voidfMu_(void)
}

public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
public typealias Input = ()
public typealias Output = Swift.Void
public let name = "void"
public let argumentCount: Int? = 0
public let isDeterministic = false
public let body: () -> Swift.Void
public init(_ body: @escaping () -> Swift.Void) {
self.body = body
}
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<Swift.Void> {
StructuredQueriesCore.SQLQueryExpression(
"\(quote: self.name)()"
)
}
public func invoke(
_ arguments: [StructuredQueriesCore.QueryBinding]
) -> StructuredQueriesCore.QueryBinding {
guard self.argumentCount == nil || self.argumentCount == arguments.count else {
return .invalid(InvalidInvocation())
}
self.body()
return .null
}
private struct InvalidInvocation: Error {
}
}
"""#
}
assertMacro {
"""
@DatabaseFunction
public func void() -> <#QueryBindable#> {
print("...")
public func void() throws {
throw Failure()
}
"""
} expansion: {
#"""
public func void() -> <#QueryBindable#> {
print("...")
public func void() throws {
throw Failure()
}

public var $void: __macro_local_4voidfMu_ {
Expand All @@ -906,15 +935,15 @@ extension SnapshotTests {

public struct __macro_local_4voidfMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
public typealias Input = ()
public typealias Output = <#QueryBindable#>
public typealias Output = Swift.Void
public let name = "void"
public let argumentCount: Int? = 0
public let isDeterministic = false
public let body: () -> <#QueryBindable#>
public init(_ body: @escaping () -> <#QueryBindable#>) {
public let body: () throws -> Swift.Void
public init(_ body: @escaping () throws -> Swift.Void) {
self.body = body
}
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<<#QueryBindable#>> {
public func callAsFunction() -> some StructuredQueriesCore.QueryExpression<Swift.Void> {
StructuredQueriesCore.SQLQueryExpression(
"\(quote: self.name)()"
)
Expand All @@ -925,10 +954,12 @@ extension SnapshotTests {
guard self.argumentCount == nil || self.argumentCount == arguments.count else {
return .invalid(InvalidInvocation())
}
return <#QueryBindable#>(
queryOutput: self.body()
)
.queryBinding
do {
try self.body()
return .null
} catch {
return .invalid(error)
}
}
private struct InvalidInvocation: Error {
}
Expand Down
40 changes: 31 additions & 9 deletions Tests/StructuredQueriesTests/DatabaseFunctionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import _StructuredQueriesSQLite

extension SnapshotTests {
@Suite struct DatabaseFunctionTests {
@Dependency(\.defaultDatabase) var database

@DatabaseFunction
func isEnabled() -> Bool {
true
}
@Test func customIsEnabled() {
@Dependency(\.defaultDatabase) var database
$isEnabled.install(database.handle)
assertQuery(
Values($isEnabled())
Expand All @@ -37,7 +38,6 @@ extension SnapshotTests {
Date(timeIntervalSince1970: 0)
}
@Test func customDateTime() {
@Dependency(\.defaultDatabase) var database
$dateTime.install(database.handle)
assertQuery(
Values($dateTime())
Expand All @@ -59,7 +59,6 @@ extension SnapshotTests {
first + second
}
@Test func customConcat() {
@Dependency(\.defaultDatabase) var database
$concat.install(database.handle)
assertQuery(
Values($concat(first: "foo", second: "bar"))
Expand All @@ -77,7 +76,6 @@ extension SnapshotTests {
}

@Test func erasedConcat() {
@Dependency(\.defaultDatabase) var database
$concat.install(database.handle)
assertQuery(
Values($concat("foo", "bar"))
Expand All @@ -104,7 +102,6 @@ extension SnapshotTests {
throw Failure()
}
@Test func customThrowing() {
@Dependency(\.defaultDatabase) var database
$throwing.install(database.handle)
assertQuery(
Values($throwing())
Expand Down Expand Up @@ -132,7 +129,6 @@ extension SnapshotTests {
completion == .incomplete ? .completing : .incomplete
}
@Test func customToggle() {
@Dependency(\.defaultDatabase) var database
$toggle.install(database.handle)
assertQuery(
Values($toggle(Completion.incomplete))
Expand All @@ -155,7 +151,6 @@ extension SnapshotTests {
}

@Test func customRepresentation() {
@Dependency(\.defaultDatabase) var database
$jsonCapitalize.install(database.handle)
assertQuery(
Values($jsonCapitalize(#bind(["hello", "world"])))
Expand Down Expand Up @@ -184,7 +179,6 @@ extension SnapshotTests {
}

@Test func customMixedRepresentation() {
@Dependency(\.defaultDatabase) var database
$jsonDropFirst.install(database.handle)
assertQuery(
Values($jsonDropFirst(#bind(["hello", "world", "goodnight", "moon"]), 2))
Expand Down Expand Up @@ -215,7 +209,6 @@ extension SnapshotTests {
}

@Test func customNilRepresentation() {
@Dependency(\.defaultDatabase) var database
$jsonCount.install(database.handle)
assertQuery(
Values($jsonCount(#bind(["hello", "world", "goodnight", "moon"])))
Expand Down Expand Up @@ -249,5 +242,34 @@ extension SnapshotTests {
"""
}
}

final class Logger {
var messages: [String] = []

@DatabaseFunction
func log(_ message: String) {
messages.append(message)
}
}

@Test func voidState() {
let logger = Logger()
logger.$log.install(database.handle)

assertQuery(
Values(logger.$log("Hello, world!"))
) {
"""
SELECT "log"('Hello, world!')
"""
} results: {
"""
┌──┐
└──┘
"""
}

#expect(logger.messages == ["Hello, world!"])
}
}
}