diff --git a/epsilon_greedy.go b/epsilon_greedy.go index f43fae7..10e60cf 100644 --- a/epsilon_greedy.go +++ b/epsilon_greedy.go @@ -1,34 +1,38 @@ package hostpool import ( + "errors" "log" "math/rand" + "sync" "time" ) type epsilonHostPoolResponse struct { - standardHostPoolResponse - started time.Time - ended time.Time + HostPoolResponse + started time.Time + ended time.Time + selector *epsilonGreedySelector } func (r *epsilonHostPoolResponse) Mark(err error) { - r.Do(func() { + if err == nil { r.ended = time.Now() - doMark(err, r) - }) - + r.selector.recordTiming(r) + } + r.HostPoolResponse.Mark(err) } -type epsilonGreedyHostPool struct { - standardHostPool // TODO - would be nifty if we could embed HostPool and Locker interfaces +type epsilonGreedySelector struct { + Selector + sync.Locker epsilon float32 // this is our exploration factor decayDuration time.Duration EpsilonValueCalculator // embed the epsilonValueCalculator timer } -// Construct an Epsilon Greedy HostPool +// Construct an Epsilon Greedy Selector // // Epsilon Greedy is an algorithm that allows HostPool not only to track failure state, // but also to learn about "better" options in terms of speed, and to pick from available hosts @@ -42,86 +46,84 @@ type epsilonGreedyHostPool struct { // To compute the weighting scores, we perform a weighted average of recent response times, over the course of // `decayDuration`. decayDuration may be set to 0 to use the default value of 5 minutes // We then use the supplied EpsilonValueCalculator to calculate a score from that weighted average response time. -func NewEpsilonGreedy(hosts []string, decayDuration time.Duration, calc EpsilonValueCalculator) HostPool { +func NewEpsilonGreedy(decayDuration time.Duration, calc EpsilonValueCalculator) Selector { if decayDuration <= 0 { decayDuration = defaultDecayDuration } - stdHP := New(hosts).(*standardHostPool) - p := &epsilonGreedyHostPool{ - standardHostPool: *stdHP, + ss := &standardSelector{} + s := &epsilonGreedySelector{ + Selector: ss, + Locker: ss, epsilon: float32(initialEpsilon), decayDuration: decayDuration, EpsilonValueCalculator: calc, timer: &realTimer{}, } + return s +} + +func (s *epsilonGreedySelector) Init(hosts []string) { + s.Selector.Init(hosts) // allocate structures - for _, h := range p.hostList { + for _, h := range s.Selector.(*standardSelector).hostList { h.epsilonCounts = make([]int64, epsilonBuckets) h.epsilonValues = make([]int64, epsilonBuckets) } - go p.epsilonGreedyDecay() - return p -} - -func (p *epsilonGreedyHostPool) SetEpsilon(newEpsilon float32) { - p.Lock() - defer p.Unlock() - p.epsilon = newEpsilon + go s.epsilonGreedyDecay() } -func (p *epsilonGreedyHostPool) epsilonGreedyDecay() { - durationPerBucket := p.decayDuration / epsilonBuckets +func (s *epsilonGreedySelector) epsilonGreedyDecay() { + durationPerBucket := s.decayDuration / epsilonBuckets ticker := time.Tick(durationPerBucket) for { <-ticker - p.performEpsilonGreedyDecay() + s.performEpsilonGreedyDecay() } } -func (p *epsilonGreedyHostPool) performEpsilonGreedyDecay() { - p.Lock() - for _, h := range p.hostList { +func (s *epsilonGreedySelector) performEpsilonGreedyDecay() { + s.Lock() + for _, h := range s.Selector.(*standardSelector).hostList { h.epsilonIndex += 1 h.epsilonIndex = h.epsilonIndex % epsilonBuckets h.epsilonCounts[h.epsilonIndex] = 0 h.epsilonValues[h.epsilonIndex] = 0 } - p.Unlock() + s.Unlock() } -func (p *epsilonGreedyHostPool) Get() HostPoolResponse { - p.Lock() - defer p.Unlock() - host := p.getEpsilonGreedy() - started := time.Now() - return &epsilonHostPoolResponse{ - standardHostPoolResponse: standardHostPoolResponse{host: host, pool: p}, - started: started, +func (s *epsilonGreedySelector) SelectNextHost() string { + s.Lock() + host, err := s.getEpsilonGreedy() + s.Unlock() + if err != nil { + host = s.Selector.SelectNextHost() } + return host } -func (p *epsilonGreedyHostPool) getEpsilonGreedy() string { +func (s *epsilonGreedySelector) getEpsilonGreedy() (string, error) { var hostToUse *hostEntry // this is our exploration phase - if rand.Float32() < p.epsilon { - p.epsilon = p.epsilon * epsilonDecay - if p.epsilon < minEpsilon { - p.epsilon = minEpsilon + if rand.Float32() < s.epsilon { + s.epsilon = s.epsilon * epsilonDecay + if s.epsilon < minEpsilon { + s.epsilon = minEpsilon } - return p.getRoundRobin() + return "", errors.New("Exploration") } // calculate values for each host in the 0..1 range (but not ormalized) var possibleHosts []*hostEntry now := time.Now() var sumValues float64 - for _, h := range p.hostList { + for _, h := range s.Selector.(*standardSelector).hostList { if h.canTryHost(now) { v := h.getWeightedAverageResponseTime() if v > 0 { - ev := p.CalcValueFromAvgResponseTime(v) + ev := s.CalcValueFromAvgResponseTime(v) h.epsilonValue = ev sumValues += ev possibleHosts = append(possibleHosts, h) @@ -151,36 +153,40 @@ func (p *epsilonGreedyHostPool) getEpsilonGreedy() string { if len(possibleHosts) != 0 { log.Println("Failed to randomly choose a host, Dan loses") } - return p.getRoundRobin() - } - - if hostToUse.dead { - hostToUse.willRetryHost(p.maxRetryInterval) + return "", errors.New("No host chosen") } - return hostToUse.host + return hostToUse.host, nil } -func (p *epsilonGreedyHostPool) markSuccess(hostR HostPoolResponse) { - // first do the base markSuccess - a little redundant with host lookup but cleaner than repeating logic - p.standardHostPool.markSuccess(hostR) - eHostR, ok := hostR.(*epsilonHostPoolResponse) - if !ok { - log.Printf("Incorrect type in eps markSuccess!") // TODO reflection to print out offending type - return - } - host := eHostR.host - duration := p.between(eHostR.started, eHostR.ended) +func (s *epsilonGreedySelector) recordTiming(eHostR *epsilonHostPoolResponse) { + host := eHostR.Host() + duration := s.between(eHostR.started, eHostR.ended) - p.Lock() - defer p.Unlock() - h, ok := p.hosts[host] + s.Lock() + defer s.Unlock() + h, ok := s.Selector.(*standardSelector).hosts[host] if !ok { - log.Fatalf("host %s not in HostPool %v", host, p.Hosts()) + log.Fatalf("host %s not in HostPool", host) } h.epsilonCounts[h.epsilonIndex]++ h.epsilonValues[h.epsilonIndex] += int64(duration.Seconds() * 1000) } +func (s *epsilonGreedySelector) MakeHostResponse(host string) HostPoolResponse { + resp := s.Selector.MakeHostResponse(host) + return s.toEpsilonHostPoolResponse(resp) +} + +// Convert regular response to one equipped for EG. Doesn't require lock, for now +func (s *epsilonGreedySelector) toEpsilonHostPoolResponse(resp HostPoolResponse) *epsilonHostPoolResponse { + started := time.Now() + return &epsilonHostPoolResponse{ + HostPoolResponse: resp, + started: started, + selector: s, + } +} + // --- timer: this just exists for testing type timer interface { diff --git a/example_test.go b/example_test.go index 88d0e55..54843be 100644 --- a/example_test.go +++ b/example_test.go @@ -1,13 +1,13 @@ package hostpool import ( - "github.com/bitly/go-hostpool" + "errors" ) func ExampleNewEpsilonGreedy() { - hp := hostpool.NewEpsilonGreedy([]string{"a", "b"}, 0, &hostpool.LinearEpsilonValueCalculator{}) + hp := NewWithSelector([]string{"a", "b"}, NewEpsilonGreedy(0, &LinearEpsilonValueCalculator{})) hostResponse := hp.Get() hostname := hostResponse.Host() - err := nil // (make a request with hostname) + err := errors.New("I am your http error from " + hostname) // (make a request with hostname) hostResponse.Mark(err) } diff --git a/hostpool.go b/hostpool.go index 25ca1fb..5e5638f 100644 --- a/hostpool.go +++ b/hostpool.go @@ -4,8 +4,6 @@ package hostpool import ( - "log" - "sync" "time" ) @@ -23,13 +21,11 @@ func Version() string { type HostPoolResponse interface { Host() string Mark(error) - hostPool() HostPool } type standardHostPoolResponse struct { host string - sync.Once - pool HostPool + ss *standardSelector } // --- HostPool structs and interfaces ---- @@ -39,21 +35,13 @@ type standardHostPoolResponse struct { // get the list of all Hosts, and use ResetAll to reset state. type HostPool interface { Get() HostPoolResponse - // keep the marks separate so we can override independently - markSuccess(HostPoolResponse) - markFailed(HostPoolResponse) - ResetAll() Hosts() []string } type standardHostPool struct { - sync.RWMutex - hosts map[string]*hostEntry - hostList []*hostEntry - initialRetryDelay time.Duration - maxRetryInterval time.Duration - nextHostIndex int + hosts []string + Selector } // ------ constants ------------------- @@ -66,127 +54,31 @@ const defaultDecayDuration = time.Duration(5) * time.Minute // Construct a basic HostPool using the hostnames provided func New(hosts []string) HostPool { - p := &standardHostPool{ - hosts: make(map[string]*hostEntry, len(hosts)), - hostList: make([]*hostEntry, len(hosts)), - initialRetryDelay: time.Duration(30) * time.Second, - maxRetryInterval: time.Duration(900) * time.Second, - } + return NewWithSelector(hosts, &standardSelector{}) +} - for i, h := range hosts { - e := &hostEntry{ - host: h, - retryDelay: p.initialRetryDelay, - } - p.hosts[h] = e - p.hostList[i] = e +func NewWithSelector(hosts []string, s Selector) HostPool { + s.Init(hosts) + return &standardHostPool{ + hosts, + s, } - - return p } func (r *standardHostPoolResponse) Host() string { return r.host } -func (r *standardHostPoolResponse) hostPool() HostPool { - return r.pool -} - func (r *standardHostPoolResponse) Mark(err error) { - r.Do(func() { - doMark(err, r) - }) -} - -func doMark(err error, r HostPoolResponse) { - if err == nil { - r.hostPool().markSuccess(r) - } else { - r.hostPool().markFailed(r) - } + r.ss.MarkHost(r.host, err) } // return an entry from the HostPool func (p *standardHostPool) Get() HostPoolResponse { - p.Lock() - defer p.Unlock() - host := p.getRoundRobin() - return &standardHostPoolResponse{host: host, pool: p} + host := p.SelectNextHost() + return p.MakeHostResponse(host) } -func (p *standardHostPool) getRoundRobin() string { - now := time.Now() - hostCount := len(p.hostList) - for i := range p.hostList { - // iterate via sequenece from where we last iterated - currentIndex := (i + p.nextHostIndex) % hostCount - - h := p.hostList[currentIndex] - if !h.dead { - p.nextHostIndex = currentIndex + 1 - return h.host - } - if h.nextRetry.Before(now) { - h.willRetryHost(p.maxRetryInterval) - p.nextHostIndex = currentIndex + 1 - return h.host - } - } - - // all hosts are down. re-add them - p.doResetAll() - p.nextHostIndex = 0 - return p.hostList[0].host -} - -func (p *standardHostPool) ResetAll() { - p.Lock() - defer p.Unlock() - p.doResetAll() -} - -// this actually performs the logic to reset, -// and should only be called when the lock has -// already been acquired -func (p *standardHostPool) doResetAll() { - for _, h := range p.hosts { - h.dead = false - } -} - -func (p *standardHostPool) markSuccess(hostR HostPoolResponse) { - host := hostR.Host() - p.Lock() - defer p.Unlock() - - h, ok := p.hosts[host] - if !ok { - log.Fatalf("host %s not in HostPool %v", host, p.Hosts()) - } - h.dead = false -} - -func (p *standardHostPool) markFailed(hostR HostPoolResponse) { - host := hostR.Host() - p.Lock() - defer p.Unlock() - h, ok := p.hosts[host] - if !ok { - log.Fatalf("host %s not in HostPool %v", host, p.Hosts()) - } - if !h.dead { - h.dead = true - h.retryCount = 0 - h.retryDelay = p.initialRetryDelay - h.nextRetry = time.Now().Add(h.retryDelay) - } - -} func (p *standardHostPool) Hosts() []string { - hosts := make([]string, len(p.hosts)) - for host, _ := range p.hosts { - hosts = append(hosts, host) - } - return hosts + return p.hosts } diff --git a/hostpool_test.go b/hostpool_test.go index 352c8ea..f51d04b 100644 --- a/hostpool_test.go +++ b/hostpool_test.go @@ -17,7 +17,7 @@ func TestHostPool(t *testing.T) { dummyErr := errors.New("Dummy Error") - p := New([]string{"a", "b", "c"}) + p := New([]string{"a", "b", "c"}).(*standardHostPool) assert.Equal(t, p.Get().Host(), "a") assert.Equal(t, p.Get().Host(), "b") assert.Equal(t, p.Get().Host(), "c") @@ -32,18 +32,16 @@ func TestHostPool(t *testing.T) { respC.Mark(nil) // get again, and verify that it's still c assert.Equal(t, p.Get().Host(), "c") - // now try to mark b as success; should fail because already marked - respB.Mark(nil) assert.Equal(t, p.Get().Host(), "c") // would be b if it were not dead // now restore a - respA = &standardHostPoolResponse{host: "a", pool: p} + respA = &standardHostPoolResponse{host: "a", ss: p.Selector.(*standardSelector)} respA.Mark(nil) assert.Equal(t, p.Get().Host(), "a") assert.Equal(t, p.Get().Host(), "c") // ensure that we get *something* back when all hosts fail for _, host := range []string{"a", "b", "c"} { - response := &standardHostPoolResponse{host: host, pool: p} + response := &standardHostPoolResponse{host: host, ss: p.Selector.(*standardSelector)} response.Mark(dummyErr) } resp := p.Get() @@ -59,13 +57,13 @@ func (t *mockTimer) between(start time.Time, end time.Time) time.Duration { } func TestEpsilonGreedy(t *testing.T) { - log.SetOutput(ioutil.Discard) - defer log.SetOutput(os.Stdout) + // log.SetOutput(ioutil.Discard) + // defer log.SetOutput(os.Stdout) rand.Seed(10) iterations := 12000 - p := NewEpsilonGreedy([]string{"a", "b"}, 0, &LinearEpsilonValueCalculator{}).(*epsilonGreedyHostPool) + p := NewWithSelector([]string{"a", "b"}, NewEpsilonGreedy(0, &LinearEpsilonValueCalculator{})).(*standardHostPool) timings := make(map[string]int64) timings["a"] = 200 @@ -79,13 +77,13 @@ func TestEpsilonGreedy(t *testing.T) { for i := 0; i < iterations; i += 1 { if i != 0 && i%100 == 0 { - p.performEpsilonGreedyDecay() + p.Selector.(*epsilonGreedySelector).performEpsilonGreedyDecay() } hostR := p.Get() host := hostR.Host() hitCounts[host]++ timing := timings[host] - p.timer = &mockTimer{t: int(timing)} + p.Selector.(*epsilonGreedySelector).timer = &mockTimer{t: int(timing)} hostR.Mark(nil) } @@ -103,13 +101,13 @@ func TestEpsilonGreedy(t *testing.T) { for i := 0; i < iterations; i += 1 { if i != 0 && i%100 == 0 { - p.performEpsilonGreedyDecay() + p.Selector.(*epsilonGreedySelector).performEpsilonGreedyDecay() } hostR := p.Get() host := hostR.Host() hitCounts[host]++ timing := timings[host] - p.timer = &mockTimer{t: int(timing)} + p.Selector.(*epsilonGreedySelector).timer = &mockTimer{t: int(timing)} hostR.Mark(nil) } @@ -131,15 +129,15 @@ func BenchmarkEpsilonGreedy(b *testing.B) { } // Make the hostpool with a few hosts - p := NewEpsilonGreedy([]string{"a", "b"}, 0, &LinearEpsilonValueCalculator{}).(*epsilonGreedyHostPool) + p := NewWithSelector([]string{"a", "b"}, NewEpsilonGreedy(0, &LinearEpsilonValueCalculator{})).(*standardHostPool) b.StartTimer() for i := 0; i < b.N; i++ { if i != 0 && i%100 == 0 { - p.performEpsilonGreedyDecay() + p.Selector.(*epsilonGreedySelector).performEpsilonGreedyDecay() } hostR := p.Get() - p.timer = &mockTimer{t: int(timings[i])} + p.Selector.(*epsilonGreedySelector).timer = &mockTimer{t: int(timings[i])} hostR.Mark(nil) } } diff --git a/selector.go b/selector.go new file mode 100644 index 0000000..ad22b3e --- /dev/null +++ b/selector.go @@ -0,0 +1,118 @@ +package hostpool + +import ( + "log" + "sync" + "time" +) + +type Selector interface { + Init([]string) + SelectNextHost() string + MakeHostResponse(string) HostPoolResponse + MarkHost(string, error) + ResetAll() +} + +type standardSelector struct { + sync.RWMutex + hosts map[string]*hostEntry + hostList []*hostEntry + initialRetryDelay time.Duration + maxRetryInterval time.Duration + nextHostIndex int +} + +func (s *standardSelector) Init(hosts []string) { + s.hosts = make(map[string]*hostEntry, len(hosts)) + s.hostList = make([]*hostEntry, len(hosts)) + s.initialRetryDelay = time.Duration(30) * time.Second + s.maxRetryInterval = time.Duration(900) * time.Second + + for i, h := range hosts { + e := &hostEntry{ + host: h, + retryDelay: s.initialRetryDelay, + } + s.hosts[h] = e + s.hostList[i] = e + } +} + +func (s *standardSelector) SelectNextHost() string { + s.Lock() + host := s.getRoundRobin() + s.Unlock() + return host +} + +func (s *standardSelector) getRoundRobin() string { + now := time.Now() + hostCount := len(s.hostList) + for i := range s.hostList { + // iterate via sequenece from where we last iterated + currentIndex := (i + s.nextHostIndex) % hostCount + + h := s.hostList[currentIndex] + if h.canTryHost(now) { + s.nextHostIndex = currentIndex + 1 + return h.host + } + } + + // all hosts are down. re-add them + s.doResetAll() + s.nextHostIndex = 0 + return s.hostList[0].host +} + +func (s *standardSelector) MakeHostResponse(host string) HostPoolResponse { + s.Lock() + defer s.Unlock() + h, ok := s.hosts[host] + if !ok { + log.Fatalf("host %s not in HostPool", host) + } + now := time.Now() + if h.dead && h.nextRetry.Before(now) { + h.willRetryHost(s.maxRetryInterval) + } + return &standardHostPoolResponse{host: host, ss: s} +} + +func (s *standardSelector) MarkHost(host string, err error) { + s.Lock() + defer s.Unlock() + + h, ok := s.hosts[host] + if !ok { + log.Fatalf("host %s not in HostPool", host) + } + if err == nil { + // success - mark host alive + h.dead = false + } else { + // failure - mark host dead + if !h.dead { + h.dead = true + h.retryCount = 0 + h.retryDelay = s.initialRetryDelay + h.nextRetry = time.Now().Add(h.retryDelay) + } + } +} + +func (s *standardSelector) ResetAll() { + s.Lock() + defer s.Unlock() + s.doResetAll() +} + +// this actually performs the logic to reset, +// and should only be called when the lock has +// already been acquired +func (s *standardSelector) doResetAll() { + for _, h := range s.hosts { + h.dead = false + } +}