diff --git a/accounts/keystore/account_cache.go b/accounts/keystore/account_cache.go index d3a98850c7a..f2db185c3dc 100644 --- a/accounts/keystore/account_cache.go +++ b/accounts/keystore/account_cache.go @@ -66,7 +66,7 @@ type accountCache struct { keydir string watcher *watcher mu sync.Mutex - all []accounts.Account + byURL map[accounts.URL][]accounts.Account byAddr map[common.Address][]accounts.Account throttle *time.Timer notify chan struct{} @@ -76,6 +76,7 @@ type accountCache struct { func newAccountCache(keydir string) (*accountCache, chan struct{}) { ac := &accountCache{ keydir: keydir, + byURL: make(map[accounts.URL][]accounts.Account), byAddr: make(map[common.Address][]accounts.Account), notify: make(chan struct{}, 1), fileC: fileCache{all: mapset.NewThreadUnsafeSet[string]()}, @@ -88,8 +89,11 @@ func (ac *accountCache) accounts() []accounts.Account { ac.maybeReload() ac.mu.Lock() defer ac.mu.Unlock() - cpy := make([]accounts.Account, len(ac.all)) - copy(cpy, ac.all) + cpy := make([]accounts.Account, 0, len(ac.byURL)) + for _, accs := range ac.byURL { + cpy = append(cpy, accs...) + } + sort.SliceStable(cpy, func(i, j int) bool { return cpy[i].URL.Cmp(cpy[j].URL) < 0 }) return cpy } @@ -104,14 +108,11 @@ func (ac *accountCache) add(newAccount accounts.Account) { ac.mu.Lock() defer ac.mu.Unlock() - i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Cmp(newAccount.URL) >= 0 }) - if i < len(ac.all) && ac.all[i] == newAccount { + if accs, ok := ac.byURL[newAccount.URL]; ok && slices.Contains(accs, newAccount) { return } // newAccount is not in the cache. - ac.all = append(ac.all, accounts.Account{}) - copy(ac.all[i+1:], ac.all[i:]) - ac.all[i] = newAccount + ac.byURL[newAccount.URL] = append(ac.byURL[newAccount.URL], newAccount) ac.byAddr[newAccount.Address] = append(ac.byAddr[newAccount.Address], newAccount) } @@ -120,7 +121,12 @@ func (ac *accountCache) delete(removed accounts.Account) { ac.mu.Lock() defer ac.mu.Unlock() - ac.all = removeAccount(ac.all, removed) + if bu := removeAccount(ac.byURL[removed.URL], removed); len(bu) == 0 { + delete(ac.byURL, removed.URL) + } else { + ac.byURL[removed.URL] = bu + } + if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 { delete(ac.byAddr, removed.Address) } else { @@ -132,11 +138,14 @@ func (ac *accountCache) delete(removed accounts.Account) { func (ac *accountCache) deleteByFile(path string) { ac.mu.Lock() defer ac.mu.Unlock() - i := sort.Search(len(ac.all), func(i int) bool { return ac.all[i].URL.Path >= path }) - - if i < len(ac.all) && ac.all[i].URL.Path == path { - removed := ac.all[i] - ac.all = append(ac.all[:i], ac.all[i+1:]...) + url := accounts.URL{Scheme: KeyStoreScheme, Path: path} + if accs, ok := ac.byURL[url]; ok { + removed := accs[0] + if len(accs) == 1 { + delete(ac.byURL, url) + } else { + ac.byURL[url] = accs[1:] + } if ba := removeAccount(ac.byAddr[removed.Address], removed); len(ba) == 0 { delete(ac.byAddr, removed.Address) } else { @@ -166,24 +175,34 @@ func removeAccount(slice []accounts.Account, elem accounts.Account) []accounts.A // The exact matching rules are explained by the documentation of accounts.Account. // Callers must hold ac.mu. func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) { - // Limit search to address candidates if possible. - matches := ac.all - if (a.Address != common.Address{}) { - matches = ac.byAddr[a.Address] - } if a.URL.Path != "" { // If only the basename is specified, complete the path. if !strings.ContainsRune(a.URL.Path, filepath.Separator) { a.URL.Path = filepath.Join(ac.keydir, a.URL.Path) } - for i := range matches { - if matches[i].URL == a.URL { - return matches[i], nil + } + // Limit search to address candidates if possible. + var matches []accounts.Account + if (a.Address != common.Address{}) { + matches = ac.byAddr[a.Address] + if a.URL.Path != "" { + for i := range matches { + if matches[i].URL == a.URL { + return matches[i], nil + } } } - if (a.Address == common.Address{}) { + } else { + if a.URL.Path != "" { + if accs, ok := ac.byURL[a.URL]; ok { + return accs[0], nil + } return accounts.Account{}, ErrNoMatch } + matches = make([]accounts.Account, 0, len(ac.byURL)) + for _, accs := range ac.byURL { + matches = append(matches, accs...) + } } switch len(matches) { case 1: @@ -193,7 +212,7 @@ func (ac *accountCache) find(a accounts.Account) (accounts.Account, error) { default: err := &AmbiguousAddrError{Addr: a.Address, Matches: make([]accounts.Account, len(matches))} copy(err.Matches, matches) - slices.SortFunc(err.Matches, byURL) + slices.SortStableFunc(err.Matches, byURL) return accounts.Account{}, err } } diff --git a/accounts/keystore/account_cache_test.go b/accounts/keystore/account_cache_test.go index c9a8cdfcef3..5ba7211d18f 100644 --- a/accounts/keystore/account_cache_test.go +++ b/accounts/keystore/account_cache_test.go @@ -17,6 +17,7 @@ package keystore import ( + "encoding/binary" "errors" "fmt" "math/rand" @@ -405,3 +406,93 @@ func forceCopyFile(dst, src string) error { } return os.WriteFile(dst, data, 0644) } + +func BenchmarkAdd(b *testing.B) { + for _, preload := range []int{10, 100, 1000, 1_000_000} { + b.Run(fmt.Sprintf("preload=%d", preload), func(b *testing.B) { + benchmarkAdd(b, preload) + }) + } +} + +func benchmarkAdd(b *testing.B, preload int) { + dir := filepath.Join("testdata", "dir") + cache, _ := newAccountCache(dir) + cache.watcher.running = true // prevent unexpected reloads + + for i := range preload { + acc := accounts.Account{ + URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/preload%08x", i)}, + } + binary.NativeEndian.PutUint64(acc.Address[0:], uint64(i)) + + cache.add(acc) + } + + b.ResetTimer() + b.ReportAllocs() + for i := range b.N { + acc := accounts.Account{ + URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/bench%08x", i)}, + } + binary.NativeEndian.PutUint64(acc.Address[12:], uint64(i)) + + cache.add(acc) + } +} + +func BenchmarkFind(b *testing.B) { + for _, preload := range []int{10, 100, 1000, 1_000_000} { + b.Run(fmt.Sprintf("preload=%d", preload), func(b *testing.B) { + benchmarkFind(b, preload) + }) + } +} + +func benchmarkFind(b *testing.B, preload int) { + dir := filepath.Join("testdata", "dir") + cache, _ := newAccountCache(dir) + cache.watcher.running = true // prevent unexpected reloads + + for i := range preload { + acc := accounts.Account{ + URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/account%08x", i)}, + } + binary.NativeEndian.PutUint64(acc.Address[0:], uint64(i)) + + cache.add(acc) + } + + b.Run("by address", func(b *testing.B) { + acc := accounts.Account{} + binary.NativeEndian.PutUint64(acc.Address[0:], uint64(preload/2)) + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + cache.find(acc) + } + }) + + b.Run("by path", func(b *testing.B) { + acc := accounts.Account{ + URL: accounts.URL{Scheme: KeyStoreScheme, Path: fmt.Sprintf("dir/account%08x", preload/2)}, + } + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + cache.find(acc) + } + }) + + b.Run("ambiguous", func(b *testing.B) { + acc := accounts.Account{} + + b.ResetTimer() + b.ReportAllocs() + for range b.N { + cache.find(acc) + } + }) +}