Skip to content

Commit 3e08070

Browse files
committed
improve transaction support
1 parent f03c5e7 commit 3e08070

File tree

3 files changed

+93
-61
lines changed

3 files changed

+93
-61
lines changed

src/easy_sqlite3/macros.nim

Lines changed: 54 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -209,35 +209,57 @@ proc db_begin() {.importdb: "BEGIN".}
209209
proc db_commit() {.importdb: "COMMIT".}
210210
proc db_rollback() {.importdb: "ROLLBACK".}
211211

212-
type Transaction*[Origin: ptr Database | ref Database] = object
213-
origin: Origin
214-
done: bool
215-
216-
proc `=destroy`*[Origin: ptr Database | ref Database](tran: var Transaction[Origin]) =
217-
assert tran.origin != nil
218-
if not tran.done:
219-
tran.origin[].db_rollback()
220-
221-
proc `=copy`*[Origin: ptr Database | ref Database](tran: var Transaction[Origin], rhs: Transaction[Origin]) {.error.}
222-
223-
proc initTransaction*(db: var Database): Transaction[ptr Database] =
224-
db.db_begin()
225-
result.origin = addr db
226-
result.done = false
227-
228-
proc initTransaction*(db: ref Database): Transaction[ref Database] =
229-
db[].db_begin()
230-
result.origin = db
231-
result.done = false
232-
233-
proc commit*[Origin: ptr Database | ref Database](tran: var Transaction[Origin]) =
234-
assert tran.origin != nil
235-
assert !tran.done
236-
tran.done = true
237-
tran.origin[].db_commit()
238-
239-
proc rollback*[Origin: ptr Database | ref Database](tran: var Transaction[Origin]) =
240-
assert tran.origin != nil
241-
assert !tran.done
242-
tran.done = true
243-
tran.origin[].db_rollback()
212+
# type Transaction* = object
213+
# origin: ptr Database
214+
215+
# proc `=destroy`*(tran: var Transaction) =
216+
# if tran.origin != nil:
217+
# tran.origin[].db_rollback()
218+
219+
# proc `=copy`*(tran: var Transaction, rhs: Transaction) {.error: "You should not copy transaction".}
220+
221+
# proc initTransaction*(db: var Database): Transaction =
222+
# db.db_begin()
223+
# result.origin = addr db
224+
225+
# proc commit*(tran: var Transaction) =
226+
# if tran.origin != nil:
227+
# tran.origin[].db_commit()
228+
# wasMoved(tran)
229+
230+
# proc rollback*(tran: var Transaction) =
231+
# if tran.origin != nil:
232+
# tran.origin[].db_rollback()
233+
# wasMoved(tran)
234+
235+
template transaction*(db: var Database, body: untyped): untyped =
236+
db_begin db
237+
block outer:
238+
var cached_exception: ref Exception
239+
block inner:
240+
try:
241+
template commit() {.inject, used.} =
242+
try:
243+
db_commit db
244+
break outer
245+
except:
246+
cached_exception = getCurrentException()
247+
break inner
248+
template rollback() {.inject, used.} =
249+
try:
250+
db_rollback db
251+
break outer
252+
except:
253+
cached_exception = getCurrentException()
254+
break inner
255+
body
256+
try:
257+
commit()
258+
except:
259+
cached_exception = getCurrentException()
260+
break inner
261+
except:
262+
db_rollback db
263+
raise getCurrentException()
264+
if cached_exception != nil:
265+
raise cached_exception

tests/test_basic.nim

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import easy_sqlite3/[memfs, logfs]
66
proc select_1(arg: int): tuple[value: int] {.importdb: "SELECT $arg".}
77

88
proc create_table() {.importdb: """
9-
CREATE TABLE mydata(name TEXT PRIMARY KEY NOT NULL, value INT NOT NULL);
9+
CREATE TABLE mydata(name TEXT PRIMARY KEY NOT NULL, value INT NOT NULL) WITHOUT ROWID;
1010
""".}
1111

