Skip to content

Commit 433a2dc

Browse files
author
DominicGBauer
committed
fix: tranasactions resulting in error when using different threads
1 parent db11a7d commit 433a2dc

File tree

3 files changed

+232
-35
lines changed

3 files changed

+232
-35
lines changed

Sources/PowerSyncSwift/Kotlin/KotlinPowerSyncDatabaseImpl.swift

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -117,52 +117,42 @@ final class KotlinPowerSyncDatabaseImpl: PowerSyncDatabaseProtocol {
117117
}
118118
}
119119

120-
func writeTransaction<R>(callback: @escaping (any PowerSyncTransactionProtocol) async throws -> R) async throws -> R {
121-
let wrappedCallback = SuspendTaskWrapper { [kmpDatabase] in
122-
// Create a wrapper that converts the KMP transaction to our Swift protocol
123-
if let kmpTransaction = kmpDatabase as? PowerSyncTransactionProtocol {
124-
return try await callback(kmpTransaction)
125-
} else {
126-
throw PowerSyncError.invalidTransaction
127-
}
128-
}
129-
130-
return try await kmpDatabase.writeTransaction(callback: wrappedCallback) as! R
120+
@MainActor
121+
public func writeTransaction<R>(callback: @escaping (any PowerSyncTransaction) async throws -> R) async throws -> R {
122+
return try await kmpDatabase.writeTransaction(callback: SuspendTaskWrapper { transaction in
123+
return try await callback(transaction)
124+
}) as! R
131125
}
132126

133-
func readTransaction<R>(callback: @escaping (any PowerSyncTransactionProtocol) async throws -> R) async throws -> R {
134-
let wrappedCallback = SuspendTaskWrapper { [kmpDatabase] in
135-
// Create a wrapper that converts the KMP transaction to our Swift protocol
136-
if let kmpTransaction = kmpDatabase as? PowerSyncTransactionProtocol {
137-
return try await callback(kmpTransaction)
138-
} else {
139-
throw PowerSyncError.invalidTransaction
140-
}
141-
}
142-
143-
return try await kmpDatabase.readTransaction(callback: wrappedCallback) as! R
127+
@MainActor
128+
public func readTransaction<R>(callback: @escaping (any PowerSyncTransaction) async throws -> R) async throws -> R {
129+
return try await kmpDatabase.writeTransaction(callback: SuspendTaskWrapper { transaction in
130+
return try await callback(transaction)
131+
}) as! R
144132
}
145133
}
146134

147135
enum PowerSyncError: Error {
148136
case invalidTransaction
149137
}
150138

139+
@MainActor
151140
class SuspendTaskWrapper: KotlinSuspendFunction1 {
152-
let handle: () async throws -> Any
141+
let handle: (any PowerSyncTransaction) async throws -> Any
153142

154-
init(_ handle: @escaping () async throws -> Any) {
143+
init(_ handle: @escaping (any PowerSyncTransaction) async throws -> Any) {
155144
self.handle = handle
156145
}
157146

158-
@MainActor
159-
func invoke(p1: Any?, completionHandler: @escaping (Any?, Error?) -> Void) {
160-
Task {
161-
do {
162-
let result = try await self.handle()
163-
completionHandler(result, nil)
164-
} catch {
165-
completionHandler(nil, error)
147+
nonisolated func __invoke(p1: Any?, completionHandler: @escaping (Any?, Error?) -> Void) {
148+
DispatchQueue.main.async {
149+
Task { @MainActor in
150+
do {
151+
let result = try await self.handle(p1 as! any PowerSyncTransaction)
152+
completionHandler(result, nil)
153+
} catch {
154+
completionHandler(nil, error)
155+
}
166156
}
167157
}
168158
}

Sources/PowerSyncSwift/QueriesProtocol.swift

Lines changed: 38 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Foundation
22
import Combine
3+
import PowerSync
34

45
public protocol Queries {
56
/// Execute a write query (INSERT, UPDATE, DELETE)
@@ -37,8 +38,42 @@ public protocol Queries {
3738
) -> AsyncStream<[RowType]>
3839

3940
/// Execute a write transaction with the given callback
40-
func writeTransaction<R>(callback: @escaping (PowerSyncTransactionProtocol) async throws -> R) async throws -> R
41-
41+
func writeTransaction<R>(callback: @escaping (any PowerSyncTransaction) async throws -> R) async throws -> R
42+
4243
/// Execute a read transaction with the given callback
43-
func readTransaction<R>(callback: @escaping (PowerSyncTransactionProtocol) async throws -> R) async throws -> R
44+
func readTransaction<R>(callback: @escaping (any PowerSyncTransaction) async throws -> R) async throws -> R
45+
}
46+
47+
extension Queries {
48+
public func execute(_ sql: String) async throws -> Int64 {
49+
return try await execute(sql: sql, parameters: [])
50+
}
51+
52+
public func get<RowType>(
53+
_ sql: String,
54+
mapper: @escaping (SqlCursor) -> RowType
55+
) async throws -> RowType {
56+
return try await get(sql: sql, parameters: [], mapper: mapper)
57+
}
58+
59+
public func getAll<RowType>(
60+
_ sql: String,
61+
mapper: @escaping (SqlCursor) -> RowType
62+
) async throws -> [RowType] {
63+
return try await getAll(sql: sql, parameters: [], mapper: mapper)
64+
}
65+
66+
public func getOptional<RowType>(
67+
_ sql: String,
68+
mapper: @escaping (SqlCursor) -> RowType
69+
) async throws -> RowType? {
70+
return try await getOptional(sql: sql, parameters: [], mapper: mapper)
71+
}
72+
73+
public func watch<RowType>(
74+
_ sql: String,
75+
mapper: @escaping (SqlCursor) -> RowType
76+
) -> AsyncStream<[RowType]> {
77+
return watch(sql: sql, parameters: [], mapper: mapper)
78+
}
4479
}
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import XCTest
2+
@testable import PowerSyncSwift
3+
4+
final class KotlinPowerSyncDatabaseImplTests: XCTestCase {
5+
private var database: KotlinPowerSyncDatabaseImpl!
6+
private var schema: Schema!
7+
8+
override func setUp() async throws {
9+
try await super.setUp()
10+
schema = Schema(tables: [
11+
Table(name: "users", columns: [
12+
.text("name"),
13+
.text("email")
14+
])
15+
])
16+
17+
database = KotlinPowerSyncDatabaseImpl(
18+
schema: schema,
19+
dbFilename: ":memory:"
20+
)
21+
}
22+
23+
override func tearDown() async throws {
24+
try await database.disconnectAndClear()
25+
database = nil
26+
try await super.tearDown()
27+
}
28+
29+
func testInsertAndGet() async throws {
30+
_ = try await database.execute(
31+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
32+
parameters: ["1", "Test User", "[email protected]"]
33+
)
34+
35+
let user: (String, String, String) = try await database.get(
36+
sql: "SELECT id, name, email FROM users WHERE id = ?",
37+
parameters: ["1"]
38+
) { cursor in
39+
(
40+
cursor.getString(index: 0)!,
41+
cursor.getString(index: 1)!,
42+
cursor.getString(index: 2)!
43+
)
44+
}
45+
46+
XCTAssertEqual(user.0, "1")
47+
XCTAssertEqual(user.1, "Test User")
48+
XCTAssertEqual(user.2, "[email protected]")
49+
}
50+
51+
func testGetOptional() async throws {
52+
let nonExistent: String? = try await database.getOptional(
53+
sql: "SELECT name FROM users WHERE id = ?",
54+
parameters: ["999"]
55+
) { cursor in
56+
cursor.getString(index: 0)!
57+
}
58+
59+
XCTAssertNil(nonExistent)
60+
61+
_ = try await database.execute(
62+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
63+
parameters: ["1", "Test User", "[email protected]"]
64+
)
65+
66+
let existing: String? = try await database.getOptional(
67+
sql: "SELECT name FROM users WHERE id = ?",
68+
parameters: ["1"]
69+
) { cursor in
70+
cursor.getString(index: 0)!
71+
}
72+
73+
XCTAssertEqual(existing, "Test User")
74+
}
75+
76+
func testGetAll() async throws {
77+
_ = try await database.execute(
78+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?), (?, ?, ?)",
79+
parameters: ["1", "User 1", "[email protected]", "2", "User 2", "[email protected]"]
80+
)
81+
82+
let users: [(String, String)] = try await database.getAll(
83+
sql: "SELECT id, name FROM users ORDER BY id",
84+
parameters: nil
85+
) { cursor in
86+
(cursor.getString(index: 0)!, cursor.getString(index: 1)!)
87+
}
88+
89+
XCTAssertEqual(users.count, 2)
90+
XCTAssertEqual(users[0].0, "1")
91+
XCTAssertEqual(users[0].1, "User 1")
92+
XCTAssertEqual(users[1].0, "2")
93+
XCTAssertEqual(users[1].1, "User 2")
94+
}
95+
96+
func testWatchTableChanges() async throws {
97+
let expectation = XCTestExpectation(description: "Watch changes")
98+
var results: [[String]] = []
99+
100+
let stream = database.watch(
101+
sql: "SELECT name FROM users ORDER BY id",
102+
parameters: nil
103+
) { cursor in
104+
cursor.getString(index: 0)!
105+
}
106+
107+
let watchTask = Task {
108+
for await names in stream {
109+
results.append(names)
110+
if results.count == 2 {
111+
expectation.fulfill()
112+
}
113+
}
114+
}
115+
116+
_ = try await database.execute(
117+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
118+
parameters: ["1", "User 1", "[email protected]"]
119+
)
120+
121+
_ = try await database.execute(
122+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
123+
parameters: ["2", "User 2", "[email protected]"]
124+
)
125+
126+
await fulfillment(of: [expectation], timeout: 5)
127+
watchTask.cancel()
128+
129+
XCTAssertEqual(results.count, 2)
130+
XCTAssertEqual(results[1], ["User 1", "User 2"])
131+
}
132+
133+
@MainActor
134+
func testWriteTransaction() async throws {
135+
try await database.writeTransaction { transaction in
136+
_ = try await transaction.execute(
137+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
138+
parameters: ["1", "Test User", "[email protected]"]
139+
)
140+
}
141+
142+
143+
let result = try await database.get(
144+
sql: "SELECT COUNT(*) FROM users",
145+
parameters: []
146+
) { cursor in
147+
cursor.getLong(index: 0)
148+
}
149+
150+
XCTAssertEqual(result as! Int, 1)
151+
}
152+
153+
@MainActor
154+
func testReadTransaction() async throws {
155+
_ = try await database.execute(
156+
sql: "INSERT INTO users (id, name, email) VALUES (?, ?, ?)",
157+
parameters: ["1", "Test User", "[email protected]"]
158+
)
159+
160+
161+
try await database.readTransaction { transaction in
162+
let result = try await transaction.get(
163+
sql: "SELECT COUNT(*) FROM users",
164+
parameters: []
165+
) { cursor in
166+
cursor.getLong(index: 0)
167+
}
168+
169+
XCTAssertEqual(result as! Int, 1)
170+
}
171+
}
172+
}

0 commit comments

Comments
 (0)