Skip to content

Commit 496ebed

Browse files
authored
Merge pull request #174 from SenseUnit/new_cache
New cache
2 parents 8c4f008 + ad178b4 commit 496ebed

File tree

8 files changed

+206
-265
lines changed

8 files changed

+206
-265
lines changed

README.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,6 @@ Usage of /home/user/go/bin/dumbproxy:
492492
email used for ACME registration
493493
-autocert-http string
494494
listen address for HTTP-01 challenges handler of ACME
495-
-autocert-local-cache-timeout duration
496-
timeout for cert cache queries (default 10s)
497495
-autocert-local-cache-ttl duration
498496
enables in-memory cache for certificates
499497
-autocert-whitelist value

certcache/local.go

Lines changed: 37 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -3,58 +3,61 @@ package certcache
33
import (
44
"context"
55
"io"
6-
"sync"
76
"time"
87

9-
"github.com/jellydator/ttlcache/v3"
8+
"github.com/Snawoot/secache"
109
"golang.org/x/crypto/acme/autocert"
10+
"golang.org/x/sync/singleflight"
1111
)
1212

1313
type certCacheKey = string
1414
type certCacheValue struct {
15+
ts time.Time
1516
res []byte
1617
err error
1718
}
1819

1920
type LocalCertCache struct {
20-
cache *ttlcache.Cache[certCacheKey, certCacheValue]
21-
next autocert.Cache
22-
stopOnce sync.Once
21+
cache secache.Cache[certCacheKey, *certCacheValue]
22+
sf singleflight.Group
23+
next autocert.Cache
2324
}
2425

