Skip to content

Commit 74ec4d1

Browse files
committed
wip
1 parent 0092915 commit 74ec4d1

File tree

6 files changed

+101
-37
lines changed

6 files changed

+101
-37
lines changed

Sources/StructuredQueriesSQLite/Macros.swift

Lines changed: 21 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -55,26 +55,24 @@ public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression>(
5555
type: "DatabaseFunctionMacro"
5656
)
5757

58-
// TODO:
59-
// @attached(peer, names: overloaded, prefixed(`$`))
60-
// public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression, R: QueryBindable>(
61-
// _ name: String = "",
62-
// as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type,
63-
// isDeterministic: Bool = false
64-
// ) =
65-
// #externalMacro(
66-
// module: "StructuredQueriesSQLiteMacros",
67-
// type: "DatabaseFunctionMacro"
68-
// )
69-
//
70-
// @attached(peer, names: overloaded, prefixed(`$`))
71-
// public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression>(
72-
// _ name: String = "",
73-
// as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type,
74-
// isDeterministic: Bool = false
75-
// ) =
76-
// #externalMacro(
77-
// module: "StructuredQueriesSQLiteMacros",
78-
// type: "DatabaseFunctionMacro"
79-
// )
80-
//
58+
@attached(peer, names: overloaded, prefixed(`$`))
59+
public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression, R: QueryBindable>(
60+
_ name: String = "",
61+
as representableFunctionType: ((any Sequence<(repeat each T)>) -> R).Type,
62+
isDeterministic: Bool = false
63+
) =
64+
#externalMacro(
65+
module: "StructuredQueriesSQLiteMacros",
66+
type: "DatabaseFunctionMacro"
67+
)
68+
69+
@attached(peer, names: overloaded, prefixed(`$`))
70+
public macro DatabaseFunction<each T: QueryRepresentable & QueryExpression>(
71+
_ name: String = "",
72+
as representableFunctionType: ((any Sequence<(repeat each T)>) -> Void).Type,
73+
isDeterministic: Bool = false
74+
) =
75+
#externalMacro(
76+
module: "StructuredQueriesSQLiteMacros",
77+
type: "DatabaseFunctionMacro"
78+
)

