|
| 1 | +import Dependencies |
| 2 | +import Foundation |
| 3 | +import InlineSnapshotTesting |
| 4 | +import StructuredQueries |
| 5 | +import StructuredQueriesSQLite |
| 6 | +import StructuredQueriesTestSupport |
| 7 | +import Testing |
| 8 | + |
| 9 | +extension SnapshotTests { |
| 10 | + @Suite struct CustomFunctionTests { |
| 11 | + @Test func customDate() { |
| 12 | + @Dependency(\.defaultDatabase) var database |
| 13 | + __$dateTime.install(database.handle) |
| 14 | + assertQuery( |
| 15 | + Values(_$dateTime()) |
| 16 | + ) { |
| 17 | + """ |
| 18 | + SELECT "dateTime"(NULL) |
| 19 | + """ |
| 20 | + } results: { |
| 21 | + """ |
| 22 | + ┌────────────────────────────────┐ |
| 23 | + │ Date(1970-01-01T00:00:00.000Z) │ |
| 24 | + └────────────────────────────────┘ |
| 25 | + """ |
| 26 | + } |
| 27 | + } |
| 28 | + } |
| 29 | +} |
| 30 | + |
| 31 | +// --- |
| 32 | +import Foundation |
| 33 | +import SQLite3 |
| 34 | + |
| 35 | +// @CustomFunction // @DatabaseFunction ? @ScalarFunction (_vs._ @AggregateFunction?) |
| 36 | +func dateTime(_ format: String? = nil) -> Date { |
| 37 | + Date(timeIntervalSince1970: 0) |
| 38 | +} |
| 39 | + |
| 40 | +// Macro expansion: |
| 41 | +@available(macOS 14, *) |
| 42 | +func _$dateTime( |
| 43 | + _ format: some QueryExpression<String?> = String?.none |
| 44 | +) -> some QueryExpression<Date> { |
| 45 | + __$dateTime(format) |
| 46 | +} |
| 47 | + |
| 48 | +@available(macOS 14, *) |
| 49 | +var __$dateTime: CustomFunction<String?, Date, Never> { |
| 50 | + CustomFunction("dateTime", isDeterministic: false, body: dateTime(_:)) |
| 51 | +} |
| 52 | +// --- |
| 53 | +// Library code: |
| 54 | +@available(macOS 14, *) |
| 55 | +struct CustomFunction<each Input: QueryBindable, Output: QueryBindable, Failure: Error> { |
| 56 | + let name: String |
| 57 | + let isDeterministic: Bool |
| 58 | + let body: (repeat each Input) throws(Failure) -> Output |
| 59 | + |
| 60 | + init( |
| 61 | + _ name: String, |
| 62 | + isDeterministic: Bool, |
| 63 | + body: @escaping (repeat each Input) throws(Failure) -> Output |
| 64 | + ) { |
| 65 | + self.name = name |
| 66 | + self.isDeterministic = isDeterministic |
| 67 | + self.body = body |
| 68 | + } |
| 69 | + |
| 70 | + func callAsFunction<each T>(_ input: repeat each T) -> SQLQueryExpression<Output> |
| 71 | + where repeat each T: QueryExpression<each Input> { |
| 72 | + var arguments: [QueryFragment] = [] |
| 73 | + for input in repeat each input { |
| 74 | + arguments.append(input.queryFragment) |
| 75 | + } |
| 76 | + return SQLQueryExpression("\(quote: name)(\(arguments.joined(separator: ", ")))") |
| 77 | + } |
| 78 | + |
| 79 | + fileprivate var anyBody: AnyBody { |
| 80 | + AnyBody { argv in |
| 81 | + var iterator = argv.makeIterator() |
| 82 | + func next<Element: QueryBindable>() throws -> Element { |
| 83 | + guard let queryBinding = iterator.next(), let element = Element(queryBinding: queryBinding) |
| 84 | + else { |
| 85 | + throw QueryDecodingError.missingRequiredColumn // FIXME: New error case |
| 86 | + } |
| 87 | + return element |
| 88 | + } |
| 89 | + return try body(repeat { _ in try next() }((each Input).self)).queryBinding |
| 90 | + } |
| 91 | + } |
| 92 | + |
| 93 | + func install(_ db: OpaquePointer) { |
| 94 | + // TODO: Should this be `-1`? |
| 95 | + var count: Int32 = 0 |
| 96 | + for _ in repeat (each Input).self { |
| 97 | + count += 1 |
| 98 | + } |
| 99 | + let body = Unmanaged.passRetained(anyBody).toOpaque() |
| 100 | + sqlite3_create_function_v2( |
| 101 | + db, |
| 102 | + name, |
| 103 | + count, |
| 104 | + SQLITE_UTF8 | (isDeterministic ? SQLITE_DETERMINISTIC : 0), |
| 105 | + body, |
| 106 | + { ctx, argc, argv in |
| 107 | + do { |
| 108 | + let body = Unmanaged<AnyBody> |
| 109 | + .fromOpaque(sqlite3_user_data(ctx)) |
| 110 | + .takeUnretainedValue() |
| 111 | + let arguments: [QueryBinding] = try (0..<argc).map { idx in |
| 112 | + let value = argv?[Int(idx)] |
| 113 | + switch sqlite3_value_type(value) { |
| 114 | + case SQLITE_BLOB: |
| 115 | + if let blob = sqlite3_value_blob(value) { |
| 116 | + let count = Int(sqlite3_value_bytes(value)) |
| 117 | + let buffer = UnsafeRawBufferPointer(start: blob, count: count) |
| 118 | + return .blob(Array(buffer)) |
| 119 | + } else { |
| 120 | + return .blob([]) |
| 121 | + } |
| 122 | + case SQLITE_FLOAT: |
| 123 | + return .double(sqlite3_value_double(value)) |
| 124 | + case SQLITE_INTEGER: |
| 125 | + return .int(sqlite3_value_int64(value)) |
| 126 | + case SQLITE_NULL: |
| 127 | + return .null |
| 128 | + case SQLITE_TEXT: |
| 129 | + return .text(String(cString: UnsafePointer(sqlite3_value_text(value)))) |
| 130 | + default: |
| 131 | + throw UnknownType() |
| 132 | + } |
| 133 | + } |
| 134 | + let output = try body(arguments) |
| 135 | + try output.result(db: ctx) |
| 136 | + } catch { |
| 137 | + // TODO: Debug description? Localized? |
| 138 | + sqlite3_result_error(ctx, "\(error)", -1) |
| 139 | + } |
| 140 | + }, |
| 141 | + nil, |
| 142 | + nil, |
| 143 | + { ctx in |
| 144 | + guard let ctx else { return } |
| 145 | + Unmanaged<AnyObject>.fromOpaque(ctx).release() |
| 146 | + } |
| 147 | + ) |
| 148 | + } |
| 149 | +} |
| 150 | + |
| 151 | +private let SQLITE_TRANSIENT = unsafeBitCast(-1, to: sqlite3_destructor_type.self) |
| 152 | + |
| 153 | +private struct UnknownType: Error {} |
| 154 | + |
| 155 | +private final class AnyBody { |
| 156 | + let body: ([QueryBinding]) throws -> QueryBinding |
| 157 | + init(body: @escaping ([QueryBinding]) throws -> QueryBinding) { |
| 158 | + self.body = body |
| 159 | + } |
| 160 | + func callAsFunction(_ arguments: [QueryBinding]) throws -> QueryBinding { |
| 161 | + try body(arguments) |
| 162 | + } |
| 163 | +} |
| 164 | + |
| 165 | +private extension QueryBinding { |
| 166 | + func result(db: OpaquePointer?) throws { |
| 167 | + switch self { |
| 168 | + case .blob(let value): |
| 169 | + sqlite3_result_blob(db, Array(value), Int32(value.count), SQLITE_TRANSIENT) |
| 170 | + case .double(let value): |
| 171 | + sqlite3_result_double(db, value) |
| 172 | + case .date(let value): |
| 173 | + sqlite3_result_text(db, value.iso8601String, -1, SQLITE_TRANSIENT) |
| 174 | + case .int(let value): |
| 175 | + sqlite3_result_int64(db, value) |
| 176 | + case .null: |
| 177 | + sqlite3_result_null(db) |
| 178 | + case .text(let value): |
| 179 | + sqlite3_result_text(db, value, -1, SQLITE_TRANSIENT) |
| 180 | + case .uuid(let value): |
| 181 | + sqlite3_result_text(db, value.uuidString.lowercased(), -1, SQLITE_TRANSIENT) |
| 182 | + case .invalid(let error): |
| 183 | + throw error |
| 184 | + } |
| 185 | + } |
| 186 | +} |
0 commit comments