diff --git a/doh-client/client.go b/doh-client/client.go index d768f10..576eb53 100644 --- a/doh-client/client.go +++ b/doh-client/client.go @@ -192,7 +192,25 @@ func NewClient(conf *config.Config) (c *Client, err error) { } c.selector = s - + case config.Hash: + if c.conf.Other.Verbose { + log.Println(config.Random, "mode start") + } + s := selector.NewHashSelector() + for _, u := range c.conf.Upstream.UpstreamGoogle { + if err := s.Add(u.URL, selector.Google); err != nil { + return nil, err + } + } + + for _, u := range c.conf.Upstream.UpstreamIETF { + if err := s.Add(u.URL, selector.IETF); err != nil { + return nil, err + } + } + + c.selector = s + default: if c.conf.Other.Verbose { log.Println(config.Random, "mode start") diff --git a/doh-client/config/config.go b/doh-client/config/config.go index 01dbe31..4122ddd 100644 --- a/doh-client/config/config.go +++ b/doh-client/config/config.go @@ -33,6 +33,7 @@ const ( Random = "random" NginxWRR = "weighted_round_robin" LVSWRR = "lvs_weighted_round_robin" + Hash = "hostname_hash" ) type upstreamDetail struct { diff --git a/doh-client/selector/hashSelector.go b/doh-client/selector/hashSelector.go new file mode 100644 index 0000000..553f4e2 --- /dev/null +++ b/doh-client/selector/hashSelector.go @@ -0,0 +1,54 @@ +package selector + +import ( + "errors" + "math/rand" + "time" +) + +func init() { + rand.Seed(time.Now().UnixNano()) +} + +type HashSelector struct { + upstreams []*Upstream +} + +func NewHashSelector() *HashSelector { + return new(HashSelector) +} + +func (rs *HashSelector) Add(url string, upstreamType UpstreamType) (err error) { + switch upstreamType { + case Google: + rs.upstreams = append(rs.upstreams, &Upstream{ + Type: Google, + URL: url, + RequestType: "application/dns-json", + }) + + case IETF: + rs.upstreams = append(rs.upstreams, &Upstream{ + Type: IETF, + URL: url, + RequestType: "application/dns-message", + }) + + default: + return errors.New("unknown upstream type") + } + + return nil +} + +func (rs *HashSelector) Get() *Upstream { + // here, if we have the name to be resolved (a string) + // we could compute the modulo over the size of upstream servers + // something like url.hash()%len(rs.upstreams) + // how to refactor Get() to get the name + return rs.upstreams[url.hash()%len(rs.upstreams)] +} + +func (rs *HashSelector) StartEvaluate() {} + +func (rs *HashSelector) ReportUpstreamStatus(upstream *Upstream, upstreamStatus upstreamStatus) {}