diff --git a/Playgrounds/README.playground/Contents.swift b/Playgrounds/README.playground/Contents.swift index 4d07064..9a13b6f 100644 --- a/Playgrounds/README.playground/Contents.swift +++ b/Playgrounds/README.playground/Contents.swift @@ -9,7 +9,7 @@ [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) - - Note: this readme file is available as Xcode playground in Playgrounds/README.playground + > this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -26,21 +26,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. */ -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) /*: ## Insert Insert new contacts Paul and John. */ -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() /*: ## Select @@ -51,4 +51,5 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) diff --git a/README.md b/README.md index 4cf31df..391411a 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,7 @@ Swift SQLite wrapper. [Documentation](https://alexander-ignition.github.io/SQLyra/documentation/sqlyra/) -- Note: this readme file is available as Xcode playground in Playgrounds/README.playground +> this readme file is available as Xcode playground in Playgrounds/README.playground ## Open @@ -25,21 +25,21 @@ let database = try Database.open( Create table for contacts with fields `id` and `name`. ```swift -try database.execute( - """ +let sql = """ CREATE TABLE contacts( id INT PRIMARY KEY NOT NULL, name TEXT ); """ -) +try database.execute(sql) ``` ## Insert Insert new contacts Paul and John. ```swift -try database.execute("INSERT INTO contacts (id, name) VALUES (1, 'Paul');") -try database.execute("INSERT INTO contacts (id, name) VALUES (2, 'John');") +let insert = try database.prepare("INSERT INTO contacts (id, name) VALUES (?, ?);") +try insert.bind(parameters: 1, "Paul").execute().reset() +try insert.bind(parameters: 2, "John").execute() ``` ## Select @@ -50,5 +50,6 @@ struct Contact: Codable { let name: String } -let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) +let select = try database.prepare("SELECT * FROM contacts;") +let contacts = try select.array(Contact.self) ``` diff --git a/Sources/SQLyra/DatabaseError.swift b/Sources/SQLyra/DatabaseError.swift index cc03556..e6154b8 100644 --- a/Sources/SQLyra/DatabaseError.swift +++ b/Sources/SQLyra/DatabaseError.swift @@ -7,7 +7,7 @@ public struct DatabaseError: Error, Equatable, Hashable { public let code: Int32 /// A short error description. - public var message: String? + public let message: String? /// A complete sentence (or more) describing why the operation failed. public let details: String? diff --git a/Sources/SQLyra/PreparedStatement.swift b/Sources/SQLyra/PreparedStatement.swift index d62bd9c..7dba370 100644 --- a/Sources/SQLyra/PreparedStatement.swift +++ b/Sources/SQLyra/PreparedStatement.swift @@ -11,9 +11,9 @@ public final class PreparedStatement: DatabaseHandle { /// Find the database handle of a prepared statement. var db: OpaquePointer! { sqlite3_db_handle(stmt) } - private(set) lazy var columnIndexByName = [String: Int32]( + private(set) lazy var columnIndexByName = [String: Int]( uniqueKeysWithValues: (0.. Bool { - switch sqlite3_step(stmt) { - case SQLITE_DONE: - return false - case SQLITE_ROW: - return true - case let code: - throw DatabaseError(code: code, db: db) - } - } - /// Reset the prepared statement. /// /// The ``PreparedStatement/reset()`` function is called to reset a prepared statement object back to its initial state, ready to be re-executed. @@ -59,32 +45,6 @@ public final class PreparedStatement: DatabaseHandle { public func reset() throws -> PreparedStatement { try check(sqlite3_reset(stmt)) } - - /// Reset all bindings on a prepared statement. - /// - /// Contrary to the intuition of many, ``PreparedStatement/reset()`` does not reset the bindings on a prepared statement. - /// Use this routine to reset all host parameters to NULL. - /// - /// - Throws: ``DatabaseError`` - @discardableResult - public func clearBindings() throws -> PreparedStatement { - try check(sqlite3_clear_bindings(stmt)) - } - - // MARK: - Decodable - - public func array(decoding type: T.Type) throws -> [T] where T: Decodable { - var array: [T] = [] - while try step() { - let value = try decode(type) - array.append(value) - } - return array - } - - public func decode(_ type: T.Type) throws -> T where T: Decodable { - try StatementDecoder().decode(type, from: self) - } } // MARK: - Retrieving Statement SQL @@ -112,16 +72,16 @@ extension PreparedStatement { extension PreparedStatement { /// Number of SQL parameters. - public var parameterCount: Int32 { sqlite3_bind_parameter_count(stmt) } + public var parameterCount: Int { Int(sqlite3_bind_parameter_count(stmt)) } /// Name of a SQL parameter. - public func parameterName(at index: Int32) -> String? { - sqlite3_bind_parameter_name(stmt, index).map { String(cString: $0) } + public func parameterName(at index: Int) -> String? { + sqlite3_bind_parameter_name(stmt, Int32(index)).map { String(cString: $0) } } /// Index of a parameter with a given name. - public func parameterIndex(for name: String) -> Int32 { - sqlite3_bind_parameter_index(stmt, name) + public func parameterIndex(for name: String) -> Int { + Int(sqlite3_bind_parameter_index(stmt, name)) } } @@ -138,13 +98,14 @@ extension PreparedStatement { @discardableResult public func bind(parameters: SQLParameter...) throws -> PreparedStatement { for (index, parameter) in parameters.enumerated() { - try bind(index: Int32(index + 1), parameter: parameter) + try bind(index: index + 1, parameter: parameter) } return self } @discardableResult - public func bind(index: Int32, parameter: SQLParameter) throws -> PreparedStatement { + public func bind(index: Int, parameter: SQLParameter) throws -> PreparedStatement { + let index = Int32(index) let code = switch parameter { case .null: @@ -162,42 +123,110 @@ extension PreparedStatement { } return try check(code) } + + /// Reset all bindings on a prepared statement. + /// + /// Contrary to the intuition of many, ``PreparedStatement/reset()`` does not reset the bindings on a prepared statement. + /// Use this routine to reset all host parameters to NULL. + /// + /// - Throws: ``DatabaseError`` + @discardableResult + public func clearBindings() throws -> PreparedStatement { + try check(sqlite3_clear_bindings(stmt)) + } } -// MARK: - Result values from a Query +// MARK: - Columns extension PreparedStatement { /// Return the number of columns in the result set. - public var columnCount: Int32 { sqlite3_column_count(stmt) } + public var columnCount: Int { Int(sqlite3_column_count(stmt)) } - public func column(at index: Int32) -> Column { - Column(index: index, statement: self) + /// Returns the name assigned to a specific column in the result set of the SELECT statement. + /// + /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. + /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. + public func columnName(at index: Int) -> String? { + sqlite3_column_name(stmt, Int32(index)).string } +} - public func column(for name: String) -> Column? { - columnIndexByName[name].map { Column(index: $0, statement: self) } +// MARK: - Result values from a Query + +extension PreparedStatement { + /// The new row of data is ready for processing. + /// + /// - Throws: ``DatabaseError`` + public func row() throws -> Row? { + switch sqlite3_step(stmt) { + case SQLITE_DONE: nil + case SQLITE_ROW: Row(statement: self) + case let code: throw DatabaseError(code: code, db: db) + } } - /// Information about a single column of the current result row of a query. - public struct Column { - let index: Int32 + public func array(_ type: T.Type) throws -> [T] where T: Decodable { + try array(type, using: RowDecoder.default) + } + + public func array(_ type: T.Type, using decoder: RowDecoder) throws -> [T] where T: Decodable { + var array: [T] = [] + while let row = try row() { + let value = try row.decode(type, using: decoder) + array.append(value) + } + return array + } + + @dynamicMemberLookup + public struct Row { let statement: PreparedStatement - private var stmt: OpaquePointer { statement.stmt } - /// Returns the name assigned to a specific column in the result set of the SELECT statement. - /// - /// The name of a result column is the value of the "AS" clause for that column, if there is an AS clause. - /// If there is no AS clause then the name of the column is unspecified and may change from one release of SQLite to the next. - public var name: String? { sqlite3_column_name(stmt, index).string } + public subscript(dynamicMember name: String) -> Value? { + self[name] + } + + public subscript(name: String) -> Value? { + statement.columnIndexByName[name].flatMap { self[$0] } + } + + public subscript(index: Int) -> Value? { + if sqlite3_column_type(statement.stmt, Int32(index)) == SQLITE_NULL { + return nil + } + return Value(index: Int32(index), statement: statement) + } - public var isNull: Bool { sqlite3_column_type(stmt, index) == SQLITE_NULL } + public func decode(_ type: T.Type) throws -> T where T: Decodable { + try decode(type, using: RowDecoder.default) + } + + public func decode(_ type: T.Type, using decoder: RowDecoder) throws -> T where T: Decodable { + try decoder.decode(type, from: self) + } + } + + /// Result value from a query. + public struct Value { + let index: Int32 + let statement: PreparedStatement + private var stmt: OpaquePointer { statement.stmt } /// 64-bit INTEGER result. public var int64: Int64 { sqlite3_column_int64(stmt, index) } + /// 32-bit INTEGER result. + public var int32: Int32 { sqlite3_column_int(stmt, index) } + + /// A platform-specific integer. + public var int: Int { Int(int64) } + /// 64-bit IEEE floating point number. public var double: Double { sqlite3_column_double(stmt, index) } + /// Size of a BLOB or a UTF-8 TEXT result in bytes. + public var count: Int { Int(sqlite3_column_bytes(stmt, index)) } + /// UTF-8 TEXT result. public var string: String? { sqlite3_column_text(stmt, index).flatMap { String(cString: $0) } @@ -205,9 +234,7 @@ extension PreparedStatement { /// BLOB result. public var blob: Data? { - sqlite3_column_blob(stmt, index).map { bytes in - Data(bytes: bytes, count: Int(sqlite3_column_bytes(stmt, index))) - } + sqlite3_column_blob(stmt, index).map { Data(bytes: $0, count: count) } } } } diff --git a/Sources/SQLyra/RowDecoder.swift b/Sources/SQLyra/RowDecoder.swift new file mode 100644 index 0000000..4376dd0 --- /dev/null +++ b/Sources/SQLyra/RowDecoder.swift @@ -0,0 +1,184 @@ +import Foundation +import SQLite3 + +/// An object that decodes instances of a data type from ``PreparedStatement``. +public final class RowDecoder { + nonisolated(unsafe) static let `default` = RowDecoder() + + /// A dictionary you use to customize the decoding process by providing contextual information. + public var userInfo: [CodingUserInfoKey: Any] = [:] + + /// Creates a new, reusable row decoder. + public init() {} + + public func decode(_ type: T.Type, from row: PreparedStatement.Row) throws -> T where T: Decodable { + let decoder = _RowDecoder(row: row, userInfo: userInfo) + return try type.init(from: decoder) + } +} + +private struct _RowDecoder: Decoder { + let row: PreparedStatement.Row + + // MARK: - Decoder + + let userInfo: [CodingUserInfoKey: Any] + var codingPath: [any CodingKey] { [] } + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + KeyedDecodingContainer(KeyedContainer(decoder: self)) + } + + func unkeyedContainer() throws -> any UnkeyedDecodingContainer { + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) + } + + func singleValueContainer() throws -> any SingleValueDecodingContainer { + throw DecodingError.typeMismatch(PreparedStatement.self, .context(codingPath, "")) + } + + // MARK: - KeyedDecodingContainer + + struct KeyedContainer: KeyedDecodingContainerProtocol { + let decoder: _RowDecoder + var codingPath: [any CodingKey] { decoder.codingPath } + var allKeys: [Key] { decoder.row.statement.columnIndexByName.keys.compactMap { Key(stringValue: $0) } } + + func contains(_ key: Key) -> Bool { decoder.row.statement.columnIndexByName.keys.contains(key.stringValue) } + func decodeNil(forKey key: Key) throws -> Bool { decoder.null(for: key) } + func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { try decoder.bool(forKey: key) } + func decode(_ type: String.Type, forKey key: Key) throws -> String { try decoder.string(forKey: key) } + func decode(_ type: Double.Type, forKey key: Key) throws -> Double { try decoder.floating(type, forKey: key) } + func decode(_ type: Float.Type, forKey key: Key) throws -> Float { try decoder.floating(type, forKey: key) } + func decode(_ type: Int.Type, forKey key: Key) throws -> Int { try decoder.integer(type, forKey: key) } + func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } + func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { + try decoder.decode(type, forKey: key) + } + + func superDecoder() throws -> any Decoder { fatalError() } + func superDecoder(forKey key: Key) throws -> any Decoder { fatalError() } + func nestedUnkeyedContainer(forKey key: Key) throws -> any UnkeyedDecodingContainer { fatalError() } + func nestedContainer( + keyedBy type: NestedKey.Type, + forKey key: Key + ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { + fatalError() + } + } + + // MARK: - Decoding Values + + @inline(__always) + func null(for key: K) -> Bool where K: CodingKey { + row[key.stringValue] == nil + } + + @inline(__always) + func bool(forKey key: K) throws -> Bool where K: CodingKey { + try integer(Int64.self, forKey: key) != 0 + } + + @inline(__always) + func string(forKey key: K) throws -> String where K: CodingKey { + try columnValue(String.self, forKey: key).string ?? "" + } + + @inline(__always) + func integer(_ type: T.Type, forKey key: K) throws -> T where T: Numeric, K: CodingKey { + let value = try columnValue(type, forKey: key) + let int64 = value.int64 + guard let number = type.init(exactly: int64) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL int64 <\(int64)> does not fit in \(type).")) + } + return number + } + + @inline(__always) + func floating(_ type: T.Type, forKey key: K) throws -> T where T: BinaryFloatingPoint, K: CodingKey { + let value = try columnValue(type, forKey: key) + let double = value.double + guard let number = type.init(exactly: double) else { + throw DecodingError.dataCorrupted(.context([key], "Parsed SQL double <\(double)> does not fit in \(type).")) + } + return number + } + + @inline(__always) + func decode(_ type: T.Type, forKey key: K) throws -> T where T: Decodable, K: CodingKey { + if type == Data.self { + let value = try columnValue(type, forKey: key) + let data = value.blob ?? Data() + // swift-format-ignore: NeverForceUnwrap + return data as! T + } + let decoder = _ValueDecoder(key: key, decoder: self) + return try type.init(from: decoder) + } + + @inline(__always) + private func columnValue(_ type: T.Type, forKey key: K) throws -> PreparedStatement.Value where K: CodingKey { + guard let index = row.statement.columnIndexByName[key.stringValue] else { + throw DecodingError.keyNotFound(key, .context([key], "Column index not found for key: \(key)")) + } + guard let column = row[index] else { + throw DecodingError.valueNotFound(type, .context([key], "Column value not found for key: \(key)")) + } + return column + } +} + +private struct _ValueDecoder: Decoder, SingleValueDecodingContainer { + let key: any CodingKey + let decoder: _RowDecoder + + // MARK: - Decoder + + var userInfo: [CodingUserInfoKey: Any] { decoder.userInfo } + var codingPath: [any CodingKey] { [key] } + + func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) + } + + func unkeyedContainer() throws -> any UnkeyedDecodingContainer { + throw DecodingError.typeMismatch(PreparedStatement.Value.self, .context(codingPath, "")) + } + + func singleValueContainer() throws -> any SingleValueDecodingContainer { + self + } + + // MARK: - SingleValueDecodingContainer + + func decodeNil() -> Bool { decoder.null(for: key) } + func decode(_ type: Bool.Type) throws -> Bool { try decoder.bool(forKey: key) } + func decode(_ type: String.Type) throws -> String { try decoder.string(forKey: key) } + func decode(_ type: Double.Type) throws -> Double { try decoder.floating(type, forKey: key) } + func decode(_ type: Float.Type) throws -> Float { try decoder.floating(type, forKey: key) } + func decode(_ type: Int.Type) throws -> Int { try decoder.integer(type, forKey: key) } + func decode(_ type: Int8.Type) throws -> Int8 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int16.Type) throws -> Int16 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int32.Type) throws -> Int32 { try decoder.integer(type, forKey: key) } + func decode(_ type: Int64.Type) throws -> Int64 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt.Type) throws -> UInt { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt8.Type) throws -> UInt8 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt16.Type) throws -> UInt16 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt32.Type) throws -> UInt32 { try decoder.integer(type, forKey: key) } + func decode(_ type: UInt64.Type) throws -> UInt64 { try decoder.integer(type, forKey: key) } + func decode(_ type: T.Type) throws -> T where T: Decodable { try decoder.decode(type, forKey: key) } +} + +private extension DecodingError.Context { + static func context(_ codingPath: [any CodingKey], _ message: String) -> DecodingError.Context { + DecodingError.Context(codingPath: codingPath, debugDescription: message) + } +} diff --git a/Sources/SQLyra/SQLParameter.swift b/Sources/SQLyra/SQLParameter.swift index 764a5f0..fc61051 100644 --- a/Sources/SQLyra/SQLParameter.swift +++ b/Sources/SQLyra/SQLParameter.swift @@ -1,7 +1,7 @@ import Foundation /// SQL parameters. -public enum SQLParameter: Equatable { +public enum SQLParameter: Equatable, Sendable { case null /// 64-bit signed integer. diff --git a/Sources/SQLyra/StatementDecoder.swift b/Sources/SQLyra/StatementDecoder.swift deleted file mode 100644 index f65a982..0000000 --- a/Sources/SQLyra/StatementDecoder.swift +++ /dev/null @@ -1,205 +0,0 @@ -import Foundation -import SQLite3 - -/// An object that decodes instances of a data type from ``PreparedStatement``. -public struct StatementDecoder { - /// A dictionary you use to customize the decoding process by providing contextual information. - public var userInfo: [CodingUserInfoKey: Any] = [:] - - /// Creates a new, reusable Statement decoder. - public init() {} - - public func decode(_ type: T.Type, from statement: PreparedStatement) throws -> T where T: Decodable { - let decoder = _StatementDecoder( - statement: statement, - userInfo: userInfo - ) - return try type.init(from: decoder) - } -} - -private final class _StatementDecoder { - let statement: PreparedStatement - let userInfo: [CodingUserInfoKey: Any] - private(set) var codingPath: [any CodingKey] = [] - - init(statement: PreparedStatement, userInfo: [CodingUserInfoKey: Any]) { - self.statement = statement - self.userInfo = userInfo - self.codingPath.reserveCapacity(3) - } - - @inline(__always) - func null(for key: K) -> Bool where K: CodingKey { - statement.column(for: key.stringValue)?.isNull ?? true - } - - @inline(__always) - func bool(forKey key: K) throws -> Bool where K: CodingKey { - try integer(Int64.self, forKey: key) != 0 - } - - @inline(__always) - func string(forKey key: K, single: Bool = false) throws -> String where K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - guard let value = statement.column(at: index).string else { - throw DecodingError.valueNotFound(String.self, context(key, single, "")) - } - return value - } - - @inline(__always) - func floating( - _ type: T.Type, - forKey key: K, - single: Bool = false - ) throws -> T where T: BinaryFloatingPoint, K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - let value = statement.column(at: index).double - guard let number = type.init(exactly: value) else { - throw DecodingError.dataCorrupted(context(key, single, numberNotFit(type, value: "\(value)"))) - } - return number - } - - @inline(__always) - func integer(_ type: T.Type, forKey key: K, single: Bool = false) throws -> T where T: Numeric, K: CodingKey { - let index = try columnIndex(forKey: key, single: single) - let value = statement.column(at: index).int64 - guard let number = type.init(exactly: value) else { - throw DecodingError.dataCorrupted(context(key, single, numberNotFit(type, value: "\(value)"))) - } - return number - } - - @inline(__always) - func decode( - _ type: T.Type, - forKey key: K, - single: Bool = false - ) throws -> T where T: Decodable, K: CodingKey { - if type == Data.self { - let index = try columnIndex(forKey: key, single: single) - guard let data = statement.column(at: index).blob else { - throw DecodingError.valueNotFound(Data.self, context(key, single, "")) - } - // swift-format-ignore: NeverForceUnwrap - return data as! T - } - if single { - return try type.init(from: self) - } - codingPath.append(key) - defer { - codingPath.removeLast() - } - return try type.init(from: self) - } - - private func columnIndex(forKey key: K, single: Bool) throws -> Int32 where K: CodingKey { - guard let index = statement.columnIndexByName[key.stringValue] else { - throw DecodingError.keyNotFound(key, context(key, single, "Column index not found for key: \(key)")) - } - return index - } - - private func context(_ key: any CodingKey, _ single: Bool, _ message: String) -> DecodingError.Context { - var path = codingPath - if !single { - path.append(key) - } - return DecodingError.Context(codingPath: path, debugDescription: message) - } -} - -private func numberNotFit(_ type: any Any.Type, value: String) -> String { - "Parsed SQL number <\(value)> does not fit in \(type)." -} - -// MARK: - Decoder - -extension _StatementDecoder: Decoder { - func container(keyedBy type: Key.Type) throws -> KeyedDecodingContainer where Key: CodingKey { - KeyedDecodingContainer(KeyedContainer(decoder: self)) - } - - func unkeyedContainer() throws -> any UnkeyedDecodingContainer { - let context = DecodingError.Context( - codingPath: codingPath, - debugDescription: "`unkeyedContainer()` not supported" - ) - throw DecodingError.dataCorrupted(context) - } - - func singleValueContainer() throws -> any SingleValueDecodingContainer { - if codingPath.isEmpty { - let context = DecodingError.Context(codingPath: codingPath, debugDescription: "key not found") - throw DecodingError.dataCorrupted(context) - } - return self - } -} - -// MARK: - SingleValueDecodingContainer - -extension _StatementDecoder: SingleValueDecodingContainer { - // swift-format-ignore: NeverForceUnwrap - private var key: any CodingKey { codingPath.last! } - - func decodeNil() -> Bool { null(for: key) } - func decode(_ type: Bool.Type) throws -> Bool { try bool(forKey: key) } - func decode(_ type: String.Type) throws -> String { try string(forKey: key, single: true) } - func decode(_ type: Double.Type) throws -> Double { try floating(type, forKey: key, single: true) } - func decode(_ type: Float.Type) throws -> Float { try floating(type, forKey: key, single: true) } - func decode(_ type: Int.Type) throws -> Int { try integer(type, forKey: key, single: true) } - func decode(_ type: Int8.Type) throws -> Int8 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int16.Type) throws -> Int16 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int32.Type) throws -> Int32 { try integer(type, forKey: key, single: true) } - func decode(_ type: Int64.Type) throws -> Int64 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt.Type) throws -> UInt { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt8.Type) throws -> UInt8 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt16.Type) throws -> UInt16 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt32.Type) throws -> UInt32 { try integer(type, forKey: key, single: true) } - func decode(_ type: UInt64.Type) throws -> UInt64 { try integer(type, forKey: key, single: true) } - func decode(_ type: T.Type) throws -> T where T: Decodable { try decode(type, forKey: key, single: true) } -} - -// MARK: - KeyedDecodingContainer - -extension _StatementDecoder { - struct KeyedContainer: KeyedDecodingContainerProtocol { - let decoder: _StatementDecoder - var codingPath: [any CodingKey] { decoder.codingPath } - var allKeys: [Key] { decoder.statement.columnIndexByName.keys.compactMap { Key(stringValue: $0) } } - - func contains(_ key: Key) -> Bool { decoder.statement.columnIndexByName.keys.contains(key.stringValue) } - func decodeNil(forKey key: Key) throws -> Bool { decoder.null(for: key) } - func decode(_ type: Bool.Type, forKey key: Key) throws -> Bool { try decoder.bool(forKey: key) } - func decode(_ type: String.Type, forKey key: Key) throws -> String { try decoder.string(forKey: key) } - func decode(_ type: Double.Type, forKey key: Key) throws -> Double { try decoder.floating(type, forKey: key) } - func decode(_ type: Float.Type, forKey key: Key) throws -> Float { try decoder.floating(type, forKey: key) } - func decode(_ type: Int.Type, forKey key: Key) throws -> Int { try decoder.integer(type, forKey: key) } - func decode(_ type: Int8.Type, forKey key: Key) throws -> Int8 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int16.Type, forKey key: Key) throws -> Int16 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int32.Type, forKey key: Key) throws -> Int32 { try decoder.integer(type, forKey: key) } - func decode(_ type: Int64.Type, forKey key: Key) throws -> Int64 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt.Type, forKey key: Key) throws -> UInt { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt8.Type, forKey key: Key) throws -> UInt8 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt16.Type, forKey key: Key) throws -> UInt16 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt32.Type, forKey key: Key) throws -> UInt32 { try decoder.integer(type, forKey: key) } - func decode(_ type: UInt64.Type, forKey key: Key) throws -> UInt64 { try decoder.integer(type, forKey: key) } - func decode(_ type: T.Type, forKey key: Key) throws -> T where T: Decodable { - try decoder.decode(type, forKey: key) - } - - func superDecoder() throws -> any Decoder { fatalError() } - func superDecoder(forKey key: Key) throws -> any Decoder { fatalError() } - func nestedUnkeyedContainer(forKey key: Key) throws -> any UnkeyedDecodingContainer { fatalError() } - func nestedContainer( - keyedBy type: NestedKey.Type, - forKey key: Key - ) throws -> KeyedDecodingContainer where NestedKey: CodingKey { - fatalError() - } - } -} diff --git a/Tests/SQLyraTests/DatabaseTests.swift b/Tests/SQLyraTests/DatabaseTests.swift index ba2d19b..0573dd8 100644 --- a/Tests/SQLyraTests/DatabaseTests.swift +++ b/Tests/SQLyraTests/DatabaseTests.swift @@ -78,12 +78,11 @@ struct DatabaseTests { let id: Int let name: String } - let contacts = try database.prepare("SELECT * FROM contacts;").array(decoding: Contact.self) + let contacts = try database.prepare("SELECT * FROM contacts;").array(Contact.self) let expected = [ Contact(id: 1, name: "Paul"), Contact(id: 2, name: "John"), ] #expect(contacts == expected) - // try database.execute("SELECT name FROM sqlite_master WHERE type ='table';") } } diff --git a/Tests/SQLyraTests/PreparedStatementTests.swift b/Tests/SQLyraTests/PreparedStatementTests.swift index eefceee..4872345 100644 --- a/Tests/SQLyraTests/PreparedStatementTests.swift +++ b/Tests/SQLyraTests/PreparedStatementTests.swift @@ -78,12 +78,12 @@ struct PreparedStatementTests { #expect(select.columnCount == 4) var contracts: [Contact] = [] - while try select.step() { + while let row = try select.row() { let contact = Contact( - id: Int(select.column(at: 0).int64), - name: select.column(at: 1).string ?? "", - rating: select.column(at: 2).double, - image: select.column(at: 3).blob + id: row.id?.int ?? 0, + name: row.name?.string ?? "", + rating: row.rating?.double ?? 0, + image: row.image?.blob ) contracts.append(contact) } diff --git a/Tests/SQLyraTests/RowDecoderTests.swift b/Tests/SQLyraTests/RowDecoderTests.swift new file mode 100644 index 0000000..7e3b7fd --- /dev/null +++ b/Tests/SQLyraTests/RowDecoderTests.swift @@ -0,0 +1,329 @@ +import Foundation +import SQLyra +import Testing + +struct RowDecoderTests { + struct SignedIntegers { + /// valid parameters for all signed integers + static let arguments: [(SQLParameter, Int)] = [ + (-1, -1), + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct IntTests: DecodableValueSuite { + let value: Int + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int(expected)) + } + } + + struct Int8Tests: DecodableValueSuite { + let value: Int8 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int8(expected)) + } + } + + struct Int16Tests: DecodableValueSuite { + let value: Int16 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int16(expected)) + } + } + + struct Int32Tests: DecodableValueSuite { + let value: Int32 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int32(expected)) + } + } + + struct Int64Tests: DecodableValueSuite { + let value: Int64 + + @Test(arguments: SignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Int) throws { + try _decode(parameter, Int64(expected)) + } + } + } + + struct UnsignedIntegers { + /// valid parameters for all unsigned integers + static let arguments: [(SQLParameter, UInt)] = [ + (0, 0), + (1, 1), + (0.9, 0), + ("2", 2), + (.blob(Data("3".utf8)), 3), + ] + + struct UIntTests: DecodableValueSuite { + let value: UInt + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt(expected)) + } + } + + struct UInt8Tests: DecodableValueSuite { + let value: UInt8 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt8(expected)) + } + + @Test(arguments: [ + SQLParameter.int64(-1), + SQLParameter.int64(Int64.max), + ]) + static func dataCorrupted(_ parameter: SQLParameter) throws { + try _dataCorrupted(parameter, "Parsed SQL int64 <\(parameter)> does not fit in UInt8.") + } + } + + struct UInt16Tests: DecodableValueSuite { + let value: UInt16 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt16(expected)) + } + } + + struct UInt32Tests: DecodableValueSuite { + let value: UInt32 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt32(expected)) + } + } + + struct UInt64Tests: DecodableValueSuite { + let value: UInt64 + + @Test(arguments: UnsignedIntegers.arguments) + static func decode(_ parameter: SQLParameter, _ expected: UInt) throws { + try _decode(parameter, UInt64(expected)) + } + } + } + + struct FloatingPointNumerics { + static let arguments: [(SQLParameter, Double)] = [ + (-1, -1.0), + (0, 0.0), + (1, 1.0), + (0.5, 0.5), + ("0.5", 0.5), + ("1.0", 1.0), + (.blob(Data("1".utf8)), 1.0), + ] + + struct DoubleTests: DecodableValueSuite { + let value: Double + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Double(expected)) + } + } + + struct FloatTests: DecodableValueSuite { + let value: Float + + @Test(arguments: FloatingPointNumerics.arguments) + static func decode(_ parameter: SQLParameter, _ expected: Double) throws { + try _decode(parameter, Float(expected)) + } + + @Test static func dataCorrupted() throws { + let parameter = SQLParameter.double(Double.greatestFiniteMagnitude) + try _dataCorrupted(parameter, "Parsed SQL double <\(parameter)> does not fit in Float.") + } + } + } + + struct BoolTests: DecodableValueSuite { + let value: Bool + + @Test(arguments: [ + (SQLParameter.int64(0), false), + (SQLParameter.int64(1), true), + (SQLParameter.double(0.9), false), + (SQLParameter.text("abc"), false), + (SQLParameter.text("true"), false), + (SQLParameter.blob(Data("zxc".utf8)), false), + (SQLParameter.blob(Data("1".utf8)), true), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Bool) throws { + try _decode(parameter, expected) + } + } + + struct StringTests: DecodableValueSuite { + let value: String + + @Test(arguments: [ + (SQLParameter.int64(0), "0"), + (SQLParameter.int64(1), "1"), + (SQLParameter.double(0.9), "0.9"), + (SQLParameter.text("abc"), "abc"), + (SQLParameter.blob(Data("zxc".utf8)), "zxc"), + ]) + static func decode(_ parameter: SQLParameter, _ expected: String) throws { + try _decode(parameter, expected) + } + } + + struct DataTests: DecodableValueSuite { + let value: Data + + @Test(arguments: [ + (SQLParameter.int64(1), Data("1".utf8)), + (SQLParameter.double(1.1), Data("1.1".utf8)), + (SQLParameter.text("123"), Data("123".utf8)), + (SQLParameter.blob(Data("zxc".utf8)), Data("zxc".utf8)), + ]) + static func decode(_ parameter: SQLParameter, _ expected: Data) throws { + try _decode(parameter, expected) + } + } + + struct OptionalTests: DecodableValueSuite { + let value: Int? + + @Test static func decode() throws { + try _decode(1, 1) + } + } + + struct DecodingErrorTests: Decodable { + let item: Int // invalid + + @Test static func keyNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect { + try row.decode(DecodingErrorTests.self) + } throws: { error in + guard case .keyNotFound(let key, let context) = error as? DecodingError else { + return false + } + return key.stringValue == "item" + && context.codingPath.map(\.stringValue) == ["item"] + && context.debugDescription == "Column index not found for key: \(key)" + && context.underlyingError == nil + } + } + + @Test static func typeMismatch() throws { + let errorMatcher = { (error: any Error) -> Bool in + guard case .typeMismatch(_, let context) = error as? DecodingError else { + return false + } + return context.codingPath.isEmpty && context.debugDescription == "" && context.underlyingError == nil + } + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.int64(1)).row()) + #expect(performing: { try row.decode(Int.self) }, throws: errorMatcher) + #expect(performing: { try row.decode([Int].self) }, throws: errorMatcher) + } + + @Test static func valueNotFound() throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(.null).row()) + #expect { + try row.decode(Single.self) + } throws: { error in + guard case .valueNotFound(let type, let context) = error as? DecodingError else { + return false + } + return type == Int8.self + && context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == "Column value not found for key: \(context.codingPath[0])" + && context.underlyingError == nil + } + } + } +} + +// MARK: - Test Suite + +protocol DecodableValueSuite: Decodable { + associatedtype Value: Decodable, Equatable + + var value: Value { get } +} + +extension DecodableValueSuite { + static func _decode( + _ parameter: SQLParameter, + _ expected: Value, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let select = try repo.select(parameter) + let row = try #require(try select.row(), sourceLocation: sourceLocation) + + let keyed = try row.decode(Self.self) + #expect(keyed.value == expected, sourceLocation: sourceLocation) + + let single = try row.decode(Single.self) + #expect(single.value == expected, sourceLocation: sourceLocation) + + #expect(try select.row() == nil, sourceLocation: sourceLocation) + } + + static func _dataCorrupted( + _ parameter: SQLParameter, + _ message: String, + sourceLocation: SourceLocation = #_sourceLocation + ) throws { + let repo = try ItemRepository(datatype: "ANY") + let row = try #require(try repo.select(parameter).row(), sourceLocation: sourceLocation) + #expect(sourceLocation: sourceLocation) { + try row.decode(Self.self) + } throws: { error in + guard case .dataCorrupted(let context) = error as? DecodingError else { + return false + } + return context.codingPath.map(\.stringValue) == ["value"] + && context.debugDescription == message + && context.underlyingError == nil + } + } +} + +struct Single: Decodable { + let value: T +} + +struct ItemRepository { + private let db: Database + + init(datatype: String) throws { + db = try Database.open(at: ":memory:", options: [.readwrite, .memory]) + try db.execute("CREATE TABLE items (value \(datatype));") + } + + func select(_ parameter: SQLParameter) throws -> PreparedStatement { + try db.prepare("INSERT INTO items (value) VALUES (?);").bind(index: 1, parameter: parameter).execute() + return try db.prepare("SELECT value FROM items;") + } +} diff --git a/Tests/SQLyraTests/SQLParameter+Testing.swift b/Tests/SQLyraTests/SQLParameter+Testing.swift new file mode 100644 index 0000000..eb5ab61 --- /dev/null +++ b/Tests/SQLyraTests/SQLParameter+Testing.swift @@ -0,0 +1,32 @@ +import SQLyra +import Testing + +extension SQLParameter: CustomTestStringConvertible { + public var testDescription: String { + switch self { + case .null: "NULL" + case .int64(let value): "INT(\(value))" + case .double(let value): "DOUBLE(\(value))" + case .text(let value): "TEXT(\(value))" + case .blob(let value): "BLOB(\(Array(value))" + } + } +} + +extension SQLParameter: CustomTestArgumentEncodable { + public func encodeTestArgument(to encoder: some Encoder) throws { + switch self { + case .null: + var container = encoder.singleValueContainer() + try container.encodeNil() + case .int64(let value): + try value.encode(to: encoder) + case .double(let value): + try value.encode(to: encoder) + case .text(let value): + try value.encode(to: encoder) + case .blob(let value): + try value.encode(to: encoder) + } + } +}