Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 43 additions & 24 deletions accounts/keystore/account_cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ type accountCache struct {
keydir string
watcher *watcher
mu sync.Mutex
all []accounts.Account
byURL map[accounts.URL][]accounts.Account
byAddr map[common.Address][]accounts.Account
throttle *time.Timer
notify chan struct{}
Expand All @@ -76,6 +76,7 @@ type accountCache struct {
func newAccountCache(keydir string) (*accountCache, chan struct{}) {
ac := &accountCache{
keydir: keydir,
byURL: make(map[accounts.URL][]accounts.Account),
byAddr: make(map[common.Address][]accounts.Account),
notify: make(chan struct{}, 1),
fileC: fileCache{all: mapset.NewThreadUnsafeSet[string]()},
Expand All @@ -88,8 +89,11 @@ func (ac *accountCache) accounts() []accounts.Account {
ac.maybeReload()
ac.mu.Lock()
defer ac.mu.Unlock()
cpy := make([]accounts.Account, len(ac.all))
copy(cpy, ac.all)
cpy := make([]accounts.Account, 0, len(ac.byURL))
for _, accs := range ac.byURL {
cpy = append(cpy, accs...)
}
sort.SliceStable(cpy, func(i, j int) bool { return cpy[i].URL.Cmp(cpy[j].URL) < 0 })
return cpy
}

Expand All @@ -104,14 +108,11 @@ func (ac *accountCache) add(newAccount accounts.Account) {
ac.mu.Lock()
defer ac.mu.Unlock()

i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Cmp(newAccount.URL) >= 0 })
if i < len(ac.all) && ac.all[i] == newAccount {
if accs, ok := ac.byURL[newAccount.URL]; ok && slices.Contains(accs, newAccount) {
return
}
// newAccount is not in the cache.
ac.all = append(ac.all, accounts.Account{})
copy(ac.all[i+1:], ac.all[i:])
ac.all[i] = newAccount
ac.byURL[newAccount.URL] = append(ac.byURL[newAccount.URL], newAccount)
ac.byAddr[newAccount.Address] = append(ac.byAddr[newAccount.Address], newAccount)
}

Expand All @@ -120,7 +121,12 @@ func (ac *accountCache) delete(removed accounts.Account) {
ac.mu.Lock()
defer ac.mu.Unlock()

ac.all = removeAccount(ac.all, removed)
if bu := removeAccount(ac.byURL[removed.URL], removed); len(bu) == 0 {
delete(ac.byURL, removed.URL)
} else {
ac.byURL[removed.URL] = bu
}

if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
delete(ac.byAddr, removed.Address)
} else {
Expand All @@ -132,11 +138,14 @@ func (ac *accountCache) delete(removed accounts.Account) {
func (ac *accountCache) deleteByFile(path string) {
ac.mu.Lock()
defer ac.mu.Unlock()
i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path })

if i < len(ac.all) && ac.all[i].URL.Path == path {
removed := ac.all[i]
ac.all = append(ac.all[:i], ac.all[i+1:]...)
url := accounts.URL{Scheme: KeyStoreScheme, Path: path}
if accs, ok := ac.byURL[url]; ok {
removed := accs[0]
if len(accs) == 1 {
delete(ac.byURL, url)
} else {
ac.byURL[url] = accs[1:]
}
if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 {
delete(ac.byAddr, removed.Address)
} else {
Expand Down Expand Up @@ -166,24 +175,34 @@ func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.A
// The exact matching rules are explained by the documentation of accounts.Account.
// Callers must hold ac.mu.
func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) {
// Limit search to address candidates if possible.
matches := ac.all
if (a.Address != common.Address{}) {
matches = ac.byAddr[a.Address]
}
if a.URL.Path != "" {
// If only the basename is specified, complete the path.
if !strings.ContainsRune(a.URL.Path, filepath.Separator) {
a.URL.Path = filepath.Join(ac.keydir, a.URL.Path)
}
for i := range matches {
if matches[i].URL == a.URL {
return matches[i], nil
}
// Limit search to address candidates if possible.
var matches []accounts.Account
if (a.Address != common.Address{}) {
matches = ac.byAddr[a.Address]
if a.URL.Path != "" {
for i := range matches {
if matches[i].URL == a.URL {
return matches[i], nil
}
}
}
if (a.Address == common.Address{}) {
} else {
if a.URL.Path != "" {
if accs, ok := ac.byURL[a.URL]; ok {
return accs[0], nil
}
return accounts.Account{}, ErrNoMatch
}
matches = make([]accounts.Account, 0, len(ac.byURL))
for _, accs := range ac.byURL {
matches = append(matches, accs...)
}
}
switch len(matches) {
case 1:
Expand All @@ -193,7 +212,7 @@ func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) {
default:
err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))}
copy(err.Matches, matches)
slices.SortFunc(err.Matches, byURL)
slices.SortStableFunc(err.Matches, byURL)
return accounts.Account{}, err
}
}
Expand Down
91 changes: 91 additions & 0 deletions accounts/keystore/account_cache_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package keystore

import (
"encoding/binary"
"errors"
"fmt"
"math/rand"
Expand Down Expand Up @@ -405,3 +406,93 @@ func forceCopyFile(dst, src string) error {
}
return os.WriteFile(dst, data, 0644)
}

func BenchmarkAdd(b *testing.B) {
for _, preload := range []int{10, 100, 1000, 1_000_000} {
b.Run(fmt.Sprintf("preload=%d", preload), func(b *testing.B) {
benchmarkAdd(b, preload)
})
}
}

func benchmarkAdd(b *testing.B, preload int) {
dir := filepath.Join("testdata", "dir")
cache, _ := newAccountCache(dir)
cache.watcher.running = true // prevent unexpected reloads

for i := range preload {
acc := accounts.Account{
URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/preload%08x", i)},
}
binary.NativeEndian.PutUint64(acc.Address[0:], uint64(i))

cache.add(acc)
}

b.ResetTimer()
b.ReportAllocs()
for i := range b.N {
acc := accounts.Account{
URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/bench%08x", i)},
}
binary.NativeEndian.PutUint64(acc.Address[12:], uint64(i))

cache.add(acc)
}
}

func BenchmarkFind(b *testing.B) {
for _, preload := range []int{10, 100, 1000, 1_000_000} {
b.Run(fmt.Sprintf("preload=%d", preload), func(b *testing.B) {
benchmarkFind(b, preload)
})
}
}

func benchmarkFind(b *testing.B, preload int) {
dir := filepath.Join("testdata", "dir")
cache, _ := newAccountCache(dir)
cache.watcher.running = true // prevent unexpected reloads

for i := range preload {
acc := accounts.Account{
URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/account%08x", i)},
}
binary.NativeEndian.PutUint64(acc.Address[0:], uint64(i))

cache.add(acc)
}

b.Run("by address", func(b *testing.B) {
acc := accounts.Account{}
binary.NativeEndian.PutUint64(acc.Address[0:], uint64(preload/2))

b.ResetTimer()
b.ReportAllocs()
for range b.N {
cache.find(acc)
}
})

b.Run("by path", func(b *testing.B) {
acc := accounts.Account{
URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/account%08x", preload/2)},
}

b.ResetTimer()
b.ReportAllocs()
for range b.N {
cache.find(acc)
}
})

b.Run("ambiguous", func(b *testing.B) {
acc := accounts.Account{}

b.ResetTimer()
b.ReportAllocs()
for range b.N {
cache.find(acc)
}
})
}