Sources/StructuredQueriesSQLiteCore/DatabaseFunction.swift

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,11 @@ extension ScalarDatabaseFunction {
5656
/// Don't conform to this protocol directly. Instead, use the `@DatabaseFunction` macro to generate
5757
/// a conformance.
5858
public protocol AggregateDatabaseFunction<Input, Output>: DatabaseFunction {
59-
func step(_ decoder: inout some QueryDecoder) throws -> Input
59+
associatedtype Row
6060

61-
func invoke(_ arguments: some Sequence<Input>) throws -> QueryBinding
61+
func step(_ decoder: inout some QueryDecoder) throws -> Row
62+
63+
func invoke(_ arguments: some Sequence<Row>) throws -> QueryBinding
6264
}
6365

6466
extension AggregateDatabaseFunction {

Sources/StructuredQueriesSQLiteMacros/DatabaseFunctionMacro.swift

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,14 +110,14 @@ extension DatabaseFunctionMacro: PeerMacro {
110110
var invocationArgumentTypes: [TypeSyntax] = []
111111
var parameters: [String] = []
112112
var argumentBindings: [String] = []
113-
var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator()
114113

115114
var decodings: [String] = []
116115
var decodingUnwrappings: [String] = []
117116
var canThrowInvalidInvocation = false
118117

119118
let isAggregate: Bool
120119
var representableInputType: String
120+
var rowType = ""
121121
let projectedCallSyntax: ExprSyntax
122122

123123
if signature.parameterClause.parameters.count == 1,
@@ -131,7 +131,6 @@ extension DatabaseFunctionMacro: PeerMacro {
131131
let genericArgument = genericArgumentClause.arguments.first
132132
{
133133
isAggregate = true
134-
representableInputType = "\(genericArgument)"
135134

136135
someOrAnyParameterType.someOrAnySpecifier.tokenKind = .keyword(.any)
137136
let bodySignature =
@@ -159,10 +158,26 @@ extension DatabaseFunctionMacro: PeerMacro {
159158
]
160159
)
161160

161+
let representableInputGeneric = functionRepresentation?
162+
.parameters.first?
163+
.type.as(SomeOrAnyTypeSyntax.self)?
164+
.constraint.as(IdentifierTypeSyntax.self)?
165+
.genericArgumentClause?
166+
.arguments.first
167+
let representableInputGenericArgument = representableInputGeneric?.argument
168+
169+
representableInputType = "\(representableInputGeneric ?? genericArgument)"
170+
rowType = "\(genericArgument)"
171+
172+
let representableInputArguments =
173+
representableInputGenericArgument?.as(TupleTypeSyntax.self)?.elements.map(\.type)
174+
?? (representableInputGenericArgument?.cast(TypeSyntax.self)).map { [$0] }
175+
var representableInputArgumentsIterator = representableInputArguments?.makeIterator()
176+
162177
var offset = 0
163178
for var element in tupleType.elements {
164179
defer { offset += 1 }
165-
var type = (functionRepresentationIterator?.next()?.type ?? element.type)
180+
var type = representableInputArgumentsIterator?.next() ?? element.type
166181
element.type = type.asQueryExpression()
167182
type = type.trimmed
168183
representableInputTypes.append(type.description)
@@ -225,6 +240,7 @@ extension DatabaseFunctionMacro: PeerMacro {
225240
"""
226241
} else {
227242
isAggregate = false
243+
var functionRepresentationIterator = functionRepresentation?.parameters.makeIterator()
228244

229245
for index in signature.parameterClause.parameters.indices {
230246
var parameter = signature.parameterClause.parameters[index]
@@ -343,7 +359,7 @@ extension DatabaseFunctionMacro: PeerMacro {
343359
"""
344360
public func step(
345361
_ decoder: inout some QueryDecoder
346-
) throws -> \(raw: representableInputType) {
362+
) throws -> \(raw: rowType) {
347363
\(raw: (decodings + decodingUnwrappings).map { "\($0)\n" }.joined())\
348364
\(raw: stepReturnClause)\
349365
}

Sources/_StructuredQueriesSQLite/DatabaseFunction.swift

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ private protocol AggregateDatabaseFunctionIteratorProtocol<Body> {
114114
associatedtype Body: AggregateDatabaseFunction
115115

116116
var body: Body { get }
117-
var stream: Stream<Body.Input> { get }
117+
var stream: Stream<Body.Row> { get }
118118
func start()
119119
func step(_ decoder: inout some QueryDecoder) throws
120120
func finish()
@@ -125,7 +125,7 @@ private final class AggregateDatabaseFunctionIterator<
125125
Body: AggregateDatabaseFunction
126126
>: AggregateDatabaseFunctionIteratorProtocol {
127127
let body: Body
128-
let stream = Stream<Body.Input>()
128+
let stream = Stream<Body.Row>()
129129
let queue = DispatchQueue.global(qos: .userInitiated)
130130
var _result: QueryBinding?
131131
init(_ body: Body) {

Tests/StructuredQueriesMacrosTests/DatabaseFunctionMacroTests.swift

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1409,7 +1409,6 @@ extension SnapshotTests {
14091409
}
14101410
}
14111411

1412-
// TODO
14131412
@Test func customRepresentations() {
14141413
assertMacro {
14151414
#"""
@@ -1433,20 +1432,20 @@ extension SnapshotTests {
14331432
}
14341433
14351434
struct __macro_local_6joinedfMu_: StructuredQueriesSQLiteCore.AggregateDatabaseFunction {
1436-
public typealias Input = [String]
1435+
public typealias Input = [String].JSONRepresentation
14371436
public typealias Output = [String].JSONRepresentation
14381437
public let name = "joined"
14391438
public var argumentCount: Int? {
14401439
var argumentCount = 0
1441-
argumentCount += any Sequence<[String].JSONRepresentation>._columnWidth
1440+
argumentCount += [String].JSONRepresentation._columnWidth
14421441
return argumentCount
14431442
}
14441443
public let isDeterministic = false
14451444
public let body: (_ arrays: any Sequence<[String]>) -> [String]
14461445
public init(_ body: @escaping (_ arrays: any Sequence<[String]>) -> [String]) {
14471446
self.body = body
14481447
}
1449-
public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression<any Sequence<[String].JSONRepresentation>>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression<Bool>)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> {
1448+
public func callAsFunction(_ arrays: some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation>, order: (some QueryExpression)? = Bool?.none, filter: (some QueryExpression<Bool>)? = Bool?.none) -> some StructuredQueriesCore.QueryExpression<[String].JSONRepresentation> {
14501449
StructuredQueriesCore.$_isSelecting.withValue(false) {
14511450
StructuredQueriesCore.AggregateFunction(
14521451
self.name, arrays, order: order, filter: filter
@@ -1456,7 +1455,7 @@ extension SnapshotTests {
14561455
public func step(
14571456
_ decoder: inout some QueryDecoder
14581457
) throws -> [String] {
1459-
let arrays = try decoder.decode(any Sequence<[String].JSONRepresentation>.self)
1458+
let arrays = try decoder.decode([String].JSONRepresentation.self)
14601459
guard let arrays else {
14611460
throw InvalidInvocation()
14621461
}

Tests/StructuredQueriesTests/DatabaseFunctionTests.swift

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ extension SnapshotTests {
477477
}
478478

479479
@DatabaseFunction
480-
func joined(_ arguments: some Sequence<(String, separator: String)>) throws -> String? {
480+
func joined(_ arguments: some Sequence<(String, separator: String)>) -> String? {
481481
var iterator = arguments.makeIterator()
482482
guard var (result, _) = iterator.next() else { return nil }
483483
while let (string, separator) = iterator.next() {
@@ -505,5 +505,54 @@ extension SnapshotTests {
505505
"""
506506
}
507507
}
508+
509+
@DatabaseFunction(
510+
as: ((any Sequence<[String].JSONRepresentation>) -> [String].JSONRepresentation).self
511+
)
512+
func jsonJoined(_ arrays: some Sequence<[String]>) -> [String] {
513+
arrays.flatMap(\.self)
514+
}
515+
516+
@Test func aggregateRepresentation() {
517+
$jsonJoined.install(database.handle)
518+
519+
assertQuery(
520+
Reminder.select {
521+
$jsonJoined(#sql("json_array(\($0.title.lower()), \($0.title.upper()))"))
522+
}
523+
) {
524+
"""
525+
SELECT "jsonJoined"(json_array(lower("reminders"."title"), upper("reminders"."title")))
526+
FROM "reminders"
527+
"""
528+
} results: {
529+
"""
530+
┌─────────────────────────────────────┐
531+
│ [ │
532+
│ [0]: "groceries", │
533+
│ [1]: "GROCERIES", │
534+
│ [2]: "haircut", │
535+
│ [3]: "HAIRCUT", │
536+
│ [4]: "doctor appointment", │
537+
│ [5]: "DOCTOR APPOINTMENT", │
538+
│ [6]: "take a walk", │
539+
│ [7]: "TAKE A WALK", │
540+
│ [8]: "buy concert tickets", │
541+
│ [9]: "BUY CONCERT TICKETS", │
542+
│ [10]: "pick up kids from school", │
543+
│ [11]: "PICK UP KIDS FROM SCHOOL", │
544+
│ [12]: "get laundry", │
545+
│ [13]: "GET LAUNDRY", │
546+
│ [14]: "take out trash", │
547+
│ [15]: "TAKE OUT TRASH", │
548+
│ [16]: "call accountant", │
549+
│ [17]: "CALL ACCOUNTANT", │
550+
│ [18]: "send weekly emails", │
551+
│ [19]: "SEND WEEKLY EMAILS"
552+
│ ] │
553+
└─────────────────────────────────────┘
554+
"""
555+
}
556+
}
508557
}
509558
}

0 commit comments

Comments
 (0)