25-
func NewLocalCertCache(next autocert.Cache, ttl, timeout time.Duration) *LocalCertCache {
26-
cache := ttlcache.New[certCacheKey, certCacheValue](
27-
ttlcache.WithTTL[certCacheKey, certCacheValue](ttl),
28-
ttlcache.WithLoader(
29-
ttlcache.NewSuppressedLoader(
30-
ttlcache.LoaderFunc[certCacheKey, certCacheValue](
31-
func(c *ttlcache.Cache[certCacheKey, certCacheValue], key certCacheKey) *ttlcache.Item[certCacheKey, certCacheValue] {
32-
ctx, cl := context.WithTimeout(context.Background(), timeout)
33-
defer cl()
34-
res, err := next.Get(ctx, key)
35-
if err != nil {
36-
return c.Set(key, certCacheValue{res, err}, -100)
37-
}
38-
return c.Set(key, certCacheValue{res, err}, 0)
39-
},
40-
),
41-
nil),
42-
),
43-
)
44-
go cache.Start()
26+
func NewLocalCertCache(next autocert.Cache, ttl time.Duration) *LocalCertCache {
4527
return &LocalCertCache{
46-
cache: cache,
47-
next: next,
28+
cache: *(secache.New[certCacheKey, *certCacheValue](3, func(key certCacheKey, item *certCacheValue) bool {
29+
return time.Now().Before(item.ts.Add(ttl))
30+
})),
31+
next: next,
4832
}
4933
}
5034

51-
func (cc *LocalCertCache) Get(_ context.Context, key string) ([]byte, error) {
52-
resItem := cc.cache.Get(key).Value()
35+
func (cc *LocalCertCache) Get(ctx context.Context, key string) ([]byte, error) {
36+
resItem, ok := cc.cache.GetValidOrDelete(key)
37+
if !ok {
38+
v, _, _ := cc.sf.Do(key, func() (any, error) {
39+
res, err := cc.next.Get(ctx, key)
40+
item := &certCacheValue{
41+
ts: time.Now(),
42+
res: res,
43+
err: err,
44+
}
45+
if ctx.Err() == nil {
46+
cc.cache.Set(key, item)
47+
}
48+
return item, err
49+
})
50+
resItem = v.(*certCacheValue)
51+
}
5352
return resItem.res, resItem.err
5453
}
5554

5655
func (cc *LocalCertCache) Put(ctx context.Context, key string, data []byte) error {
57-
cc.cache.Set(key, certCacheValue{data, nil}, 0)
56+
cc.cache.Set(key, &certCacheValue{
57+
ts: time.Now(),
58+
res: data,
59+
err: nil,
60+
})
5861
return cc.next.Put(ctx, key, data)
5962
}
6063

@@ -64,14 +67,10 @@ func (cc *LocalCertCache) Delete(ctx context.Context, key string) error {
6467
}
6568

6669
func (cc *LocalCertCache) Close() error {
67-
var err error
68-
cc.stopOnce.Do(func() {
69-
cc.cache.Stop()
70-
if cacheCloser, ok := cc.next.(io.Closer); ok {
71-
err = cacheCloser.Close()
72-
}
73-
})
74-
return err
70+
if cacheCloser, ok := cc.next.(io.Closer); ok {
71+
return cacheCloser.Close()
72+
}
73+
return nil
7574
}
7675

7776
var _ autocert.Cache = new(LocalCertCache)

dialer/cache.go

Lines changed: 18 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,8 @@ import (
66
"net/url"
77
"time"
88

9-
"github.com/jellydator/ttlcache/v3"
9+
"github.com/Snawoot/secache"
1010
xproxy "golang.org/x/net/proxy"
11-
"golang.org/x/sync/singleflight"
1211
)
1312

1413
type dialerCacheKey struct {
@@ -17,20 +16,14 @@ type dialerCacheKey struct {
1716
}
1817

1918
type dialerCacheValue struct {
20-
dialer xproxy.Dialer
21-
err error
19+
expires time.Time
20+
dialer xproxy.Dialer
21+
err error
2222
}
2323

24-
var (
25-
dialerCache = ttlcache.New[dialerCacheKey, dialerCacheValue](
26-
ttlcache.WithDisableTouchOnHit[dialerCacheKey, dialerCacheValue](),
27-
)
28-
dialerCacheSingleFlight = new(singleflight.Group)
29-
)
30-
31-
func init() {
32-
go dialerCache.Start()
33-
}
24+
var dialerCache = secache.New[dialerCacheKey, *dialerCacheValue](3, func(key dialerCacheKey, val *dialerCacheValue) bool {
25+
return time.Now().Before(val.expires)
26+
})
3427

3528
func GetCachedDialer(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
3629
params, err := url.ParseQuery(u.RawQuery)
@@ -51,29 +44,19 @@ func GetCachedDialer(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) {
5144
if err != nil {
5245
return nil, fmt.Errorf("cached dialer: unable to parse TTL duration %q: %w", params.Get("ttl"), err)
5346
}
54-
cacheRes := dialerCache.Get(
47+
item := dialerCache.GetOrCreate(
5548
dialerCacheKey{
5649
url: params.Get("url"),
5750
next: next,
5851
},
59-
ttlcache.WithLoader[dialerCacheKey, dialerCacheValue](
60-
ttlcache.NewSuppressedLoader[dialerCacheKey, dialerCacheValue](
61-
ttlcache.LoaderFunc[dialerCacheKey, dialerCacheValue](
62-
func(c *ttlcache.Cache[dialerCacheKey, dialerCacheValue], key dialerCacheKey) *ttlcache.Item[dialerCacheKey, dialerCacheValue] {
63-
dialer, err := xproxy.FromURL(parsedURL, next)
64-
return c.Set(
65-
key,
66-
dialerCacheValue{
67-
dialer: dialer,
68-
err: err,
69-
},
70-
ttl,
71-
)
72-
},
73-
),
74-
dialerCacheSingleFlight,
75-
),
76-
),
77-
).Value()
78-
return cacheRes.dialer, cacheRes.err
52+
func() *dialerCacheValue {
53+
dialer, err := xproxy.FromURL(parsedURL, next)
54+
return &dialerCacheValue{
55+
expires: time.Now().Add(ttl),
56+
dialer: dialer,
57+
err: err,
58+
}
59+
},
60+
)
61+
return item.dialer, item.err
7962
}

dialer/rescache.go

Lines changed: 52 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,13 @@ import (
66
"net"
77
"net/netip"
88
"strings"
9-
"sync"
109
"time"
1110

12-
"github.com/SenseUnit/dumbproxy/dialer/dto"
11+
"github.com/Snawoot/secache"
1312
"github.com/hashicorp/go-multierror"
14-
"github.com/jellydator/ttlcache/v3"
13+
"golang.org/x/sync/singleflight"
14+
15+
"github.com/SenseUnit/dumbproxy/dialer/dto"
1516
)
1617

1718
type resolverCacheKey struct {
@@ -20,46 +21,36 @@ type resolverCacheKey struct {
2021
}
2122

2223
type resolverCacheValue struct {
23-
addrs []netip.Addr
24-
err error
24+
expires time.Time
25+
addrs []netip.Addr
26+
err error
2527
}
2628

2729
type NameResolveCachingDialer struct {
28-
cache *ttlcache.Cache[resolverCacheKey, resolverCacheValue]
29-
next Dialer
30-
startOnce sync.Once
31-
stopOnce sync.Once
30+
resolver Resolver
31+
cache secache.Cache[resolverCacheKey, *resolverCacheValue]
32+
sf singleflight.Group
33+
posTTL time.Duration
34+
negTTL time.Duration
35+
timeout time.Duration
36+
next Dialer
3237
}
3338

3439
func NewNameResolveCachingDialer(next Dialer, resolver Resolver, posTTL, negTTL, timeout time.Duration) *NameResolveCachingDialer {
35-
cache := ttlcache.New[resolverCacheKey, resolverCacheValue](
36-
ttlcache.WithDisableTouchOnHit[resolverCacheKey, resolverCacheValue](),
37-
ttlcache.WithLoader(
38-
ttlcache.NewSuppressedLoader(
39-
ttlcache.LoaderFunc[resolverCacheKey, resolverCacheValue](
40-
func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
41-
ctx, cl := context.WithTimeout(context.Background(), timeout)
42-
defer cl()
43-
res, err := resolver.LookupNetIP(ctx, key.network, key.host)
44-
for i := range res {
45-
res[i] = res[i].Unmap()
46-
}
47-
setTTL := negTTL
48-
if err == nil {
49-
setTTL = posTTL
50-
}
51-
return c.Set(key, resolverCacheValue{
52-
addrs: res,
53-
err: err,
54-
}, setTTL)
55-
},
56-
),
57-
nil),
58-
),
59-
)
40+
// func(c *ttlcache.Cache[resolverCacheKey, resolverCacheValue], key resolverCacheKey) *ttlcache.Item[resolverCacheKey, resolverCacheValue] {
41+
// },
6042
return &NameResolveCachingDialer{
61-
cache: cache,
62-
next: next,
43+
resolver: resolver,
44+
cache: *(secache.New[resolverCacheKey, *resolverCacheValue](
45+
3,
46+
func(key resolverCacheKey, item *resolverCacheValue) bool {
47+
return time.Now().Before(item.expires)
48+
},
49+
)),
50+
posTTL: posTTL,
51+
negTTL: negTTL,
52+
timeout: timeout,
53+
next: next,
6354
}
6455
}
6556

@@ -91,16 +82,35 @@ func (nrcd *NameResolveCachingDialer) DialContext(ctx context.Context, network,
9182
}
9283

9384
host = strings.ToLower(host)
94-
95-
resItem := nrcd.cache.Get(resolverCacheKey{
85+
key := resolverCacheKey{
9686
network: resolveNetwork,
9787
host: host,
98-
})
99-
if resItem == nil {
100-
return nil, fmt.Errorf("cache lookup failed for pair <%q, %q>", resolveNetwork, host)
10188
}
10289

103-
res := resItem.Value()
90+
res, ok := nrcd.cache.GetValidOrDelete(key)
91+
if !ok {
92+
v, _, _ := nrcd.sf.Do(key.network+":"+key.host, func() (any, error) {
93+
ctx, cl := context.WithTimeout(context.Background(), nrcd.timeout)
94+
defer cl()
95+
res, err := nrcd.resolver.LookupNetIP(ctx, key.network, key.host)
96+
for i := range res {
97+
res[i] = res[i].Unmap()
98+
}
99+
setTTL := nrcd.negTTL
100+
if err == nil {
101+
setTTL = nrcd.posTTL
102+
}
103+
item := &resolverCacheValue{
104+
expires: time.Now().Add(setTTL),
105+
addrs: res,
106+
err: err,
107+
}
108+
nrcd.cache.Set(key, item)
109+
return item, nil
110+
})
111+
res = v.(*resolverCacheValue)
112+
}
113+
104114
if res.err != nil {
105115
return nil, res.err
106116
}
@@ -129,15 +139,5 @@ func (nrcd *NameResolveCachingDialer) WantsHostname(ctx context.Context, net, ad
129139
return WantsHostname(ctx, net, address, nrcd.next)
130140
}
131141

132-
func (nrcd *NameResolveCachingDialer) Start() {
133-
nrcd.startOnce.Do(func() {
134-
go nrcd.cache.Start()
135-
})
136-
}
137-
138-
func (nrcd *NameResolveCachingDialer) Stop() {
139-
nrcd.stopOnce.Do(nrcd.cache.Stop)
140-
}
141-
142142
var _ Dialer = new(NameResolveCachingDialer)
143143
var _ HostnameWanter = new(NameResolveCachingDialer)

0 commit comments

Comments
 (0)