Skip to content

Commit e325781

Browse files
committed
wip
1 parent f68e4a0 commit e325781

File tree

5 files changed

+230
-87
lines changed

5 files changed

+230
-87
lines changed
Lines changed: 0 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,88 +1 @@
11
@_exported import StructuredQueriesCore
2-
3-
// ---
4-
import Foundation
5-
6-
// @CustomFunction // @DatabaseFunction ? @ScalarFunction (_vs._ @AggregateFunction?)
7-
func dateTime(_ format: String? = nil) -> Date {
8-
Date()
9-
}
10-
11-
// Macro expansion:
12-
@available(macOS 14, *)
13-
@_disfavoredOverload // Or can/should this be applied the the above?
14-
func dateTime(
15-
_ format: some QueryExpression<String?> = String?.none
16-
) -> some QueryExpression<Date> {
17-
_$dateTime(format)
18-
}
19-
20-
@available(macOS 14, *)
21-
var _$dateTime: CustomFunction<String?, Date> {
22-
CustomFunction("dateTime", isDeterministic: false, body: dateTime(_:))
23-
}
24-
// ---
25-
26-
import SQLite3
27-
28-
@available(macOS 14, *)
29-
struct CustomFunction<each Input, Output: QueryBindable> {
30-
let name: String
31-
let isDeterministic: Bool
32-
// private let body: Body
33-
34-
init(
35-
_ name: String,
36-
isDeterministic: Bool,
37-
body: @escaping (repeat each Input) -> Output
38-
) {
39-
self.name = name
40-
self.isDeterministic = isDeterministic
41-
// self.body = Body(body)
42-
}
43-
44-
func callAsFunction<each T>(_ input: repeat each T) -> SQLQueryExpression<Output>
45-
where repeat each T: QueryExpression<each Input> {
46-
var arguments: [QueryFragment] = []
47-
for input in repeat each input {
48-
arguments.append(input.queryFragment)
49-
}
50-
return SQLQueryExpression("\(quote: name)(\(arguments.joined(separator: ", "))")
51-
}
52-
53-
// func install(_ db: OpaquePointer) {
54-
// // TODO: Should this be `-1`?
55-
// var count: Int32 = 0
56-
// for _ in repeat (each Input).self {
57-
// count += 1
58-
// }
59-
// let body = Unmanaged.passRetained(body).toOpaque()
60-
// sqlite3_create_function_v2(
61-
// db,
62-
// name,
63-
// count,
64-
// SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0),
65-
// body,
66-
// { ctx, argc, argv in
67-
//// let body = Unmanaged<Body>
68-
//// .fromOpaque(sqlite3_user_data(ctx))
69-
//// .takeUnretainedValue()
70-
// },
71-
// nil,
72-
// nil,
73-
// { ctx in
74-
//// Unmanaged<AnyObject>.fromOpaque(body).release()
75-
// }
76-
// )
77-
// }
78-
}
79-
80-
81-
private final class Body {
82-
let body: ([Any]) -> Any
83-
init<each Input, Output: QueryBindable>(_ body: @escaping (repeat each Input) -> Output) {
84-
fatalError()
85-
}
86-
}
87-
88-

Sources/StructuredQueriesCore/Optional.swift

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@ extension Optional: QueryBindable where Wrapped: QueryBindable {
2727
public var queryBinding: QueryBinding {
2828
self?.queryBinding ?? .null
2929
}
30+
31+
public init?(queryBinding: QueryBinding) {
32+
self = Wrapped(queryBinding: queryBinding)
33+
}
3034
}
3135

3236
extension Optional: QueryDecodable where Wrapped: QueryDecodable {

Sources/StructuredQueriesCore/QueryBindable.swift

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,25 @@ public protocol QueryBindable: QueryRepresentable, QueryExpression where QueryVa
99

1010
/// A value that can be bound to a parameter of a SQL statement.
1111
var queryBinding: QueryBinding { get }
12+
13+
init?(queryBinding: QueryBinding)
1214
}
1315

1416
extension QueryBindable {
1517
public var queryFragment: QueryFragment { "\(queryBinding)" }
18+
19+
public init?(queryBinding: QueryBinding) {
20+
guard let queryValue = QueryValue(queryBinding: queryBinding) else { return nil }
21+
self.init(queryBinding: queryValue.queryBinding)
22+
}
1623
}
1724

1825
extension [UInt8]: QueryBindable, QueryExpression {
1926
public var queryBinding: QueryBinding { .blob(self) }
27+
public init?(queryBinding: QueryBinding) {
28+
guard case .blob = queryBinding else { return nil }
29+
self.init(queryBinding: queryBinding)
30+
}
2031
}
2132

