|
| 1 | +/* |
| 2 | +Copyright 2024 The Dapr Authors |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +Unless required by applicable law or agreed to in writing, software |
| 8 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | +See the License for the specific language governing permissions and |
| 11 | +limitations under the License. |
| 12 | +*/ |
| 13 | + |
| 14 | +package watchhosts |
| 15 | + |
| 16 | +import ( |
| 17 | + "context" |
| 18 | + "fmt" |
| 19 | + "math/rand" |
| 20 | + "slices" |
| 21 | + "sync" |
| 22 | + "sync/atomic" |
| 23 | + |
| 24 | + "google.golang.org/grpc/codes" |
| 25 | + "google.golang.org/grpc/status" |
| 26 | + |
| 27 | + schedulerv1pb "github.com/dapr/dapr/pkg/proto/scheduler/v1" |
| 28 | + "github.com/dapr/dapr/pkg/scheduler/client" |
| 29 | + "github.com/dapr/dapr/pkg/security" |
| 30 | + "github.com/dapr/kit/events/broadcaster" |
| 31 | + "github.com/dapr/kit/logger" |
| 32 | +) |
| 33 | + |
| 34 | +var log = logger.NewLogger("dapr.runtime.scheduler.watchhosts") |
| 35 | + |
| 36 | +type Options struct { |
| 37 | + Addresses []string |
| 38 | + Security security.Handler |
| 39 | +} |
| 40 | + |
| 41 | +type WatchHosts struct { |
| 42 | + allAddrs []string |
| 43 | + security security.Handler |
| 44 | + |
| 45 | + gotAddrs atomic.Pointer[ContextAddress] |
| 46 | + subs *broadcaster.Broadcaster[*ContextAddress] |
| 47 | + |
| 48 | + cancel context.CancelFunc |
| 49 | + readyCh chan struct{} |
| 50 | + lock sync.Mutex |
| 51 | +} |
| 52 | + |
| 53 | +type ContextAddress struct { |
| 54 | + context.Context |
| 55 | + Addresses []string |
| 56 | +} |
| 57 | + |
| 58 | +func New(opts Options) *WatchHosts { |
| 59 | + return &WatchHosts{ |
| 60 | + allAddrs: opts.Addresses, |
| 61 | + security: opts.Security, |
| 62 | + subs: broadcaster.New[*ContextAddress](), |
| 63 | + readyCh: make(chan struct{}), |
| 64 | + } |
| 65 | +} |
| 66 | + |
| 67 | +func (w *WatchHosts) Run(ctx context.Context) error { |
| 68 | + defer w.subs.Close() |
| 69 | + |
| 70 | + stream, closeCon, err := w.connSchedulerHosts(ctx) |
| 71 | + if err != nil { |
| 72 | + return fmt.Errorf("failed to connect to scheduler host: %s", err) |
| 73 | + } |
| 74 | + |
| 75 | + if stream != nil { |
| 76 | + defer func() { |
| 77 | + stream.CloseSend() |
| 78 | + closeCon() |
| 79 | + }() |
| 80 | + } |
| 81 | + |
| 82 | + err = w.handleStream(ctx, stream) |
| 83 | + |
| 84 | + if ctx.Err() != nil { |
| 85 | + return ctx.Err() |
| 86 | + } |
| 87 | + |
| 88 | + if err != nil { |
| 89 | + return fmt.Errorf("failed to handle Scheduler WatchHosts stream: %s", err) |
| 90 | + } |
| 91 | + |
| 92 | + return nil |
| 93 | +} |
| 94 | + |
| 95 | +func (w *WatchHosts) Addresses(ctx context.Context) <-chan *ContextAddress { |
| 96 | + ch := make(chan *ContextAddress, 1) |
| 97 | + select { |
| 98 | + case <-ctx.Done(): |
| 99 | + ch <- &ContextAddress{ctx, nil} |
| 100 | + return ch |
| 101 | + case <-w.readyCh: |
| 102 | + } |
| 103 | + |
| 104 | + w.lock.Lock() |
| 105 | + defer w.lock.Unlock() |
| 106 | + w.subs.Subscribe(ctx, ch) |
| 107 | + ch <- w.gotAddrs.Load() |
| 108 | + |
| 109 | + return ch |
| 110 | +} |
| 111 | + |
| 112 | +func (w *WatchHosts) handleStream(ctx context.Context, stream schedulerv1pb.Scheduler_WatchHostsClient) error { |
| 113 | + // If no stream was made the server doesn't support watching hosts |
| 114 | + // (pre-1.15), so we use static. Remove in 1.16. |
| 115 | + if stream == nil { |
| 116 | + w.lock.Lock() |
| 117 | + got := &ContextAddress{ctx, w.allAddrs} |
| 118 | + w.gotAddrs.Store(got) |
| 119 | + w.subs.Broadcast(got) |
| 120 | + w.lock.Unlock() |
| 121 | + close(w.readyCh) |
| 122 | + <-ctx.Done() |
| 123 | + return nil |
| 124 | + } |
| 125 | + |
| 126 | + for { |
| 127 | + gotAddrs, err := w.watchNextAddresses(ctx, stream) |
| 128 | + if err != nil { |
| 129 | + return err |
| 130 | + } |
| 131 | + |
| 132 | + w.lock.Lock() |
| 133 | + if w.cancel != nil { |
| 134 | + w.cancel() |
| 135 | + } |
| 136 | + |
| 137 | + actx, cancel := context.WithCancel(ctx) |
| 138 | + w.cancel = cancel |
| 139 | + got := &ContextAddress{actx, gotAddrs} |
| 140 | + w.gotAddrs.Store(got) |
| 141 | + w.subs.Broadcast(got) |
| 142 | + w.lock.Unlock() |
| 143 | + |
| 144 | + select { |
| 145 | + case <-w.readyCh: |
| 146 | + default: |
| 147 | + close(w.readyCh) |
| 148 | + } |
| 149 | + } |
| 150 | +} |
| 151 | + |
| 152 | +func (w *WatchHosts) watchNextAddresses(ctx context.Context, stream schedulerv1pb.Scheduler_WatchHostsClient) ([]string, error) { |
| 153 | + resp, err := stream.Recv() |
| 154 | + if err != nil { |
| 155 | + if status.Code(err) == codes.Unimplemented { |
| 156 | + // Ignore unimplemented error code as we are talking to an old server. |
| 157 | + // TODO: @joshvanl: remove special case in v1.16. |
| 158 | + return slices.Clone(w.allAddrs), nil |
| 159 | + } |
| 160 | + return nil, err |
| 161 | + } |
| 162 | + |
| 163 | + gotAddrs := make([]string, 0, len(resp.GetHosts())) |
| 164 | + for _, host := range resp.GetHosts() { |
| 165 | + gotAddrs = append(gotAddrs, host.GetAddress()) |
| 166 | + } |
| 167 | + |
| 168 | + log.Infof("Received updated scheduler hosts addresses: %v", gotAddrs) |
| 169 | + |
| 170 | + return gotAddrs, nil |
| 171 | +} |
| 172 | + |
| 173 | +func (w *WatchHosts) connSchedulerHosts(ctx context.Context) (schedulerv1pb.Scheduler_WatchHostsClient, context.CancelFunc, error) { |
| 174 | + //nolint:gosec |
| 175 | + i := rand.Intn(len(w.allAddrs)) |
| 176 | + log.Debugf("Attempting to connect to scheduler to WatchHosts: %s", w.allAddrs[i]) |
| 177 | + |
| 178 | + // This is connecting to a DNS A rec which will return healthy scheduler |
| 179 | + // hosts. |
| 180 | + cl, closeCon, err := client.New(ctx, w.allAddrs[i], w.security) |
| 181 | + if err != nil { |
| 182 | + return nil, nil, fmt.Errorf("scheduler client not initialized for address %s: %s", w.allAddrs[i], err) |
| 183 | + } |
| 184 | + |
| 185 | + stream, err := cl.WatchHosts(ctx, new(schedulerv1pb.WatchHostsRequest)) |
| 186 | + if err != nil { |
| 187 | + if status.Code(err) == codes.Unimplemented { |
| 188 | + // Ignore unimplemented error code as we are talking to an old server. |
| 189 | + // TODO: @joshvanl: remove special case in v1.16. |
| 190 | + return nil, nil, nil |
| 191 | + } |
| 192 | + |
| 193 | + return nil, nil, fmt.Errorf("failed to watch scheduler hosts: %s", err) |
| 194 | + } |
| 195 | + |
| 196 | + return stream, closeCon, nil |
| 197 | +} |
0 commit comments