Skip to content

Commit d2362b0

Browse files
committed
Use coroutines.
1 parent 17f1681 commit d2362b0

File tree

3 files changed

+91
-38
lines changed

3 files changed

+91
-38
lines changed

ext/fileio/coro.go

Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package fileio
2+
3+
import (
4+
"fmt"
5+
6+
"github.com/ncruces/go-sqlite3/internal/util"
7+
)
8+
9+
// Adapted from: https://research.swtch.com/coro
10+
11+
const errCoroCanceled = util.ErrorString("coroutine canceled")
12+
13+
func coroNew[In, Out any](f func(In, func(Out) In) Out) (resume func(In) (Out, bool), cancel func()) {
14+
type msg[T any] struct {
15+
panic any
16+
val T
17+
}
18+
19+
cin := make(chan msg[In])
20+
cout := make(chan msg[Out])
21+
running := true
22+
resume = func(in In) (out Out, ok bool) {
23+
if !running {
24+
return
25+
}
26+
cin <- msg[In]{val: in}
27+
m := <-cout
28+
if m.panic != nil {
29+
panic(m.panic)
30+
}
31+
return m.val, running
32+
}
33+
cancel = func() {
34+
if !running {
35+
return
36+
}
37+
e := fmt.Errorf("%w", errCoroCanceled)
38+
cin <- msg[In]{panic: e}
39+
m := <-cout
40+
if m.panic != nil && m.panic != e {
41+
panic(m.panic)
42+
}
43+
}
44+
yield := func(out Out) In {
45+
cout <- msg[Out]{val: out}
46+
m := <-cin
47+
if m.panic != nil {
48+
panic(m.panic)
49+
}
50+
return m.val
51+
}
52+
go func() {
53+
defer func() {
54+
if running {
55+
running = false
56+
cout <- msg[Out]{panic: recover()}
57+
}
58+
}()
59+
var out Out
60+
m := <-cin
61+
if m.panic == nil {
62+
out = f(m.val, yield)
63+
}
64+
running = false
65+
cout <- msg[Out]{val: out}
66+
}()
67+
return resume, cancel
68+
}

ext/fileio/fsdir.go

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,12 @@ func (d fsdir) Open() (sqlite3.VTabCursor, error) {
5353

5454
type cursor struct {
5555
fsdir
56-
curr entry
57-
next chan entry
58-
done chan struct{}
59-
base string
60-
rowID int64
61-
eof bool
62-
open bool
56+
base string
57+
resume func(struct{}) (entry, bool)
58+
cancel func()
59+
curr entry
60+
eof bool
61+
rowID int64
6362
}
6463

6564
type entry struct {
@@ -69,11 +68,8 @@ type entry struct {
6968
}
7069

7170
func (c *cursor) Close() error {
72-
if c.open {
73-
close(c.done)
74-
s := <-c.next
75-
c.open = false
76-
return s.err
71+
if c.cancel != nil {
72+
c.cancel()
7773
}
7874
return nil
7975
}
@@ -96,17 +92,25 @@ func (c *cursor) Filter(idxNum int, idxStr string, arg ...sqlite3.Value) error {
9692
c.base = base
9793
}
9894

99-
c.rowID = 0
95+
c.resume, c.cancel = coroNew(func(_ struct{}, yield func(entry) struct{}) entry {
96+
walkDir := func(path string, d fs.DirEntry, err error) error {
97+
yield(entry{d, err, path})
98+
return nil
99+
}
100+
if c.fsys != nil {
101+
fs.WalkDir(c.fsys, root, walkDir)
102+
} else {
103+
filepath.WalkDir(root, walkDir)
104+
}
105+
return entry{}
106+
})
100107
c.eof = false
101-
c.open = true
102-
c.next = make(chan entry)
103-
c.done = make(chan struct{})
104-
go c.WalkDir(root)
108+
c.rowID = 0
105109
return c.Next()
106110
}
107111

108112
func (c *cursor) Next() error {
109-
curr, ok := <-c.next
113+
curr, ok := c.resume(struct{}{})
110114
c.curr = curr
111115
c.eof = !ok
112116
c.rowID++
@@ -166,22 +170,3 @@ func (c *cursor) Column(ctx *sqlite3.Context, n int) error {
166170
}
167171
return nil
168172
}
169-
170-
func (c *cursor) WalkDir(path string) {
171-
defer close(c.next)
172-
173-
if c.fsys != nil {
174-
fs.WalkDir(c.fsys, path, c.WalkDirFunc)
175-
} else {
176-
filepath.WalkDir(path, c.WalkDirFunc)
177-
}
178-
}
179-
180-
func (c *cursor) WalkDirFunc(path string, d fs.DirEntry, err error) error {
181-
select {
182-
case <-c.done:
183-
return fs.SkipAll
184-
case c.next <- entry{d, err, path}:
185-
return nil
186-
}
187-
}

ext/fileio/fsdir_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func Test_fsdir(t *testing.T) {
2828
}
2929
defer db.Close()
3030

31-
rows, err := db.Query(`SELECT * FROM fsdir('.', '.') LIMIT 4`)
31+
rows, err := db.Query(`SELECT * FROM fsdir('.', '.')`)
3232
if err != nil {
3333
t.Fatal(err)
3434
}

0 commit comments

Comments
 (0)