2233
extension Bool: QueryBindable {
@@ -25,10 +36,18 @@ extension Bool: QueryBindable {
2536

2637
extension Double: QueryBindable {
2738
public var queryBinding: QueryBinding { .double(self) }
39+
public init?(queryBinding: QueryBinding) {
40+
guard case .double = queryBinding else { return nil }
41+
self.init(queryBinding: queryBinding)
42+
}
2843
}
2944

3045
extension Date: QueryBindable {
3146
public var queryBinding: QueryBinding { .date(self) }
47+
public init?(queryBinding: QueryBinding) {
48+
guard case .date = queryBinding else { return nil }
49+
self.init(queryBinding: queryBinding)
50+
}
3251
}
3352

3453
extension Float: QueryBindable {
@@ -53,10 +72,18 @@ extension Int32: QueryBindable {
5372

5473
extension Int64: QueryBindable {
5574
public var queryBinding: QueryBinding { .int(self) }
75+
public init?(queryBinding: QueryBinding) {
76+
guard case .int = queryBinding else { return nil }
77+
self.init(queryBinding: queryBinding)
78+
}
5679
}
5780

5881
extension String: QueryBindable {
5982
public var queryBinding: QueryBinding { .text(self) }
83+
public init?(queryBinding: QueryBinding) {
84+
guard case .text = queryBinding else { return nil }
85+
self.init(queryBinding: queryBinding)
86+
}
6087
}
6188

6289
extension UInt8: QueryBindable {
@@ -83,6 +110,10 @@ extension UInt64: QueryBindable {
83110

84111
extension UUID: QueryBindable {
85112
public var queryBinding: QueryBinding { .uuid(self) }
113+
public init?(queryBinding: QueryBinding) {
114+
guard case .uuid = queryBinding else { return nil }
115+
self.init(queryBinding: queryBinding)
116+
}
86117
}
87118

88119
extension DefaultStringInterpolation {

Sources/StructuredQueriesSQLite/Database.swift

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,15 @@ public struct Database {
1111
@usableFromInline
1212
let storage: Storage
1313

14+
public var handle: OpaquePointer {
15+
switch storage {
16+
case .owned(let storage):
17+
return storage.handle
18+
case .unowned(let handle):
19+
return handle
20+
}
21+
}
22+
1423
public init(_ ptr: OpaquePointer) {
1524
self.storage = .unowned(ptr)
1625
}
Lines changed: 186 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,186 @@
1+
import Dependencies
2+
import Foundation
3+
import InlineSnapshotTesting
4+
import StructuredQueries
5+
import StructuredQueriesSQLite
6+
import StructuredQueriesTestSupport
7+
import Testing
8+
9+
extension SnapshotTests {
10+
@Suite struct CustomFunctionTests {
11+
@Test func customDate() {
12+
@Dependency(\.defaultDatabase) var database
13+
__$dateTime.install(database.handle)
14+
assertQuery(
15+
Values(_$dateTime())
16+
) {
17+
"""
18+
SELECT "dateTime"(NULL)
19+
"""
20+
} results: {
21+
"""
22+
┌────────────────────────────────┐
23+
│ Date(1970-01-01T00:00:00.000Z) │
24+
└────────────────────────────────┘
25+
"""
26+
}
27+
}
28+
}
29+
}
30+
31+
// ---
32+
import Foundation
33+
import SQLite3
34+
35+
// @CustomFunction // @DatabaseFunction ? @ScalarFunction (_vs._ @AggregateFunction?)
36+
func dateTime(_ format: String? = nil) -> Date {
37+
Date(timeIntervalSince1970: 0)
38+
}
39+
40+
// Macro expansion:
41+
@available(macOS 14, *)
42+
func _$dateTime(
43+
_ format: some QueryExpression<String?> = String?.none
44+
) -> some QueryExpression<Date> {
45+
__$dateTime(format)
46+
}
47+
48+
@available(macOS 14, *)
49+
var __$dateTime: CustomFunction<String?, Date, Never> {
50+
CustomFunction("dateTime", isDeterministic: false, body: dateTime(_:))
51+
}
52+
// ---
53+
// Library code:
54+
@available(macOS 14, *)
55+
struct CustomFunction<each Input: QueryBindable, Output: QueryBindable, Failure: Error> {
56+
let name: String
57+
let isDeterministic: Bool
58+
let body: (repeat each Input) throws(Failure) -> Output
59+
60+
init(
61+
_ name: String,
62+
isDeterministic: Bool,
63+
body: @escaping (repeat each Input) throws(Failure) -> Output
64+
) {
65+
self.name = name
66+
self.isDeterministic = isDeterministic
67+
self.body = body
68+
}
69+
70+
func callAsFunction<each T>(_ input: repeat each T) -> SQLQueryExpression<Output>
71+
where repeat each T: QueryExpression<each Input> {
72+
var arguments: [QueryFragment] = []
73+
for input in repeat each input {
74+
arguments.append(input.queryFragment)
75+
}
76+
return SQLQueryExpression("\(quote: name)(\(arguments.joined(separator: ", ")))")
77+
}
78+
79+
fileprivate var anyBody: AnyBody {
80+
AnyBody { argv in
81+
var iterator = argv.makeIterator()
82+
func next<Element: QueryBindable>() throws -> Element {
83+
guard let queryBinding = iterator.next(), let element = Element(queryBinding: queryBinding)
84+
else {
85+
throw QueryDecodingError.missingRequiredColumn // FIXME: New error case
86+
}
87+
return element
88+
}
89+
return try body(repeat { _ in try next() }((each Input).self)).queryBinding
90+
}
91+
}
92+
93+
func install(_ db: OpaquePointer) {
94+
// TODO: Should this be `-1`?
95+
var count: Int32 = 0
96+
for _ in repeat (each Input).self {
97+
count += 1
98+
}
99+
let body = Unmanaged.passRetained(anyBody).toOpaque()
100+
sqlite3_create_function_v2(
101+
db,
102+
name,
103+
count,
104+
SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0),
105+
body,
106+
{ ctx, argc, argv in
107+
do {
108+
let body = Unmanaged<AnyBody>
109+
.fromOpaque(sqlite3_user_data(ctx))
110+
.takeUnretainedValue()
111+
let arguments: [QueryBinding] = try (0..<argc).map { idx in
112+
let value = argv?[Int(idx)]
113+
switch sqlite3_value_type(value) {
114+
case SQLITE_BLOB:
115+
if let blob = sqlite3_value_blob(value) {
116+
let count = Int(sqlite3_value_bytes(value))
117+
let buffer = UnsafeRawBufferPointer(start: blob, count: count)
118+
return .blob(Array(buffer))
119+
} else {
120+
return .blob([])
121+
}
122+
case SQLITE_FLOAT:
123+
return .double(sqlite3_value_double(value))
124+
case SQLITE_INTEGER:
125+
return .int(sqlite3_value_int64(value))
126+
case SQLITE_NULL:
127+
return .null
128+
case SQLITE_TEXT:
129+
return .text(String(cString: UnsafePointer(sqlite3_value_text(value))))
130+
default:
131+
throw UnknownType()
132+
}
133+
}
134+
let output = try body(arguments)
135+
try output.result(db: ctx)
136+
} catch {
137+
// TODO: Debug description? Localized?
138+
sqlite3_result_error(ctx, "\(error)", -1)
139+
}
140+
},
141+
nil,
142+
nil,
143+
{ ctx in
144+
guard let ctx else { return }
145+
Unmanaged<AnyObject>.fromOpaque(ctx).release()
146+
}
147+
)
148+
}
149+
}
150+
151+
private let SQLITE_TRANSIENT = unsafeBitCast(-1, to: sqlite3_destructor_type.self)
152+
153+
private struct UnknownType: Error {}
154+
155+
private final class AnyBody {
156+
let body: ([QueryBinding]) throws -> QueryBinding
157+
init(body: @escaping ([QueryBinding]) throws -> QueryBinding) {
158+
self.body = body
159+
}
160+
func callAsFunction(_ arguments: [QueryBinding]) throws -> QueryBinding {
161+
try body(arguments)
162+
}
163+
}
164+
165+
private extension QueryBinding {
166+
func result(db: OpaquePointer?) throws {
167+
switch self {
168+
case .blob(let value):
169+
sqlite3_result_blob(db, Array(value), Int32(value.count), SQLITE_TRANSIENT)
170+
case .double(let value):
171+
sqlite3_result_double(db, value)
172+
case .date(let value):
173+
sqlite3_result_text(db, value.iso8601String, -1, SQLITE_TRANSIENT)
174+
case .int(let value):
175+
sqlite3_result_int64(db, value)
176+
case .null:
177+
sqlite3_result_null(db)
178+
case .text(let value):
179+
sqlite3_result_text(db, value, -1, SQLITE_TRANSIENT)
180+
case .uuid(let value):
181+
sqlite3_result_text(db, value.uuidString.lowercased(), -1, SQLITE_TRANSIENT)
182+
case .invalid(let error):
183+
throw error
184+
}
185+
}
186+
}

0 commit comments

Comments
 (0)