1212
proc insert_data(name: string, value: int) {.importdb: """
@@ -17,6 +17,8 @@ iterator iterate_data(): tuple[name: string, value: int] {.importdb: """
1717
SELECT name, value FROM mydata;
1818
""".} = discard
1919

20+
proc count_data(): tuple[count: int] {.importdb: "SELECT count(*) FROM mydata".}
21+
2022
test "simple":
2123
var db = initDatabase(":memory:")
2224
check db.select_1(1) == (value: 1)
@@ -31,9 +33,11 @@ test "full":
3133
var db = initDatabase("test")
3234
db.exec "PRAGMA journal_mode=DELETE"
3335
db.create_table()
34-
for name, value in dataset:
35-
db.insert_data name, value
36+
db.transaction:
37+
for name, value in dataset:
38+
db.insert_data name, value
3639
db.exec "VACUUM"
3740
for name, value in db.iterate_data():
3841
check name in dataset
39-
check dataset[name] == value
42+
check dataset[name] == value
43+
check db.count_data() == (count: dataset.len)

tests/test_thread.nim

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,21 @@ import easy_sqlite3/memfs
55

66
const useMemFs = true
77

8-
when not useMemFs:
8+
when useMemFs:
9+
template retry(body: untyped) = body
10+
else:
911
enableSharedCache()
1012
var failedCount = 0
13+
template retry(body: untyped) =
14+
var failed = 0
15+
while true:
16+
try:
17+
body
18+
break
19+
except:
20+
failed.inc
21+
if failed > 0:
22+
failedCount.atomicInc(failed)
1123

1224
proc create_table() {.importdb: """
1325
CREATE TABLE store(key INTEGER PRIMARY KEY, value INT NOT NULL);
@@ -30,52 +42,46 @@ var gdb = connectDatabase()
3042
gdb.create_table()
3143
gdb.exec "VACUUM"
3244

33-
const COUNT = 100000
45+
const COUNT = 1000000
46+
const GROUP = 100
3447

3548
proc worker_fn() {.thread.} =
3649
echo "thread start"
3750
var tdb = connectDatabase()
3851
var r = initRand(42)
39-
for _ in 0..<COUNT:
40-
let val = r.rand(1048576)
41-
# increase the chance of collision
42-
if val < 1024:
43-
sleep(1)
44-
when useMemFs:
45-
tdb.insert_data(val)
46-
else:
47-
var failed = 0
48-
block retry:
49-
while true:
50-
try:
51-
tdb.insert_data(val)
52-
break retry
53-
except:
54-
failed.inc
55-
if failed > 0:
56-
failedCount.atomicInc(failed)
52+
for _ in 0..<(COUNT div GROUP):
53+
retry:
54+
tdb.transaction:
55+
for _ in 0..<GROUP:
56+
let val = r.rand(1048576)
57+
# increase the chance of collision
58+
if val < 1024:
59+
sleep(1)
60+
tdb.insert_data(val)
5761

5862
var worker: Thread[void]
5963
createThread(worker, worker_fn)
6064

6165
let init = cpuTime()
6266
var prev = init
6367
while true:
64-
let c = gdb.count_items().count
68+
var c: int
69+
retry:
70+
c = gdb.count_items().count
6571
let curr = cpuTime()
6672
let diff = curr - prev - 0.2
6773
if diff > 0:
6874
when useMemFs:
69-
echo fmt"{curr - init:>6.1f}s: {c:<6}"
75+
echo fmt"{curr - init:>6.1f}s: {c:>7}"
7076
else:
71-
echo fmt"{curr - init:>6.1f}s: {c:<6} failures: {failedCount}"
77+
echo fmt"{curr - init:>6.1f}s: {c:>7} failures: {failedCount}"
7278
prev = curr - diff
7379
if c == COUNT:
7480
break
7581

7682
when useMemFs:
77-
echo fmt"time: {cpuTime() - init:>9.5f}s"
83+
echo fmt"time: {cpuTime() - init:>9.4f}s"
7884
else:
79-
echo fmt"time: {cpuTime() - init:>9.5f}s failures: {failedCount}"
85+
echo fmt"time: {cpuTime() - init:>9.4f}s failures: {failedCount}"
8086

8187
worker.joinThread()

0 commit comments

Comments
 (0)