Skip to content

Commit e30ebd6

Browse files
authored
Functions should be generic over their query-representable types (#156)
* Functions should be generic over their query-representable types * wip
1 parent adad5c6 commit e30ebd6

File tree

6 files changed

+84
-19
lines changed

6 files changed

+84
-19
lines changed

Package.resolved

Lines changed: 1 addition & 10 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Sources/StructuredQueriesCore/QueryBindable.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ extension [UInt8]: QueryBindable, QueryExpression {
2828

2929
extension Bool: QueryBindable {
3030
public var queryBinding: QueryBinding { .int(self ? 1 : 0) }
31+
public init?(queryBinding: QueryBinding) {
32+
guard case .int(let value) = queryBinding else { return nil }
33+
self = value != 0
34+
}
3135
}
3236

3337
extension Double: QueryBindable {

Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ extension ScalarDatabaseFunction {
3838
///
3939
/// - Parameter input: Expressions representing the arguments of the function.
4040
/// - Returns: An expression representing the function call.
41+
@_disfavoredOverload
4142
public func callAsFunction<each T: QueryExpression>(
4243
_ input: repeat each T
4344
) -> some QueryExpression<Output>

Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ extension DatabaseFunctionMacro: PeerMacro {
105105
let argumentCount = declaration.signature.parameterClause.parameters.count
106106

107107
var bodyArguments: [String] = []
108+
var representableInputTypes: [String] = []
108109
var signature = declaration.signature
109110
var invocationArgumentTypes: [TypeSyntax] = []
110111
var parameters: [String] = []
@@ -125,6 +126,7 @@ extension DatabaseFunctionMacro: PeerMacro {
125126
}
126127
bodyArguments.append("\(parameter.type.trimmed)")
127128
let type = (functionRepresentationIterator?.next()?.type ?? parameter.type).trimmed
129+
representableInputTypes.append(type.trimmedDescription)
128130
parameter.type = type.asQueryExpression()
129131
if let defaultValue = parameter.defaultValue,
130132
defaultValue.value.is(NilLiteralExprSyntax.self)
@@ -137,15 +139,16 @@ extension DatabaseFunctionMacro: PeerMacro {
137139
parameters.append(parameterName)
138140
argumentBindings.append((parameterName, "\(type)(queryBinding: arguments[\(offset)])"))
139141
}
140-
var inputType = bodyArguments.joined(separator: ", ")
142+
var representableInputType = representableInputTypes.joined(separator: ", ")
141143
let isVoidReturning = signature.returnClause == nil
142144
let outputType = returnClause.type.trimmed
143145
signature.returnClause = returnClause
144-
signature.returnClause?.type = (functionRepresentation?.returnClause ?? returnClause).type
145-
.asQueryExpression()
146+
let representableOutputType = (functionRepresentation?.returnClause ?? returnClause).type
147+
.trimmed
148+
signature.returnClause?.type = representableOutputType.asQueryExpression()
146149
let bodyReturnClause = " \(returnClause.trimmedDescription)"
147150
let bodyType = """
148-
(\(inputType))\
151+
(\(bodyArguments.joined(separator: ", ")))\
149152
\(declaration.signature.effectSpecifiers?.trimmedDescription ?? "")\
150153
\(bodyReturnClause)
151154
"""
@@ -198,7 +201,9 @@ extension DatabaseFunctionMacro: PeerMacro {
198201
continue
199202
}
200203
}
201-
inputType = bodyArguments.count == 1 ? inputType : "(\(inputType))"
204+
representableInputType = representableInputTypes.count == 1
205+
? representableInputType
206+
: "(\(representableInputType))"
202207

203208
return [
204209
"""
@@ -209,8 +214,8 @@ extension DatabaseFunctionMacro: PeerMacro {
209214
"""
210215
\(attributes)\(access)struct \(functionTypeName): \
211216
StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
212-
public typealias Input = \(raw: inputType)
213-
public typealias Output = \(outputType)
217+
public typealias Input = \(raw: representableInputType)
218+
public typealias Output = \(representableOutputType)
214219
public let name = \(databaseFunctionName)
215220
public let argumentCount: Int? = \(raw: argumentCount)
216221
public let isDeterministic = \(raw: isDeterministic)

Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,8 @@ extension SnapshotTests {
126126
}
127127
128128
struct __macro_local_14jsonCapitalizefMu_: StructuredQueriesSQLiteCore.ScalarDatabaseFunction {
129-
public typealias Input = [String]
130-
public typealias Output = [String]
129+
public typealias Input = [String].JSONRepresentation
130+
public typealias Output = [String].JSONRepresentation
131131
public let name = "jsonCapitalize"
132132
public let argumentCount: Int? = 1
133133
public let isDeterministic = false

Tests/StructuredQueriesTests/DatabaseFunctionTests.swift

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -271,5 +271,69 @@ extension SnapshotTests {
271271

272272
#expect(logger.messages == ["Hello, world!"])
273273
}
274+
275+
@DatabaseFunction(as: (([Tag].JSONRepresentation) -> String).self)
276+
func joinTags(_ tags: [Tag]) -> String {
277+
tags.map(\.title).joined(separator: ", ")
278+
}
279+
280+
@Test func jsonArray() {
281+
$joinTags.install(database.handle)
282+
283+
assertQuery(
284+
Reminder
285+
.group(by: \.id)
286+
.leftJoin(ReminderTag.all) { $0.id.eq($1.reminderID) }
287+
.leftJoin(Tag.all) { $1.tagID.eq($2.id) }
288+
.select { $joinTags($2.jsonGroupArray()) }
289+
) {
290+
"""
291+
SELECT "joinTags"(json_group_array(CASE WHEN ("tags"."rowid" IS NOT NULL) THEN json_object('id', json_quote("tags"."id"), 'title', json_quote("tags"."title")) END) FILTER (WHERE ("tags"."id" IS NOT NULL)))
292+
FROM "reminders"
293+
LEFT JOIN "remindersTags" ON ("reminders"."id" = "remindersTags"."reminderID")
294+
LEFT JOIN "tags" ON ("remindersTags"."tagID" = "tags"."id")
295+
GROUP BY "reminders"."id"
296+
"""
297+
} results: {
298+
"""
299+
┌─────────────────────┐
300+
"someday, optional"
301+
"someday, optional"
302+
""
303+
"car, kids"
304+
""
305+
""
306+
""
307+
""
308+
""
309+
""
310+
└─────────────────────┘
311+
"""
312+
}
313+
}
314+
315+
@DatabaseFunction(as: ((Reminder.JSONRepresentation, Bool) -> Bool).self)
316+
func isValid(_ reminder: Reminder, _ override: Bool = false) -> Bool {
317+
!reminder.title.isEmpty || override
318+
}
319+
@Test func jsonObject() {
320+
$isValid.install(database.handle)
321+
322+
assertQuery(
323+
Reminder.select { $isValid($0.jsonObject(), true) }.limit(1)
324+
) {
325+
"""
326+
SELECT "isValid"(json_object('id', json_quote("reminders"."id"), 'assignedUserID', json_quote("reminders"."assignedUserID"), 'dueDate', json_quote("reminders"."dueDate"), 'isCompleted', json(CASE "reminders"."isCompleted" WHEN 0 THEN 'false' WHEN 1 THEN 'true' END), 'isFlagged', json(CASE "reminders"."isFlagged" WHEN 0 THEN 'false' WHEN 1 THEN 'true' END), 'notes', json_quote("reminders"."notes"), 'priority', json_quote("reminders"."priority"), 'remindersListID', json_quote("reminders"."remindersListID"), 'title', json_quote("reminders"."title"), 'updatedAt', json_quote("reminders"."updatedAt")), 1)
327+
FROM "reminders"
328+
LIMIT 1
329+
"""
330+
} results: {
331+
"""
332+
┌──────┐
333+
│ true │
334+
└──────┘
335+
"""
336+
}
337+
}
274338
}
275339
}

0 commit comments

Comments
 (0)