|
| 1 | +package dialer |
| 2 | + |
| 3 | +import ( |
| 4 | + "errors" |
| 5 | + "fmt" |
| 6 | + "net/url" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/jellydator/ttlcache/v3" |
| 10 | + xproxy "golang.org/x/net/proxy" |
| 11 | + "golang.org/x/sync/singleflight" |
| 12 | +) |
| 13 | + |
| 14 | +type dialerCacheKey struct { |
| 15 | + url string |
| 16 | + next xproxy.Dialer |
| 17 | +} |
| 18 | + |
| 19 | +type dialerCacheValue struct { |
| 20 | + dialer xproxy.Dialer |
| 21 | + err error |
| 22 | +} |
| 23 | + |
| 24 | +var ( |
| 25 | + dialerCache = ttlcache.New[dialerCacheKey, dialerCacheValue]( |
| 26 | + ttlcache.WithDisableTouchOnHit[dialerCacheKey, dialerCacheValue](), |
| 27 | + ) |
| 28 | + dialerCacheSingleFlight = new(singleflight.Group) |
| 29 | +) |
| 30 | + |
| 31 | +func GetCachedDialer(u *url.URL, next xproxy.Dialer) (xproxy.Dialer, error) { |
| 32 | + params, err := url.ParseQuery(u.RawQuery) |
| 33 | + if err != nil { |
| 34 | + return nil, err |
| 35 | + } |
| 36 | + if !params.Has("url") { |
| 37 | + return nil, errors.New("cached dialer: no \"url\" parameter specified") |
| 38 | + } |
| 39 | + parsedURL, err := url.Parse(params.Get("url")) |
| 40 | + if err != nil { |
| 41 | + return nil, fmt.Errorf("unable to parse proxy URL: %w", err) |
| 42 | + } |
| 43 | + if !params.Has("ttl") { |
| 44 | + return nil, errors.New("cached dialer: no \"ttl\" parameter specified") |
| 45 | + } |
| 46 | + ttl, err := time.ParseDuration(params.Get("ttl")) |
| 47 | + if err != nil { |
| 48 | + return nil, fmt.Errorf("cached dialer: unable to parse TTL duration %q: %w", params.Get("ttl"), err) |
| 49 | + } |
| 50 | + cacheRes := dialerCache.Get( |
| 51 | + dialerCacheKey{ |
| 52 | + url: params.Get("url"), |
| 53 | + next: next, |
| 54 | + }, |
| 55 | + ttlcache.WithLoader[dialerCacheKey, dialerCacheValue]( |
| 56 | + ttlcache.NewSuppressedLoader[dialerCacheKey, dialerCacheValue]( |
| 57 | + ttlcache.LoaderFunc[dialerCacheKey, dialerCacheValue]( |
| 58 | + func(c *ttlcache.Cache[dialerCacheKey, dialerCacheValue], key dialerCacheKey) *ttlcache.Item[dialerCacheKey, dialerCacheValue] { |
| 59 | + dialer, err := xproxy.FromURL(parsedURL, next) |
| 60 | + return c.Set( |
| 61 | + key, |
| 62 | + dialerCacheValue{ |
| 63 | + dialer: dialer, |
| 64 | + err: err, |
| 65 | + }, |
| 66 | + ttl, |
| 67 | + ) |
| 68 | + }, |
| 69 | + ), |
| 70 | + dialerCacheSingleFlight, |
| 71 | + ), |
| 72 | + ), |
| 73 | + ).Value() |
| 74 | + return cacheRes.dialer, cacheRes.err |
| 75 | +} |
0 commit comments