diff --git a/go.mod b/go.mod index 18e51e36a..a1c5fd6af 100644 --- a/go.mod +++ b/go.mod @@ -38,7 +38,7 @@ require ( go.mongodb.org/mongo-driver v1.14.0 go.virtual-secrets.dev/apimachinery v0.0.1 gomodules.xyz/pointer v0.1.0 - google.golang.org/grpc v1.76.0 + google.golang.org/grpc v1.79.3 k8s.io/api v0.34.3 k8s.io/apimachinery v0.34.3 k8s.io/klog/v2 v2.130.1 @@ -236,7 +236,7 @@ require ( golang.org/x/crypto v0.46.0 // indirect golang.org/x/exp v0.0.0-20250718183923-645b1fa84792 // indirect golang.org/x/net v0.48.0 // indirect - golang.org/x/oauth2 v0.33.0 // indirect + golang.org/x/oauth2 v0.34.0 // indirect golang.org/x/sync v0.19.0 // indirect golang.org/x/sys v0.40.0 // indirect golang.org/x/term v0.38.0 // indirect @@ -251,8 +251,8 @@ require ( gomodules.xyz/sync v0.1.0 // indirect gomodules.xyz/wait v0.2.0 // indirect gomodules.xyz/x v0.0.17 // indirect - google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda // indirect + google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 // indirect + google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 // indirect google.golang.org/protobuf v1.36.10 // indirect gopkg.in/evanphx/json-patch.v4 v4.13.0 // indirect gopkg.in/inf.v0 v0.9.1 // indirect diff --git a/go.sum b/go.sum index 0ef3b719d..178548a91 100644 --- a/go.sum +++ b/go.sum @@ -903,8 +903,8 @@ golang.org/x/net v0.17.0/go.mod h1:NxSsAGuq816PNPmqtQdLE42eU2Fs7NoRIZrHJAlaCOE= golang.org/x/net v0.48.0 h1:zyQRTTrjc33Lhh0fBgT/H3oZq9WuvRR5gPC70xpDiQU= golang.org/x/net v0.48.0/go.mod h1:+ndRgGjkh8FGtu1w1FGbEC31if4VrNVMuKTgcAAnQRY= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= -golang.org/x/oauth2 v0.33.0 h1:4Q+qn+E5z8gPRJfmRy7C2gGG3T4jIprK6aSYgTXGRpo= -golang.org/x/oauth2 v0.33.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= +golang.org/x/oauth2 v0.34.0 h1:hqK/t4AKgbqWkdkcAeI8XLmbK+4m4G5YeQRrmiotGlw= +golang.org/x/oauth2 v0.34.0/go.mod h1:lzm5WQJQwKZ3nwavOZ3IS5Aulzxi68dUSgRHujetwEA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20181108010431-42b317875d0f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20190227155943-e225da77a7e6/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -1038,10 +1038,10 @@ google.golang.org/genproto v0.0.0-20190819201941-24fa4b261c55/go.mod h1:DMBHOl98 google.golang.org/genproto v0.0.0-20200423170343-7949de9c1215/go.mod h1:55QSHmfGQM9UVYDPBsyGGes0y52j32PQ3BqQfXhyH3c= google.golang.org/genproto v0.0.0-20200526211855-cb27e3aa2013/go.mod h1:NbSheEEYHJ7i3ixzK3sjbqSGDJWnxyFXZblF3eUsNvo= google.golang.org/genproto v0.0.0-20210624195500-8bfb893ecb84/go.mod h1:SzzZ/N+nwJDaO1kznhnlzqS8ocJICar6hYhVyhi++24= -google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c h1:AtEkQdl5b6zsybXcbz00j1LwNodDuH6hVifIaNqk7NQ= -google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c/go.mod h1:ea2MjsO70ssTfCjiwHgI0ZFqcw45Ksuk2ckf9G468GA= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda h1:i/Q+bfisr7gq6feoJnS/DlpdwEL4ihp41fvRiM3Ork0= -google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 h1:fCvbg86sFXwdrl5LgVcTEvNC+2txB5mgROGmRL5mrls= +google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:+rXWjjaukWZun3mLfjmVnQi18E1AsFbDN9QdJ5YXLto= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 h1:gRkg/vSppuSQoDjxyiGfN4Upv/h/DQmIR10ZU8dh4Ww= +google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217/go.mod h1:7i2o+ce6H/6BluujYR+kqX3GKH+dChPTQU19wjRPiGk= google.golang.org/grpc v1.12.0/go.mod h1:yo6s7OP7yaDglbqo1J04qKzAhqBH6lvTonzMVmEdcZw= google.golang.org/grpc v1.19.0/go.mod h1:mqu4LbDTu4XGKhr4mRzUsmM4RtVoemTSY81AxZiDr8c= google.golang.org/grpc v1.23.0/go.mod h1:Y5yQAOtifL1yxbo5wqy6BxZv8vAUGQwXBOALyacEbxg= @@ -1049,8 +1049,8 @@ google.golang.org/grpc v1.25.1/go.mod h1:c3i+UQWmh7LiEpx4sFZnkU36qjEYZ0imhYfXVyQ google.golang.org/grpc v1.27.0/go.mod h1:qbnxyOmOxrQa7FizSgH+ReBfzJrCY1pSN7KXBS8abTk= google.golang.org/grpc v1.29.1/go.mod h1:itym6AZVZYACWQqET3MqgPpjcuV5QH3BxFS3IjizoKk= google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= -google.golang.org/grpc v1.76.0 h1:UnVkv1+uMLYXoIz6o7chp59WfQUYA2ex/BXQ9rHZu7A= -google.golang.org/grpc v1.76.0/go.mod h1:Ju12QI8M6iQJtbcsV+awF5a4hfJMLi4X0JLo94ULZ6c= +google.golang.org/grpc v1.79.3 h1:sybAEdRIEtvcD68Gx7dmnwjZKlyfuc61Dyo9pGXXkKE= +google.golang.org/grpc v1.79.3/go.mod h1:KmT0Kjez+0dde/v2j9vzwoAScgEPx/Bw1CYChhHLrHQ= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/vendor/google.golang.org/grpc/balancer/balancer.go b/vendor/google.golang.org/grpc/balancer/balancer.go index b1264017d..d08b7ad63 100644 --- a/vendor/google.golang.org/grpc/balancer/balancer.go +++ b/vendor/google.golang.org/grpc/balancer/balancer.go @@ -75,8 +75,6 @@ func unregisterForTesting(name string) { func init() { internal.BalancerUnregister = unregisterForTesting - internal.ConnectedAddress = connectedAddress - internal.SetConnectedAddress = setConnectedAddress } // Get returns the resolver builder registered with the given name. diff --git a/vendor/google.golang.org/grpc/balancer/pickfirst/internal/internal.go b/vendor/google.golang.org/grpc/balancer/pickfirst/internal/internal.go index 7d66cb491..cc902a4de 100644 --- a/vendor/google.golang.org/grpc/balancer/pickfirst/internal/internal.go +++ b/vendor/google.golang.org/grpc/balancer/pickfirst/internal/internal.go @@ -26,6 +26,8 @@ import ( var ( // RandShuffle pseudo-randomizes the order of addresses. RandShuffle = rand.Shuffle + // RandFloat64 returns, as a float64, a pseudo-random number in [0.0,1.0). + RandFloat64 = rand.Float64 // TimeAfterFunc allows mocking the timer for testing connection delay // related functionality. TimeAfterFunc = func(d time.Duration, f func()) func() { diff --git a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirst.go b/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirst.go index b15c10e46..dccd9f0bf 100644 --- a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirst.go +++ b/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirst.go @@ -16,55 +16,129 @@ * */ -// Package pickfirst contains the pick_first load balancing policy. +// Package pickfirst contains the pick_first load balancing policy which +// is the universal leaf policy. package pickfirst import ( + "cmp" "encoding/json" "errors" "fmt" - rand "math/rand/v2" + "math" + "net" + "net/netip" + "slices" + "sync" + "time" "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/pickfirst/internal" "google.golang.org/grpc/connectivity" + expstats "google.golang.org/grpc/experimental/stats" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/balancer/weight" "google.golang.org/grpc/internal/envconfig" internalgrpclog "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/pretty" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" - - _ "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" // For automatically registering the new pickfirst if required. ) func init() { - if envconfig.NewPickFirstEnabled { - return - } balancer.Register(pickfirstBuilder{}) } -var logger = grpclog.Component("pick-first-lb") +// Name is the name of the pick_first balancer. +const Name = "pick_first" + +// enableHealthListenerKeyType is a unique key type used in resolver +// attributes to indicate whether the health listener usage is enabled. +type enableHealthListenerKeyType struct{} + +var ( + logger = grpclog.Component("pick-first-leaf-lb") + disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.lb.pick_first.disconnections", + Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.", + Unit: "{disconnection}", + Labels: []string{"grpc.target"}, + Default: false, + }) + connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.lb.pick_first.connection_attempts_succeeded", + Description: "EXPERIMENTAL. Number of successful connection attempts.", + Unit: "{attempt}", + Labels: []string{"grpc.target"}, + Default: false, + }) + connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.lb.pick_first.connection_attempts_failed", + Description: "EXPERIMENTAL. Number of failed connection attempts.", + Unit: "{attempt}", + Labels: []string{"grpc.target"}, + Default: false, + }) +) const ( - // Name is the name of the pick_first balancer. - Name = "pick_first" - logPrefix = "[pick-first-lb %p] " + // TODO: change to pick-first when this becomes the default pick_first policy. + logPrefix = "[pick-first-leaf-lb %p] " + // connectionDelayInterval is the time to wait for during the happy eyeballs + // pass before starting the next connection attempt. + connectionDelayInterval = 250 * time.Millisecond +) + +type ipAddrFamily int + +const ( + // ipAddrFamilyUnknown represents strings that can't be parsed as an IP + // address. + ipAddrFamilyUnknown ipAddrFamily = iota + ipAddrFamilyV4 + ipAddrFamilyV6 ) type pickfirstBuilder struct{} -func (pickfirstBuilder) Build(cc balancer.ClientConn, _ balancer.BuildOptions) balancer.Balancer { - b := &pickfirstBalancer{cc: cc} +func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer { + b := &pickfirstBalancer{ + cc: cc, + target: bo.Target.String(), + metricsRecorder: cc.MetricsRecorder(), + + subConns: resolver.NewAddressMapV2[*scData](), + state: connectivity.Connecting, + cancelConnectionTimer: func() {}, + } b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) return b } -func (pickfirstBuilder) Name() string { +func (b pickfirstBuilder) Name() string { return Name } +func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { + var cfg pfConfig + if err := json.Unmarshal(js, &cfg); err != nil { + return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) + } + return cfg, nil +} + +// EnableHealthListener updates the state to configure pickfirst for using a +// generic health listener. +// +// # Experimental +// +// Notice: This API is EXPERIMENTAL and may be changed or removed in a later +// release. +func EnableHealthListener(state resolver.State) resolver.State { + state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true) + return state +} + type pfConfig struct { serviceconfig.LoadBalancingConfig `json:"-"` @@ -74,90 +148,163 @@ type pfConfig struct { ShuffleAddressList bool `json:"shuffleAddressList"` } -func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { - var cfg pfConfig - if err := json.Unmarshal(js, &cfg); err != nil { - return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) +// scData keeps track of the current state of the subConn. +// It is not safe for concurrent access. +type scData struct { + // The following fields are initialized at build time and read-only after + // that. + subConn balancer.SubConn + addr resolver.Address + + rawConnectivityState connectivity.State + // The effective connectivity state based on raw connectivity, health state + // and after following sticky TransientFailure behaviour defined in A62. + effectiveState connectivity.State + lastErr error + connectionFailedInFirstPass bool +} + +func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { + sd := &scData{ + rawConnectivityState: connectivity.Idle, + effectiveState: connectivity.Idle, + addr: addr, } - return cfg, nil + sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ + StateListener: func(state balancer.SubConnState) { + b.updateSubConnState(sd, state) + }, + }) + if err != nil { + return nil, err + } + sd.subConn = sc + return sd, nil } type pickfirstBalancer struct { - logger *internalgrpclog.PrefixLogger - state connectivity.State - cc balancer.ClientConn - subConn balancer.SubConn + // The following fields are initialized at build time and read-only after + // that and therefore do not need to be guarded by a mutex. + logger *internalgrpclog.PrefixLogger + cc balancer.ClientConn + target string + metricsRecorder expstats.MetricsRecorder // guaranteed to be non nil + + // The mutex is used to ensure synchronization of updates triggered + // from the idle picker and the already serialized resolver, + // SubConn state updates. + mu sync.Mutex + // State reported to the channel based on SubConn states and resolver + // updates. + state connectivity.State + // scData for active subonns mapped by address. + subConns *resolver.AddressMapV2[*scData] + addressList addressList + firstPass bool + numTF int + cancelConnectionTimer func() + healthCheckingEnabled bool } +// ResolverError is called by the ClientConn when the name resolver produces +// an error or when pickfirst determined the resolver update to be invalid. func (b *pickfirstBalancer) ResolverError(err error) { + b.mu.Lock() + defer b.mu.Unlock() + b.resolverErrorLocked(err) +} + +func (b *pickfirstBalancer) resolverErrorLocked(err error) { if b.logger.V(2) { b.logger.Infof("Received error from the name resolver: %v", err) } - if b.subConn == nil { - b.state = connectivity.TransientFailure - } - if b.state != connectivity.TransientFailure { - // The picker will not change since the balancer does not currently - // report an error. + // The picker will not change since the balancer does not currently + // report an error. If the balancer hasn't received a single good resolver + // update yet, transition to TRANSIENT_FAILURE. + if b.state != connectivity.TransientFailure && b.addressList.size() > 0 { + if b.logger.V(2) { + b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.") + } return } - b.cc.UpdateState(balancer.State{ + + b.updateBalancerState(balancer.State{ ConnectivityState: connectivity.TransientFailure, Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, }) } -// Shuffler is an interface for shuffling an address list. -type Shuffler interface { - ShuffleAddressListForTesting(n int, swap func(i, j int)) -} - -// ShuffleAddressListForTesting pseudo-randomizes the order of addresses. n -// is the number of elements. swap swaps the elements with indexes i and j. -func ShuffleAddressListForTesting(n int, swap func(i, j int)) { rand.Shuffle(n, swap) } - func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { + b.mu.Lock() + defer b.mu.Unlock() + b.cancelConnectionTimer() if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { - // The resolver reported an empty address list. Treat it like an error by - // calling b.ResolverError. - if b.subConn != nil { - // Shut down the old subConn. All addresses were removed, so it is - // no longer valid. - b.subConn.Shutdown() - b.subConn = nil - } - b.ResolverError(errors.New("produced zero addresses")) + // Cleanup state pertaining to the previous resolver state. + // Treat an empty address list like an error by calling b.ResolverError. + b.closeSubConnsLocked() + b.addressList.updateAddrs(nil) + b.resolverErrorLocked(errors.New("produced zero addresses")) return balancer.ErrBadResolverState } - // We don't have to guard this block with the env var because ParseConfig - // already does so. + b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil cfg, ok := state.BalancerConfig.(pfConfig) if state.BalancerConfig != nil && !ok { - return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v", state.BalancerConfig, state.BalancerConfig) + return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState) } if b.logger.V(2) { b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState)) } - var addrs []resolver.Address + var newAddrs []resolver.Address if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 { - // Perform the optional shuffling described in gRFC A62. The shuffling will - // change the order of endpoints but not touch the order of the addresses - // within each endpoint. - A61 + // Perform the optional shuffling described in gRFC A62. The shuffling + // will change the order of endpoints but not touch the order of the + // addresses within each endpoint. - A61 if cfg.ShuffleAddressList { - endpoints = append([]resolver.Endpoint{}, endpoints...) - internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) + if envconfig.PickFirstWeightedShuffling { + type weightedEndpoint struct { + endpoint resolver.Endpoint + weight float64 + } + + // For each endpoint, compute a key as described in A113 and + // https://utopia.duth.gr/~pefraimi/research/data/2007EncOfAlg.pdf: + var weightedEndpoints []weightedEndpoint + for _, endpoint := range endpoints { + u := internal.RandFloat64() // Random number in [0.0, 1.0) + weight := weightAttribute(endpoint) + weightedEndpoints = append(weightedEndpoints, weightedEndpoint{ + endpoint: endpoint, + weight: math.Pow(u, 1.0/float64(weight)), + }) + } + // Sort endpoints by key in descending order and reconstruct the + // endpoints slice. + slices.SortFunc(weightedEndpoints, func(a, b weightedEndpoint) int { + return cmp.Compare(b.weight, a.weight) + }) + + // Here, and in the "else" block below, we clone the endpoints + // slice to avoid mutating the resolver state. Doing the latter + // would lead to data races if the caller is accessing the same + // slice concurrently. + sortedEndpoints := make([]resolver.Endpoint, len(endpoints)) + for i, we := range weightedEndpoints { + sortedEndpoints[i] = we.endpoint + } + endpoints = sortedEndpoints + } else { + endpoints = slices.Clone(endpoints) + internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) + } } - // "Flatten the list by concatenating the ordered list of addresses for each - // of the endpoints, in order." - A61 + // "Flatten the list by concatenating the ordered list of addresses for + // each of the endpoints, in order." - A61 for _, endpoint := range endpoints { - // "In the flattened list, interleave addresses from the two address - // families, as per RFC-8304 section 4." - A61 - // TODO: support the above language. - addrs = append(addrs, endpoint.Addresses...) + newAddrs = append(newAddrs, endpoint.Addresses...) } } else { // Endpoints not set, process addresses until we migrate resolver @@ -166,42 +313,53 @@ func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState // target do not forward the corresponding correct endpoints down/split // endpoints properly. Once all balancers correctly forward endpoints // down, can delete this else conditional. - addrs = state.ResolverState.Addresses + newAddrs = state.ResolverState.Addresses if cfg.ShuffleAddressList { - addrs = append([]resolver.Address{}, addrs...) - internal.RandShuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }) + newAddrs = append([]resolver.Address{}, newAddrs...) + internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] }) } } - if b.subConn != nil { - b.cc.UpdateAddresses(b.subConn, addrs) + // If an address appears in multiple endpoints or in the same endpoint + // multiple times, we keep it only once. We will create only one SubConn + // for the address because an AddressMap is used to store SubConns. + // Not de-duplicating would result in attempting to connect to the same + // SubConn multiple times in the same pass. We don't want this. + newAddrs = deDupAddresses(newAddrs) + newAddrs = interleaveAddresses(newAddrs) + + prevAddr := b.addressList.currentAddress() + prevSCData, found := b.subConns.Get(prevAddr) + prevAddrsCount := b.addressList.size() + isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready + b.addressList.updateAddrs(newAddrs) + + // If the previous ready SubConn exists in new address list, + // keep this connection and don't create new SubConns. + if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) { return nil } - var subConn balancer.SubConn - subConn, err := b.cc.NewSubConn(addrs, balancer.NewSubConnOptions{ - StateListener: func(state balancer.SubConnState) { - b.updateSubConnState(subConn, state) - }, - }) - if err != nil { - if b.logger.V(2) { - b.logger.Infof("Failed to create new SubConn: %v", err) - } - b.state = connectivity.TransientFailure - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: fmt.Errorf("error creating connection: %v", err)}, + b.reconcileSubConnsLocked(newAddrs) + // If it's the first resolver update or the balancer was already READY + // (but the new address list does not contain the ready SubConn) or + // CONNECTING, enter CONNECTING. + // We may be in TRANSIENT_FAILURE due to a previous empty address list, + // we should still enter CONNECTING because the sticky TF behaviour + // mentioned in A62 applies only when the TRANSIENT_FAILURE is reported + // due to connectivity failures. + if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 { + // Start connection attempt at first address. + b.forceUpdateConcludedStateLocked(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) - return balancer.ErrBadResolverState + b.startFirstPassLocked() + } else if b.state == connectivity.TransientFailure { + // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until + // we're READY. See A62. + b.startFirstPassLocked() } - b.subConn = subConn - b.state = connectivity.Idle - b.cc.UpdateState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - b.subConn.Connect() return nil } @@ -211,63 +369,484 @@ func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state b b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state) } -func (b *pickfirstBalancer) updateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) { - if b.logger.V(2) { - b.logger.Infof("Received SubConn state update: %p, %+v", subConn, state) +func (b *pickfirstBalancer) Close() { + b.mu.Lock() + defer b.mu.Unlock() + b.closeSubConnsLocked() + b.cancelConnectionTimer() + b.state = connectivity.Shutdown +} + +// ExitIdle moves the balancer out of idle state. It can be called concurrently +// by the idlePicker and clientConn so access to variables should be +// synchronized. +func (b *pickfirstBalancer) ExitIdle() { + b.mu.Lock() + defer b.mu.Unlock() + if b.state == connectivity.Idle { + // Move the balancer into CONNECTING state immediately. This is done to + // avoid staying in IDLE if a resolver update arrives before the first + // SubConn reports CONNECTING. + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + b.startFirstPassLocked() + } +} + +func (b *pickfirstBalancer) startFirstPassLocked() { + b.firstPass = true + b.numTF = 0 + // Reset the connection attempt record for existing SubConns. + for _, sd := range b.subConns.Values() { + sd.connectionFailedInFirstPass = false } - if b.subConn != subConn { + b.requestConnectionLocked() +} + +func (b *pickfirstBalancer) closeSubConnsLocked() { + for _, sd := range b.subConns.Values() { + sd.subConn.Shutdown() + } + b.subConns = resolver.NewAddressMapV2[*scData]() +} + +// deDupAddresses ensures that each address appears only once in the slice. +func deDupAddresses(addrs []resolver.Address) []resolver.Address { + seenAddrs := resolver.NewAddressMapV2[bool]() + retAddrs := []resolver.Address{} + + for _, addr := range addrs { + if _, ok := seenAddrs.Get(addr); ok { + continue + } + seenAddrs.Set(addr, true) + retAddrs = append(retAddrs, addr) + } + return retAddrs +} + +// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6) +// as per RFC-8305 section 4. +// Whichever address family is first in the list is followed by an address of +// the other address family; that is, if the first address in the list is IPv6, +// then the first IPv4 address should be moved up in the list to be second in +// the list. It doesn't support configuring "First Address Family Count", i.e. +// there will always be a single member of the first address family at the +// beginning of the interleaved list. +// Addresses that are neither IPv4 nor IPv6 are treated as part of a third +// "unknown" family for interleaving. +// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6 +func interleaveAddresses(addrs []resolver.Address) []resolver.Address { + familyAddrsMap := map[ipAddrFamily][]resolver.Address{} + interleavingOrder := []ipAddrFamily{} + for _, addr := range addrs { + family := addressFamily(addr.Addr) + if _, found := familyAddrsMap[family]; !found { + interleavingOrder = append(interleavingOrder, family) + } + familyAddrsMap[family] = append(familyAddrsMap[family], addr) + } + + interleavedAddrs := make([]resolver.Address, 0, len(addrs)) + + for curFamilyIdx := 0; len(interleavedAddrs) < len(addrs); curFamilyIdx = (curFamilyIdx + 1) % len(interleavingOrder) { + // Some IP types may have fewer addresses than others, so we look for + // the next type that has a remaining member to add to the interleaved + // list. + family := interleavingOrder[curFamilyIdx] + remainingMembers := familyAddrsMap[family] + if len(remainingMembers) > 0 { + interleavedAddrs = append(interleavedAddrs, remainingMembers[0]) + familyAddrsMap[family] = remainingMembers[1:] + } + } + + return interleavedAddrs +} + +// addressFamily returns the ipAddrFamily after parsing the address string. +// If the address isn't of the format "ip-address:port", it returns +// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when +// using a resolver like passthrough where the address may be a hostname in +// some format that the dialer can resolve. +func addressFamily(address string) ipAddrFamily { + // Parse the IP after removing the port. + host, _, err := net.SplitHostPort(address) + if err != nil { + return ipAddrFamilyUnknown + } + ip, err := netip.ParseAddr(host) + if err != nil { + return ipAddrFamilyUnknown + } + switch { + case ip.Is4() || ip.Is4In6(): + return ipAddrFamilyV4 + case ip.Is6(): + return ipAddrFamilyV6 + default: + return ipAddrFamilyUnknown + } +} + +// reconcileSubConnsLocked updates the active subchannels based on a new address +// list from the resolver. It does this by: +// - closing subchannels: any existing subchannels associated with addresses +// that are no longer in the updated list are shut down. +// - removing subchannels: entries for these closed subchannels are removed +// from the subchannel map. +// +// This ensures that the subchannel map accurately reflects the current set of +// addresses received from the name resolver. +func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) { + newAddrsMap := resolver.NewAddressMapV2[bool]() + for _, addr := range newAddrs { + newAddrsMap.Set(addr, true) + } + + for _, oldAddr := range b.subConns.Keys() { + if _, ok := newAddrsMap.Get(oldAddr); ok { + continue + } + val, _ := b.subConns.Get(oldAddr) + val.subConn.Shutdown() + b.subConns.Delete(oldAddr) + } +} + +// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn +// becomes ready, which means that all other subConn must be shutdown. +func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) { + b.cancelConnectionTimer() + for _, sd := range b.subConns.Values() { + if sd.subConn != selected.subConn { + sd.subConn.Shutdown() + } + } + b.subConns = resolver.NewAddressMapV2[*scData]() + b.subConns.Set(selected.addr, selected) +} + +// requestConnectionLocked starts connecting on the subchannel corresponding to +// the current address. If no subchannel exists, one is created. If the current +// subchannel is in TransientFailure, a connection to the next address is +// attempted until a subchannel is found. +func (b *pickfirstBalancer) requestConnectionLocked() { + if !b.addressList.isValid() { + return + } + var lastErr error + for valid := true; valid; valid = b.addressList.increment() { + curAddr := b.addressList.currentAddress() + sd, ok := b.subConns.Get(curAddr) + if !ok { + var err error + // We want to assign the new scData to sd from the outer scope, + // hence we can't use := below. + sd, err = b.newSCData(curAddr) + if err != nil { + // This should never happen, unless the clientConn is being shut + // down. + if b.logger.V(2) { + b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err) + } + // Do nothing, the LB policy will be closed soon. + return + } + b.subConns.Set(curAddr, sd) + } + + switch sd.rawConnectivityState { + case connectivity.Idle: + sd.subConn.Connect() + b.scheduleNextConnectionLocked() + return + case connectivity.TransientFailure: + // The SubConn is being re-used and failed during a previous pass + // over the addressList. It has not completed backoff yet. + // Mark it as having failed and try the next address. + sd.connectionFailedInFirstPass = true + lastErr = sd.lastErr + continue + case connectivity.Connecting: + // Wait for the connection attempt to complete or the timer to fire + // before attempting the next address. + b.scheduleNextConnectionLocked() + return + default: + b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState) + return + + } + } + + // All the remaining addresses in the list are in TRANSIENT_FAILURE, end the + // first pass if possible. + b.endFirstPassIfPossibleLocked(lastErr) +} + +func (b *pickfirstBalancer) scheduleNextConnectionLocked() { + b.cancelConnectionTimer() + if !b.addressList.hasNext() { + return + } + curAddr := b.addressList.currentAddress() + cancelled := false // Access to this is protected by the balancer's mutex. + closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() { + b.mu.Lock() + defer b.mu.Unlock() + // If the scheduled task is cancelled while acquiring the mutex, return. + if cancelled { + return + } if b.logger.V(2) { - b.logger.Infof("Ignored state change because subConn is not recognized") + b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr) + } + if b.addressList.increment() { + b.requestConnectionLocked() } + }) + // Access to the cancellation callback held by the balancer is guarded by + // the balancer's mutex, so it's safe to set the boolean from the callback. + b.cancelConnectionTimer = sync.OnceFunc(func() { + cancelled = true + closeFn() + }) +} + +func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) { + b.mu.Lock() + defer b.mu.Unlock() + oldState := sd.rawConnectivityState + sd.rawConnectivityState = newState.ConnectivityState + // Previously relevant SubConns can still callback with state updates. + // To prevent pickers from returning these obsolete SubConns, this logic + // is included to check if the current list of active SubConns includes this + // SubConn. + if !b.isActiveSCData(sd) { return } - if state.ConnectivityState == connectivity.Shutdown { - b.subConn = nil + if newState.ConnectivityState == connectivity.Shutdown { + sd.effectiveState = connectivity.Shutdown return } - switch state.ConnectivityState { - case connectivity.Ready: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &picker{result: balancer.PickResult{SubConn: subConn}}, - }) - case connectivity.Connecting: - if b.state == connectivity.TransientFailure { - // We stay in TransientFailure until we are Ready. See A62. + // Record a connection attempt when exiting CONNECTING. + if newState.ConnectivityState == connectivity.TransientFailure { + sd.connectionFailedInFirstPass = true + connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target) + } + + if newState.ConnectivityState == connectivity.Ready { + connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target) + b.shutdownRemainingLocked(sd) + if !b.addressList.seekTo(sd.addr) { + // This should not fail as we should have only one SubConn after + // entering READY. The SubConn should be present in the addressList. + b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses) + return + } + if !b.healthCheckingEnabled { + if b.logger.V(2) { + b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn) + } + + sd.effectiveState = connectivity.Ready + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, + }) return } - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, + if b.logger.V(2) { + b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn) + } + // Send a CONNECTING update to take the SubConn out of sticky-TF if + // required. + sd.effectiveState = connectivity.Connecting + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) + sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) { + b.updateSubConnHealthState(sd, scs) + }) + return + } + + // If the LB policy is READY, and it receives a subchannel state change, + // it means that the READY subchannel has failed. + // A SubConn can also transition from CONNECTING directly to IDLE when + // a transport is successfully created, but the connection fails + // before the SubConn can send the notification for READY. We treat + // this as a successful connection and transition to IDLE. + // TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second + // part of the if condition below once the issue is fixed. + if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) { + // Once a transport fails, the balancer enters IDLE and starts from + // the first address when the picker is used. + b.shutdownRemainingLocked(sd) + sd.effectiveState = newState.ConnectivityState + // READY SubConn interspliced in between CONNECTING and IDLE, need to + // account for that. + if oldState == connectivity.Connecting { + // A known issue (https://github.com/grpc/grpc-go/issues/7862) + // causes a race that prevents the READY state change notification. + // This works around it. + connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target) + } + disconnectionsMetric.Record(b.metricsRecorder, 1, b.target) + b.addressList.reset() + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Idle, + Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)}, + }) + return + } + + if b.firstPass { + switch newState.ConnectivityState { + case connectivity.Connecting: + // The effective state can be in either IDLE, CONNECTING or + // TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in + // TRANSIENT_FAILURE until it's READY. See A62. + if sd.effectiveState != connectivity.TransientFailure { + sd.effectiveState = connectivity.Connecting + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, + }) + } + case connectivity.TransientFailure: + sd.lastErr = newState.ConnectionError + sd.effectiveState = connectivity.TransientFailure + // Since we're re-using common SubConns while handling resolver + // updates, we could receive an out of turn TRANSIENT_FAILURE from + // a pass over the previous address list. Happy Eyeballs will also + // cause out of order updates to arrive. + + if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) { + b.cancelConnectionTimer() + if b.addressList.increment() { + b.requestConnectionLocked() + return + } + } + + // End the first pass if we've seen a TRANSIENT_FAILURE from all + // SubConns once. + b.endFirstPassIfPossibleLocked(newState.ConnectionError) + } + return + } + + // We have finished the first pass, keep re-connecting failing SubConns. + switch newState.ConnectivityState { + case connectivity.TransientFailure: + b.numTF = (b.numTF + 1) % b.subConns.Len() + sd.lastErr = newState.ConnectionError + if b.numTF%b.subConns.Len() == 0 { + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: newState.ConnectionError}, + }) + } + // We don't need to request re-resolution since the SubConn already + // does that before reporting TRANSIENT_FAILURE. + // TODO: #7534 - Move re-resolution requests from SubConn into + // pick_first. case connectivity.Idle: - if b.state == connectivity.TransientFailure { - // We stay in TransientFailure until we are Ready. Also kick the - // subConn out of Idle into Connecting. See A62. - b.subConn.Connect() + sd.subConn.Connect() + } +} + +// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the +// addresses are tried and their SubConns have reported a failure. +func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) { + // An optimization to avoid iterating over the entire SubConn map. + if b.addressList.isValid() { + return + } + // Connect() has been called on all the SubConns. The first pass can be + // ended if all the SubConns have reported a failure. + for _, sd := range b.subConns.Values() { + if !sd.connectionFailedInFirstPass { return } - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &idlePicker{subConn: subConn}, + } + b.firstPass = false + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: lastErr}, + }) + // Start re-connecting all the SubConns that are already in IDLE. + for _, sd := range b.subConns.Values() { + if sd.rawConnectivityState == connectivity.Idle { + sd.subConn.Connect() + } + } +} + +func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool { + activeSD, found := b.subConns.Get(sd.addr) + return found && activeSD == sd +} + +func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) { + b.mu.Lock() + defer b.mu.Unlock() + // Previously relevant SubConns can still callback with state updates. + // To prevent pickers from returning these obsolete SubConns, this logic + // is included to check if the current list of active SubConns includes + // this SubConn. + if !b.isActiveSCData(sd) { + return + } + sd.effectiveState = state.ConnectivityState + switch state.ConnectivityState { + case connectivity.Ready: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Ready, + Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, }) case connectivity.TransientFailure: - b.cc.UpdateState(balancer.State{ - ConnectivityState: state.ConnectivityState, - Picker: &picker{err: state.ConnectionError}, + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.TransientFailure, + Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)}, + }) + case connectivity.Connecting: + b.updateBalancerState(balancer.State{ + ConnectivityState: connectivity.Connecting, + Picker: &picker{err: balancer.ErrNoSubConnAvailable}, }) + default: + b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state) } - b.state = state.ConnectivityState } -func (b *pickfirstBalancer) Close() { +// updateBalancerState stores the state reported to the channel and calls +// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate +// updates to the channel. +func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) { + // In case of TransientFailures allow the picker to be updated to update + // the connectivity error, in all other cases don't send duplicate state + // updates. + if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure { + return + } + b.forceUpdateConcludedStateLocked(newState) } -func (b *pickfirstBalancer) ExitIdle() { - if b.subConn != nil && b.state == connectivity.Idle { - b.subConn.Connect() - } +// forceUpdateConcludedStateLocked stores the state reported to the channel and +// calls ClientConn.UpdateState(). +// A separate function is defined to force update the ClientConn state since the +// channel doesn't correctly assume that LB policies start in CONNECTING and +// relies on LB policy to send an initial CONNECTING update. +func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) { + b.state = newState.ConnectivityState + b.cc.UpdateState(newState) } type picker struct { @@ -282,10 +861,101 @@ func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { // idlePicker is used when the SubConn is IDLE and kicks the SubConn into // CONNECTING when Pick is called. type idlePicker struct { - subConn balancer.SubConn + exitIdle func() } func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - i.subConn.Connect() + i.exitIdle() return balancer.PickResult{}, balancer.ErrNoSubConnAvailable } + +// addressList manages sequentially iterating over addresses present in a list +// of endpoints. It provides a 1 dimensional view of the addresses present in +// the endpoints. +// This type is not safe for concurrent access. +type addressList struct { + addresses []resolver.Address + idx int +} + +func (al *addressList) isValid() bool { + return al.idx < len(al.addresses) +} + +func (al *addressList) size() int { + return len(al.addresses) +} + +// increment moves to the next index in the address list. +// This method returns false if it went off the list, true otherwise. +func (al *addressList) increment() bool { + if !al.isValid() { + return false + } + al.idx++ + return al.idx < len(al.addresses) +} + +// currentAddress returns the current address pointed to in the addressList. +// If the list is in an invalid state, it returns an empty address instead. +func (al *addressList) currentAddress() resolver.Address { + if !al.isValid() { + return resolver.Address{} + } + return al.addresses[al.idx] +} + +func (al *addressList) reset() { + al.idx = 0 +} + +func (al *addressList) updateAddrs(addrs []resolver.Address) { + al.addresses = addrs + al.reset() +} + +// seekTo returns false if the needle was not found and the current index was +// left unchanged. +func (al *addressList) seekTo(needle resolver.Address) bool { + for ai, addr := range al.addresses { + if !equalAddressIgnoringBalAttributes(&addr, &needle) { + continue + } + al.idx = ai + return true + } + return false +} + +// hasNext returns whether incrementing the addressList will result in moving +// past the end of the list. If the list has already moved past the end, it +// returns false. +func (al *addressList) hasNext() bool { + if !al.isValid() { + return false + } + return al.idx+1 < len(al.addresses) +} + +// equalAddressIgnoringBalAttributes returns true is a and b are considered +// equal. This is different from the Equal method on the resolver.Address type +// which considers all fields to determine equality. Here, we only consider +// fields that are meaningful to the SubConn. +func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { + return a.Addr == b.Addr && a.ServerName == b.ServerName && + a.Attributes.Equal(b.Attributes) +} + +// weightAttribute is a convenience function which returns the value of the +// weight endpoint Attribute. +// +// When used in the xDS context, the weight attribute is guaranteed to be +// non-zero. But, when used in a non-xDS context, the weight attribute could be +// unset. A Default of 1 is used in the latter case. +func weightAttribute(e resolver.Endpoint) uint32 { + w := weight.FromEndpoint(e).Weight + if w == 0 { + return 1 + } + return w +} diff --git a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go b/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go deleted file mode 100644 index 9ffdd28a0..000000000 --- a/vendor/google.golang.org/grpc/balancer/pickfirst/pickfirstleaf/pickfirstleaf.go +++ /dev/null @@ -1,913 +0,0 @@ -/* - * - * Copyright 2024 gRPC authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -// Package pickfirstleaf contains the pick_first load balancing policy which -// will be the universal leaf policy after dualstack changes are implemented. -// -// # Experimental -// -// Notice: This package is EXPERIMENTAL and may be changed or removed in a -// later release. -package pickfirstleaf - -import ( - "encoding/json" - "errors" - "fmt" - "net" - "net/netip" - "sync" - "time" - - "google.golang.org/grpc/balancer" - "google.golang.org/grpc/balancer/pickfirst/internal" - "google.golang.org/grpc/connectivity" - expstats "google.golang.org/grpc/experimental/stats" - "google.golang.org/grpc/grpclog" - "google.golang.org/grpc/internal/envconfig" - internalgrpclog "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/pretty" - "google.golang.org/grpc/resolver" - "google.golang.org/grpc/serviceconfig" -) - -func init() { - if envconfig.NewPickFirstEnabled { - // Register as the default pick_first balancer. - Name = "pick_first" - } - balancer.Register(pickfirstBuilder{}) -} - -// enableHealthListenerKeyType is a unique key type used in resolver -// attributes to indicate whether the health listener usage is enabled. -type enableHealthListenerKeyType struct{} - -var ( - logger = grpclog.Component("pick-first-leaf-lb") - // Name is the name of the pick_first_leaf balancer. - // It is changed to "pick_first" in init() if this balancer is to be - // registered as the default pickfirst. - Name = "pick_first_leaf" - disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ - Name: "grpc.lb.pick_first.disconnections", - Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.", - Unit: "{disconnection}", - Labels: []string{"grpc.target"}, - Default: false, - }) - connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ - Name: "grpc.lb.pick_first.connection_attempts_succeeded", - Description: "EXPERIMENTAL. Number of successful connection attempts.", - Unit: "{attempt}", - Labels: []string{"grpc.target"}, - Default: false, - }) - connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ - Name: "grpc.lb.pick_first.connection_attempts_failed", - Description: "EXPERIMENTAL. Number of failed connection attempts.", - Unit: "{attempt}", - Labels: []string{"grpc.target"}, - Default: false, - }) -) - -const ( - // TODO: change to pick-first when this becomes the default pick_first policy. - logPrefix = "[pick-first-leaf-lb %p] " - // connectionDelayInterval is the time to wait for during the happy eyeballs - // pass before starting the next connection attempt. - connectionDelayInterval = 250 * time.Millisecond -) - -type ipAddrFamily int - -const ( - // ipAddrFamilyUnknown represents strings that can't be parsed as an IP - // address. - ipAddrFamilyUnknown ipAddrFamily = iota - ipAddrFamilyV4 - ipAddrFamilyV6 -) - -type pickfirstBuilder struct{} - -func (pickfirstBuilder) Build(cc balancer.ClientConn, bo balancer.BuildOptions) balancer.Balancer { - b := &pickfirstBalancer{ - cc: cc, - target: bo.Target.String(), - metricsRecorder: cc.MetricsRecorder(), - - subConns: resolver.NewAddressMapV2[*scData](), - state: connectivity.Connecting, - cancelConnectionTimer: func() {}, - } - b.logger = internalgrpclog.NewPrefixLogger(logger, fmt.Sprintf(logPrefix, b)) - return b -} - -func (b pickfirstBuilder) Name() string { - return Name -} - -func (pickfirstBuilder) ParseConfig(js json.RawMessage) (serviceconfig.LoadBalancingConfig, error) { - var cfg pfConfig - if err := json.Unmarshal(js, &cfg); err != nil { - return nil, fmt.Errorf("pickfirst: unable to unmarshal LB policy config: %s, error: %v", string(js), err) - } - return cfg, nil -} - -// EnableHealthListener updates the state to configure pickfirst for using a -// generic health listener. -func EnableHealthListener(state resolver.State) resolver.State { - state.Attributes = state.Attributes.WithValue(enableHealthListenerKeyType{}, true) - return state -} - -type pfConfig struct { - serviceconfig.LoadBalancingConfig `json:"-"` - - // If set to true, instructs the LB policy to shuffle the order of the list - // of endpoints received from the name resolver before attempting to - // connect to them. - ShuffleAddressList bool `json:"shuffleAddressList"` -} - -// scData keeps track of the current state of the subConn. -// It is not safe for concurrent access. -type scData struct { - // The following fields are initialized at build time and read-only after - // that. - subConn balancer.SubConn - addr resolver.Address - - rawConnectivityState connectivity.State - // The effective connectivity state based on raw connectivity, health state - // and after following sticky TransientFailure behaviour defined in A62. - effectiveState connectivity.State - lastErr error - connectionFailedInFirstPass bool -} - -func (b *pickfirstBalancer) newSCData(addr resolver.Address) (*scData, error) { - sd := &scData{ - rawConnectivityState: connectivity.Idle, - effectiveState: connectivity.Idle, - addr: addr, - } - sc, err := b.cc.NewSubConn([]resolver.Address{addr}, balancer.NewSubConnOptions{ - StateListener: func(state balancer.SubConnState) { - b.updateSubConnState(sd, state) - }, - }) - if err != nil { - return nil, err - } - sd.subConn = sc - return sd, nil -} - -type pickfirstBalancer struct { - // The following fields are initialized at build time and read-only after - // that and therefore do not need to be guarded by a mutex. - logger *internalgrpclog.PrefixLogger - cc balancer.ClientConn - target string - metricsRecorder expstats.MetricsRecorder // guaranteed to be non nil - - // The mutex is used to ensure synchronization of updates triggered - // from the idle picker and the already serialized resolver, - // SubConn state updates. - mu sync.Mutex - // State reported to the channel based on SubConn states and resolver - // updates. - state connectivity.State - // scData for active subonns mapped by address. - subConns *resolver.AddressMapV2[*scData] - addressList addressList - firstPass bool - numTF int - cancelConnectionTimer func() - healthCheckingEnabled bool -} - -// ResolverError is called by the ClientConn when the name resolver produces -// an error or when pickfirst determined the resolver update to be invalid. -func (b *pickfirstBalancer) ResolverError(err error) { - b.mu.Lock() - defer b.mu.Unlock() - b.resolverErrorLocked(err) -} - -func (b *pickfirstBalancer) resolverErrorLocked(err error) { - if b.logger.V(2) { - b.logger.Infof("Received error from the name resolver: %v", err) - } - - // The picker will not change since the balancer does not currently - // report an error. If the balancer hasn't received a single good resolver - // update yet, transition to TRANSIENT_FAILURE. - if b.state != connectivity.TransientFailure && b.addressList.size() > 0 { - if b.logger.V(2) { - b.logger.Infof("Ignoring resolver error because balancer is using a previous good update.") - } - return - } - - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: fmt.Errorf("name resolver error: %v", err)}, - }) -} - -func (b *pickfirstBalancer) UpdateClientConnState(state balancer.ClientConnState) error { - b.mu.Lock() - defer b.mu.Unlock() - b.cancelConnectionTimer() - if len(state.ResolverState.Addresses) == 0 && len(state.ResolverState.Endpoints) == 0 { - // Cleanup state pertaining to the previous resolver state. - // Treat an empty address list like an error by calling b.ResolverError. - b.closeSubConnsLocked() - b.addressList.updateAddrs(nil) - b.resolverErrorLocked(errors.New("produced zero addresses")) - return balancer.ErrBadResolverState - } - b.healthCheckingEnabled = state.ResolverState.Attributes.Value(enableHealthListenerKeyType{}) != nil - cfg, ok := state.BalancerConfig.(pfConfig) - if state.BalancerConfig != nil && !ok { - return fmt.Errorf("pickfirst: received illegal BalancerConfig (type %T): %v: %w", state.BalancerConfig, state.BalancerConfig, balancer.ErrBadResolverState) - } - - if b.logger.V(2) { - b.logger.Infof("Received new config %s, resolver state %s", pretty.ToJSON(cfg), pretty.ToJSON(state.ResolverState)) - } - - var newAddrs []resolver.Address - if endpoints := state.ResolverState.Endpoints; len(endpoints) != 0 { - // Perform the optional shuffling described in gRFC A62. The shuffling - // will change the order of endpoints but not touch the order of the - // addresses within each endpoint. - A61 - if cfg.ShuffleAddressList { - endpoints = append([]resolver.Endpoint{}, endpoints...) - internal.RandShuffle(len(endpoints), func(i, j int) { endpoints[i], endpoints[j] = endpoints[j], endpoints[i] }) - } - - // "Flatten the list by concatenating the ordered list of addresses for - // each of the endpoints, in order." - A61 - for _, endpoint := range endpoints { - newAddrs = append(newAddrs, endpoint.Addresses...) - } - } else { - // Endpoints not set, process addresses until we migrate resolver - // emissions fully to Endpoints. The top channel does wrap emitted - // addresses with endpoints, however some balancers such as weighted - // target do not forward the corresponding correct endpoints down/split - // endpoints properly. Once all balancers correctly forward endpoints - // down, can delete this else conditional. - newAddrs = state.ResolverState.Addresses - if cfg.ShuffleAddressList { - newAddrs = append([]resolver.Address{}, newAddrs...) - internal.RandShuffle(len(newAddrs), func(i, j int) { newAddrs[i], newAddrs[j] = newAddrs[j], newAddrs[i] }) - } - } - - // If an address appears in multiple endpoints or in the same endpoint - // multiple times, we keep it only once. We will create only one SubConn - // for the address because an AddressMap is used to store SubConns. - // Not de-duplicating would result in attempting to connect to the same - // SubConn multiple times in the same pass. We don't want this. - newAddrs = deDupAddresses(newAddrs) - newAddrs = interleaveAddresses(newAddrs) - - prevAddr := b.addressList.currentAddress() - prevSCData, found := b.subConns.Get(prevAddr) - prevAddrsCount := b.addressList.size() - isPrevRawConnectivityStateReady := found && prevSCData.rawConnectivityState == connectivity.Ready - b.addressList.updateAddrs(newAddrs) - - // If the previous ready SubConn exists in new address list, - // keep this connection and don't create new SubConns. - if isPrevRawConnectivityStateReady && b.addressList.seekTo(prevAddr) { - return nil - } - - b.reconcileSubConnsLocked(newAddrs) - // If it's the first resolver update or the balancer was already READY - // (but the new address list does not contain the ready SubConn) or - // CONNECTING, enter CONNECTING. - // We may be in TRANSIENT_FAILURE due to a previous empty address list, - // we should still enter CONNECTING because the sticky TF behaviour - // mentioned in A62 applies only when the TRANSIENT_FAILURE is reported - // due to connectivity failures. - if isPrevRawConnectivityStateReady || b.state == connectivity.Connecting || prevAddrsCount == 0 { - // Start connection attempt at first address. - b.forceUpdateConcludedStateLocked(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - b.startFirstPassLocked() - } else if b.state == connectivity.TransientFailure { - // If we're in TRANSIENT_FAILURE, we stay in TRANSIENT_FAILURE until - // we're READY. See A62. - b.startFirstPassLocked() - } - return nil -} - -// UpdateSubConnState is unused as a StateListener is always registered when -// creating SubConns. -func (b *pickfirstBalancer) UpdateSubConnState(subConn balancer.SubConn, state balancer.SubConnState) { - b.logger.Errorf("UpdateSubConnState(%v, %+v) called unexpectedly", subConn, state) -} - -func (b *pickfirstBalancer) Close() { - b.mu.Lock() - defer b.mu.Unlock() - b.closeSubConnsLocked() - b.cancelConnectionTimer() - b.state = connectivity.Shutdown -} - -// ExitIdle moves the balancer out of idle state. It can be called concurrently -// by the idlePicker and clientConn so access to variables should be -// synchronized. -func (b *pickfirstBalancer) ExitIdle() { - b.mu.Lock() - defer b.mu.Unlock() - if b.state == connectivity.Idle { - // Move the balancer into CONNECTING state immediately. This is done to - // avoid staying in IDLE if a resolver update arrives before the first - // SubConn reports CONNECTING. - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - b.startFirstPassLocked() - } -} - -func (b *pickfirstBalancer) startFirstPassLocked() { - b.firstPass = true - b.numTF = 0 - // Reset the connection attempt record for existing SubConns. - for _, sd := range b.subConns.Values() { - sd.connectionFailedInFirstPass = false - } - b.requestConnectionLocked() -} - -func (b *pickfirstBalancer) closeSubConnsLocked() { - for _, sd := range b.subConns.Values() { - sd.subConn.Shutdown() - } - b.subConns = resolver.NewAddressMapV2[*scData]() -} - -// deDupAddresses ensures that each address appears only once in the slice. -func deDupAddresses(addrs []resolver.Address) []resolver.Address { - seenAddrs := resolver.NewAddressMapV2[*scData]() - retAddrs := []resolver.Address{} - - for _, addr := range addrs { - if _, ok := seenAddrs.Get(addr); ok { - continue - } - retAddrs = append(retAddrs, addr) - } - return retAddrs -} - -// interleaveAddresses interleaves addresses of both families (IPv4 and IPv6) -// as per RFC-8305 section 4. -// Whichever address family is first in the list is followed by an address of -// the other address family; that is, if the first address in the list is IPv6, -// then the first IPv4 address should be moved up in the list to be second in -// the list. It doesn't support configuring "First Address Family Count", i.e. -// there will always be a single member of the first address family at the -// beginning of the interleaved list. -// Addresses that are neither IPv4 nor IPv6 are treated as part of a third -// "unknown" family for interleaving. -// See: https://datatracker.ietf.org/doc/html/rfc8305#autoid-6 -func interleaveAddresses(addrs []resolver.Address) []resolver.Address { - familyAddrsMap := map[ipAddrFamily][]resolver.Address{} - interleavingOrder := []ipAddrFamily{} - for _, addr := range addrs { - family := addressFamily(addr.Addr) - if _, found := familyAddrsMap[family]; !found { - interleavingOrder = append(interleavingOrder, family) - } - familyAddrsMap[family] = append(familyAddrsMap[family], addr) - } - - interleavedAddrs := make([]resolver.Address, 0, len(addrs)) - - for curFamilyIdx := 0; len(interleavedAddrs) < len(addrs); curFamilyIdx = (curFamilyIdx + 1) % len(interleavingOrder) { - // Some IP types may have fewer addresses than others, so we look for - // the next type that has a remaining member to add to the interleaved - // list. - family := interleavingOrder[curFamilyIdx] - remainingMembers := familyAddrsMap[family] - if len(remainingMembers) > 0 { - interleavedAddrs = append(interleavedAddrs, remainingMembers[0]) - familyAddrsMap[family] = remainingMembers[1:] - } - } - - return interleavedAddrs -} - -// addressFamily returns the ipAddrFamily after parsing the address string. -// If the address isn't of the format "ip-address:port", it returns -// ipAddrFamilyUnknown. The address may be valid even if it's not an IP when -// using a resolver like passthrough where the address may be a hostname in -// some format that the dialer can resolve. -func addressFamily(address string) ipAddrFamily { - // Parse the IP after removing the port. - host, _, err := net.SplitHostPort(address) - if err != nil { - return ipAddrFamilyUnknown - } - ip, err := netip.ParseAddr(host) - if err != nil { - return ipAddrFamilyUnknown - } - switch { - case ip.Is4() || ip.Is4In6(): - return ipAddrFamilyV4 - case ip.Is6(): - return ipAddrFamilyV6 - default: - return ipAddrFamilyUnknown - } -} - -// reconcileSubConnsLocked updates the active subchannels based on a new address -// list from the resolver. It does this by: -// - closing subchannels: any existing subchannels associated with addresses -// that are no longer in the updated list are shut down. -// - removing subchannels: entries for these closed subchannels are removed -// from the subchannel map. -// -// This ensures that the subchannel map accurately reflects the current set of -// addresses received from the name resolver. -func (b *pickfirstBalancer) reconcileSubConnsLocked(newAddrs []resolver.Address) { - newAddrsMap := resolver.NewAddressMapV2[bool]() - for _, addr := range newAddrs { - newAddrsMap.Set(addr, true) - } - - for _, oldAddr := range b.subConns.Keys() { - if _, ok := newAddrsMap.Get(oldAddr); ok { - continue - } - val, _ := b.subConns.Get(oldAddr) - val.subConn.Shutdown() - b.subConns.Delete(oldAddr) - } -} - -// shutdownRemainingLocked shuts down remaining subConns. Called when a subConn -// becomes ready, which means that all other subConn must be shutdown. -func (b *pickfirstBalancer) shutdownRemainingLocked(selected *scData) { - b.cancelConnectionTimer() - for _, sd := range b.subConns.Values() { - if sd.subConn != selected.subConn { - sd.subConn.Shutdown() - } - } - b.subConns = resolver.NewAddressMapV2[*scData]() - b.subConns.Set(selected.addr, selected) -} - -// requestConnectionLocked starts connecting on the subchannel corresponding to -// the current address. If no subchannel exists, one is created. If the current -// subchannel is in TransientFailure, a connection to the next address is -// attempted until a subchannel is found. -func (b *pickfirstBalancer) requestConnectionLocked() { - if !b.addressList.isValid() { - return - } - var lastErr error - for valid := true; valid; valid = b.addressList.increment() { - curAddr := b.addressList.currentAddress() - sd, ok := b.subConns.Get(curAddr) - if !ok { - var err error - // We want to assign the new scData to sd from the outer scope, - // hence we can't use := below. - sd, err = b.newSCData(curAddr) - if err != nil { - // This should never happen, unless the clientConn is being shut - // down. - if b.logger.V(2) { - b.logger.Infof("Failed to create a subConn for address %v: %v", curAddr.String(), err) - } - // Do nothing, the LB policy will be closed soon. - return - } - b.subConns.Set(curAddr, sd) - } - - switch sd.rawConnectivityState { - case connectivity.Idle: - sd.subConn.Connect() - b.scheduleNextConnectionLocked() - return - case connectivity.TransientFailure: - // The SubConn is being re-used and failed during a previous pass - // over the addressList. It has not completed backoff yet. - // Mark it as having failed and try the next address. - sd.connectionFailedInFirstPass = true - lastErr = sd.lastErr - continue - case connectivity.Connecting: - // Wait for the connection attempt to complete or the timer to fire - // before attempting the next address. - b.scheduleNextConnectionLocked() - return - default: - b.logger.Errorf("SubConn with unexpected state %v present in SubConns map.", sd.rawConnectivityState) - return - - } - } - - // All the remaining addresses in the list are in TRANSIENT_FAILURE, end the - // first pass if possible. - b.endFirstPassIfPossibleLocked(lastErr) -} - -func (b *pickfirstBalancer) scheduleNextConnectionLocked() { - b.cancelConnectionTimer() - if !b.addressList.hasNext() { - return - } - curAddr := b.addressList.currentAddress() - cancelled := false // Access to this is protected by the balancer's mutex. - closeFn := internal.TimeAfterFunc(connectionDelayInterval, func() { - b.mu.Lock() - defer b.mu.Unlock() - // If the scheduled task is cancelled while acquiring the mutex, return. - if cancelled { - return - } - if b.logger.V(2) { - b.logger.Infof("Happy Eyeballs timer expired while waiting for connection to %q.", curAddr.Addr) - } - if b.addressList.increment() { - b.requestConnectionLocked() - } - }) - // Access to the cancellation callback held by the balancer is guarded by - // the balancer's mutex, so it's safe to set the boolean from the callback. - b.cancelConnectionTimer = sync.OnceFunc(func() { - cancelled = true - closeFn() - }) -} - -func (b *pickfirstBalancer) updateSubConnState(sd *scData, newState balancer.SubConnState) { - b.mu.Lock() - defer b.mu.Unlock() - oldState := sd.rawConnectivityState - sd.rawConnectivityState = newState.ConnectivityState - // Previously relevant SubConns can still callback with state updates. - // To prevent pickers from returning these obsolete SubConns, this logic - // is included to check if the current list of active SubConns includes this - // SubConn. - if !b.isActiveSCData(sd) { - return - } - if newState.ConnectivityState == connectivity.Shutdown { - sd.effectiveState = connectivity.Shutdown - return - } - - // Record a connection attempt when exiting CONNECTING. - if newState.ConnectivityState == connectivity.TransientFailure { - sd.connectionFailedInFirstPass = true - connectionAttemptsFailedMetric.Record(b.metricsRecorder, 1, b.target) - } - - if newState.ConnectivityState == connectivity.Ready { - connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target) - b.shutdownRemainingLocked(sd) - if !b.addressList.seekTo(sd.addr) { - // This should not fail as we should have only one SubConn after - // entering READY. The SubConn should be present in the addressList. - b.logger.Errorf("Address %q not found address list in %v", sd.addr, b.addressList.addresses) - return - } - if !b.healthCheckingEnabled { - if b.logger.V(2) { - b.logger.Infof("SubConn %p reported connectivity state READY and the health listener is disabled. Transitioning SubConn to READY.", sd.subConn) - } - - sd.effectiveState = connectivity.Ready - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, - }) - return - } - if b.logger.V(2) { - b.logger.Infof("SubConn %p reported connectivity state READY. Registering health listener.", sd.subConn) - } - // Send a CONNECTING update to take the SubConn out of sticky-TF if - // required. - sd.effectiveState = connectivity.Connecting - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - sd.subConn.RegisterHealthListener(func(scs balancer.SubConnState) { - b.updateSubConnHealthState(sd, scs) - }) - return - } - - // If the LB policy is READY, and it receives a subchannel state change, - // it means that the READY subchannel has failed. - // A SubConn can also transition from CONNECTING directly to IDLE when - // a transport is successfully created, but the connection fails - // before the SubConn can send the notification for READY. We treat - // this as a successful connection and transition to IDLE. - // TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second - // part of the if condition below once the issue is fixed. - if oldState == connectivity.Ready || (oldState == connectivity.Connecting && newState.ConnectivityState == connectivity.Idle) { - // Once a transport fails, the balancer enters IDLE and starts from - // the first address when the picker is used. - b.shutdownRemainingLocked(sd) - sd.effectiveState = newState.ConnectivityState - // READY SubConn interspliced in between CONNECTING and IDLE, need to - // account for that. - if oldState == connectivity.Connecting { - // A known issue (https://github.com/grpc/grpc-go/issues/7862) - // causes a race that prevents the READY state change notification. - // This works around it. - connectionAttemptsSucceededMetric.Record(b.metricsRecorder, 1, b.target) - } - disconnectionsMetric.Record(b.metricsRecorder, 1, b.target) - b.addressList.reset() - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Idle, - Picker: &idlePicker{exitIdle: sync.OnceFunc(b.ExitIdle)}, - }) - return - } - - if b.firstPass { - switch newState.ConnectivityState { - case connectivity.Connecting: - // The effective state can be in either IDLE, CONNECTING or - // TRANSIENT_FAILURE. If it's TRANSIENT_FAILURE, stay in - // TRANSIENT_FAILURE until it's READY. See A62. - if sd.effectiveState != connectivity.TransientFailure { - sd.effectiveState = connectivity.Connecting - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - } - case connectivity.TransientFailure: - sd.lastErr = newState.ConnectionError - sd.effectiveState = connectivity.TransientFailure - // Since we're re-using common SubConns while handling resolver - // updates, we could receive an out of turn TRANSIENT_FAILURE from - // a pass over the previous address list. Happy Eyeballs will also - // cause out of order updates to arrive. - - if curAddr := b.addressList.currentAddress(); equalAddressIgnoringBalAttributes(&curAddr, &sd.addr) { - b.cancelConnectionTimer() - if b.addressList.increment() { - b.requestConnectionLocked() - return - } - } - - // End the first pass if we've seen a TRANSIENT_FAILURE from all - // SubConns once. - b.endFirstPassIfPossibleLocked(newState.ConnectionError) - } - return - } - - // We have finished the first pass, keep re-connecting failing SubConns. - switch newState.ConnectivityState { - case connectivity.TransientFailure: - b.numTF = (b.numTF + 1) % b.subConns.Len() - sd.lastErr = newState.ConnectionError - if b.numTF%b.subConns.Len() == 0 { - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: newState.ConnectionError}, - }) - } - // We don't need to request re-resolution since the SubConn already - // does that before reporting TRANSIENT_FAILURE. - // TODO: #7534 - Move re-resolution requests from SubConn into - // pick_first. - case connectivity.Idle: - sd.subConn.Connect() - } -} - -// endFirstPassIfPossibleLocked ends the first happy-eyeballs pass if all the -// addresses are tried and their SubConns have reported a failure. -func (b *pickfirstBalancer) endFirstPassIfPossibleLocked(lastErr error) { - // An optimization to avoid iterating over the entire SubConn map. - if b.addressList.isValid() { - return - } - // Connect() has been called on all the SubConns. The first pass can be - // ended if all the SubConns have reported a failure. - for _, sd := range b.subConns.Values() { - if !sd.connectionFailedInFirstPass { - return - } - } - b.firstPass = false - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: lastErr}, - }) - // Start re-connecting all the SubConns that are already in IDLE. - for _, sd := range b.subConns.Values() { - if sd.rawConnectivityState == connectivity.Idle { - sd.subConn.Connect() - } - } -} - -func (b *pickfirstBalancer) isActiveSCData(sd *scData) bool { - activeSD, found := b.subConns.Get(sd.addr) - return found && activeSD == sd -} - -func (b *pickfirstBalancer) updateSubConnHealthState(sd *scData, state balancer.SubConnState) { - b.mu.Lock() - defer b.mu.Unlock() - // Previously relevant SubConns can still callback with state updates. - // To prevent pickers from returning these obsolete SubConns, this logic - // is included to check if the current list of active SubConns includes - // this SubConn. - if !b.isActiveSCData(sd) { - return - } - sd.effectiveState = state.ConnectivityState - switch state.ConnectivityState { - case connectivity.Ready: - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Ready, - Picker: &picker{result: balancer.PickResult{SubConn: sd.subConn}}, - }) - case connectivity.TransientFailure: - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.TransientFailure, - Picker: &picker{err: fmt.Errorf("pickfirst: health check failure: %v", state.ConnectionError)}, - }) - case connectivity.Connecting: - b.updateBalancerState(balancer.State{ - ConnectivityState: connectivity.Connecting, - Picker: &picker{err: balancer.ErrNoSubConnAvailable}, - }) - default: - b.logger.Errorf("Got unexpected health update for SubConn %p: %v", state) - } -} - -// updateBalancerState stores the state reported to the channel and calls -// ClientConn.UpdateState(). As an optimization, it avoids sending duplicate -// updates to the channel. -func (b *pickfirstBalancer) updateBalancerState(newState balancer.State) { - // In case of TransientFailures allow the picker to be updated to update - // the connectivity error, in all other cases don't send duplicate state - // updates. - if newState.ConnectivityState == b.state && b.state != connectivity.TransientFailure { - return - } - b.forceUpdateConcludedStateLocked(newState) -} - -// forceUpdateConcludedStateLocked stores the state reported to the channel and -// calls ClientConn.UpdateState(). -// A separate function is defined to force update the ClientConn state since the -// channel doesn't correctly assume that LB policies start in CONNECTING and -// relies on LB policy to send an initial CONNECTING update. -func (b *pickfirstBalancer) forceUpdateConcludedStateLocked(newState balancer.State) { - b.state = newState.ConnectivityState - b.cc.UpdateState(newState) -} - -type picker struct { - result balancer.PickResult - err error -} - -func (p *picker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - return p.result, p.err -} - -// idlePicker is used when the SubConn is IDLE and kicks the SubConn into -// CONNECTING when Pick is called. -type idlePicker struct { - exitIdle func() -} - -func (i *idlePicker) Pick(balancer.PickInfo) (balancer.PickResult, error) { - i.exitIdle() - return balancer.PickResult{}, balancer.ErrNoSubConnAvailable -} - -// addressList manages sequentially iterating over addresses present in a list -// of endpoints. It provides a 1 dimensional view of the addresses present in -// the endpoints. -// This type is not safe for concurrent access. -type addressList struct { - addresses []resolver.Address - idx int -} - -func (al *addressList) isValid() bool { - return al.idx < len(al.addresses) -} - -func (al *addressList) size() int { - return len(al.addresses) -} - -// increment moves to the next index in the address list. -// This method returns false if it went off the list, true otherwise. -func (al *addressList) increment() bool { - if !al.isValid() { - return false - } - al.idx++ - return al.idx < len(al.addresses) -} - -// currentAddress returns the current address pointed to in the addressList. -// If the list is in an invalid state, it returns an empty address instead. -func (al *addressList) currentAddress() resolver.Address { - if !al.isValid() { - return resolver.Address{} - } - return al.addresses[al.idx] -} - -func (al *addressList) reset() { - al.idx = 0 -} - -func (al *addressList) updateAddrs(addrs []resolver.Address) { - al.addresses = addrs - al.reset() -} - -// seekTo returns false if the needle was not found and the current index was -// left unchanged. -func (al *addressList) seekTo(needle resolver.Address) bool { - for ai, addr := range al.addresses { - if !equalAddressIgnoringBalAttributes(&addr, &needle) { - continue - } - al.idx = ai - return true - } - return false -} - -// hasNext returns whether incrementing the addressList will result in moving -// past the end of the list. If the list has already moved past the end, it -// returns false. -func (al *addressList) hasNext() bool { - if !al.isValid() { - return false - } - return al.idx+1 < len(al.addresses) -} - -// equalAddressIgnoringBalAttributes returns true is a and b are considered -// equal. This is different from the Equal method on the resolver.Address type -// which considers all fields to determine equality. Here, we only consider -// fields that are meaningful to the SubConn. -func equalAddressIgnoringBalAttributes(a, b *resolver.Address) bool { - return a.Addr == b.Addr && a.ServerName == b.ServerName && - a.Attributes.Equal(b.Attributes) -} diff --git a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go index 22045bf39..22e6e3267 100644 --- a/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go +++ b/vendor/google.golang.org/grpc/balancer/roundrobin/roundrobin.go @@ -26,7 +26,7 @@ import ( "google.golang.org/grpc/balancer" "google.golang.org/grpc/balancer/endpointsharding" - "google.golang.org/grpc/balancer/pickfirst/pickfirstleaf" + "google.golang.org/grpc/balancer/pickfirst" "google.golang.org/grpc/grpclog" internalgrpclog "google.golang.org/grpc/internal/grpclog" ) @@ -47,7 +47,7 @@ func (bb builder) Name() string { } func (bb builder) Build(cc balancer.ClientConn, opts balancer.BuildOptions) balancer.Balancer { - childBuilder := balancer.Get(pickfirstleaf.Name).Build + childBuilder := balancer.Get(pickfirst.Name).Build bal := &rrBalancer{ cc: cc, Balancer: endpointsharding.NewBalancer(cc, opts, childBuilder, endpointsharding.Options{}), @@ -67,6 +67,6 @@ func (b *rrBalancer) UpdateClientConnState(ccs balancer.ClientConnState) error { return b.Balancer.UpdateClientConnState(balancer.ClientConnState{ // Enable the health listener in pickfirst children for client side health // checks and outlier detection, if configured. - ResolverState: pickfirstleaf.EnableHealthListener(ccs.ResolverState), + ResolverState: pickfirst.EnableHealthListener(ccs.ResolverState), }) } diff --git a/vendor/google.golang.org/grpc/balancer/subconn.go b/vendor/google.golang.org/grpc/balancer/subconn.go index 9ee44d4af..c1ca7c92e 100644 --- a/vendor/google.golang.org/grpc/balancer/subconn.go +++ b/vendor/google.golang.org/grpc/balancer/subconn.go @@ -111,20 +111,6 @@ type SubConnState struct { // ConnectionError is set if the ConnectivityState is TransientFailure, // describing the reason the SubConn failed. Otherwise, it is nil. ConnectionError error - // connectedAddr contains the connected address when ConnectivityState is - // Ready. Otherwise, it is indeterminate. - connectedAddress resolver.Address -} - -// connectedAddress returns the connected address for a SubConnState. The -// address is only valid if the state is READY. -func connectedAddress(scs SubConnState) resolver.Address { - return scs.connectedAddress -} - -// setConnectedAddress sets the connected address for a SubConnState. -func setConnectedAddress(scs *SubConnState, addr resolver.Address) { - scs.connectedAddress = addr } // A Producer is a type shared among potentially many consumers. It is diff --git a/vendor/google.golang.org/grpc/balancer_wrapper.go b/vendor/google.golang.org/grpc/balancer_wrapper.go index 948a21ef6..a1e56a389 100644 --- a/vendor/google.golang.org/grpc/balancer_wrapper.go +++ b/vendor/google.golang.org/grpc/balancer_wrapper.go @@ -36,7 +36,6 @@ import ( ) var ( - setConnectedAddress = internal.SetConnectedAddress.(func(*balancer.SubConnState, resolver.Address)) // noOpRegisterHealthListenerFn is used when client side health checking is // disabled. It sends a single READY update on the registered listener. noOpRegisterHealthListenerFn = func(_ context.Context, listener func(balancer.SubConnState)) func() { @@ -305,7 +304,7 @@ func newHealthData(s connectivity.State) *healthData { // updateState is invoked by grpc to push a subConn state update to the // underlying balancer. -func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolver.Address, err error) { +func (acbw *acBalancerWrapper) updateState(s connectivity.State, err error) { acbw.ccb.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil || acbw.ccb.balancer == nil { return @@ -317,9 +316,6 @@ func (acbw *acBalancerWrapper) updateState(s connectivity.State, curAddr resolve // opts.StateListener is set, so this cannot ever be nil. // TODO: delete this comment when UpdateSubConnState is removed. scs := balancer.SubConnState{ConnectivityState: s, ConnectionError: err} - if s == connectivity.Ready { - setConnectedAddress(&scs, curAddr) - } // Invalidate the health listener by updating the healthData. acbw.healthMu.Lock() // A race may occur if a health listener is registered soon after the @@ -450,13 +446,14 @@ func (acbw *acBalancerWrapper) healthListenerRegFn() func(context.Context, func( if acbw.ccb.cc.dopts.disableHealthCheck { return noOpRegisterHealthListenerFn } + cfg := acbw.ac.cc.healthCheckConfig() + if cfg == nil { + return noOpRegisterHealthListenerFn + } regHealthLisFn := internal.RegisterClientHealthCheckListener if regHealthLisFn == nil { // The health package is not imported. - return noOpRegisterHealthListenerFn - } - cfg := acbw.ac.cc.healthCheckConfig() - if cfg == nil { + channelz.Error(logger, acbw.ac.channelz, "Health check is requested but health package is not imported.") return noOpRegisterHealthListenerFn } return func(ctx context.Context, listener func(balancer.SubConnState)) func() { diff --git a/vendor/google.golang.org/grpc/binarylog/grpc_binarylog_v1/binarylog.pb.go b/vendor/google.golang.org/grpc/binarylog/grpc_binarylog_v1/binarylog.pb.go index b1364a032..42c61cf9f 100644 --- a/vendor/google.golang.org/grpc/binarylog/grpc_binarylog_v1/binarylog.pb.go +++ b/vendor/google.golang.org/grpc/binarylog/grpc_binarylog_v1/binarylog.pb.go @@ -18,7 +18,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.10 // protoc v5.27.1 // source: grpc/binlog/v1/binarylog.proto diff --git a/vendor/google.golang.org/grpc/clientconn.go b/vendor/google.golang.org/grpc/clientconn.go index a3c315f2d..5dec2dacc 100644 --- a/vendor/google.golang.org/grpc/clientconn.go +++ b/vendor/google.golang.org/grpc/clientconn.go @@ -35,16 +35,19 @@ import ( "google.golang.org/grpc/balancer/pickfirst" "google.golang.org/grpc/codes" "google.golang.org/grpc/connectivity" + "google.golang.org/grpc/credentials" + expstats "google.golang.org/grpc/experimental/stats" "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/channelz" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/idle" iresolver "google.golang.org/grpc/internal/resolver" - "google.golang.org/grpc/internal/stats" + istats "google.golang.org/grpc/internal/stats" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/keepalive" "google.golang.org/grpc/resolver" "google.golang.org/grpc/serviceconfig" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" _ "google.golang.org/grpc/balancer/roundrobin" // To register roundrobin. @@ -97,6 +100,41 @@ var ( errTransportCredentialsMissing = errors.New("grpc: the credentials require transport level security (use grpc.WithTransportCredentials() to set)") ) +var ( + disconnectionsMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.subchannel.disconnections", + Description: "EXPERIMENTAL. Number of times the selected subchannel becomes disconnected.", + Unit: "{disconnection}", + Labels: []string{"grpc.target"}, + OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality", "grpc.disconnect_error"}, + Default: false, + }) + connectionAttemptsSucceededMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.subchannel.connection_attempts_succeeded", + Description: "EXPERIMENTAL. Number of successful connection attempts.", + Unit: "{attempt}", + Labels: []string{"grpc.target"}, + OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"}, + Default: false, + }) + connectionAttemptsFailedMetric = expstats.RegisterInt64Count(expstats.MetricDescriptor{ + Name: "grpc.subchannel.connection_attempts_failed", + Description: "EXPERIMENTAL. Number of failed connection attempts.", + Unit: "{attempt}", + Labels: []string{"grpc.target"}, + OptionalLabels: []string{"grpc.lb.backend_service", "grpc.lb.locality"}, + Default: false, + }) + openConnectionsMetric = expstats.RegisterInt64UpDownCount(expstats.MetricDescriptor{ + Name: "grpc.subchannel.open_connections", + Description: "EXPERIMENTAL. Number of open connections.", + Unit: "{attempt}", + Labels: []string{"grpc.target"}, + OptionalLabels: []string{"grpc.lb.backend_service", "grpc.security_level", "grpc.lb.locality"}, + Default: false, + }) +) + const ( defaultClientMaxReceiveMessageSize = 1024 * 1024 * 4 defaultClientMaxSendMessageSize = math.MaxInt32 @@ -210,7 +248,8 @@ func NewClient(target string, opts ...DialOption) (conn *ClientConn, err error) cc.csMgr = newConnectivityStateManager(cc.ctx, cc.channelz) cc.pickerWrapper = newPickerWrapper() - cc.metricsRecorderList = stats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers) + cc.metricsRecorderList = istats.NewMetricsRecorderList(cc.dopts.copts.StatsHandlers) + cc.statsHandler = istats.NewCombinedHandler(cc.dopts.copts.StatsHandlers...) cc.initIdleStateLocked() // Safe to call without the lock, since nothing else has a reference to cc. cc.idlenessMgr = idle.NewManager((*idler)(cc), cc.dopts.idleTimeout) @@ -260,9 +299,10 @@ func DialContext(ctx context.Context, target string, opts ...DialOption) (conn * }() // This creates the name resolver, load balancer, etc. - if err := cc.idlenessMgr.ExitIdleMode(); err != nil { - return nil, err + if err := cc.exitIdleMode(); err != nil { + return nil, fmt.Errorf("failed to exit idle mode: %w", err) } + cc.idlenessMgr.UnsafeSetNotIdle() // Return now for non-blocking dials. if !cc.dopts.block { @@ -330,7 +370,7 @@ func (cc *ClientConn) addTraceEvent(msg string) { Severity: channelz.CtInfo, } } - channelz.AddTraceEvent(logger, cc.channelz, 0, ted) + channelz.AddTraceEvent(logger, cc.channelz, 1, ted) } type idler ClientConn @@ -339,14 +379,17 @@ func (i *idler) EnterIdleMode() { (*ClientConn)(i).enterIdleMode() } -func (i *idler) ExitIdleMode() error { - return (*ClientConn)(i).exitIdleMode() +func (i *idler) ExitIdleMode() { + // Ignore the error returned from this method, because from the perspective + // of the caller (idleness manager), the channel would have always moved out + // of IDLE by the time this method returns. + (*ClientConn)(i).exitIdleMode() } // exitIdleMode moves the channel out of idle mode by recreating the name // resolver and load balancer. This should never be called directly; use // cc.idlenessMgr.ExitIdleMode instead. -func (cc *ClientConn) exitIdleMode() (err error) { +func (cc *ClientConn) exitIdleMode() error { cc.mu.Lock() if cc.conns == nil { cc.mu.Unlock() @@ -354,11 +397,23 @@ func (cc *ClientConn) exitIdleMode() (err error) { } cc.mu.Unlock() + // Set state to CONNECTING before building the name resolver + // so the channel does not remain in IDLE. + cc.csMgr.updateState(connectivity.Connecting) + // This needs to be called without cc.mu because this builds a new resolver // which might update state or report error inline, which would then need to // acquire cc.mu. if err := cc.resolverWrapper.start(); err != nil { - return err + // If resolver creation fails, treat it like an error reported by the + // resolver before any valid updates. Set channel's state to + // TransientFailure, and set an erroring picker with the resolver build + // error, which will returned as part of any subsequent RPCs. + logger.Warningf("Failed to start resolver: %v", err) + cc.csMgr.updateState(connectivity.TransientFailure) + cc.mu.Lock() + cc.updateResolverStateAndUnlock(resolver.State{}, err) + return fmt.Errorf("failed to start resolver: %w", err) } cc.addTraceEvent("exiting idle mode") @@ -621,7 +676,8 @@ type ClientConn struct { channelz *channelz.Channel // Channelz object. resolverBuilder resolver.Builder // See initParsedTargetAndResolverBuilder(). idlenessMgr *idle.Manager - metricsRecorderList *stats.MetricsRecorderList + metricsRecorderList *istats.MetricsRecorderList + statsHandler stats.Handler // The following provide their own synchronization, and therefore don't // require cc.mu to be held to access them. @@ -678,10 +734,8 @@ func (cc *ClientConn) GetState() connectivity.State { // Notice: This API is EXPERIMENTAL and may be changed or removed in a later // release. func (cc *ClientConn) Connect() { - if err := cc.idlenessMgr.ExitIdleMode(); err != nil { - cc.addTraceEvent(err.Error()) - return - } + cc.idlenessMgr.ExitIdleMode() + // If the ClientConn was not in idle mode, we need to call ExitIdle on the // LB policy so that connections can be created. cc.mu.Lock() @@ -732,8 +786,8 @@ func init() { internal.EnterIdleModeForTesting = func(cc *ClientConn) { cc.idlenessMgr.EnterIdleModeForTesting() } - internal.ExitIdleModeForTesting = func(cc *ClientConn) error { - return cc.idlenessMgr.ExitIdleMode() + internal.ExitIdleModeForTesting = func(cc *ClientConn) { + cc.idlenessMgr.ExitIdleMode() } } @@ -858,6 +912,7 @@ func (cc *ClientConn) newAddrConnLocked(addrs []resolver.Address, opts balancer. channelz: channelz.RegisterSubChannel(cc.channelz, ""), resetBackoff: make(chan struct{}), } + ac.updateTelemetryLabelsLocked() ac.ctx, ac.cancel = context.WithCancel(cc.ctx) // Start with our address set to the first address; this may be updated if // we connect to different addresses. @@ -922,25 +977,24 @@ func (cc *ClientConn) incrCallsFailed() { // connect starts creating a transport. // It does nothing if the ac is not IDLE. // TODO(bar) Move this to the addrConn section. -func (ac *addrConn) connect() error { +func (ac *addrConn) connect() { ac.mu.Lock() if ac.state == connectivity.Shutdown { if logger.V(2) { logger.Infof("connect called on shutdown addrConn; ignoring.") } ac.mu.Unlock() - return errConnClosing + return } if ac.state != connectivity.Idle { if logger.V(2) { logger.Infof("connect called on addrConn in non-idle state (%v); ignoring.", ac.state) } ac.mu.Unlock() - return nil + return } ac.resetTransportAndUnlock() - return nil } // equalAddressIgnoringBalAttributes returns true is a and b are considered equal. @@ -974,7 +1028,7 @@ func (ac *addrConn) updateAddrs(addrs []resolver.Address) { } ac.addrs = addrs - + ac.updateTelemetryLabelsLocked() if ac.state == connectivity.Shutdown || ac.state == connectivity.TransientFailure || ac.state == connectivity.Idle { @@ -1213,6 +1267,9 @@ type addrConn struct { resetBackoff chan struct{} channelz *channelz.SubChannel + + localityLabel string + backendServiceLabel string } // Note: this requires a lock on ac.mu. @@ -1220,6 +1277,18 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error) if ac.state == s { return } + + // If we are transitioning out of Ready, it means there is a disconnection. + // A SubConn can also transition from CONNECTING directly to IDLE when + // a transport is successfully created, but the connection fails + // before the SubConn can send the notification for READY. We treat + // this as a successful connection and transition to IDLE. + // TODO: https://github.com/grpc/grpc-go/issues/7862 - Remove the second + // part of the if condition below once the issue is fixed. + if ac.state == connectivity.Ready || (ac.state == connectivity.Connecting && s == connectivity.Idle) { + disconnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel, "unknown") + openConnectionsMetric.Record(ac.cc.metricsRecorderList, -1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel) + } ac.state = s ac.channelz.ChannelMetrics.State.Store(&s) if lastErr == nil { @@ -1227,7 +1296,7 @@ func (ac *addrConn) updateConnectivityState(s connectivity.State, lastErr error) } else { channelz.Infof(logger, ac.channelz, "Subchannel Connectivity change to %v, last error: %s", s, lastErr) } - ac.acbw.updateState(s, ac.curAddr, lastErr) + ac.acbw.updateState(s, lastErr) } // adjustParams updates parameters used to create transports upon @@ -1277,6 +1346,15 @@ func (ac *addrConn) resetTransportAndUnlock() { ac.mu.Unlock() if err := ac.tryAllAddrs(acCtx, addrs, connectDeadline); err != nil { + if !errors.Is(err, context.Canceled) { + connectionAttemptsFailedMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel) + } else { + if logger.V(2) { + // This records cancelled connection attempts which can be later + // replaced by a metric. + logger.Infof("Context cancellation detected; not recording this as a failed connection attempt.") + } + } // TODO: #7534 - Move re-resolution requests into the pick_first LB policy // to ensure one resolution request per pass instead of per subconn failure. ac.cc.resolveNow(resolver.ResolveNowOptions{}) @@ -1316,10 +1394,50 @@ func (ac *addrConn) resetTransportAndUnlock() { } // Success; reset backoff. ac.mu.Lock() + connectionAttemptsSucceededMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.localityLabel) + openConnectionsMetric.Record(ac.cc.metricsRecorderList, 1, ac.cc.target, ac.backendServiceLabel, ac.securityLevelLocked(), ac.localityLabel) ac.backoffIdx = 0 ac.mu.Unlock() } +// updateTelemetryLabelsLocked calculates and caches the telemetry labels based on the +// first address in addrConn. +func (ac *addrConn) updateTelemetryLabelsLocked() { + labelsFunc, ok := internal.AddressToTelemetryLabels.(func(resolver.Address) map[string]string) + if !ok || len(ac.addrs) == 0 { + // Reset defaults + ac.localityLabel = "" + ac.backendServiceLabel = "" + return + } + labels := labelsFunc(ac.addrs[0]) + ac.localityLabel = labels["grpc.lb.locality"] + ac.backendServiceLabel = labels["grpc.lb.backend_service"] +} + +type securityLevelKey struct{} + +func (ac *addrConn) securityLevelLocked() string { + var secLevel string + // During disconnection, ac.transport is nil. Fall back to the security level + // stored in the current address during connection. + if ac.transport == nil { + secLevel, _ = ac.curAddr.Attributes.Value(securityLevelKey{}).(string) + return secLevel + } + authInfo := ac.transport.Peer().AuthInfo + if ci, ok := authInfo.(interface { + GetCommonAuthInfo() credentials.CommonAuthInfo + }); ok { + secLevel = ci.GetCommonAuthInfo().SecurityLevel.String() + // Store the security level in the current address' attributes so + // that it remains available for disconnection metrics after the + // transport is closed. + ac.curAddr.Attributes = ac.curAddr.Attributes.WithValue(securityLevelKey{}, secLevel) + } + return secLevel +} + // tryAllAddrs tries to create a connection to the addresses, and stop when at // the first successful one. It returns an error if no address was successfully // connected, or updates ac appropriately with the new transport. @@ -1409,25 +1527,26 @@ func (ac *addrConn) createTransport(ctx context.Context, addr resolver.Address, } ac.mu.Lock() - defer ac.mu.Unlock() if ctx.Err() != nil { // This can happen if the subConn was removed while in `Connecting` // state. tearDown() would have set the state to `Shutdown`, but // would not have closed the transport since ac.transport would not // have been set at that point. - // - // We run this in a goroutine because newTr.Close() calls onClose() + + // We unlock ac.mu because newTr.Close() calls onClose() // inline, which requires locking ac.mu. - // + ac.mu.Unlock() + // The error we pass to Close() is immaterial since there are no open // streams at this point, so no trailers with error details will be sent // out. We just need to pass a non-nil error. // // This can also happen when updateAddrs is called during a connection // attempt. - go newTr.Close(transport.ErrConnClosing) + newTr.Close(transport.ErrConnClosing) return nil } + defer ac.mu.Unlock() if hctx.Err() != nil { // onClose was already called for this connection, but the connection // was successfully established first. Consider it a success and set diff --git a/vendor/google.golang.org/grpc/credentials/credentials.go b/vendor/google.golang.org/grpc/credentials/credentials.go index c8e337cdd..06f6c6c70 100644 --- a/vendor/google.golang.org/grpc/credentials/credentials.go +++ b/vendor/google.golang.org/grpc/credentials/credentials.go @@ -44,8 +44,7 @@ type PerRPCCredentials interface { // A54). uri is the URI of the entry point for the request. When supported // by the underlying implementation, ctx can be used for timeout and // cancellation. Additionally, RequestInfo data will be available via ctx - // to this call. TODO(zhaoq): Define the set of the qualified keys instead - // of leaving it as an arbitrary string. + // to this call. GetRequestMetadata(ctx context.Context, uri ...string) (map[string]string, error) // RequireTransportSecurity indicates whether the credentials requires // transport security. diff --git a/vendor/google.golang.org/grpc/credentials/tls.go b/vendor/google.golang.org/grpc/credentials/tls.go index 8277be7d6..0bcd16dbb 100644 --- a/vendor/google.golang.org/grpc/credentials/tls.go +++ b/vendor/google.golang.org/grpc/credentials/tls.go @@ -56,9 +56,13 @@ func (t TLSInfo) AuthType() string { // non-nil error if the validation fails. func (t TLSInfo) ValidateAuthority(authority string) error { var errs []error + host, _, err := net.SplitHostPort(authority) + if err != nil { + host = authority + } for _, cert := range t.State.PeerCertificates { var err error - if err = cert.VerifyHostname(authority); err == nil { + if err = cert.VerifyHostname(host); err == nil { return nil } errs = append(errs, err) diff --git a/vendor/google.golang.org/grpc/encoding/encoding.go b/vendor/google.golang.org/grpc/encoding/encoding.go index 11d0ae142..296f38c3a 100644 --- a/vendor/google.golang.org/grpc/encoding/encoding.go +++ b/vendor/google.golang.org/grpc/encoding/encoding.go @@ -27,8 +27,10 @@ package encoding import ( "io" + "slices" "strings" + "google.golang.org/grpc/encoding/internal" "google.golang.org/grpc/internal/grpcutil" ) @@ -36,12 +38,26 @@ import ( // It is intended for grpc internal use only. const Identity = "identity" +func init() { + internal.RegisterCompressorForTesting = func(c Compressor) func() { + name := c.Name() + curCompressor, found := registeredCompressor[name] + RegisterCompressor(c) + return func() { + if found { + registeredCompressor[name] = curCompressor + return + } + delete(registeredCompressor, name) + grpcutil.RegisteredCompressorNames = slices.DeleteFunc(grpcutil.RegisteredCompressorNames, func(s string) bool { + return s == name + }) + } + } +} + // Compressor is used for compressing and decompressing when sending or // receiving messages. -// -// If a Compressor implements `DecompressedSize(compressedBytes []byte) int`, -// gRPC will invoke it to determine the size of the buffer allocated for the -// result of decompression. A return value of -1 indicates unknown size. type Compressor interface { // Compress writes the data written to wc to w after compressing it. If an // error occurs while initializing the compressor, that error is returned diff --git a/vendor/google.golang.org/grpc/encoding/gzip/gzip.go b/vendor/google.golang.org/grpc/encoding/gzip/gzip.go index 6306e8bb0..153e4dbfb 100644 --- a/vendor/google.golang.org/grpc/encoding/gzip/gzip.go +++ b/vendor/google.golang.org/grpc/encoding/gzip/gzip.go @@ -27,7 +27,6 @@ package gzip import ( "compress/gzip" - "encoding/binary" "fmt" "io" "sync" @@ -111,17 +110,6 @@ func (z *reader) Read(p []byte) (n int, err error) { return n, err } -// RFC1952 specifies that the last four bytes "contains the size of -// the original (uncompressed) input data modulo 2^32." -// gRPC has a max message size of 2GB so we don't need to worry about wraparound. -func (c *compressor) DecompressedSize(buf []byte) int { - last := len(buf) - if last < 4 { - return -1 - } - return int(binary.LittleEndian.Uint32(buf[last-4 : last])) -} - func (c *compressor) Name() string { return Name } diff --git a/vendor/google.golang.org/grpc/encoding/internal/internal.go b/vendor/google.golang.org/grpc/encoding/internal/internal.go new file mode 100644 index 000000000..ee9acb437 --- /dev/null +++ b/vendor/google.golang.org/grpc/encoding/internal/internal.go @@ -0,0 +1,28 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package internal contains code internal to the encoding package. +package internal + +// RegisterCompressorForTesting registers a compressor in the global compressor +// registry. It returns a cleanup function that should be called at the end +// of the test to unregister the compressor. +// +// This prevents compressors registered in one test from appearing in the +// encoding headers of subsequent tests. +var RegisterCompressorForTesting any // func RegisterCompressor(c Compressor) func() diff --git a/vendor/google.golang.org/grpc/experimental/stats/metricregistry.go b/vendor/google.golang.org/grpc/experimental/stats/metricregistry.go index ad75313a1..472813f58 100644 --- a/vendor/google.golang.org/grpc/experimental/stats/metricregistry.go +++ b/vendor/google.golang.org/grpc/experimental/stats/metricregistry.go @@ -75,6 +75,8 @@ const ( MetricTypeIntHisto MetricTypeFloatHisto MetricTypeIntGauge + MetricTypeIntUpDownCount + MetricTypeIntAsyncGauge ) // Int64CountHandle is a typed handle for a int count metric. This handle @@ -93,6 +95,23 @@ func (h *Int64CountHandle) Record(recorder MetricsRecorder, incr int64, labels . recorder.RecordInt64Count(h, incr, labels...) } +// Int64UpDownCountHandle is a typed handle for an int up-down counter metric. +// This handle is passed at the recording point in order to know which metric +// to record on. +type Int64UpDownCountHandle MetricDescriptor + +// Descriptor returns the int64 up-down counter handle typecast to a pointer to a +// MetricDescriptor. +func (h *Int64UpDownCountHandle) Descriptor() *MetricDescriptor { + return (*MetricDescriptor)(h) +} + +// Record records the int64 up-down counter value on the metrics recorder provided. +// The value 'v' can be positive to increment or negative to decrement. +func (h *Int64UpDownCountHandle) Record(recorder MetricsRecorder, v int64, labels ...string) { + recorder.RecordInt64UpDownCount(h, v, labels...) +} + // Float64CountHandle is a typed handle for a float count metric. This handle is // passed at the recording point in order to know which metric to record on. type Float64CountHandle MetricDescriptor @@ -154,6 +173,30 @@ func (h *Int64GaugeHandle) Record(recorder MetricsRecorder, incr int64, labels . recorder.RecordInt64Gauge(h, incr, labels...) } +// AsyncMetric is a marker interface for asynchronous metric types. +type AsyncMetric interface { + isAsync() + Descriptor() *MetricDescriptor +} + +// Int64AsyncGaugeHandle is a typed handle for an int gauge metric. This handle is +// passed at the recording point in order to know which metric to record on. +type Int64AsyncGaugeHandle MetricDescriptor + +// isAsync implements the AsyncMetric interface. +func (h *Int64AsyncGaugeHandle) isAsync() {} + +// Descriptor returns the int64 gauge handle typecast to a pointer to a +// MetricDescriptor. +func (h *Int64AsyncGaugeHandle) Descriptor() *MetricDescriptor { + return (*MetricDescriptor)(h) +} + +// Record records the int64 gauge value on the metrics recorder provided. +func (h *Int64AsyncGaugeHandle) Record(recorder AsyncMetricsRecorder, value int64, labels ...string) { + recorder.RecordInt64AsyncGauge(h, value, labels...) +} + // registeredMetrics are the registered metric descriptor names. var registeredMetrics = make(map[string]bool) @@ -249,6 +292,35 @@ func RegisterInt64Gauge(descriptor MetricDescriptor) *Int64GaugeHandle { return (*Int64GaugeHandle)(descPtr) } +// RegisterInt64UpDownCount registers the metric description onto the global registry. +// It returns a typed handle to use for recording data. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. If multiple metrics are +// registered with the same name, this function will panic. +func RegisterInt64UpDownCount(descriptor MetricDescriptor) *Int64UpDownCountHandle { + registerMetric(descriptor.Name, descriptor.Default) + // Set the specific metric type for the up-down counter + descriptor.Type = MetricTypeIntUpDownCount + descPtr := &descriptor + metricsRegistry[descriptor.Name] = descPtr + return (*Int64UpDownCountHandle)(descPtr) +} + +// RegisterInt64AsyncGauge registers the metric description onto the global registry. +// It returns a typed handle to use for recording data. +// +// NOTE: this function must only be called during initialization time (i.e. in +// an init() function), and is not thread-safe. If multiple metrics are +// registered with the same name, this function will panic. +func RegisterInt64AsyncGauge(descriptor MetricDescriptor) *Int64AsyncGaugeHandle { + registerMetric(descriptor.Name, descriptor.Default) + descriptor.Type = MetricTypeIntAsyncGauge + descPtr := &descriptor + metricsRegistry[descriptor.Name] = descPtr + return (*Int64AsyncGaugeHandle)(descPtr) +} + // snapshotMetricsRegistryForTesting snapshots the global data of the metrics // registry. Returns a cleanup function that sets the metrics registry to its // original state. diff --git a/vendor/google.golang.org/grpc/experimental/stats/metrics.go b/vendor/google.golang.org/grpc/experimental/stats/metrics.go index ee1423605..88742724a 100644 --- a/vendor/google.golang.org/grpc/experimental/stats/metrics.go +++ b/vendor/google.golang.org/grpc/experimental/stats/metrics.go @@ -19,9 +19,13 @@ // Package stats contains experimental metrics/stats API's. package stats -import "google.golang.org/grpc/stats" +import ( + "google.golang.org/grpc/internal" + "google.golang.org/grpc/stats" +) // MetricsRecorder records on metrics derived from metric registry. +// Implementors must embed UnimplementedMetricsRecorder. type MetricsRecorder interface { // RecordInt64Count records the measurement alongside labels on the int // count associated with the provided handle. @@ -38,6 +42,49 @@ type MetricsRecorder interface { // RecordInt64Gauge records the measurement alongside labels on the int // gauge associated with the provided handle. RecordInt64Gauge(handle *Int64GaugeHandle, incr int64, labels ...string) + // RecordInt64UpDownCounter records the measurement alongside labels on the int + // count associated with the provided handle. + RecordInt64UpDownCount(handle *Int64UpDownCountHandle, incr int64, labels ...string) + // RegisterAsyncReporter registers a reporter to produce metric values for + // only the listed descriptors. The returned function must be called when + // the metrics are no longer needed, which will remove the reporter. The + // returned method needs to be idempotent and concurrent safe. + RegisterAsyncReporter(reporter AsyncMetricReporter, descriptors ...AsyncMetric) func() + + // EnforceMetricsRecorderEmbedding is included to force implementers to embed + // another implementation of this interface, allowing gRPC to add methods + // without breaking users. + internal.EnforceMetricsRecorderEmbedding +} + +// AsyncMetricReporter is an interface for types that record metrics asynchronously +// for the set of descriptors they are registered with. The AsyncMetricsRecorder +// parameter is used to record values for these metrics. +// +// Implementations must make unique recordings across all registered +// AsyncMetricReporters. Meaning, they should not report values for a metric with +// the same attributes as another AsyncMetricReporter will report. +// +// Implementations must be concurrent-safe. +type AsyncMetricReporter interface { + // Report records metric values using the provided recorder. + Report(AsyncMetricsRecorder) error +} + +// AsyncMetricReporterFunc is an adapter to allow the use of ordinary functions as +// AsyncMetricReporters. +type AsyncMetricReporterFunc func(AsyncMetricsRecorder) error + +// Report calls f(r). +func (f AsyncMetricReporterFunc) Report(r AsyncMetricsRecorder) error { + return f(r) +} + +// AsyncMetricsRecorder records on asynchronous metrics derived from metric registry. +type AsyncMetricsRecorder interface { + // RecordInt64AsyncGauge records the measurement alongside labels on the int + // count associated with the provided handle asynchronously + RecordInt64AsyncGauge(handle *Int64AsyncGaugeHandle, incr int64, labels ...string) } // Metrics is an experimental legacy alias of the now-stable stats.MetricSet. @@ -52,3 +99,33 @@ type Metric = string func NewMetrics(metrics ...Metric) *Metrics { return stats.NewMetricSet(metrics...) } + +// UnimplementedMetricsRecorder must be embedded to have forward compatible implementations. +type UnimplementedMetricsRecorder struct { + internal.EnforceMetricsRecorderEmbedding +} + +// RecordInt64Count provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordInt64Count(*Int64CountHandle, int64, ...string) {} + +// RecordFloat64Count provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordFloat64Count(*Float64CountHandle, float64, ...string) {} + +// RecordInt64Histo provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordInt64Histo(*Int64HistoHandle, int64, ...string) {} + +// RecordFloat64Histo provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordFloat64Histo(*Float64HistoHandle, float64, ...string) {} + +// RecordInt64Gauge provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordInt64Gauge(*Int64GaugeHandle, int64, ...string) {} + +// RecordInt64UpDownCount provides a no-op implementation. +func (UnimplementedMetricsRecorder) RecordInt64UpDownCount(*Int64UpDownCountHandle, int64, ...string) { +} + +// RegisterAsyncReporter provides a no-op implementation. +func (UnimplementedMetricsRecorder) RegisterAsyncReporter(AsyncMetricReporter, ...AsyncMetric) func() { + // No-op: Return an empty function to ensure caller doesn't panic on nil function call + return func() {} +} diff --git a/vendor/google.golang.org/grpc/health/grpc_health_v1/health.pb.go b/vendor/google.golang.org/grpc/health/grpc_health_v1/health.pb.go index 22d263fb9..8f7d9f6bb 100644 --- a/vendor/google.golang.org/grpc/health/grpc_health_v1/health.pb.go +++ b/vendor/google.golang.org/grpc/health/grpc_health_v1/health.pb.go @@ -17,7 +17,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.10 // protoc v5.27.1 // source: grpc/health/v1/health.proto diff --git a/vendor/google.golang.org/grpc/health/grpc_health_v1/health_grpc.pb.go b/vendor/google.golang.org/grpc/health/grpc_health_v1/health_grpc.pb.go index f2c01f296..e99cd5c83 100644 --- a/vendor/google.golang.org/grpc/health/grpc_health_v1/health_grpc.pb.go +++ b/vendor/google.golang.org/grpc/health/grpc_health_v1/health_grpc.pb.go @@ -17,7 +17,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 +// - protoc-gen-go-grpc v1.6.0 // - protoc v5.27.1 // source: grpc/health/v1/health.proto diff --git a/vendor/google.golang.org/grpc/interceptor.go b/vendor/google.golang.org/grpc/interceptor.go index 877d78fc3..099e3d093 100644 --- a/vendor/google.golang.org/grpc/interceptor.go +++ b/vendor/google.golang.org/grpc/interceptor.go @@ -97,8 +97,12 @@ type StreamServerInfo struct { IsServerStream bool } -// StreamServerInterceptor provides a hook to intercept the execution of a streaming RPC on the server. -// info contains all the information of this RPC the interceptor can operate on. And handler is the -// service method implementation. It is the responsibility of the interceptor to invoke handler to -// complete the RPC. +// StreamServerInterceptor provides a hook to intercept the execution of a +// streaming RPC on the server. +// +// srv is the service implementation on which the RPC was invoked, and needs to +// be passed to handler, and not used otherwise. ss is the server side of the +// stream. info contains all the information of this RPC the interceptor can +// operate on. And handler is the service method implementation. It is the +// responsibility of the interceptor to invoke handler to complete the RPC. type StreamServerInterceptor func(srv any, ss ServerStream, info *StreamServerInfo, handler StreamHandler) error diff --git a/vendor/google.golang.org/grpc/internal/balancer/gracefulswitch/gracefulswitch.go b/vendor/google.golang.org/grpc/internal/balancer/gracefulswitch/gracefulswitch.go index ba25b8988..f38de74a4 100644 --- a/vendor/google.golang.org/grpc/internal/balancer/gracefulswitch/gracefulswitch.go +++ b/vendor/google.golang.org/grpc/internal/balancer/gracefulswitch/gracefulswitch.go @@ -67,6 +67,10 @@ type Balancer struct { // balancerCurrent before the UpdateSubConnState is called on the // balancerCurrent. currentMu sync.Mutex + + // activeGoroutines tracks all the goroutines that this balancer has started + // and that should be waited on when the balancer closes. + activeGoroutines sync.WaitGroup } // swap swaps out the current lb with the pending lb and updates the ClientConn. @@ -76,7 +80,9 @@ func (gsb *Balancer) swap() { cur := gsb.balancerCurrent gsb.balancerCurrent = gsb.balancerPending gsb.balancerPending = nil + gsb.activeGoroutines.Add(1) go func() { + defer gsb.activeGoroutines.Done() gsb.currentMu.Lock() defer gsb.currentMu.Unlock() cur.Close() @@ -274,6 +280,7 @@ func (gsb *Balancer) Close() { currentBalancerToClose.Close() pendingBalancerToClose.Close() + gsb.activeGoroutines.Wait() } // balancerWrapper wraps a balancer.Balancer, and overrides some Balancer @@ -324,7 +331,12 @@ func (bw *balancerWrapper) UpdateState(state balancer.State) { defer bw.gsb.mu.Unlock() bw.lastState = state + // If Close() acquires the mutex before UpdateState(), the balancer + // will already have been removed from the current or pending state when + // reaching this point. if !bw.gsb.balancerCurrentOrPending(bw) { + // Returning here ensures that (*Balancer).swap() is not invoked after + // (*Balancer).Close() and therefore prevents "use after close". return } diff --git a/vendor/google.golang.org/grpc/internal/balancer/weight/weight.go b/vendor/google.golang.org/grpc/internal/balancer/weight/weight.go new file mode 100644 index 000000000..11beb07d1 --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/balancer/weight/weight.go @@ -0,0 +1,66 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +// Package weight contains utilities to manage endpoint weights. Weights are +// used by LB policies such as ringhash to distribute load across multiple +// endpoints. +package weight + +import ( + "fmt" + + "google.golang.org/grpc/resolver" +) + +// attributeKey is the type used as the key to store EndpointInfo in the +// Attributes field of resolver.Endpoint. +type attributeKey struct{} + +// EndpointInfo will be stored in the Attributes field of Endpoints in order to +// use the ringhash balancer. +type EndpointInfo struct { + Weight uint32 +} + +// Equal allows the values to be compared by Attributes.Equal. +func (a EndpointInfo) Equal(o any) bool { + oa, ok := o.(EndpointInfo) + return ok && oa.Weight == a.Weight +} + +// Set returns a copy of endpoint in which the Attributes field is updated with +// EndpointInfo. +func Set(endpoint resolver.Endpoint, epInfo EndpointInfo) resolver.Endpoint { + endpoint.Attributes = endpoint.Attributes.WithValue(attributeKey{}, epInfo) + return endpoint +} + +// String returns a human-readable representation of EndpointInfo. +// This method is intended for logging, testing, and debugging purposes only. +// Do not rely on the output format, as it is not guaranteed to remain stable. +func (a EndpointInfo) String() string { + return fmt.Sprintf("Weight: %d", a.Weight) +} + +// FromEndpoint returns the EndpointInfo stored in the Attributes field of an +// endpoint. It returns an empty EndpointInfo if attribute is not found. +func FromEndpoint(endpoint resolver.Endpoint) EndpointInfo { + v := endpoint.Attributes.Value(attributeKey{}) + ei, _ := v.(EndpointInfo) + return ei +} diff --git a/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go b/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go index 7e060f5ed..7ad6fb44c 100644 --- a/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go +++ b/vendor/google.golang.org/grpc/internal/envconfig/envconfig.go @@ -52,12 +52,6 @@ var ( // or "false". EnforceALPNEnabled = boolFromEnv("GRPC_ENFORCE_ALPN_ENABLED", true) - // NewPickFirstEnabled is set if the new pickfirst leaf policy is to be used - // instead of the exiting pickfirst implementation. This can be disabled by - // setting the environment variable "GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST" - // to "false". - NewPickFirstEnabled = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_NEW_PICK_FIRST", true) - // XDSEndpointHashKeyBackwardCompat controls the parsing of the endpoint hash // key from EDS LbEndpoint metadata. Endpoint hash keys can be disabled by // setting "GRPC_XDS_ENDPOINT_HASH_KEY_BACKWARD_COMPAT" to "true". When the @@ -75,6 +69,41 @@ var ( // ALTSHandshakerKeepaliveParams is set if we should add the // KeepaliveParams when dial the ALTS handshaker service. ALTSHandshakerKeepaliveParams = boolFromEnv("GRPC_EXPERIMENTAL_ALTS_HANDSHAKER_KEEPALIVE_PARAMS", false) + + // EnableDefaultPortForProxyTarget controls whether the resolver adds a default port 443 + // to a target address that lacks one. This flag only has an effect when all of + // the following conditions are met: + // - A connect proxy is being used. + // - Target resolution is disabled. + // - The DNS resolver is being used. + EnableDefaultPortForProxyTarget = boolFromEnv("GRPC_EXPERIMENTAL_ENABLE_DEFAULT_PORT_FOR_PROXY_TARGET", true) + + // XDSAuthorityRewrite indicates whether xDS authority rewriting is enabled. + // This feature is defined in gRFC A81 and is enabled by setting the + // environment variable GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE to "true". + XDSAuthorityRewrite = boolFromEnv("GRPC_EXPERIMENTAL_XDS_AUTHORITY_REWRITE", false) + + // PickFirstWeightedShuffling indicates whether weighted endpoint shuffling + // is enabled in the pick_first LB policy, as defined in gRFC A113. This + // feature can be disabled by setting the environment variable + // GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING to "false". + PickFirstWeightedShuffling = boolFromEnv("GRPC_EXPERIMENTAL_PF_WEIGHTED_SHUFFLING", true) + + // DisableStrictPathChecking indicates whether strict path checking is + // disabled. This feature can be disabled by setting the environment + // variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to "true". + // + // When strict path checking is enabled, gRPC will reject requests with + // paths that do not conform to the gRPC over HTTP/2 specification found at + // https://github.com/grpc/grpc/blob/master/doc/PROTOCOL-HTTP2.md. + // + // When disabled, gRPC will allow paths that do not contain a leading slash. + // Enabling strict path checking is recommended for security reasons, as it + // prevents potential path traversal vulnerabilities. + // + // A future release will remove this environment variable, enabling strict + // path checking behavior unconditionally. + DisableStrictPathChecking = boolFromEnv("GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING", false) ) func boolFromEnv(envVar string, def bool) bool { diff --git a/vendor/google.golang.org/grpc/internal/envconfig/xds.go b/vendor/google.golang.org/grpc/internal/envconfig/xds.go index b1f883bca..7685d08b5 100644 --- a/vendor/google.golang.org/grpc/internal/envconfig/xds.go +++ b/vendor/google.golang.org/grpc/internal/envconfig/xds.go @@ -74,4 +74,9 @@ var ( // For more details, see: // https://github.com/grpc/proposal/blob/master/A86-xds-http-connect.md XDSHTTPConnectEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_HTTP_CONNECT", false) + + // XDSBootstrapCallCredsEnabled controls if call credentials can be used in + // xDS bootstrap configuration via the `call_creds` field. For more details, + // see: https://github.com/grpc/proposal/blob/master/A97-xds-jwt-call-creds.md + XDSBootstrapCallCredsEnabled = boolFromEnv("GRPC_EXPERIMENTAL_XDS_BOOTSTRAP_CALL_CREDS", false) ) diff --git a/vendor/google.golang.org/grpc/internal/experimental.go b/vendor/google.golang.org/grpc/internal/experimental.go index 7617be215..8a999917d 100644 --- a/vendor/google.golang.org/grpc/internal/experimental.go +++ b/vendor/google.golang.org/grpc/internal/experimental.go @@ -25,4 +25,11 @@ var ( // BufferPool is implemented by the grpc package and returns a server // option to configure a shared buffer pool for a grpc.Server. BufferPool any // func (grpc.SharedBufferPool) grpc.ServerOption + + // SetDefaultBufferPool updates the default buffer pool. + SetDefaultBufferPool any // func(mem.BufferPool) + + // AcceptCompressors is implemented by the grpc package and returns + // a call option that restricts the grpc-accept-encoding header for a call. + AcceptCompressors any // func(...string) grpc.CallOption ) diff --git a/vendor/google.golang.org/grpc/internal/idle/idle.go b/vendor/google.golang.org/grpc/internal/idle/idle.go index 2c13ee9da..d3cd24f80 100644 --- a/vendor/google.golang.org/grpc/internal/idle/idle.go +++ b/vendor/google.golang.org/grpc/internal/idle/idle.go @@ -21,7 +21,6 @@ package idle import ( - "fmt" "math" "sync" "sync/atomic" @@ -33,15 +32,15 @@ var timeAfterFunc = func(d time.Duration, f func()) *time.Timer { return time.AfterFunc(d, f) } -// Enforcer is the functionality provided by grpc.ClientConn to enter -// and exit from idle mode. -type Enforcer interface { - ExitIdleMode() error +// ClientConn is the functionality provided by grpc.ClientConn to enter and exit +// from idle mode. +type ClientConn interface { + ExitIdleMode() EnterIdleMode() } -// Manager implements idleness detection and calls the configured Enforcer to -// enter/exit idle mode when appropriate. Must be created by NewManager. +// Manager implements idleness detection and calls the ClientConn to enter/exit +// idle mode when appropriate. Must be created by NewManager. type Manager struct { // State accessed atomically. lastCallEndTime int64 // Unix timestamp in nanos; time when the most recent RPC completed. @@ -51,8 +50,8 @@ type Manager struct { // Can be accessed without atomics or mutex since these are set at creation // time and read-only after that. - enforcer Enforcer // Functionality provided by grpc.ClientConn. - timeout time.Duration + cc ClientConn // Functionality provided by grpc.ClientConn. + timeout time.Duration // idleMu is used to guarantee mutual exclusion in two scenarios: // - Opposing intentions: @@ -72,9 +71,9 @@ type Manager struct { // NewManager creates a new idleness manager implementation for the // given idle timeout. It begins in idle mode. -func NewManager(enforcer Enforcer, timeout time.Duration) *Manager { +func NewManager(cc ClientConn, timeout time.Duration) *Manager { return &Manager{ - enforcer: enforcer, + cc: cc, timeout: timeout, actuallyIdle: true, activeCallsCount: -math.MaxInt32, @@ -127,7 +126,7 @@ func (m *Manager) handleIdleTimeout() { // Now that we've checked that there has been no activity, attempt to enter // idle mode, which is very likely to succeed. - if m.tryEnterIdleMode() { + if m.tryEnterIdleMode(true) { // Successfully entered idle mode. No timer needed until we exit idle. return } @@ -142,10 +141,13 @@ func (m *Manager) handleIdleTimeout() { // that, it performs a last minute check to ensure that no new RPC has come in, // making the channel active. // +// checkActivity controls if a check for RPC activity, since the last time the +// idle_timeout fired, is made. + // Return value indicates whether or not the channel moved to idle mode. // // Holds idleMu which ensures mutual exclusion with exitIdleMode. -func (m *Manager) tryEnterIdleMode() bool { +func (m *Manager) tryEnterIdleMode(checkActivity bool) bool { // Setting the activeCallsCount to -math.MaxInt32 indicates to OnCallBegin() // that the channel is either in idle mode or is trying to get there. if !atomic.CompareAndSwapInt32(&m.activeCallsCount, 0, -math.MaxInt32) { @@ -166,7 +168,7 @@ func (m *Manager) tryEnterIdleMode() bool { atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) return false } - if atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { + if checkActivity && atomic.LoadInt32(&m.activeSinceLastTimerCheck) == 1 { // A very short RPC could have come in (and also finished) after we // checked for calls count and activity in handleIdleTimeout(), but // before the CAS operation. So, we need to check for activity again. @@ -177,44 +179,37 @@ func (m *Manager) tryEnterIdleMode() bool { // No new RPCs have come in since we set the active calls count value to // -math.MaxInt32. And since we have the lock, it is safe to enter idle mode // unconditionally now. - m.enforcer.EnterIdleMode() + m.cc.EnterIdleMode() m.actuallyIdle = true return true } // EnterIdleModeForTesting instructs the channel to enter idle mode. func (m *Manager) EnterIdleModeForTesting() { - m.tryEnterIdleMode() + m.tryEnterIdleMode(false) } // OnCallBegin is invoked at the start of every RPC. -func (m *Manager) OnCallBegin() error { +func (m *Manager) OnCallBegin() { if m.isClosed() { - return nil + return } if atomic.AddInt32(&m.activeCallsCount, 1) > 0 { // Channel is not idle now. Set the activity bit and allow the call. atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) - return nil + return } // Channel is either in idle mode or is in the process of moving to idle // mode. Attempt to exit idle mode to allow this RPC. - if err := m.ExitIdleMode(); err != nil { - // Undo the increment to calls count, and return an error causing the - // RPC to fail. - atomic.AddInt32(&m.activeCallsCount, -1) - return err - } - + m.ExitIdleMode() atomic.StoreInt32(&m.activeSinceLastTimerCheck, 1) - return nil } -// ExitIdleMode instructs m to call the enforcer's ExitIdleMode and update m's +// ExitIdleMode instructs m to call the ClientConn's ExitIdleMode and update its // internal state. -func (m *Manager) ExitIdleMode() error { +func (m *Manager) ExitIdleMode() { // Holds idleMu which ensures mutual exclusion with tryEnterIdleMode. m.idleMu.Lock() defer m.idleMu.Unlock() @@ -231,12 +226,10 @@ func (m *Manager) ExitIdleMode() error { // m.ExitIdleMode. // // In any case, there is nothing to do here. - return nil + return } - if err := m.enforcer.ExitIdleMode(); err != nil { - return fmt.Errorf("failed to exit idle mode: %w", err) - } + m.cc.ExitIdleMode() // Undo the idle entry process. This also respects any new RPC attempts. atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) @@ -244,7 +237,23 @@ func (m *Manager) ExitIdleMode() error { // Start a new timer to fire after the configured idle timeout. m.resetIdleTimerLocked(m.timeout) - return nil +} + +// UnsafeSetNotIdle instructs the Manager to update its internal state to +// reflect the reality that the channel is no longer in IDLE mode. +// +// N.B. This method is intended only for internal use by the gRPC client +// when it exits IDLE mode **manually** from `Dial`. The callsite must ensure: +// - The channel was **actually in IDLE mode** immediately prior to the call. +// - There is **no concurrent activity** that could cause the channel to exit +// IDLE mode *naturally* at the same time. +func (m *Manager) UnsafeSetNotIdle() { + m.idleMu.Lock() + defer m.idleMu.Unlock() + + atomic.AddInt32(&m.activeCallsCount, math.MaxInt32) + m.actuallyIdle = false + m.resetIdleTimerLocked(m.timeout) } // OnCallEnd is invoked at the end of every RPC. diff --git a/vendor/google.golang.org/grpc/internal/internal.go b/vendor/google.golang.org/grpc/internal/internal.go index 2699223a2..4b3d563f8 100644 --- a/vendor/google.golang.org/grpc/internal/internal.go +++ b/vendor/google.golang.org/grpc/internal/internal.go @@ -211,22 +211,11 @@ var ( // default resolver scheme. UserSetDefaultScheme = false - // ConnectedAddress returns the connected address for a SubConnState. The - // address is only valid if the state is READY. - ConnectedAddress any // func (scs SubConnState) resolver.Address - - // SetConnectedAddress sets the connected address for a SubConnState. - SetConnectedAddress any // func(scs *SubConnState, addr resolver.Address) - // SnapshotMetricRegistryForTesting snapshots the global data of the metric // registry. Returns a cleanup function that sets the metric registry to its // original state. Only called in testing functions. SnapshotMetricRegistryForTesting func() func() - // SetDefaultBufferPoolForTesting updates the default buffer pool, for - // testing purposes. - SetDefaultBufferPoolForTesting any // func(mem.BufferPool) - // SetBufferPoolingThresholdForTesting updates the buffer pooling threshold, for // testing purposes. SetBufferPoolingThresholdForTesting any // func(int) @@ -244,6 +233,18 @@ var ( // When set, the function will be called before the stream enters // the blocking state. NewStreamWaitingForResolver = func() {} + + // AddressToTelemetryLabels is an xDS-provided function to extract telemetry + // labels from a resolver.Address. Callers must assert its type before calling. + AddressToTelemetryLabels any // func(addr resolver.Address) map[string]string + + // AsyncReporterCleanupDelegate is initialized to a pass-through function by + // default (production behavior), allowing tests to swap it with an + // implementation which tracks registration of async reporter and its + // corresponding cleanup. + AsyncReporterCleanupDelegate = func(cleanup func()) func() { + return cleanup + } ) // HealthChecker defines the signature of the client-side LB channel health @@ -291,3 +292,9 @@ type EnforceClientConnEmbedding interface { type Timer interface { Stop() bool } + +// EnforceMetricsRecorderEmbedding is used to enforce proper MetricsRecorder +// implementation embedding. +type EnforceMetricsRecorderEmbedding interface { + enforceMetricsRecorderEmbedding() +} diff --git a/vendor/google.golang.org/grpc/internal/resolver/delegatingresolver/delegatingresolver.go b/vendor/google.golang.org/grpc/internal/resolver/delegatingresolver/delegatingresolver.go index 20b8fb098..5bfa67b72 100644 --- a/vendor/google.golang.org/grpc/internal/resolver/delegatingresolver/delegatingresolver.go +++ b/vendor/google.golang.org/grpc/internal/resolver/delegatingresolver/delegatingresolver.go @@ -22,11 +22,13 @@ package delegatingresolver import ( "fmt" + "net" "net/http" "net/url" "sync" "google.golang.org/grpc/grpclog" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/proxyattributes" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/internal/transport/networktype" @@ -40,6 +42,8 @@ var ( HTTPSProxyFromEnvironment = http.ProxyFromEnvironment ) +const defaultPort = "443" + // delegatingResolver manages both target URI and proxy address resolution by // delegating these tasks to separate child resolvers. Essentially, it acts as // an intermediary between the gRPC ClientConn and the child resolvers. @@ -107,10 +111,18 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti targetResolver: nopResolver{}, } + addr := target.Endpoint() var err error - r.proxyURL, err = proxyURLForTarget(target.Endpoint()) + if target.URL.Scheme == "dns" && !targetResolutionEnabled && envconfig.EnableDefaultPortForProxyTarget { + addr, err = parseTarget(addr) + if err != nil { + return nil, fmt.Errorf("delegating_resolver: invalid target address %q: %v", target.Endpoint(), err) + } + } + + r.proxyURL, err = proxyURLForTarget(addr) if err != nil { - return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %s: %v", target, err) + return nil, fmt.Errorf("delegating_resolver: failed to determine proxy URL for target %q: %v", target, err) } // proxy is not configured or proxy address excluded using `NO_PROXY` env @@ -132,8 +144,8 @@ func New(target resolver.Target, cc resolver.ClientConn, opts resolver.BuildOpti // bypass the target resolver and store the unresolved target address. if target.URL.Scheme == "dns" && !targetResolutionEnabled { r.targetResolverState = &resolver.State{ - Addresses: []resolver.Address{{Addr: target.Endpoint()}}, - Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: target.Endpoint()}}}}, + Addresses: []resolver.Address{{Addr: addr}}, + Endpoints: []resolver.Endpoint{{Addresses: []resolver.Address{{Addr: addr}}}}, } r.updateTargetResolverState(*r.targetResolverState) return r, nil @@ -202,6 +214,44 @@ func needsProxyResolver(state *resolver.State) bool { return false } +// parseTarget takes a target string and ensures it is a valid "host:port" target. +// +// It does the following: +// 1. If the target already has a port (e.g., "host:port", "[ipv6]:port"), +// it is returned as is. +// 2. If the host part is empty (e.g., ":80"), it defaults to "localhost", +// returning "localhost:80". +// 3. If the target is missing a port (e.g., "host", "ipv6"), the defaultPort +// is added. +// +// An error is returned for empty targets or targets with a trailing colon +// but no port (e.g., "host:"). +func parseTarget(target string) (string, error) { + if target == "" { + return "", fmt.Errorf("missing address") + } + + host, port, err := net.SplitHostPort(target) + if err != nil { + // If SplitHostPort fails, it's likely because the port is missing. + // We append the default port and return the result. + return net.JoinHostPort(target, defaultPort), nil + } + + // If SplitHostPort succeeds, we check for edge cases. + if port == "" { + // A success with an empty port means the target had a trailing colon, + // e.g., "host:", which is an error. + return "", fmt.Errorf("missing port after port-separator colon") + } + if host == "" { + // A success with an empty host means the target was like ":80". + // We default the host to "localhost". + host = "localhost" + } + return net.JoinHostPort(host, port), nil +} + func skipProxy(address resolver.Address) bool { // Avoid proxy when network is not tcp. networkType, ok := networktype.Get(address) diff --git a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go index ada5251cf..70b89e4d7 100644 --- a/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go +++ b/vendor/google.golang.org/grpc/internal/resolver/dns/dns_resolver.go @@ -125,7 +125,10 @@ func (b *dnsBuilder) Build(target resolver.Target, cc resolver.ClientConn, opts // IP address. if ipAddr, err := formatIP(host); err == nil { addr := []resolver.Address{{Addr: ipAddr + ":" + port}} - cc.UpdateState(resolver.State{Addresses: addr}) + cc.UpdateState(resolver.State{ + Addresses: addr, + Endpoints: []resolver.Endpoint{{Addresses: addr}}, + }) return deadResolver{}, nil } @@ -342,7 +345,15 @@ func (d *dnsResolver) lookup() (*resolver.State, error) { return nil, hostErr } - state := resolver.State{Addresses: addrs} + eps := make([]resolver.Endpoint, 0, len(addrs)) + for _, addr := range addrs { + eps = append(eps, resolver.Endpoint{Addresses: []resolver.Address{addr}}) + } + + state := resolver.State{ + Addresses: addrs, + Endpoints: eps, + } if len(srv) > 0 { state = grpclbstate.Set(state, &grpclbstate.State{BalancerAddresses: srv}) } diff --git a/vendor/google.golang.org/grpc/internal/stats/metrics_recorder_list.go b/vendor/google.golang.org/grpc/internal/stats/metrics_recorder_list.go index 79044657b..1c8c2ab30 100644 --- a/vendor/google.golang.org/grpc/internal/stats/metrics_recorder_list.go +++ b/vendor/google.golang.org/grpc/internal/stats/metrics_recorder_list.go @@ -20,6 +20,7 @@ import ( "fmt" estats "google.golang.org/grpc/experimental/stats" + "google.golang.org/grpc/internal" "google.golang.org/grpc/stats" ) @@ -28,6 +29,7 @@ import ( // It eats any record calls where the label values provided do not match the // number of label keys. type MetricsRecorderList struct { + internal.EnforceMetricsRecorderEmbedding // metricsRecorders are the metrics recorders this list will forward to. metricsRecorders []estats.MetricsRecorder } @@ -64,6 +66,16 @@ func (l *MetricsRecorderList) RecordInt64Count(handle *estats.Int64CountHandle, } } +// RecordInt64UpDownCount records the measurement alongside labels on the int +// count associated with the provided handle. +func (l *MetricsRecorderList) RecordInt64UpDownCount(handle *estats.Int64UpDownCountHandle, incr int64, labels ...string) { + verifyLabels(handle.Descriptor(), labels...) + + for _, metricRecorder := range l.metricsRecorders { + metricRecorder.RecordInt64UpDownCount(handle, incr, labels...) + } +} + // RecordFloat64Count records the measurement alongside labels on the float // count associated with the provided handle. func (l *MetricsRecorderList) RecordFloat64Count(handle *estats.Float64CountHandle, incr float64, labels ...string) { @@ -103,3 +115,61 @@ func (l *MetricsRecorderList) RecordInt64Gauge(handle *estats.Int64GaugeHandle, metricRecorder.RecordInt64Gauge(handle, incr, labels...) } } + +// RegisterAsyncReporter forwards the registration to all underlying metrics +// recorders. +// +// It returns a cleanup function that, when called, invokes the cleanup function +// returned by each underlying recorder, ensuring the reporter is unregistered +// from all of them. +func (l *MetricsRecorderList) RegisterAsyncReporter(reporter estats.AsyncMetricReporter, metrics ...estats.AsyncMetric) func() { + descriptorsMap := make(map[*estats.MetricDescriptor]bool, len(metrics)) + for _, m := range metrics { + descriptorsMap[m.Descriptor()] = true + } + unregisterFns := make([]func(), 0, len(l.metricsRecorders)) + for _, mr := range l.metricsRecorders { + // Wrap the AsyncMetricsRecorder to intercept calls to RecordInt64Gauge + // and validate the labels. + wrappedCallback := func(recorder estats.AsyncMetricsRecorder) error { + wrappedRecorder := &asyncRecorderWrapper{ + delegate: recorder, + descriptors: descriptorsMap, + } + return reporter.Report(wrappedRecorder) + } + unregisterFns = append(unregisterFns, mr.RegisterAsyncReporter(estats.AsyncMetricReporterFunc(wrappedCallback), metrics...)) + } + + // Wrap the cleanup function using the internal delegate. + // In production, this returns realCleanup as-is. + // In tests, the leak checker can swap this to track the registration lifetime. + return internal.AsyncReporterCleanupDelegate(defaultCleanUp(unregisterFns)) +} + +func defaultCleanUp(unregisterFns []func()) func() { + return func() { + for _, unregister := range unregisterFns { + unregister() + } + } +} + +type asyncRecorderWrapper struct { + delegate estats.AsyncMetricsRecorder + descriptors map[*estats.MetricDescriptor]bool +} + +// RecordIntAsync64Gauge records the measurement alongside labels on the int +// gauge associated with the provided handle. +func (w *asyncRecorderWrapper) RecordInt64AsyncGauge(handle *estats.Int64AsyncGaugeHandle, value int64, labels ...string) { + // Ensure only metrics for descriptors passed during callback registration + // are emitted. + d := handle.Descriptor() + if _, ok := w.descriptors[d]; !ok { + return + } + // Validate labels and delegate. + verifyLabels(d, labels...) + w.delegate.RecordInt64AsyncGauge(handle, value, labels...) +} diff --git a/vendor/google.golang.org/grpc/internal/stats/stats.go b/vendor/google.golang.org/grpc/internal/stats/stats.go new file mode 100644 index 000000000..49019b80d --- /dev/null +++ b/vendor/google.golang.org/grpc/internal/stats/stats.go @@ -0,0 +1,70 @@ +/* + * + * Copyright 2025 gRPC authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +package stats + +import ( + "context" + + "google.golang.org/grpc/stats" +) + +type combinedHandler struct { + handlers []stats.Handler +} + +// NewCombinedHandler combines multiple stats.Handlers into a single handler. +// +// It returns nil if no handlers are provided. If only one handler is +// provided, it is returned directly without wrapping. +func NewCombinedHandler(handlers ...stats.Handler) stats.Handler { + switch len(handlers) { + case 0: + return nil + case 1: + return handlers[0] + default: + return &combinedHandler{handlers: handlers} + } +} + +func (ch *combinedHandler) TagRPC(ctx context.Context, info *stats.RPCTagInfo) context.Context { + for _, h := range ch.handlers { + ctx = h.TagRPC(ctx, info) + } + return ctx +} + +func (ch *combinedHandler) HandleRPC(ctx context.Context, stats stats.RPCStats) { + for _, h := range ch.handlers { + h.HandleRPC(ctx, stats) + } +} + +func (ch *combinedHandler) TagConn(ctx context.Context, info *stats.ConnTagInfo) context.Context { + for _, h := range ch.handlers { + ctx = h.TagConn(ctx, info) + } + return ctx +} + +func (ch *combinedHandler) HandleConn(ctx context.Context, stats stats.ConnStats) { + for _, h := range ch.handlers { + h.HandleConn(ctx, stats) + } +} diff --git a/vendor/google.golang.org/grpc/internal/transport/client_stream.go b/vendor/google.golang.org/grpc/internal/transport/client_stream.go index ccc0e017e..cd8152ef1 100644 --- a/vendor/google.golang.org/grpc/internal/transport/client_stream.go +++ b/vendor/google.golang.org/grpc/internal/transport/client_stream.go @@ -24,30 +24,34 @@ import ( "golang.org/x/net/http2" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" + "google.golang.org/grpc/stats" "google.golang.org/grpc/status" ) // ClientStream implements streaming functionality for a gRPC client. type ClientStream struct { - *Stream // Embed for common stream functionality. + Stream // Embed for common stream functionality. ct *http2Client done chan struct{} // closed at the end of stream to unblock writers. doneFunc func() // invoked at the end of stream. - headerChan chan struct{} // closed to indicate the end of header metadata. - headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + headerChan chan struct{} // closed to indicate the end of header metadata. + header metadata.MD // the received header metadata + + status *status.Status // the status error received from the server + + // Non-pointer fields are at the end to optimize GC allocations. + // headerValid indicates whether a valid header was received. Only // meaningful after headerChan is closed (always call waitOnHeader() before // reading its value). - headerValid bool - header metadata.MD // the received header metadata - noHeaders bool // set if the client never received headers (set only after the stream is done). - - bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream - unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream - - status *status.Status // the status error received from the server + headerValid bool + noHeaders bool // set if the client never received headers (set only after the stream is done). + headerChanClosed uint32 // set when headerChan is closed. Used to avoid closing headerChan multiple times. + bytesReceived atomic.Bool // indicates whether any bytes have been received on this stream + unprocessed atomic.Bool // set if the server sends a refused stream or GOAWAY including this stream + statsHandler stats.Handler // nil for internal streams (e.g., health check, ORCA) where telemetry is not supported. } // Read reads an n byte message from the input stream. @@ -142,3 +146,11 @@ func (s *ClientStream) TrailersOnly() bool { func (s *ClientStream) Status() *status.Status { return s.status } + +func (s *ClientStream) requestRead(n int) { + s.ct.adjustWindow(s, uint32(n)) +} + +func (s *ClientStream) updateWindow(n int) { + s.ct.updateWindow(s, uint32(n)) +} diff --git a/vendor/google.golang.org/grpc/internal/transport/controlbuf.go b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go index a2831e5d0..7efa52478 100644 --- a/vendor/google.golang.org/grpc/internal/transport/controlbuf.go +++ b/vendor/google.golang.org/grpc/internal/transport/controlbuf.go @@ -24,16 +24,13 @@ import ( "fmt" "net" "runtime" - "strconv" "sync" "sync/atomic" "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/internal/grpclog" - "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/mem" - "google.golang.org/grpc/status" ) var updateHeaderTblSize = func(e *hpack.Encoder, v uint32) { @@ -147,11 +144,9 @@ type cleanupStream struct { func (c *cleanupStream) isTransportResponseFrame() bool { return c.rst } // Results in a RST_STREAM type earlyAbortStream struct { - httpStatus uint32 - streamID uint32 - contentSubtype string - status *status.Status - rst bool + streamID uint32 + rst bool + hf []hpack.HeaderField // Pre-built header fields } func (*earlyAbortStream) isTransportResponseFrame() bool { return false } @@ -496,6 +491,16 @@ const ( serverSide ) +// maxWriteBufSize is the maximum length (number of elements) the cached +// writeBuf can grow to. The length depends on the number of buffers +// contained within the BufferSlice produced by the codec, which is +// generally small. +// +// If a writeBuf larger than this limit is required, it will be allocated +// and freed after use, rather than being cached. This avoids holding +// on to large amounts of memory. +const maxWriteBufSize = 64 + // Loopy receives frames from the control buffer. // Each frame is handled individually; most of the work done by loopy goes // into handling data frames. Loopy maintains a queue of active streams, and each @@ -530,6 +535,8 @@ type loopyWriter struct { // Side-specific handlers ssGoAwayHandler func(*goAway) (bool, error) + + writeBuf [][]byte // cached slice to avoid heap allocations for calls to mem.Reader.Peek. } func newLoopyWriter(s side, fr *framer, cbuf *controlBuffer, bdpEst *bdpEstimator, conn net.Conn, logger *grpclog.PrefixLogger, goAwayHandler func(*goAway) (bool, error), bufferPool mem.BufferPool) *loopyWriter { @@ -665,11 +672,10 @@ func (l *loopyWriter) incomingSettingsHandler(s *incomingSettings) error { func (l *loopyWriter) registerStreamHandler(h *registerStream) { str := &outStream{ - id: h.streamID, - state: empty, - itl: &itemList{}, - wq: h.wq, - reader: mem.BufferSlice{}.Reader(), + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, } l.estdStreams[h.streamID] = str } @@ -701,11 +707,10 @@ func (l *loopyWriter) headerHandler(h *headerFrame) error { } // Case 2: Client wants to originate stream. str := &outStream{ - id: h.streamID, - state: empty, - itl: &itemList{}, - wq: h.wq, - reader: mem.BufferSlice{}.Reader(), + id: h.streamID, + state: empty, + itl: &itemList{}, + wq: h.wq, } return l.originateStream(str, h) } @@ -833,18 +838,7 @@ func (l *loopyWriter) earlyAbortStreamHandler(eas *earlyAbortStream) error { if l.side == clientSide { return errors.New("earlyAbortStream not handled on client") } - // In case the caller forgets to set the http status, default to 200. - if eas.httpStatus == 0 { - eas.httpStatus = 200 - } - headerFields := []hpack.HeaderField{ - {Name: ":status", Value: strconv.Itoa(int(eas.httpStatus))}, - {Name: "content-type", Value: grpcutil.ContentType(eas.contentSubtype)}, - {Name: "grpc-status", Value: strconv.Itoa(int(eas.status.Code()))}, - {Name: "grpc-message", Value: encodeGrpcMessage(eas.status.Message())}, - } - - if err := l.writeHeader(eas.streamID, true, headerFields, nil); err != nil { + if err := l.writeHeader(eas.streamID, true, eas.hf, nil); err != nil { return err } if eas.rst { @@ -948,11 +942,11 @@ func (l *loopyWriter) processData() (bool, error) { if str == nil { return true, nil } - reader := str.reader + reader := &str.reader dataItem := str.itl.peek().(*dataFrame) // Peek at the first data item this stream. if !dataItem.processing { dataItem.processing = true - str.reader.Reset(dataItem.data) + reader.Reset(dataItem.data) dataItem.data.Free() } // A data item is represented by a dataFrame, since it later translates into @@ -964,11 +958,11 @@ func (l *loopyWriter) processData() (bool, error) { if len(dataItem.h) == 0 && reader.Remaining() == 0 { // Empty data frame // Client sends out empty data frame with endStream = true - if err := l.framer.fr.WriteData(dataItem.streamID, dataItem.endStream, nil); err != nil { + if err := l.framer.writeData(dataItem.streamID, dataItem.endStream, nil); err != nil { return false, err } str.itl.dequeue() // remove the empty data item from stream - _ = reader.Close() + reader.Close() if str.itl.isEmpty() { str.state = empty } else if trailer, ok := str.itl.peek().(*headerFrame); ok { // the next item is trailers. @@ -1001,25 +995,20 @@ func (l *loopyWriter) processData() (bool, error) { remainingBytes := len(dataItem.h) + reader.Remaining() - hSize - dSize size := hSize + dSize - var buf *[]byte - - if hSize != 0 && dSize == 0 { - buf = &dataItem.h - } else { - // Note: this is only necessary because the http2.Framer does not support - // partially writing a frame, so the sequence must be materialized into a buffer. - // TODO: Revisit once https://github.com/golang/go/issues/66655 is addressed. - pool := l.bufferPool - if pool == nil { - // Note that this is only supposed to be nil in tests. Otherwise, stream is - // always initialized with a BufferPool. - pool = mem.DefaultBufferPool() + l.writeBuf = l.writeBuf[:0] + if hSize > 0 { + l.writeBuf = append(l.writeBuf, dataItem.h[:hSize]) + } + if dSize > 0 { + var err error + l.writeBuf, err = reader.Peek(dSize, l.writeBuf) + if err != nil { + // This must never happen since the reader must have at least dSize + // bytes. + // Log an error to fail tests. + l.logger.Errorf("unexpected error while reading Data frame payload: %v", err) + return false, err } - buf = pool.Get(size) - defer pool.Put(buf) - - copy((*buf)[:hSize], dataItem.h) - _, _ = reader.Read((*buf)[hSize:]) } // Now that outgoing flow controls are checked we can replenish str's write quota @@ -1032,7 +1021,14 @@ func (l *loopyWriter) processData() (bool, error) { if dataItem.onEachWrite != nil { dataItem.onEachWrite() } - if err := l.framer.fr.WriteData(dataItem.streamID, endStream, (*buf)[:size]); err != nil { + err := l.framer.writeData(dataItem.streamID, endStream, l.writeBuf) + reader.Discard(dSize) + if cap(l.writeBuf) > maxWriteBufSize { + l.writeBuf = nil + } else { + clear(l.writeBuf) + } + if err != nil { return false, err } str.bytesOutStanding += size @@ -1040,7 +1036,7 @@ func (l *loopyWriter) processData() (bool, error) { dataItem.h = dataItem.h[hSize:] if remainingBytes == 0 { // All the data from that message was written out. - _ = reader.Close() + reader.Close() str.itl.dequeue() } if str.itl.isEmpty() { diff --git a/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go b/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go index dfc0f224e..7cfbc9637 100644 --- a/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go +++ b/vendor/google.golang.org/grpc/internal/transport/flowcontrol.go @@ -28,7 +28,7 @@ import ( // writeQuota is a soft limit on the amount of data a stream can // schedule before some of it is written out. type writeQuota struct { - quota int32 + _ noCopy // get waits on read from when quota goes less than or equal to zero. // replenish writes on it when quota goes positive again. ch chan struct{} @@ -38,16 +38,17 @@ type writeQuota struct { // It is implemented as a field so that it can be updated // by tests. replenish func(n int) + quota int32 } -func newWriteQuota(sz int32, done <-chan struct{}) *writeQuota { - w := &writeQuota{ - quota: sz, - ch: make(chan struct{}, 1), - done: done, - } +// init allows a writeQuota to be initialized in-place, which is useful for +// resetting a buffer or for avoiding a heap allocation when the buffer is +// embedded in another struct. +func (w *writeQuota) init(sz int32, done <-chan struct{}) { + w.quota = sz + w.ch = make(chan struct{}, 1) + w.done = done w.replenish = w.realReplenish - return w } func (w *writeQuota) get(sz int32) error { @@ -67,9 +68,9 @@ func (w *writeQuota) get(sz int32) error { func (w *writeQuota) realReplenish(n int) { sz := int32(n) - a := atomic.AddInt32(&w.quota, sz) - b := a - sz - if b <= 0 && a > 0 { + newQuota := atomic.AddInt32(&w.quota, sz) + previousQuota := newQuota - sz + if previousQuota <= 0 && newQuota > 0 { select { case w.ch <- struct{}{}: default: diff --git a/vendor/google.golang.org/grpc/internal/transport/handler_server.go b/vendor/google.golang.org/grpc/internal/transport/handler_server.go index d954a64c3..7ab3422b8 100644 --- a/vendor/google.golang.org/grpc/internal/transport/handler_server.go +++ b/vendor/google.golang.org/grpc/internal/transport/handler_server.go @@ -50,7 +50,7 @@ import ( // NewServerHandlerTransport returns a ServerTransport handling gRPC from // inside an http.Handler, or writes an HTTP error to w and returns an error. // It requires that the http Server supports HTTP/2. -func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats []stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) { +func NewServerHandlerTransport(w http.ResponseWriter, r *http.Request, stats stats.Handler, bufferPool mem.BufferPool) (ServerTransport, error) { if r.Method != http.MethodPost { w.Header().Set("Allow", http.MethodPost) msg := fmt.Sprintf("invalid gRPC request method %q", r.Method) @@ -170,7 +170,7 @@ type serverHandlerTransport struct { // TODO make sure this is consistent across handler_server and http2_server contentSubtype string - stats []stats.Handler + stats stats.Handler logger *grpclog.PrefixLogger bufferPool mem.BufferPool @@ -274,15 +274,13 @@ func (ht *serverHandlerTransport) writeStatus(s *ServerStream, st *status.Status } }) - if err == nil { // transport has not been closed + if err == nil && ht.stats != nil { // transport has not been closed // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. s.hdrMu.Lock() - for _, sh := range ht.stats { - sh.HandleRPC(s.Context(), &stats.OutTrailer{ - Trailer: s.trailer.Copy(), - }) - } + ht.stats.HandleRPC(s.Context(), &stats.OutTrailer{ + Trailer: s.trailer.Copy(), + }) s.hdrMu.Unlock() } ht.Close(errors.New("finished writing status")) @@ -374,19 +372,23 @@ func (ht *serverHandlerTransport) writeHeader(s *ServerStream, md metadata.MD) e ht.rw.(http.Flusher).Flush() }) - if err == nil { - for _, sh := range ht.stats { - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. - sh.HandleRPC(s.Context(), &stats.OutHeader{ - Header: md.Copy(), - Compression: s.sendCompress, - }) - } + if err == nil && ht.stats != nil { + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + ht.stats.HandleRPC(s.Context(), &stats.OutHeader{ + Header: md.Copy(), + Compression: s.sendCompress, + }) } return err } +func (ht *serverHandlerTransport) adjustWindow(*ServerStream, uint32) { +} + +func (ht *serverHandlerTransport) updateWindow(*ServerStream, uint32) { +} + func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream func(*ServerStream)) { // With this transport type there will be exactly 1 stream: this HTTP request. var cancel context.CancelFunc @@ -411,11 +413,9 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream ctx = metadata.NewIncomingContext(ctx, ht.headerMD) req := ht.req s := &ServerStream{ - Stream: &Stream{ + Stream: Stream{ id: 0, // irrelevant ctx: ctx, - requestRead: func(int) {}, - buf: newRecvBuffer(), method: req.URL.Path, recvCompress: req.Header.Get("grpc-encoding"), contentSubtype: ht.contentSubtype, @@ -424,9 +424,11 @@ func (ht *serverHandlerTransport) HandleStreams(ctx context.Context, startStream st: ht, headerWireLength: 0, // won't have access to header wire length until golang/go#18997. } - s.trReader = &transportReader{ - reader: &recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: s.buf}, - windowHandler: func(int) {}, + s.Stream.buf.init() + s.readRequester = s + s.trReader = transportReader{ + reader: recvBufferReader{ctx: s.ctx, ctxDone: s.ctx.Done(), recv: &s.buf}, + windowHandler: s, } // readerDone is closed when the Body.Read-ing goroutine exits. diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_client.go b/vendor/google.golang.org/grpc/internal/transport/http2_client.go index 7cb238794..37b1acc34 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http2_client.go +++ b/vendor/google.golang.org/grpc/internal/transport/http2_client.go @@ -44,6 +44,7 @@ import ( "google.golang.org/grpc/internal/grpcutil" imetadata "google.golang.org/grpc/internal/metadata" "google.golang.org/grpc/internal/proxyattributes" + istats "google.golang.org/grpc/internal/stats" istatus "google.golang.org/grpc/internal/status" isyscall "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/internal/transport/networktype" @@ -105,7 +106,7 @@ type http2Client struct { kp keepalive.ClientParameters keepaliveEnabled bool - statsHandlers []stats.Handler + statsHandler stats.Handler initialWindowSize int32 @@ -335,14 +336,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts writerDone: make(chan struct{}), goAway: make(chan struct{}), keepaliveDone: make(chan struct{}), - framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize), + framer: newFramer(conn, writeBufSize, readBufSize, opts.SharedWriteBuffer, maxHeaderListSize, opts.BufferPool), fc: &trInFlow{limit: uint32(icwz)}, scheme: scheme, activeStreams: make(map[uint32]*ClientStream), isSecure: isSecure, perRPCCreds: perRPCCreds, kp: kp, - statsHandlers: opts.StatsHandlers, + statsHandler: istats.NewCombinedHandler(opts.StatsHandlers...), initialWindowSize: initialWindowSize, nextID: 1, maxConcurrentStreams: defaultMaxStreamsClient, @@ -369,7 +370,7 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts }) t.logger = prefixLoggerForClientTransport(t) // Add peer information to the http2client context. - t.ctx = peer.NewContext(t.ctx, t.getPeer()) + t.ctx = peer.NewContext(t.ctx, t.Peer()) if md, ok := addr.Metadata.(*metadata.MD); ok { t.md = *md @@ -386,15 +387,14 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts updateFlowControl: t.updateFlowControl, } } - for _, sh := range t.statsHandlers { - t.ctx = sh.TagConn(t.ctx, &stats.ConnTagInfo{ + if t.statsHandler != nil { + t.ctx = t.statsHandler.TagConn(t.ctx, &stats.ConnTagInfo{ RemoteAddr: t.remoteAddr, LocalAddr: t.localAddr, }) - connBegin := &stats.ConnBegin{ + t.statsHandler.HandleConn(t.ctx, &stats.ConnBegin{ Client: true, - } - sh.HandleConn(t.ctx, connBegin) + }) } if t.keepaliveEnabled { t.kpDormancyCond = sync.NewCond(&t.mu) @@ -478,45 +478,40 @@ func NewHTTP2Client(connectCtx, ctx context.Context, addr resolver.Address, opts return t, nil } -func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr) *ClientStream { +func (t *http2Client) newStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) *ClientStream { // TODO(zhaoq): Handle uint32 overflow of Stream.id. s := &ClientStream{ - Stream: &Stream{ + Stream: Stream{ method: callHdr.Method, sendCompress: callHdr.SendCompress, - buf: newRecvBuffer(), contentSubtype: callHdr.ContentSubtype, }, - ct: t, - done: make(chan struct{}), - headerChan: make(chan struct{}), - doneFunc: callHdr.DoneFunc, - } - s.wq = newWriteQuota(defaultWriteQuota, s.done) - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } + ct: t, + done: make(chan struct{}), + headerChan: make(chan struct{}), + doneFunc: callHdr.DoneFunc, + statsHandler: handler, + } + s.Stream.buf.init() + s.Stream.wq.init(defaultWriteQuota, s.done) + s.readRequester = s // The client side stream context should have exactly the same life cycle with the user provided context. // That means, s.ctx should be read-only. And s.ctx is done iff ctx is done. // So we use the original context here instead of creating a copy. s.ctx = ctx - s.trReader = &transportReader{ - reader: &recvBufferReader{ - ctx: s.ctx, - ctxDone: s.ctx.Done(), - recv: s.buf, - closeStream: func(err error) { - s.Close(err) - }, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) + s.trReader = transportReader{ + reader: recvBufferReader{ + ctx: s.ctx, + ctxDone: s.ctx.Done(), + recv: &s.buf, + clientStream: s, }, + windowHandler: s, } return s } -func (t *http2Client) getPeer() *peer.Peer { +func (t *http2Client) Peer() *peer.Peer { return &peer.Peer{ Addr: t.remoteAddr, AuthInfo: t.authInfo, // Can be nil @@ -557,6 +552,9 @@ func (t *http2Client) createHeaderFields(ctx context.Context, callHdr *CallHdr) hfLen := 7 // :method, :scheme, :path, :authority, content-type, user-agent, te hfLen += len(authData) + len(callAuthData) registeredCompressors := t.registeredCompressors + if callHdr.AcceptedCompressors != nil { + registeredCompressors = *callHdr.AcceptedCompressors + } if callHdr.PreviousAttempts > 0 { hfLen++ } @@ -747,8 +745,8 @@ func (e NewStreamError) Error() string { // NewStream creates a stream and registers it into the transport as "active" // streams. All non-nil errors returned will be *NewStreamError. -func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) { - ctx = peer.NewContext(ctx, t.getPeer()) +func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) { + ctx = peer.NewContext(ctx, t.Peer()) // ServerName field of the resolver returned address takes precedence over // Host field of CallHdr to determine the :authority header. This is because, @@ -784,7 +782,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS if err != nil { return nil, &NewStreamError{Err: err, AllowTransparentRetry: false} } - s := t.newStream(ctx, callHdr) + s := t.newStream(ctx, callHdr, handler) cleanup := func(err error) { if s.swapState(streamDone) == streamDone { // If it was already done, return. @@ -823,7 +821,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS return nil }, onOrphaned: cleanup, - wq: s.wq, + wq: &s.wq, } firstTry := true var ch chan struct{} @@ -854,7 +852,7 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS transportDrainRequired = t.nextID > MaxStreamID s.id = hdr.streamID - s.fc = &inFlow{limit: uint32(t.initialWindowSize)} + s.fc = inFlow{limit: uint32(t.initialWindowSize)} t.activeStreams[s.id] = s t.mu.Unlock() @@ -905,27 +903,23 @@ func (t *http2Client) NewStream(ctx context.Context, callHdr *CallHdr) (*ClientS return nil, &NewStreamError{Err: ErrConnClosing, AllowTransparentRetry: true} } } - if len(t.statsHandlers) != 0 { + if s.statsHandler != nil { header, ok := metadata.FromOutgoingContext(ctx) if ok { header.Set("user-agent", t.userAgent) } else { header = metadata.Pairs("user-agent", t.userAgent) } - for _, sh := range t.statsHandlers { - // Note: The header fields are compressed with hpack after this call returns. - // No WireLength field is set here. - // Note: Creating a new stats object to prevent pollution. - outHeader := &stats.OutHeader{ - Client: true, - FullMethod: callHdr.Method, - RemoteAddr: t.remoteAddr, - LocalAddr: t.localAddr, - Compression: callHdr.SendCompress, - Header: header, - } - sh.HandleRPC(s.ctx, outHeader) - } + // Note: The header fields are compressed with hpack after this call returns. + // No WireLength field is set here. + s.statsHandler.HandleRPC(s.ctx, &stats.OutHeader{ + Client: true, + FullMethod: callHdr.Method, + RemoteAddr: t.remoteAddr, + LocalAddr: t.localAddr, + Compression: callHdr.SendCompress, + Header: header, + }) } if transportDrainRequired { if t.logger.V(logLevel) { @@ -1002,6 +996,9 @@ func (t *http2Client) closeStream(s *ClientStream, err error, rst bool, rstCode // accessed anymore. func (t *http2Client) Close(err error) { t.conn.SetWriteDeadline(time.Now().Add(time.Second * 10)) + // For background on the deadline value chosen here, see + // https://github.com/grpc/grpc-go/issues/8425#issuecomment-3057938248 . + t.conn.SetReadDeadline(time.Now().Add(time.Second)) t.mu.Lock() // Make sure we only close once. if t.state == closing { @@ -1063,11 +1060,10 @@ func (t *http2Client) Close(err error) { for _, s := range streams { t.closeStream(s, err, false, http2.ErrCodeNo, st, nil, false) } - for _, sh := range t.statsHandlers { - connEnd := &stats.ConnEnd{ + if t.statsHandler != nil { + t.statsHandler.HandleConn(t.ctx, &stats.ConnEnd{ Client: true, - } - sh.HandleConn(t.ctx, connEnd) + }) } } @@ -1178,7 +1174,7 @@ func (t *http2Client) updateFlowControl(n uint32) { }) } -func (t *http2Client) handleData(f *http2.DataFrame) { +func (t *http2Client) handleData(f *parsedDataFrame) { size := f.Header().Length var sendBDPPing bool if t.bdpEst != nil { @@ -1222,22 +1218,15 @@ func (t *http2Client) handleData(f *http2.DataFrame) { t.closeStream(s, io.EOF, true, http2.ErrCodeFlowControl, status.New(codes.Internal, err.Error()), nil, false) return } + dataLen := f.data.Len() if f.Header().Flags.Has(http2.FlagDataPadded) { - if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + if w := s.fc.onRead(size - uint32(dataLen)); w > 0 { t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) } } - // TODO(bradfitz, zhaoq): A copy is required here because there is no - // guarantee f.Data() is consumed before the arrival of next frame. - // Can this copy be eliminated? - if len(f.Data()) > 0 { - pool := t.bufferPool - if pool == nil { - // Note that this is only supposed to be nil in tests. Otherwise, stream is - // always initialized with a BufferPool. - pool = mem.DefaultBufferPool() - } - s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)}) + if dataLen > 0 { + f.data.Ref() + s.write(recvMsg{buffer: f.data}) } } // The server has closed the stream without sending trailers. Record that @@ -1477,17 +1466,14 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { contentTypeErr = "malformed header: missing HTTP content-type" grpcMessage string recvCompress string - httpStatusCode *int httpStatusErr string - rawStatusCode = codes.Unknown + // the code from the grpc-status header, if present + grpcStatusCode = codes.Unknown // headerError is set if an error is encountered while parsing the headers headerError string + httpStatus string ) - if initialHeader { - httpStatusErr = "malformed header: missing HTTP status" - } - for _, hf := range frame.Fields { switch hf.Name { case "content-type": @@ -1503,73 +1489,75 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { case "grpc-status": code, err := strconv.ParseInt(hf.Value, 10, 32) if err != nil { - se := status.New(codes.Internal, fmt.Sprintf("transport: malformed grpc-status: %v", err)) + se := status.New(codes.Unknown, fmt.Sprintf("transport: malformed grpc-status: %v", err)) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) return } - rawStatusCode = codes.Code(uint32(code)) + grpcStatusCode = codes.Code(uint32(code)) case "grpc-message": grpcMessage = decodeGrpcMessage(hf.Value) case ":status": - c, err := strconv.ParseInt(hf.Value, 10, 32) + httpStatus = hf.Value + default: + if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { + break + } + v, err := decodeMetadataHeader(hf.Name, hf.Value) + if err != nil { + headerError = fmt.Sprintf("transport: malformed %s: %v", hf.Name, err) + logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) + break + } + mdata[hf.Name] = append(mdata[hf.Name], v) + } + } + + // If a non-gRPC response is received, then evaluate the HTTP status to + // process the response and close the stream. + // In case http status doesn't provide any error information (status : 200), + // then evalute response code to be Unknown. + if !isGRPC { + var grpcErrorCode = codes.Internal + if httpStatus == "" { + httpStatusErr = "malformed header: missing HTTP status" + } else { + // Parse the status codes (e.g. "200", 404"). + statusCode, err := strconv.Atoi(httpStatus) if err != nil { - se := status.New(codes.Internal, fmt.Sprintf("transport: malformed http-status: %v", err)) + se := status.New(grpcErrorCode, fmt.Sprintf("transport: malformed http-status: %v", err)) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) return } - statusCode := int(c) if statusCode >= 100 && statusCode < 200 { if endStream { se := status.New(codes.Internal, fmt.Sprintf( "protocol error: informational header with status code %d must not have END_STREAM set", statusCode)) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) } + // In case of informational headers, return. return } - httpStatusCode = &statusCode - if statusCode == 200 { - httpStatusErr = "" - break - } - httpStatusErr = fmt.Sprintf( "unexpected HTTP status code received from server: %d (%s)", statusCode, http.StatusText(statusCode), ) - default: - if isReservedHeader(hf.Name) && !isWhitelistedHeader(hf.Name) { - break - } - v, err := decodeMetadataHeader(hf.Name, hf.Value) - if err != nil { - headerError = fmt.Sprintf("transport: malformed %s: %v", hf.Name, err) - logger.Warningf("Failed to decode metadata header (%q, %q): %v", hf.Name, hf.Value, err) - break - } - mdata[hf.Name] = append(mdata[hf.Name], v) - } - } - - if !isGRPC || httpStatusErr != "" { - var code = codes.Internal // when header does not include HTTP status, return INTERNAL - - if httpStatusCode != nil { var ok bool - code, ok = HTTPStatusConvTab[*httpStatusCode] + grpcErrorCode, ok = HTTPStatusConvTab[statusCode] if !ok { - code = codes.Unknown + grpcErrorCode = codes.Unknown } } var errs []string if httpStatusErr != "" { errs = append(errs, httpStatusErr) } + if contentTypeErr != "" { errs = append(errs, contentTypeErr) } - // Verify the HTTP response is a 200. - se := status.New(code, strings.Join(errs, "; ")) + + se := status.New(grpcErrorCode, strings.Join(errs, "; ")) t.closeStream(s, se.Err(), true, http2.ErrCodeProtocol, se, nil, endStream) return } @@ -1600,22 +1588,20 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { } } - for _, sh := range t.statsHandlers { + if s.statsHandler != nil { if !endStream { - inHeader := &stats.InHeader{ + s.statsHandler.HandleRPC(s.ctx, &stats.InHeader{ Client: true, WireLength: int(frame.Header().Length), Header: metadata.MD(mdata).Copy(), Compression: s.recvCompress, - } - sh.HandleRPC(s.ctx, inHeader) + }) } else { - inTrailer := &stats.InTrailer{ + s.statsHandler.HandleRPC(s.ctx, &stats.InTrailer{ Client: true, WireLength: int(frame.Header().Length), Trailer: metadata.MD(mdata).Copy(), - } - sh.HandleRPC(s.ctx, inTrailer) + }) } } @@ -1623,7 +1609,7 @@ func (t *http2Client) operateHeaders(frame *http2.MetaHeadersFrame) { return } - status := istatus.NewWithProto(rawStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader]) + status := istatus.NewWithProto(grpcStatusCode, grpcMessage, mdata[grpcStatusDetailsBinHeader]) // If client received END_STREAM from server while stream was still active, // send RST_STREAM. @@ -1670,7 +1656,7 @@ func (t *http2Client) reader(errCh chan<- error) { // loop to keep reading incoming messages on this transport. for { t.controlBuf.throttle() - frame, err := t.framer.fr.ReadFrame() + frame, err := t.framer.readFrame() if t.keepaliveEnabled { atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) } @@ -1685,7 +1671,7 @@ func (t *http2Client) reader(errCh chan<- error) { if s != nil { // use error detail to provide better err message code := http2ErrConvTab[se.Code] - errorDetail := t.framer.fr.ErrorDetail() + errorDetail := t.framer.errorDetail() var msg string if errorDetail != nil { msg = errorDetail.Error() @@ -1703,8 +1689,9 @@ func (t *http2Client) reader(errCh chan<- error) { switch frame := frame.(type) { case *http2.MetaHeadersFrame: t.operateHeaders(frame) - case *http2.DataFrame: + case *parsedDataFrame: t.handleData(frame) + frame.data.Free() case *http2.RSTStreamFrame: t.handleRSTStream(frame) case *http2.SettingsFrame: @@ -1824,8 +1811,6 @@ func (t *http2Client) socketMetrics() *channelz.EphemeralSocketMetrics { } } -func (t *http2Client) RemoteAddr() net.Addr { return t.remoteAddr } - func (t *http2Client) incrMsgSent() { if channelz.IsOn() { t.channelz.SocketMetrics.MessagesSent.Add(1) diff --git a/vendor/google.golang.org/grpc/internal/transport/http2_server.go b/vendor/google.golang.org/grpc/internal/transport/http2_server.go index 83cee314c..a1a14e14f 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http2_server.go +++ b/vendor/google.golang.org/grpc/internal/transport/http2_server.go @@ -35,6 +35,8 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" + "google.golang.org/protobuf/proto" + "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/grpclog" "google.golang.org/grpc/internal/grpcutil" @@ -42,7 +44,6 @@ import ( istatus "google.golang.org/grpc/internal/status" "google.golang.org/grpc/internal/syscall" "google.golang.org/grpc/mem" - "google.golang.org/protobuf/proto" "google.golang.org/grpc/codes" "google.golang.org/grpc/credentials" @@ -86,7 +87,7 @@ type http2Server struct { // updates, reset streams, and various settings) to the controller. controlBuf *controlBuffer fc *trInFlow - stats []stats.Handler + stats stats.Handler // Keepalive and max-age parameters for the server. kp keepalive.ServerParameters // Keepalive enforcement policy. @@ -168,7 +169,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, if config.MaxHeaderListSize != nil { maxHeaderListSize = *config.MaxHeaderListSize } - framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize) + framer := newFramer(conn, writeBufSize, readBufSize, config.SharedWriteBuffer, maxHeaderListSize, config.BufferPool) // Send initial settings as connection preface to client. isettings := []http2.Setting{{ ID: http2.SettingMaxFrameSize, @@ -260,7 +261,7 @@ func NewServerTransport(conn net.Conn, config *ServerConfig) (_ ServerTransport, fc: &trInFlow{limit: uint32(icwz)}, state: reachable, activeStreams: make(map[uint32]*ServerStream), - stats: config.StatsHandlers, + stats: config.StatsHandler, kp: kp, idle: time.Now(), kep: kep, @@ -390,16 +391,15 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade } t.maxStreamID = streamID - buf := newRecvBuffer() s := &ServerStream{ - Stream: &Stream{ - id: streamID, - buf: buf, - fc: &inFlow{limit: uint32(t.initialWindowSize)}, + Stream: Stream{ + id: streamID, + fc: inFlow{limit: uint32(t.initialWindowSize)}, }, st: t, headerWireLength: int(frame.Header().Length), } + s.Stream.buf.init() var ( // if false, content-type was missing or invalid isGRPC = false @@ -479,13 +479,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade if t.logger.V(logLevel) { t.logger.Infof("Aborting the stream early: %v", errMsg) } - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusBadRequest, - streamID: streamID, - contentSubtype: s.contentSubtype, - status: status.New(codes.Internal, errMsg), - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusBadRequest, !frame.StreamEnded()) return nil } @@ -499,23 +493,11 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade return nil } if !isGRPC { - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusUnsupportedMediaType, - streamID: streamID, - contentSubtype: s.contentSubtype, - status: status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(streamID, s.contentSubtype, status.Newf(codes.InvalidArgument, "invalid gRPC request content-type %q", contentType), http.StatusUnsupportedMediaType, !frame.StreamEnded()) return nil } if headerError != nil { - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusBadRequest, - streamID: streamID, - contentSubtype: s.contentSubtype, - status: headerError, - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(streamID, s.contentSubtype, headerError, http.StatusBadRequest, !frame.StreamEnded()) return nil } @@ -569,13 +551,7 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade if t.logger.V(logLevel) { t.logger.Infof("Aborting the stream early: %v", errMsg) } - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusMethodNotAllowed, - streamID: streamID, - contentSubtype: s.contentSubtype, - status: status.New(codes.Internal, errMsg), - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(streamID, s.contentSubtype, status.New(codes.Internal, errMsg), http.StatusMethodNotAllowed, !frame.StreamEnded()) s.cancel() return nil } @@ -590,27 +566,16 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade if !ok { stat = status.New(codes.PermissionDenied, err.Error()) } - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusOK, - streamID: s.id, - contentSubtype: s.contentSubtype, - status: stat, - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(s.id, s.contentSubtype, stat, http.StatusOK, !frame.StreamEnded()) return nil } } if s.ctx.Err() != nil { t.mu.Unlock() + st := status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()) // Early abort in case the timeout was zero or so low it already fired. - t.controlBuf.put(&earlyAbortStream{ - httpStatus: http.StatusOK, - streamID: s.id, - contentSubtype: s.contentSubtype, - status: status.New(codes.DeadlineExceeded, context.DeadlineExceeded.Error()), - rst: !frame.StreamEnded(), - }) + t.writeEarlyAbort(s.id, s.contentSubtype, st, http.StatusOK, !frame.StreamEnded()) return nil } @@ -640,25 +605,21 @@ func (t *http2Server) operateHeaders(ctx context.Context, frame *http2.MetaHeade t.channelz.SocketMetrics.StreamsStarted.Add(1) t.channelz.SocketMetrics.LastRemoteStreamCreatedTimestamp.Store(time.Now().UnixNano()) } - s.requestRead = func(n int) { - t.adjustWindow(s, uint32(n)) - } + s.readRequester = s s.ctxDone = s.ctx.Done() - s.wq = newWriteQuota(defaultWriteQuota, s.ctxDone) - s.trReader = &transportReader{ - reader: &recvBufferReader{ + s.Stream.wq.init(defaultWriteQuota, s.ctxDone) + s.trReader = transportReader{ + reader: recvBufferReader{ ctx: s.ctx, ctxDone: s.ctxDone, - recv: s.buf, - }, - windowHandler: func(n int) { - t.updateWindow(s, uint32(n)) + recv: &s.buf, }, + windowHandler: s, } // Register the stream with loopy. t.controlBuf.put(®isterStream{ streamID: s.id, - wq: s.wq, + wq: &s.wq, }) handle(s) return nil @@ -674,7 +635,7 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre }() for { t.controlBuf.throttle() - frame, err := t.framer.fr.ReadFrame() + frame, err := t.framer.readFrame() atomic.StoreInt64(&t.lastRead, time.Now().UnixNano()) if err != nil { if se, ok := err.(http2.StreamError); ok { @@ -711,8 +672,9 @@ func (t *http2Server) HandleStreams(ctx context.Context, handle func(*ServerStre }) continue } - case *http2.DataFrame: + case *parsedDataFrame: t.handleData(frame) + frame.data.Free() case *http2.RSTStreamFrame: t.handleRSTStream(frame) case *http2.SettingsFrame: @@ -792,7 +754,7 @@ func (t *http2Server) updateFlowControl(n uint32) { } -func (t *http2Server) handleData(f *http2.DataFrame) { +func (t *http2Server) handleData(f *parsedDataFrame) { size := f.Header().Length var sendBDPPing bool if t.bdpEst != nil { @@ -837,22 +799,15 @@ func (t *http2Server) handleData(f *http2.DataFrame) { t.closeStream(s, true, http2.ErrCodeFlowControl, false) return } + dataLen := f.data.Len() if f.Header().Flags.Has(http2.FlagDataPadded) { - if w := s.fc.onRead(size - uint32(len(f.Data()))); w > 0 { + if w := s.fc.onRead(size - uint32(dataLen)); w > 0 { t.controlBuf.put(&outgoingWindowUpdate{s.id, w}) } } - // TODO(bradfitz, zhaoq): A copy is required here because there is no - // guarantee f.Data() is consumed before the arrival of next frame. - // Can this copy be eliminated? - if len(f.Data()) > 0 { - pool := t.bufferPool - if pool == nil { - // Note that this is only supposed to be nil in tests. Otherwise, stream is - // always initialized with a BufferPool. - pool = mem.DefaultBufferPool() - } - s.write(recvMsg{buffer: mem.Copy(f.Data(), pool)}) + if dataLen > 0 { + f.data.Ref() + s.write(recvMsg{buffer: f.data}) } } if f.StreamEnded() { @@ -979,13 +934,12 @@ func appendHeaderFieldsFromMD(headerFields []hpack.HeaderField, md metadata.MD) return headerFields } -func (t *http2Server) checkForHeaderListSize(it any) bool { +func (t *http2Server) checkForHeaderListSize(hf []hpack.HeaderField) bool { if t.maxSendHeaderListSize == nil { return true } - hdrFrame := it.(*headerFrame) var sz int64 - for _, f := range hdrFrame.hf { + for _, f := range hf { if sz += int64(f.Size()); sz > int64(*t.maxSendHeaderListSize) { if t.logger.V(logLevel) { t.logger.Infof("Header list size to send violates the maximum size (%d bytes) set by client", *t.maxSendHeaderListSize) @@ -996,6 +950,42 @@ func (t *http2Server) checkForHeaderListSize(it any) bool { return true } +// writeEarlyAbort sends an early abort response with the given HTTP status and +// gRPC status. If the header list size exceeds the peer's limit, it sends a +// RST_STREAM instead. +func (t *http2Server) writeEarlyAbort(streamID uint32, contentSubtype string, stat *status.Status, httpStatus uint32, rst bool) { + hf := []hpack.HeaderField{ + {Name: ":status", Value: strconv.Itoa(int(httpStatus))}, + {Name: "content-type", Value: grpcutil.ContentType(contentSubtype)}, + {Name: "grpc-status", Value: strconv.Itoa(int(stat.Code()))}, + {Name: "grpc-message", Value: encodeGrpcMessage(stat.Message())}, + } + if p := istatus.RawStatusProto(stat); len(p.GetDetails()) > 0 { + stBytes, err := proto.Marshal(p) + if err != nil { + t.logger.Errorf("Failed to marshal rpc status: %s, error: %v", pretty.ToJSON(p), err) + } + if err == nil { + hf = append(hf, hpack.HeaderField{Name: grpcStatusDetailsBinHeader, Value: encodeBinHeader(stBytes)}) + } + } + success, _ := t.controlBuf.executeAndPut(func() bool { + return t.checkForHeaderListSize(hf) + }, &earlyAbortStream{ + streamID: streamID, + rst: rst, + hf: hf, + }) + if !success { + t.controlBuf.put(&cleanupStream{ + streamID: streamID, + rst: true, + rstCode: http2.ErrCodeInternal, + onWrite: func() {}, + }) + } +} + func (t *http2Server) streamContextErr(s *ServerStream) error { select { case <-t.done: @@ -1051,7 +1041,7 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error { endStream: false, onWrite: t.setResetPingStrikes, } - success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf) }, hf) + success, err := t.controlBuf.executeAndPut(func() bool { return t.checkForHeaderListSize(hf.hf) }, hf) if !success { if err != nil { return err @@ -1059,14 +1049,13 @@ func (t *http2Server) writeHeaderLocked(s *ServerStream) error { t.closeStream(s, true, http2.ErrCodeInternal, false) return ErrHeaderListSizeLimitViolation } - for _, sh := range t.stats { + if t.stats != nil { // Note: Headers are compressed with hpack after this call returns. // No WireLength field is set here. - outHeader := &stats.OutHeader{ + t.stats.HandleRPC(s.Context(), &stats.OutHeader{ Header: s.header.Copy(), Compression: s.sendCompress, - } - sh.HandleRPC(s.Context(), outHeader) + }) } return nil } @@ -1122,7 +1111,7 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error { } success, err := t.controlBuf.executeAndPut(func() bool { - return t.checkForHeaderListSize(trailingHeader) + return t.checkForHeaderListSize(trailingHeader.hf) }, nil) if !success { if err != nil { @@ -1134,10 +1123,10 @@ func (t *http2Server) writeStatus(s *ServerStream, st *status.Status) error { // Send a RST_STREAM after the trailers if the client has not already half-closed. rst := s.getState() == streamActive t.finishStream(s, rst, http2.ErrCodeNo, trailingHeader, true) - for _, sh := range t.stats { + if t.stats != nil { // Note: The trailer fields are compressed with hpack after this call returns. // No WireLength field is set here. - sh.HandleRPC(s.Context(), &stats.OutTrailer{ + t.stats.HandleRPC(s.Context(), &stats.OutTrailer{ Trailer: s.trailer.Copy(), }) } @@ -1305,7 +1294,8 @@ func (t *http2Server) Close(err error) { // deleteStream deletes the stream s from transport's active streams. func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) { t.mu.Lock() - if _, ok := t.activeStreams[s.id]; ok { + _, isActive := t.activeStreams[s.id] + if isActive { delete(t.activeStreams, s.id) if len(t.activeStreams) == 0 { t.idle = time.Now() @@ -1313,7 +1303,7 @@ func (t *http2Server) deleteStream(s *ServerStream, eosReceived bool) { } t.mu.Unlock() - if channelz.IsOn() { + if isActive && channelz.IsOn() { if eosReceived { t.channelz.SocketMetrics.StreamsSucceeded.Add(1) } else { diff --git a/vendor/google.golang.org/grpc/internal/transport/http_util.go b/vendor/google.golang.org/grpc/internal/transport/http_util.go index e3663f87f..5bbb641ad 100644 --- a/vendor/google.golang.org/grpc/internal/transport/http_util.go +++ b/vendor/google.golang.org/grpc/internal/transport/http_util.go @@ -25,7 +25,6 @@ import ( "fmt" "io" "math" - "net" "net/http" "net/url" "strconv" @@ -37,6 +36,7 @@ import ( "golang.org/x/net/http2" "golang.org/x/net/http2/hpack" "google.golang.org/grpc/codes" + "google.golang.org/grpc/mem" ) const ( @@ -300,11 +300,11 @@ type bufWriter struct { buf []byte offset int batchSize int - conn net.Conn + conn io.Writer err error } -func newBufWriter(conn net.Conn, batchSize int, pool *sync.Pool) *bufWriter { +func newBufWriter(conn io.Writer, batchSize int, pool *sync.Pool) *bufWriter { w := &bufWriter{ batchSize: batchSize, conn: conn, @@ -388,15 +388,29 @@ func toIOError(err error) error { return ioError{error: err} } +type parsedDataFrame struct { + http2.FrameHeader + data mem.Buffer +} + +func (df *parsedDataFrame) StreamEnded() bool { + return df.FrameHeader.Flags.Has(http2.FlagDataEndStream) +} + type framer struct { - writer *bufWriter - fr *http2.Framer + writer *bufWriter + fr *http2.Framer + headerBuf []byte // cached slice for framer headers to reduce heap allocs. + reader io.Reader + dataFrame parsedDataFrame // Cached data frame to avoid heap allocations. + pool mem.BufferPool + errDetail error } var writeBufferPoolMap = make(map[int]*sync.Pool) var writeBufferMutex sync.Mutex -func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32) *framer { +func newFramer(conn io.ReadWriter, writeBufferSize, readBufferSize int, sharedWriteBuffer bool, maxHeaderListSize uint32, memPool mem.BufferPool) *framer { if writeBufferSize < 0 { writeBufferSize = 0 } @@ -412,6 +426,8 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu f := &framer{ writer: w, fr: http2.NewFramer(w, r), + reader: r, + pool: memPool, } f.fr.SetMaxReadFrameSize(http2MaxFrameLen) // Opt-in to Frame reuse API on framer to reduce garbage. @@ -422,6 +438,146 @@ func newFramer(conn net.Conn, writeBufferSize, readBufferSize int, sharedWriteBu return f } +// writeData writes a DATA frame. +// +// It is the caller's responsibility not to violate the maximum frame size. +func (f *framer) writeData(streamID uint32, endStream bool, data [][]byte) error { + var flags http2.Flags + if endStream { + flags = http2.FlagDataEndStream + } + length := uint32(0) + for _, d := range data { + length += uint32(len(d)) + } + // TODO: Replace the header write with the framer API being added in + // https://github.com/golang/go/issues/66655. + f.headerBuf = append(f.headerBuf[:0], + byte(length>>16), + byte(length>>8), + byte(length), + byte(http2.FrameData), + byte(flags), + byte(streamID>>24), + byte(streamID>>16), + byte(streamID>>8), + byte(streamID)) + if _, err := f.writer.Write(f.headerBuf); err != nil { + return err + } + for _, d := range data { + if _, err := f.writer.Write(d); err != nil { + return err + } + } + return nil +} + +// readFrame reads a single frame. The returned Frame is only valid +// until the next call to readFrame. +func (f *framer) readFrame() (any, error) { + f.errDetail = nil + fh, err := f.fr.ReadFrameHeader() + if err != nil { + f.errDetail = f.fr.ErrorDetail() + return nil, err + } + // Read the data frame directly from the underlying io.Reader to avoid + // copies. + if fh.Type == http2.FrameData { + err = f.readDataFrame(fh) + return &f.dataFrame, err + } + fr, err := f.fr.ReadFrameForHeader(fh) + if err != nil { + f.errDetail = f.fr.ErrorDetail() + return nil, err + } + return fr, err +} + +// errorDetail returns a more detailed error of the last error +// returned by framer.readFrame. For instance, if readFrame +// returns a StreamError with code PROTOCOL_ERROR, errorDetail +// will say exactly what was invalid. errorDetail is not guaranteed +// to return a non-nil value. +// errorDetail is reset after the next call to readFrame. +func (f *framer) errorDetail() error { + return f.errDetail +} + +func (f *framer) readDataFrame(fh http2.FrameHeader) (err error) { + if fh.StreamID == 0 { + // DATA frames MUST be associated with a stream. If a + // DATA frame is received whose stream identifier + // field is 0x0, the recipient MUST respond with a + // connection error (Section 5.4.1) of type + // PROTOCOL_ERROR. + f.errDetail = errors.New("DATA frame with stream ID 0") + return http2.ConnectionError(http2.ErrCodeProtocol) + } + // Converting a *[]byte to a mem.SliceBuffer incurs a heap allocation. This + // conversion is performed by mem.NewBuffer. To avoid the extra allocation + // a []byte is allocated directly if required and cast to a mem.SliceBuffer. + var buf []byte + // poolHandle is the pointer returned by the buffer pool (if it's used.). + var poolHandle *[]byte + useBufferPool := !mem.IsBelowBufferPoolingThreshold(int(fh.Length)) + if useBufferPool { + poolHandle = f.pool.Get(int(fh.Length)) + buf = *poolHandle + defer func() { + if err != nil { + f.pool.Put(poolHandle) + } + }() + } else { + buf = make([]byte, int(fh.Length)) + } + if fh.Flags.Has(http2.FlagDataPadded) { + if fh.Length == 0 { + return io.ErrUnexpectedEOF + } + // This initial 1-byte read can be inefficient for unbuffered readers, + // but it allows the rest of the payload to be read directly to the + // start of the destination slice. This makes it easy to return the + // original slice back to the buffer pool. + if _, err := io.ReadFull(f.reader, buf[:1]); err != nil { + return err + } + padSize := buf[0] + buf = buf[:len(buf)-1] + if int(padSize) > len(buf) { + // If the length of the padding is greater than the + // length of the frame payload, the recipient MUST + // treat this as a connection error. + // Filed: https://github.com/http2/http2-spec/issues/610 + f.errDetail = errors.New("pad size larger than data payload") + return http2.ConnectionError(http2.ErrCodeProtocol) + } + if _, err := io.ReadFull(f.reader, buf); err != nil { + return err + } + buf = buf[:len(buf)-int(padSize)] + } else if _, err := io.ReadFull(f.reader, buf); err != nil { + return err + } + + f.dataFrame.FrameHeader = fh + if useBufferPool { + // Update the handle to point to the (potentially re-sliced) buf. + *poolHandle = buf + f.dataFrame.data = mem.NewBuffer(poolHandle, f.pool) + } else { + f.dataFrame.data = mem.SliceBuffer(buf) + } + return nil +} + +func (df *parsedDataFrame) Header() http2.FrameHeader { + return df.FrameHeader +} + func getWriteBufferPool(size int) *sync.Pool { writeBufferMutex.Lock() defer writeBufferMutex.Unlock() diff --git a/vendor/google.golang.org/grpc/internal/transport/server_stream.go b/vendor/google.golang.org/grpc/internal/transport/server_stream.go index cf8da0b52..ed6a13b75 100644 --- a/vendor/google.golang.org/grpc/internal/transport/server_stream.go +++ b/vendor/google.golang.org/grpc/internal/transport/server_stream.go @@ -32,7 +32,7 @@ import ( // ServerStream implements streaming functionality for a gRPC server. type ServerStream struct { - *Stream // Embed for common stream functionality. + Stream // Embed for common stream functionality. st internalServerTransport ctxDone <-chan struct{} // closed at the end of stream. Cache of ctx.Done() (for performance) @@ -43,12 +43,13 @@ type ServerStream struct { // Holds compressor names passed in grpc-accept-encoding metadata from the // client. clientAdvertisedCompressors string - headerWireLength int // hdrMu protects outgoing header and trailer metadata. hdrMu sync.Mutex header metadata.MD // the outgoing header metadata. Updated by WriteHeader. headerSent atomic.Bool // atomically set when the headers are sent out. + + headerWireLength int } // Read reads an n byte message from the input stream. @@ -178,3 +179,11 @@ func (s *ServerStream) SetTrailer(md metadata.MD) error { s.hdrMu.Unlock() return nil } + +func (s *ServerStream) requestRead(n int) { + s.st.adjustWindow(s, uint32(n)) +} + +func (s *ServerStream) updateWindow(n int) { + s.st.updateWindow(s, uint32(n)) +} diff --git a/vendor/google.golang.org/grpc/internal/transport/transport.go b/vendor/google.golang.org/grpc/internal/transport/transport.go index 7dd53e80a..b86094da9 100644 --- a/vendor/google.golang.org/grpc/internal/transport/transport.go +++ b/vendor/google.golang.org/grpc/internal/transport/transport.go @@ -68,11 +68,11 @@ type recvBuffer struct { err error } -func newRecvBuffer() *recvBuffer { - b := &recvBuffer{ - c: make(chan recvMsg, 1), - } - return b +// init allows a recvBuffer to be initialized in-place, which is useful +// for resetting a buffer or for avoiding a heap allocation when the buffer +// is embedded in another struct. +func (b *recvBuffer) init() { + b.c = make(chan recvMsg, 1) } func (b *recvBuffer) put(r recvMsg) { @@ -123,12 +123,13 @@ func (b *recvBuffer) get() <-chan recvMsg { // recvBufferReader implements io.Reader interface to read the data from // recvBuffer. type recvBufferReader struct { - closeStream func(error) // Closes the client transport stream with the given error and nil trailer metadata. - ctx context.Context - ctxDone <-chan struct{} // cache of ctx.Done() (for performance). - recv *recvBuffer - last mem.Buffer // Stores the remaining data in the previous calls. - err error + _ noCopy + clientStream *ClientStream // The client transport stream is closed with a status representing ctx.Err() and nil trailer metadata. + ctx context.Context + ctxDone <-chan struct{} // cache of ctx.Done() (for performance). + recv *recvBuffer + last mem.Buffer // Stores the remaining data in the previous calls. + err error } func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) { @@ -139,7 +140,7 @@ func (r *recvBufferReader) ReadMessageHeader(header []byte) (n int, err error) { n, r.last = mem.ReadUnsafe(header, r.last) return n, nil } - if r.closeStream != nil { + if r.clientStream != nil { n, r.err = r.readMessageHeaderClient(header) } else { n, r.err = r.readMessageHeader(header) @@ -164,7 +165,7 @@ func (r *recvBufferReader) Read(n int) (buf mem.Buffer, err error) { } return buf, nil } - if r.closeStream != nil { + if r.clientStream != nil { buf, r.err = r.readClient(n) } else { buf, r.err = r.read(n) @@ -209,7 +210,7 @@ func (r *recvBufferReader) readMessageHeaderClient(header []byte) (n int, err er // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.clientStream.Close(ContextErr(r.ctx.Err())) m := <-r.recv.get() return r.readMessageHeaderAdditional(m, header) case m := <-r.recv.get(): @@ -236,7 +237,7 @@ func (r *recvBufferReader) readClient(n int) (buf mem.Buffer, err error) { // TODO: delaying ctx error seems like a unnecessary side effect. What // we really want is to mark the stream as done, and return ctx error // faster. - r.closeStream(ContextErr(r.ctx.Err())) + r.clientStream.Close(ContextErr(r.ctx.Err())) m := <-r.recv.get() return r.readAdditional(m, n) case m := <-r.recv.get(): @@ -285,27 +286,32 @@ const ( // Stream represents an RPC in the transport layer. type Stream struct { - id uint32 ctx context.Context // the associated context of the stream method string // the associated RPC method of the stream recvCompress string sendCompress string - buf *recvBuffer - trReader *transportReader - fc *inFlow - wq *writeQuota - - // Callback to state application's intentions to read data. This - // is used to adjust flow control, if needed. - requestRead func(int) - state streamState + readRequester readRequester // contentSubtype is the content-subtype for requests. // this must be lowercase or the behavior is undefined. contentSubtype string trailer metadata.MD // the key-value map of trailer metadata. + + // Non-pointer fields are at the end to optimize GC performance. + state streamState + id uint32 + buf recvBuffer + trReader transportReader + fc inFlow + wq writeQuota +} + +// readRequester is used to state application's intentions to read data. This +// is used to adjust flow control, if needed. +type readRequester interface { + requestRead(int) } func (s *Stream) swapState(st streamState) streamState { @@ -355,7 +361,7 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) { if er := s.trReader.er; er != nil { return er } - s.requestRead(len(header)) + s.readRequester.requestRead(len(header)) for len(header) != 0 { n, err := s.trReader.ReadMessageHeader(header) header = header[n:] @@ -372,13 +378,29 @@ func (s *Stream) ReadMessageHeader(header []byte) (err error) { return nil } +// ceil returns the ceil after dividing the numerator and denominator while +// avoiding integer overflows. +func ceil(numerator, denominator int) int { + if numerator == 0 { + return 0 + } + return (numerator-1)/denominator + 1 +} + // Read reads n bytes from the wire for this stream. func (s *Stream) read(n int) (data mem.BufferSlice, err error) { // Don't request a read if there was an error earlier if er := s.trReader.er; er != nil { return nil, er } - s.requestRead(n) + // gRPC Go accepts data frames with a maximum length of 16KB. Larger + // messages must be split into multiple frames. We pre-allocate the + // buffer to avoid resizing during the read loop, but cap the initial + // capacity to 128 frames (2MB) to prevent over-allocation or panics + // when reading extremely large streams. + allocCap := min(ceil(n, http2MaxFrameLen), 128) + data = make(mem.BufferSlice, 0, allocCap) + s.readRequester.requestRead(n) for n != 0 { buf, err := s.trReader.Read(n) var bufLen int @@ -401,16 +423,34 @@ func (s *Stream) read(n int) (data mem.BufferSlice, err error) { return data, nil } +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct { +} + +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} + // transportReader reads all the data available for this Stream from the transport and // passes them into the decoder, which converts them into a gRPC message stream. // The error is io.EOF when the stream is done or another non-nil error if // the stream broke. type transportReader struct { - reader *recvBufferReader + _ noCopy // The handler to control the window update procedure for both this // particular stream and the associated transport. - windowHandler func(int) + windowHandler windowHandler er error + reader recvBufferReader +} + +// The handler to control the window update procedure for both this +// particular stream and the associated transport. +type windowHandler interface { + updateWindow(int) } func (t *transportReader) ReadMessageHeader(header []byte) (int, error) { @@ -419,7 +459,7 @@ func (t *transportReader) ReadMessageHeader(header []byte) (int, error) { t.er = err return 0, err } - t.windowHandler(n) + t.windowHandler.updateWindow(n) return n, nil } @@ -429,7 +469,7 @@ func (t *transportReader) Read(n int) (mem.Buffer, error) { t.er = err return buf, err } - t.windowHandler(buf.Len()) + t.windowHandler.updateWindow(buf.Len()) return buf, nil } @@ -454,7 +494,7 @@ type ServerConfig struct { ConnectionTimeout time.Duration Credentials credentials.TransportCredentials InTapHandle tap.ServerInHandle - StatsHandlers []stats.Handler + StatsHandler stats.Handler KeepaliveParams keepalive.ServerParameters KeepalivePolicy keepalive.EnforcementPolicy InitialWindowSize int32 @@ -529,6 +569,12 @@ type CallHdr struct { // outbound message. SendCompress string + // AcceptedCompressors overrides the grpc-accept-encoding header for this + // call. When nil, the transport advertises the default set of registered + // compressors. A non-nil pointer overrides that value (including the empty + // string to advertise none). + AcceptedCompressors *string + // Creds specifies credentials.PerRPCCredentials for a call. Creds credentials.PerRPCCredentials @@ -544,9 +590,14 @@ type CallHdr struct { DoneFunc func() // called when the stream is finished - // Authority is used to explicitly override the `:authority` header. If set, - // this value takes precedence over the Host field and will be used as the - // value for the `:authority` header. + // Authority is used to explicitly override the `:authority` header. + // + // This value comes from one of two sources: + // 1. The `CallAuthority` call option, if specified by the user. + // 2. An override provided by the LB picker (e.g. xDS authority rewriting). + // + // The `CallAuthority` call option always takes precedence over the LB + // picker override. Authority string } @@ -566,7 +617,7 @@ type ClientTransport interface { GracefulClose() // NewStream creates a Stream for an RPC. - NewStream(ctx context.Context, callHdr *CallHdr) (*ClientStream, error) + NewStream(ctx context.Context, callHdr *CallHdr, handler stats.Handler) (*ClientStream, error) // Error returns a channel that is closed when some I/O error // happens. Typically the caller should have a goroutine to monitor @@ -584,8 +635,9 @@ type ClientTransport interface { // with a human readable string with debug info. GetGoAwayReason() (GoAwayReason, string) - // RemoteAddr returns the remote network address. - RemoteAddr() net.Addr + // Peer returns information about the peer associated with the Transport. + // The returned information includes authentication and network address details. + Peer() *peer.Peer } // ServerTransport is the common interface for all gRPC server-side transport @@ -615,6 +667,8 @@ type internalServerTransport interface { write(s *ServerStream, hdr []byte, data mem.BufferSlice, opts *WriteOptions) error writeStatus(s *ServerStream, st *status.Status) error incrMsgRecv() + adjustWindow(s *ServerStream, n uint32) + updateWindow(s *ServerStream, n uint32) } // connectionErrorf creates an ConnectionError with the specified error description. diff --git a/vendor/google.golang.org/grpc/mem/buffer_pool.go b/vendor/google.golang.org/grpc/mem/buffer_pool.go index c37c58c02..2ea763a49 100644 --- a/vendor/google.golang.org/grpc/mem/buffer_pool.go +++ b/vendor/google.golang.org/grpc/mem/buffer_pool.go @@ -32,12 +32,17 @@ type BufferPool interface { Get(length int) *[]byte // Put returns a buffer to the pool. + // + // The provided pointer must hold a prefix of the buffer obtained via + // BufferPool.Get to ensure the buffer's entire capacity can be re-used. Put(*[]byte) } +const goPageSize = 4 << 10 // 4KiB. N.B. this must be a power of 2. + var defaultBufferPoolSizes = []int{ 256, - 4 << 10, // 4KB (go page size) + goPageSize, 16 << 10, // 16KB (max HTTP/2 frame size used by gRPC) 32 << 10, // 32KB (default buffer size for io.Copy) 1 << 20, // 1MB @@ -48,7 +53,7 @@ var defaultBufferPool BufferPool func init() { defaultBufferPool = NewTieredBufferPool(defaultBufferPoolSizes...) - internal.SetDefaultBufferPoolForTesting = func(pool BufferPool) { + internal.SetDefaultBufferPool = func(pool BufferPool) { defaultBufferPool = pool } @@ -118,7 +123,11 @@ type sizedBufferPool struct { } func (p *sizedBufferPool) Get(size int) *[]byte { - buf := p.pool.Get().(*[]byte) + buf, ok := p.pool.Get().(*[]byte) + if !ok { + buf := make([]byte, size, p.defaultSize) + return &buf + } b := *buf clear(b[:cap(b)]) *buf = b[:size] @@ -137,12 +146,6 @@ func (p *sizedBufferPool) Put(buf *[]byte) { func newSizedBufferPool(size int) *sizedBufferPool { return &sizedBufferPool{ - pool: sync.Pool{ - New: func() any { - buf := make([]byte, size) - return &buf - }, - }, defaultSize: size, } } @@ -160,6 +163,7 @@ type simpleBufferPool struct { func (p *simpleBufferPool) Get(size int) *[]byte { bs, ok := p.pool.Get().(*[]byte) if ok && cap(*bs) >= size { + clear((*bs)[:cap(*bs)]) *bs = (*bs)[:size] return bs } @@ -170,7 +174,14 @@ func (p *simpleBufferPool) Get(size int) *[]byte { p.pool.Put(bs) } - b := make([]byte, size) + // If we're going to allocate, round up to the nearest page. This way if + // requests frequently arrive with small variation we don't allocate + // repeatedly if we get unlucky and they increase over time. By default we + // only allocate here if size > 1MiB. Because goPageSize is a power of 2, we + // can round up efficiently. + allocSize := (size + goPageSize - 1) & ^(goPageSize - 1) + + b := make([]byte, size, allocSize) return &b } diff --git a/vendor/google.golang.org/grpc/mem/buffer_slice.go b/vendor/google.golang.org/grpc/mem/buffer_slice.go index af510d20c..084fb19c6 100644 --- a/vendor/google.golang.org/grpc/mem/buffer_slice.go +++ b/vendor/google.golang.org/grpc/mem/buffer_slice.go @@ -19,6 +19,7 @@ package mem import ( + "fmt" "io" ) @@ -117,43 +118,36 @@ func (s BufferSlice) MaterializeToBuffer(pool BufferPool) Buffer { // Reader returns a new Reader for the input slice after taking references to // each underlying buffer. -func (s BufferSlice) Reader() Reader { +func (s BufferSlice) Reader() *Reader { s.Ref() - return &sliceReader{ + return &Reader{ data: s, len: s.Len(), } } // Reader exposes a BufferSlice's data as an io.Reader, allowing it to interface -// with other parts systems. It also provides an additional convenience method -// Remaining(), which returns the number of unread bytes remaining in the slice. +// with other systems. +// // Buffers will be freed as they are read. -type Reader interface { - io.Reader - io.ByteReader - // Close frees the underlying BufferSlice and never returns an error. Subsequent - // calls to Read will return (0, io.EOF). - Close() error - // Remaining returns the number of unread bytes remaining in the slice. - Remaining() int - // Reset frees the currently held buffer slice and starts reading from the - // provided slice. This allows reusing the reader object. - Reset(s BufferSlice) -} - -type sliceReader struct { +// +// A Reader can be constructed from a BufferSlice; alternatively the zero value +// of a Reader may be used after calling Reset on it. +type Reader struct { data BufferSlice len int // The index into data[0].ReadOnlyData(). bufferIdx int } -func (r *sliceReader) Remaining() int { +// Remaining returns the number of unread bytes remaining in the slice. +func (r *Reader) Remaining() int { return r.len } -func (r *sliceReader) Reset(s BufferSlice) { +// Reset frees the currently held buffer slice and starts reading from the +// provided slice. This allows reusing the reader object. +func (r *Reader) Reset(s BufferSlice) { r.data.Free() s.Ref() r.data = s @@ -161,14 +155,16 @@ func (r *sliceReader) Reset(s BufferSlice) { r.bufferIdx = 0 } -func (r *sliceReader) Close() error { +// Close frees the underlying BufferSlice and never returns an error. Subsequent +// calls to Read will return (0, io.EOF). +func (r *Reader) Close() error { r.data.Free() r.data = nil r.len = 0 return nil } -func (r *sliceReader) freeFirstBufferIfEmpty() bool { +func (r *Reader) freeFirstBufferIfEmpty() bool { if len(r.data) == 0 || r.bufferIdx != len(r.data[0].ReadOnlyData()) { return false } @@ -179,7 +175,7 @@ func (r *sliceReader) freeFirstBufferIfEmpty() bool { return true } -func (r *sliceReader) Read(buf []byte) (n int, _ error) { +func (r *Reader) Read(buf []byte) (n int, _ error) { if r.len == 0 { return 0, io.EOF } @@ -202,7 +198,8 @@ func (r *sliceReader) Read(buf []byte) (n int, _ error) { return n, nil } -func (r *sliceReader) ReadByte() (byte, error) { +// ReadByte reads a single byte. +func (r *Reader) ReadByte() (byte, error) { if r.len == 0 { return 0, io.EOF } @@ -290,3 +287,59 @@ nextBuffer: } } } + +// Discard skips the next n bytes, returning the number of bytes discarded. +// +// It frees buffers as they are fully consumed. +// +// If Discard skips fewer than n bytes, it also returns an error. +func (r *Reader) Discard(n int) (discarded int, err error) { + total := n + for n > 0 && r.len > 0 { + curData := r.data[0].ReadOnlyData() + curSize := min(n, len(curData)-r.bufferIdx) + n -= curSize + r.len -= curSize + r.bufferIdx += curSize + if r.bufferIdx >= len(curData) { + r.data[0].Free() + r.data = r.data[1:] + r.bufferIdx = 0 + } + } + discarded = total - n + if n > 0 { + return discarded, fmt.Errorf("insufficient bytes in reader") + } + return discarded, nil +} + +// Peek returns the next n bytes without advancing the reader. +// +// Peek appends results to the provided res slice and returns the updated slice. +// This pattern allows re-using the storage of res if it has sufficient +// capacity. +// +// The returned subslices are views into the underlying buffers and are only +// valid until the reader is advanced past the corresponding buffer. +// +// If Peek returns fewer than n bytes, it also returns an error. +func (r *Reader) Peek(n int, res [][]byte) ([][]byte, error) { + for i := 0; n > 0 && i < len(r.data); i++ { + curData := r.data[i].ReadOnlyData() + start := 0 + if i == 0 { + start = r.bufferIdx + } + curSize := min(n, len(curData)-start) + if curSize == 0 { + continue + } + res = append(res, curData[start:start+curSize]) + n -= curSize + } + if n > 0 { + return nil, fmt.Errorf("insufficient bytes in reader") + } + return res, nil +} diff --git a/vendor/google.golang.org/grpc/mem/buffers.go b/vendor/google.golang.org/grpc/mem/buffers.go index ecbf0b9a7..db1620e6a 100644 --- a/vendor/google.golang.org/grpc/mem/buffers.go +++ b/vendor/google.golang.org/grpc/mem/buffers.go @@ -62,7 +62,6 @@ var ( bufferPoolingThreshold = 1 << 10 bufferObjectPool = sync.Pool{New: func() any { return new(buffer) }} - refObjectPool = sync.Pool{New: func() any { return new(atomic.Int32) }} ) // IsBelowBufferPoolingThreshold returns true if the given size is less than or @@ -73,9 +72,19 @@ func IsBelowBufferPoolingThreshold(size int) bool { } type buffer struct { + refs atomic.Int32 + data []byte + + // rootBuf is the buffer responsible for returning origData to the pool + // once the reference count drops to 0. + // + // When a buffer is split, the new buffer inherits the rootBuf of the + // original and increments the root's reference count. For the + // initial buffer (the root), this field points to itself. + rootBuf *buffer + + // The following fields are only set for root buffers. origData *[]byte - data []byte - refs *atomic.Int32 pool BufferPool } @@ -103,8 +112,8 @@ func NewBuffer(data *[]byte, pool BufferPool) Buffer { b.origData = data b.data = *data b.pool = pool - b.refs = refObjectPool.Get().(*atomic.Int32) - b.refs.Add(1) + b.rootBuf = b + b.refs.Store(1) return b } @@ -127,42 +136,44 @@ func Copy(data []byte, pool BufferPool) Buffer { } func (b *buffer) ReadOnlyData() []byte { - if b.refs == nil { + if b.rootBuf == nil { panic("Cannot read freed buffer") } return b.data } func (b *buffer) Ref() { - if b.refs == nil { + if b.refs.Add(1) <= 1 { panic("Cannot ref freed buffer") } - b.refs.Add(1) } func (b *buffer) Free() { - if b.refs == nil { + refs := b.refs.Add(-1) + if refs < 0 { panic("Cannot free freed buffer") } - - refs := b.refs.Add(-1) - switch { - case refs > 0: + if refs > 0 { return - case refs == 0: + } + + b.data = nil + if b.rootBuf == b { + // This buffer is the owner of the data slice and its ref count reached + // 0, free the slice. if b.pool != nil { b.pool.Put(b.origData) + b.pool = nil } - - refObjectPool.Put(b.refs) b.origData = nil - b.data = nil - b.refs = nil - b.pool = nil - bufferObjectPool.Put(b) - default: - panic("Cannot free freed buffer") + } else { + // This buffer doesn't own the data slice, decrement a ref on the root + // buffer. + b.rootBuf.Free() } + + b.rootBuf = nil + bufferObjectPool.Put(b) } func (b *buffer) Len() int { @@ -170,16 +181,14 @@ func (b *buffer) Len() int { } func (b *buffer) split(n int) (Buffer, Buffer) { - if b.refs == nil { + if b.rootBuf == nil || b.rootBuf.refs.Add(1) <= 1 { panic("Cannot split freed buffer") } - b.refs.Add(1) split := newBuffer() - split.origData = b.origData split.data = b.data[n:] - split.refs = b.refs - split.pool = b.pool + split.rootBuf = b.rootBuf + split.refs.Store(1) b.data = b.data[:n] @@ -187,7 +196,7 @@ func (b *buffer) split(n int) (Buffer, Buffer) { } func (b *buffer) read(buf []byte) (int, Buffer) { - if b.refs == nil { + if b.rootBuf == nil { panic("Cannot read freed buffer") } diff --git a/vendor/google.golang.org/grpc/preloader.go b/vendor/google.golang.org/grpc/preloader.go index ee0ff969a..1e783febf 100644 --- a/vendor/google.golang.org/grpc/preloader.go +++ b/vendor/google.golang.org/grpc/preloader.go @@ -47,9 +47,6 @@ func (p *PreparedMsg) Encode(s Stream, msg any) error { } // check if the context has the relevant information to prepareMsg - if rpcInfo.preloaderInfo == nil { - return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo is nil") - } if rpcInfo.preloaderInfo.codec == nil { return status.Errorf(codes.Internal, "grpc: rpcInfo.preloaderInfo.codec is nil") } diff --git a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection.pb.go b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection.pb.go index 92f529221..92fdc3afa 100644 --- a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection.pb.go +++ b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection.pb.go @@ -21,7 +21,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.10 // protoc v5.27.1 // source: grpc/reflection/v1/reflection.proto diff --git a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection_grpc.pb.go b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection_grpc.pb.go index f4a361c64..93a243631 100644 --- a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection_grpc.pb.go +++ b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1/reflection_grpc.pb.go @@ -21,7 +21,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 +// - protoc-gen-go-grpc v1.6.0 // - protoc v5.27.1 // source: grpc/reflection/v1/reflection.proto diff --git a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection.pb.go b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection.pb.go index 5253e862f..c803cf3ba 100644 --- a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection.pb.go +++ b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection.pb.go @@ -18,7 +18,7 @@ // Code generated by protoc-gen-go. DO NOT EDIT. // versions: -// protoc-gen-go v1.36.6 +// protoc-gen-go v1.36.10 // protoc v5.27.1 // grpc/reflection/v1alpha/reflection.proto is a deprecated file. diff --git a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go index 0a43b521c..cee004ab5 100644 --- a/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go +++ b/vendor/google.golang.org/grpc/reflection/grpc_reflection_v1alpha/reflection_grpc.pb.go @@ -18,7 +18,7 @@ // Code generated by protoc-gen-go-grpc. DO NOT EDIT. // versions: -// - protoc-gen-go-grpc v1.5.1 +// - protoc-gen-go-grpc v1.6.0 // - protoc v5.27.1 // grpc/reflection/v1alpha/reflection.proto is a deprecated file. diff --git a/vendor/google.golang.org/grpc/resolver/resolver.go b/vendor/google.golang.org/grpc/resolver/resolver.go index 8e6af9514..598ed21a2 100644 --- a/vendor/google.golang.org/grpc/resolver/resolver.go +++ b/vendor/google.golang.org/grpc/resolver/resolver.go @@ -182,6 +182,7 @@ type BuildOptions struct { // An Endpoint is one network endpoint, or server, which may have multiple // addresses with which it can be accessed. +// TODO(i/8773) : make resolver.Endpoint and resolver.Address immutable type Endpoint struct { // Addresses contains a list of addresses used to access this endpoint. Addresses []Address diff --git a/vendor/google.golang.org/grpc/resolver_wrapper.go b/vendor/google.golang.org/grpc/resolver_wrapper.go index 80e16a327..6e6137643 100644 --- a/vendor/google.golang.org/grpc/resolver_wrapper.go +++ b/vendor/google.golang.org/grpc/resolver_wrapper.go @@ -69,6 +69,7 @@ func (ccr *ccResolverWrapper) start() error { errCh := make(chan error) ccr.serializer.TrySchedule(func(ctx context.Context) { if ctx.Err() != nil { + errCh <- ctx.Err() return } opts := resolver.BuildOptions{ diff --git a/vendor/google.golang.org/grpc/rpc_util.go b/vendor/google.golang.org/grpc/rpc_util.go index 47ea09f5c..8160f9430 100644 --- a/vendor/google.golang.org/grpc/rpc_util.go +++ b/vendor/google.golang.org/grpc/rpc_util.go @@ -33,6 +33,8 @@ import ( "google.golang.org/grpc/credentials" "google.golang.org/grpc/encoding" "google.golang.org/grpc/encoding/proto" + "google.golang.org/grpc/internal" + "google.golang.org/grpc/internal/grpcutil" "google.golang.org/grpc/internal/transport" "google.golang.org/grpc/mem" "google.golang.org/grpc/metadata" @@ -41,6 +43,10 @@ import ( "google.golang.org/grpc/status" ) +func init() { + internal.AcceptCompressors = acceptCompressors +} + // Compressor defines the interface gRPC uses to compress a message. // // Deprecated: use package encoding. @@ -151,16 +157,32 @@ func (d *gzipDecompressor) Type() string { // callInfo contains all related configuration and information about an RPC. type callInfo struct { - compressorName string - failFast bool - maxReceiveMessageSize *int - maxSendMessageSize *int - creds credentials.PerRPCCredentials - contentSubtype string - codec baseCodec - maxRetryRPCBufferSize int - onFinish []func(err error) - authority string + compressorName string + failFast bool + maxReceiveMessageSize *int + maxSendMessageSize *int + creds credentials.PerRPCCredentials + contentSubtype string + codec baseCodec + maxRetryRPCBufferSize int + onFinish []func(err error) + authority string + acceptedResponseCompressors []string +} + +func acceptedCompressorAllows(allowed []string, name string) bool { + if allowed == nil { + return true + } + if name == "" || name == encoding.Identity { + return true + } + for _, a := range allowed { + if a == name { + return true + } + } + return false } func defaultCallInfo() *callInfo { @@ -170,6 +192,29 @@ func defaultCallInfo() *callInfo { } } +func newAcceptedCompressionConfig(names []string) ([]string, error) { + if len(names) == 0 { + return nil, nil + } + var allowed []string + seen := make(map[string]struct{}, len(names)) + for _, name := range names { + name = strings.TrimSpace(name) + if name == "" || name == encoding.Identity { + continue + } + if !grpcutil.IsCompressorNameRegistered(name) { + return nil, status.Errorf(codes.InvalidArgument, "grpc: compressor %q is not registered", name) + } + if _, dup := seen[name]; dup { + continue + } + seen[name] = struct{}{} + allowed = append(allowed, name) + } + return allowed, nil +} + // CallOption configures a Call before it starts or extracts information from // a Call after it completes. type CallOption interface { @@ -471,6 +516,31 @@ func (o CompressorCallOption) before(c *callInfo) error { } func (o CompressorCallOption) after(*callInfo, *csAttempt) {} +// acceptCompressors returns a CallOption that limits the compression algorithms +// advertised in the grpc-accept-encoding header for response messages. +// Compression algorithms not in the provided list will not be advertised, and +// responses compressed with non-listed algorithms will be rejected. +func acceptCompressors(names ...string) CallOption { + cp := append([]string(nil), names...) + return acceptCompressorsCallOption{names: cp} +} + +// acceptCompressorsCallOption is a CallOption that limits response compression. +type acceptCompressorsCallOption struct { + names []string +} + +func (o acceptCompressorsCallOption) before(c *callInfo) error { + allowed, err := newAcceptedCompressionConfig(o.names) + if err != nil { + return err + } + c.acceptedResponseCompressors = allowed + return nil +} + +func (acceptCompressorsCallOption) after(*callInfo, *csAttempt) {} + // CallContentSubtype returns a CallOption that will set the content-subtype // for a call. For example, if content-subtype is "json", the Content-Type over // the wire will be "application/grpc+json". The content-subtype is converted @@ -657,8 +727,20 @@ type streamReader interface { Read(n int) (mem.BufferSlice, error) } +// noCopy may be embedded into structs which must not be copied +// after the first use. +// +// See https://golang.org/issues/8005#issuecomment-190753527 +// for details. +type noCopy struct { +} + +func (*noCopy) Lock() {} +func (*noCopy) Unlock() {} + // parser reads complete gRPC messages from the underlying reader. type parser struct { + _ noCopy // r is the underlying reader. // See the comment on recvMsg for the permissible // error types. @@ -845,8 +927,7 @@ func (p *payloadInfo) free() { // the buffer is no longer needed. // TODO: Refactor this function to reduce the number of arguments. // See: https://google.github.io/styleguide/go/best-practices.html#function-argument-lists -func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool, -) (out mem.BufferSlice, err error) { +func recvAndDecompress(p *parser, s recvCompressor, dc Decompressor, maxReceiveMessageSize int, payInfo *payloadInfo, compressor encoding.Compressor, isServer bool) (out mem.BufferSlice, err error) { pf, compressed, err := p.recvMsg(maxReceiveMessageSize) if err != nil { return nil, err @@ -949,7 +1030,7 @@ func recv(p *parser, c baseCodec, s recvCompressor, dc Decompressor, m any, maxR // Information about RPC type rpcInfo struct { failfast bool - preloaderInfo *compressorInfo + preloaderInfo compressorInfo } // Information about Preloader @@ -968,7 +1049,7 @@ type rpcInfoContextKey struct{} func newContextWithRPCInfo(ctx context.Context, failfast bool, codec baseCodec, cp Compressor, comp encoding.Compressor) context.Context { return context.WithValue(ctx, rpcInfoContextKey{}, &rpcInfo{ failfast: failfast, - preloaderInfo: &compressorInfo{ + preloaderInfo: compressorInfo{ codec: codec, cp: cp, comp: comp, diff --git a/vendor/google.golang.org/grpc/server.go b/vendor/google.golang.org/grpc/server.go index 1da2a542a..8efb29a7b 100644 --- a/vendor/google.golang.org/grpc/server.go +++ b/vendor/google.golang.org/grpc/server.go @@ -42,6 +42,7 @@ import ( "google.golang.org/grpc/internal" "google.golang.org/grpc/internal/binarylog" "google.golang.org/grpc/internal/channelz" + "google.golang.org/grpc/internal/envconfig" "google.golang.org/grpc/internal/grpcsync" "google.golang.org/grpc/internal/grpcutil" istats "google.golang.org/grpc/internal/stats" @@ -124,7 +125,8 @@ type serviceInfo struct { // Server is a gRPC server to serve RPC requests. type Server struct { - opts serverOptions + opts serverOptions + statsHandler stats.Handler mu sync.Mutex // guards following lis map[net.Listener]bool @@ -148,6 +150,8 @@ type Server struct { serverWorkerChannel chan func() serverWorkerChannelClose func() + + strictPathCheckingLogEmitted atomic.Bool } type serverOptions struct { @@ -692,13 +696,14 @@ func NewServer(opt ...ServerOption) *Server { o.apply(&opts) } s := &Server{ - lis: make(map[net.Listener]bool), - opts: opts, - conns: make(map[string]map[transport.ServerTransport]bool), - services: make(map[string]*serviceInfo), - quit: grpcsync.NewEvent(), - done: grpcsync.NewEvent(), - channelz: channelz.RegisterServer(""), + lis: make(map[net.Listener]bool), + opts: opts, + statsHandler: istats.NewCombinedHandler(opts.statsHandlers...), + conns: make(map[string]map[transport.ServerTransport]bool), + services: make(map[string]*serviceInfo), + quit: grpcsync.NewEvent(), + done: grpcsync.NewEvent(), + channelz: channelz.RegisterServer(""), } chainUnaryServerInterceptors(s) chainStreamServerInterceptors(s) @@ -921,9 +926,7 @@ func (s *Server) Serve(lis net.Listener) error { tempDelay = 5 * time.Millisecond } else { tempDelay *= 2 - } - if max := 1 * time.Second; tempDelay > max { - tempDelay = max + tempDelay = min(tempDelay, 1*time.Second) } s.mu.Lock() s.printf("Accept error: %v; retrying in %v", err, tempDelay) @@ -999,7 +1002,7 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { ConnectionTimeout: s.opts.connectionTimeout, Credentials: s.opts.creds, InTapHandle: s.opts.inTapHandle, - StatsHandlers: s.opts.statsHandlers, + StatsHandler: s.statsHandler, KeepaliveParams: s.opts.keepaliveParams, KeepalivePolicy: s.opts.keepalivePolicy, InitialWindowSize: s.opts.initialWindowSize, @@ -1036,18 +1039,18 @@ func (s *Server) newHTTP2Transport(c net.Conn) transport.ServerTransport { func (s *Server) serveStreams(ctx context.Context, st transport.ServerTransport, rawConn net.Conn) { ctx = transport.SetConnection(ctx, rawConn) ctx = peer.NewContext(ctx, st.Peer()) - for _, sh := range s.opts.statsHandlers { - ctx = sh.TagConn(ctx, &stats.ConnTagInfo{ + if s.statsHandler != nil { + ctx = s.statsHandler.TagConn(ctx, &stats.ConnTagInfo{ RemoteAddr: st.Peer().Addr, LocalAddr: st.Peer().LocalAddr, }) - sh.HandleConn(ctx, &stats.ConnBegin{}) + s.statsHandler.HandleConn(ctx, &stats.ConnBegin{}) } defer func() { st.Close(errors.New("finished serving streams for the server transport")) - for _, sh := range s.opts.statsHandlers { - sh.HandleConn(ctx, &stats.ConnEnd{}) + if s.statsHandler != nil { + s.statsHandler.HandleConn(ctx, &stats.ConnEnd{}) } }() @@ -1104,7 +1107,7 @@ var _ http.Handler = (*Server)(nil) // Notice: This API is EXPERIMENTAL and may be changed or removed in a // later release. func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { - st, err := transport.NewServerHandlerTransport(w, r, s.opts.statsHandlers, s.opts.bufferPool) + st, err := transport.NewServerHandlerTransport(w, r, s.statsHandler, s.opts.bufferPool) if err != nil { // Errors returned from transport.NewServerHandlerTransport have // already been written to w. @@ -1198,12 +1201,8 @@ func (s *Server) sendResponse(ctx context.Context, stream *transport.ServerStrea return status.Errorf(codes.ResourceExhausted, "grpc: trying to send message larger than max (%d vs. %d)", payloadLen, s.opts.maxSendMessageSize) } err = stream.Write(hdr, payload, opts) - if err == nil { - if len(s.opts.statsHandlers) != 0 { - for _, sh := range s.opts.statsHandlers { - sh.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now())) - } - } + if err == nil && s.statsHandler != nil { + s.statsHandler.HandleRPC(ctx, outPayload(false, msg, dataLen, payloadLen, time.Now())) } return err } @@ -1245,16 +1244,15 @@ func getChainUnaryHandler(interceptors []UnaryServerInterceptor, curr int, info } func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerStream, info *serviceInfo, md *MethodDesc, trInfo *traceInfo) (err error) { - shs := s.opts.statsHandlers - if len(shs) != 0 || trInfo != nil || channelz.IsOn() { + sh := s.statsHandler + if sh != nil || trInfo != nil || channelz.IsOn() { if channelz.IsOn() { s.incrCallsStarted() } var statsBegin *stats.Begin - for _, sh := range shs { - beginTime := time.Now() + if sh != nil { statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: time.Now(), IsClientStream: false, IsServerStream: false, } @@ -1282,7 +1280,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt trInfo.tr.Finish() } - for _, sh := range shs { + if sh != nil { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1379,7 +1377,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt } var payInfo *payloadInfo - if len(shs) != 0 || len(binlogs) != 0 { + if sh != nil || len(binlogs) != 0 { payInfo = &payloadInfo{} defer payInfo.free() } @@ -1405,7 +1403,7 @@ func (s *Server) processUnaryRPC(ctx context.Context, stream *transport.ServerSt return status.Errorf(codes.Internal, "grpc: error unmarshalling request: %v", err) } - for _, sh := range shs { + if sh != nil { sh.HandleRPC(ctx, &stats.InPayload{ RecvTime: time.Now(), Payload: v, @@ -1579,33 +1577,30 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv if channelz.IsOn() { s.incrCallsStarted() } - shs := s.opts.statsHandlers + sh := s.statsHandler var statsBegin *stats.Begin - if len(shs) != 0 { - beginTime := time.Now() + if sh != nil { statsBegin = &stats.Begin{ - BeginTime: beginTime, + BeginTime: time.Now(), IsClientStream: sd.ClientStreams, IsServerStream: sd.ServerStreams, } - for _, sh := range shs { - sh.HandleRPC(ctx, statsBegin) - } + sh.HandleRPC(ctx, statsBegin) } ctx = NewContextWithServerTransportStream(ctx, stream) ss := &serverStream{ ctx: ctx, s: stream, - p: &parser{r: stream, bufferPool: s.opts.bufferPool}, + p: parser{r: stream, bufferPool: s.opts.bufferPool}, codec: s.getCodec(stream.ContentSubtype()), desc: sd, maxReceiveMessageSize: s.opts.maxReceiveMessageSize, maxSendMessageSize: s.opts.maxSendMessageSize, trInfo: trInfo, - statsHandler: shs, + statsHandler: sh, } - if len(shs) != 0 || trInfo != nil || channelz.IsOn() { + if sh != nil || trInfo != nil || channelz.IsOn() { // See comment in processUnaryRPC on defers. defer func() { if trInfo != nil { @@ -1619,7 +1614,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv ss.mu.Unlock() } - if len(shs) != 0 { + if sh != nil { end := &stats.End{ BeginTime: statsBegin.BeginTime, EndTime: time.Now(), @@ -1627,9 +1622,7 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv if err != nil && err != io.EOF { end.Error = toRPCErr(err) } - for _, sh := range shs { - sh.HandleRPC(ctx, end) - } + sh.HandleRPC(ctx, end) } if channelz.IsOn() { @@ -1772,6 +1765,24 @@ func (s *Server) processStreamingRPC(ctx context.Context, stream *transport.Serv return ss.s.WriteStatus(statusOK) } +func (s *Server) handleMalformedMethodName(stream *transport.ServerStream, ti *traceInfo) { + if ti != nil { + ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{stream.Method()}}, true) + ti.tr.SetError() + } + errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) + if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil { + if ti != nil { + ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) + ti.tr.SetError() + } + channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err) + } + if ti != nil { + ti.tr.Finish() + } +} + func (s *Server) handleStream(t transport.ServerTransport, stream *transport.ServerStream) { ctx := stream.Context() ctx = contextWithServer(ctx, s) @@ -1792,45 +1803,47 @@ func (s *Server) handleStream(t transport.ServerTransport, stream *transport.Ser } sm := stream.Method() - if sm != "" && sm[0] == '/' { + if sm == "" { + s.handleMalformedMethodName(stream, ti) + return + } + if sm[0] != '/' { + // TODO(easwars): Add a link to the CVE in the below log messages once + // published. + if envconfig.DisableStrictPathChecking { + if old := s.strictPathCheckingLogEmitted.Swap(true); !old { + channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream received malformed method name %q. Allowing it because the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING is set to true, but this option will be removed in a future release.", sm) + } + } else { + if old := s.strictPathCheckingLogEmitted.Swap(true); !old { + channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream rejected malformed method name %q. To temporarily allow such requests, set the environment variable GRPC_GO_EXPERIMENTAL_DISABLE_STRICT_PATH_CHECKING to true. Note that this is not recommended as it may allow requests to bypass security policies.", sm) + } + s.handleMalformedMethodName(stream, ti) + return + } + } else { sm = sm[1:] } pos := strings.LastIndex(sm, "/") if pos == -1 { - if ti != nil { - ti.tr.LazyLog(&fmtStringer{"Malformed method name %q", []any{sm}}, true) - ti.tr.SetError() - } - errDesc := fmt.Sprintf("malformed method name: %q", stream.Method()) - if err := stream.WriteStatus(status.New(codes.Unimplemented, errDesc)); err != nil { - if ti != nil { - ti.tr.LazyLog(&fmtStringer{"%v", []any{err}}, true) - ti.tr.SetError() - } - channelz.Warningf(logger, s.channelz, "grpc: Server.handleStream failed to write status: %v", err) - } - if ti != nil { - ti.tr.Finish() - } + s.handleMalformedMethodName(stream, ti) return } service := sm[:pos] method := sm[pos+1:] // FromIncomingContext is expensive: skip if there are no statsHandlers - if len(s.opts.statsHandlers) > 0 { + if s.statsHandler != nil { md, _ := metadata.FromIncomingContext(ctx) - for _, sh := range s.opts.statsHandlers { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) - sh.HandleRPC(ctx, &stats.InHeader{ - FullMethod: stream.Method(), - RemoteAddr: t.Peer().Addr, - LocalAddr: t.Peer().LocalAddr, - Compression: stream.RecvCompress(), - WireLength: stream.HeaderWireLength(), - Header: md, - }) - } + ctx = s.statsHandler.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: stream.Method()}) + s.statsHandler.HandleRPC(ctx, &stats.InHeader{ + FullMethod: stream.Method(), + RemoteAddr: t.Peer().Addr, + LocalAddr: t.Peer().LocalAddr, + Compression: stream.RecvCompress(), + WireLength: stream.HeaderWireLength(), + Header: md, + }) } // To have calls in stream callouts work. Will delete once all stats handler // calls come from the gRPC layer. diff --git a/vendor/google.golang.org/grpc/stream.go b/vendor/google.golang.org/grpc/stream.go index 0a0af8961..eedb5f9b9 100644 --- a/vendor/google.golang.org/grpc/stream.go +++ b/vendor/google.golang.org/grpc/stream.go @@ -25,6 +25,7 @@ import ( "math" rand "math/rand/v2" "strconv" + "strings" "sync" "time" @@ -51,7 +52,8 @@ import ( var metadataFromOutgoingContextRaw = internal.FromOutgoingContextRaw.(func(context.Context) (metadata.MD, [][]string, bool)) // StreamHandler defines the handler called by gRPC server to complete the -// execution of a streaming RPC. +// execution of a streaming RPC. srv is the service implementation on which the +// RPC was invoked. // // If a StreamHandler returns an error, it should either be produced by the // status package, or be one of the context errors. Otherwise, gRPC will use @@ -177,13 +179,43 @@ func NewClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth return cc.NewStream(ctx, desc, method, opts...) } +var emptyMethodConfig = serviceconfig.MethodConfig{} + +// endOfClientStream performs cleanup actions required for both successful and +// failed streams. This includes incrementing channelz stats and invoking all +// registered OnFinish call options. +func endOfClientStream(cc *ClientConn, err error, opts ...CallOption) { + if channelz.IsOn() { + if err != nil { + cc.incrCallsFailed() + } else { + cc.incrCallsSucceeded() + } + } + + for _, o := range opts { + if o, ok := o.(OnFinishCallOption); ok { + o.OnFinish(err) + } + } +} + func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, opts ...CallOption) (_ ClientStream, err error) { + if channelz.IsOn() { + cc.incrCallsStarted() + } + defer func() { + if err != nil { + // Ensure cleanup when stream creation fails. + endOfClientStream(cc, err, opts...) + } + }() + // Start tracking the RPC for idleness purposes. This is where a stream is // created for both streaming and unary RPCs, and hence is a good place to // track active RPC count. - if err := cc.idlenessMgr.OnCallBegin(); err != nil { - return nil, err - } + cc.idlenessMgr.OnCallBegin() + // Add a calloption, to decrement the active call count, that gets executed // when the RPC completes. opts = append([]CallOption{OnFinish(func(error) { cc.idlenessMgr.OnCallEnd() })}, opts...) @@ -202,14 +234,6 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth } } } - if channelz.IsOn() { - cc.incrCallsStarted() - defer func() { - if err != nil { - cc.incrCallsFailed() - } - }() - } // Provide an opportunity for the first RPC to see the first service config // provided by the resolver. nameResolutionDelayed, err := cc.waitForResolvedAddrs(ctx) @@ -217,7 +241,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth return nil, err } - var mc serviceconfig.MethodConfig + mc := &emptyMethodConfig var onCommit func() newStream := func(ctx context.Context, done func()) (iresolver.ClientStream, error) { return newClientStreamWithParams(ctx, desc, cc, method, mc, onCommit, done, nameResolutionDelayed, opts...) @@ -240,7 +264,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth if rpcConfig.Context != nil { ctx = rpcConfig.Context } - mc = rpcConfig.MethodConfig + mc = &rpcConfig.MethodConfig onCommit = rpcConfig.OnCommitted if rpcConfig.Interceptor != nil { rpcInfo.Context = nil @@ -258,7 +282,7 @@ func newClientStream(ctx context.Context, desc *StreamDesc, cc *ClientConn, meth return newStream(ctx, func() {}) } -func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) { +func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *ClientConn, method string, mc *serviceconfig.MethodConfig, onCommit, doneFunc func(), nameResolutionDelayed bool, opts ...CallOption) (_ iresolver.ClientStream, err error) { callInfo := defaultCallInfo() if mc.WaitForReady != nil { callInfo.failFast = !*mc.WaitForReady @@ -299,6 +323,10 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client DoneFunc: doneFunc, Authority: callInfo.authority, } + if allowed := callInfo.acceptedResponseCompressors; len(allowed) > 0 { + headerValue := strings.Join(allowed, ",") + callHdr.AcceptedCompressors = &headerValue + } // Set our outgoing compression according to the UseCompressor CallOption, if // set. In that case, also find the compressor from the encoding package. @@ -325,7 +353,7 @@ func newClientStreamWithParams(ctx context.Context, desc *StreamDesc, cc *Client cs := &clientStream{ callHdr: callHdr, ctx: ctx, - methodConfig: &mc, + methodConfig: mc, opts: opts, callInfo: callInfo, cc: cc, @@ -418,19 +446,21 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) ctx := newContextWithRPCInfo(cs.ctx, cs.callInfo.failFast, cs.callInfo.codec, cs.compressorV0, cs.compressorV1) method := cs.callHdr.Method var beginTime time.Time - shs := cs.cc.dopts.copts.StatsHandlers - for _, sh := range shs { - ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{FullMethodName: method, FailFast: cs.callInfo.failFast, NameResolutionDelay: cs.nameResolutionDelay}) + sh := cs.cc.statsHandler + if sh != nil { beginTime = time.Now() - begin := &stats.Begin{ + ctx = sh.TagRPC(ctx, &stats.RPCTagInfo{ + FullMethodName: method, FailFast: cs.callInfo.failFast, + NameResolutionDelay: cs.nameResolutionDelay, + }) + sh.HandleRPC(ctx, &stats.Begin{ Client: true, BeginTime: beginTime, FailFast: cs.callInfo.failFast, IsClientStream: cs.desc.ClientStreams, IsServerStream: cs.desc.ServerStreams, IsTransparentRetryAttempt: isTransparent, - } - sh.HandleRPC(ctx, begin) + }) } var trInfo *traceInfo @@ -461,7 +491,7 @@ func (cs *clientStream) newAttemptLocked(isTransparent bool) (*csAttempt, error) beginTime: beginTime, cs: cs, decompressorV0: cs.cc.dopts.dc, - statsHandlers: shs, + statsHandler: sh, trInfo: trInfo, }, nil } @@ -480,12 +510,10 @@ func (a *csAttempt) getTransport() error { return err } if a.trInfo != nil { - a.trInfo.firstLine.SetRemoteAddr(a.transport.RemoteAddr()) + a.trInfo.firstLine.SetRemoteAddr(a.transport.Peer().Addr) } - if pick.blocked { - for _, sh := range a.statsHandlers { - sh.HandleRPC(a.ctx, &stats.DelayedPickComplete{}) - } + if pick.blocked && a.statsHandler != nil { + a.statsHandler.HandleRPC(a.ctx, &stats.DelayedPickComplete{}) } return nil } @@ -510,9 +538,17 @@ func (a *csAttempt) newStream() error { md, _ := metadata.FromOutgoingContext(a.ctx) md = metadata.Join(md, a.pickResult.Metadata) a.ctx = metadata.NewOutgoingContext(a.ctx, md) - } - s, err := a.transport.NewStream(a.ctx, cs.callHdr) + // If the `CallAuthority` CallOption is not set, check if the LB picker + // has provided an authority override in the PickResult metadata and + // apply it, as specified in gRFC A81. + if cs.callInfo.authority == "" { + if authMD := a.pickResult.Metadata.Get(":authority"); len(authMD) > 0 { + cs.callHdr.Authority = authMD[0] + } + } + } + s, err := a.transport.NewStream(a.ctx, cs.callHdr, a.statsHandler) if err != nil { nse, ok := err.(*transport.NewStreamError) if !ok { @@ -529,7 +565,7 @@ func (a *csAttempt) newStream() error { } a.transportStream = s a.ctx = s.Context() - a.parser = &parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool} + a.parser = parser{r: s, bufferPool: a.cs.cc.dopts.copts.BufferPool} return nil } @@ -601,7 +637,7 @@ type csAttempt struct { cs *clientStream transport transport.ClientTransport transportStream *transport.ClientStream - parser *parser + parser parser pickResult balancer.PickResult finished bool @@ -615,8 +651,8 @@ type csAttempt struct { // and cleared when the finish method is called. trInfo *traceInfo - statsHandlers []stats.Handler - beginTime time.Time + statsHandler stats.Handler + beginTime time.Time // set for newStream errors that may be transparently retried allowTransparentRetry bool @@ -1040,9 +1076,6 @@ func (cs *clientStream) finish(err error) { return } cs.finished = true - for _, onFinish := range cs.callInfo.onFinish { - onFinish(err) - } cs.commitAttemptLocked() if cs.attempt != nil { cs.attempt.finish(err) @@ -1082,13 +1115,7 @@ func (cs *clientStream) finish(err error) { if err == nil { cs.retryThrottler.successfulRPC() } - if channelz.IsOn() { - if err != nil { - cs.cc.incrCallsFailed() - } else { - cs.cc.incrCallsSucceeded() - } - } + endOfClientStream(cs.cc, err, cs.opts...) cs.cancel() } @@ -1110,17 +1137,15 @@ func (a *csAttempt) sendMsg(m any, hdr []byte, payld mem.BufferSlice, dataLength } return io.EOF } - if len(a.statsHandlers) != 0 { - for _, sh := range a.statsHandlers { - sh.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now())) - } + if a.statsHandler != nil { + a.statsHandler.HandleRPC(a.ctx, outPayload(true, m, dataLength, payloadLength, time.Now())) } return nil } func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { cs := a.cs - if len(a.statsHandlers) != 0 && payInfo == nil { + if a.statsHandler != nil && payInfo == nil { payInfo = &payloadInfo{} defer payInfo.free() } @@ -1134,6 +1159,10 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { a.decompressorV0 = nil a.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(cs.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. a.decompressorV0 = nil @@ -1141,7 +1170,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { // Only initialize this state once per stream. a.decompressorSet = true } - if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, payInfo, a.decompressorV1, false); err != nil { if err == io.EOF { if statusErr := a.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1163,8 +1192,8 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } a.mu.Unlock() } - for _, sh := range a.statsHandlers { - sh.HandleRPC(a.ctx, &stats.InPayload{ + if a.statsHandler != nil { + a.statsHandler.HandleRPC(a.ctx, &stats.InPayload{ Client: true, RecvTime: time.Now(), Payload: m, @@ -1179,7 +1208,7 @@ func (a *csAttempt) recvMsg(m any, payInfo *payloadInfo) (err error) { } // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { + if err := recv(&a.parser, cs.codec, a.transportStream, a.decompressorV0, m, *cs.callInfo.maxReceiveMessageSize, nil, a.decompressorV1, false); err == io.EOF { return a.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1217,15 +1246,14 @@ func (a *csAttempt) finish(err error) { ServerLoad: balancerload.Parse(tr), }) } - for _, sh := range a.statsHandlers { - end := &stats.End{ + if a.statsHandler != nil { + a.statsHandler.HandleRPC(a.ctx, &stats.End{ Client: true, BeginTime: a.beginTime, EndTime: time.Now(), Trailer: tr, Error: err, - } - sh.HandleRPC(a.ctx, end) + }) } if a.trInfo != nil && a.trInfo.tr != nil { if err == nil { @@ -1322,16 +1350,18 @@ func newNonRetryClientStream(ctx context.Context, desc *StreamDesc, method strin codec: c.codec, sendCompressorV0: cp, sendCompressorV1: comp, + decompressorV0: ac.cc.dopts.dc, transport: t, } - s, err := as.transport.NewStream(as.ctx, as.callHdr) + // nil stats handler: internal streams like health and ORCA do not support telemetry. + s, err := as.transport.NewStream(as.ctx, as.callHdr, nil) if err != nil { err = toRPCErr(err) return nil, err } as.transportStream = s - as.parser = &parser{r: s, bufferPool: ac.dopts.copts.BufferPool} + as.parser = parser{r: s, bufferPool: ac.dopts.copts.BufferPool} ac.incrCallsStarted() if desc != unaryStreamDesc { // Listen on stream context to cleanup when the stream context is @@ -1374,7 +1404,7 @@ type addrConnStream struct { decompressorSet bool decompressorV0 Decompressor decompressorV1 encoding.Compressor - parser *parser + parser parser // mu guards finished and is held for the entire finish method. mu sync.Mutex @@ -1480,6 +1510,10 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { as.decompressorV0 = nil as.decompressorV1 = encoding.GetCompressor(ct) } + // Validate that the compression method is acceptable for this call. + if !acceptedCompressorAllows(as.callInfo.acceptedResponseCompressors, ct) { + return status.Errorf(codes.Internal, "grpc: peer compressed the response with %q which is not allowed by AcceptCompressors", ct) + } } else { // No compression is used; disable our decompressor. as.decompressorV0 = nil @@ -1487,7 +1521,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Only initialize this state once per stream. as.decompressorSet = true } - if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err != nil { if err == io.EOF { if statusErr := as.transportStream.Status().Err(); statusErr != nil { return statusErr @@ -1509,7 +1543,7 @@ func (as *addrConnStream) RecvMsg(m any) (err error) { // Special handling for non-server-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { + if err := recv(&as.parser, as.codec, as.transportStream, as.decompressorV0, m, *as.callInfo.maxReceiveMessageSize, nil, as.decompressorV1, false); err == io.EOF { return as.transportStream.Status().Err() // non-server streaming Recv returns nil on success } else if err != nil { return toRPCErr(err) @@ -1597,7 +1631,7 @@ type ServerStream interface { type serverStream struct { ctx context.Context s *transport.ServerStream - p *parser + p parser codec baseCodec desc *StreamDesc @@ -1614,7 +1648,7 @@ type serverStream struct { maxSendMessageSize int trInfo *traceInfo - statsHandler []stats.Handler + statsHandler stats.Handler binlogs []binarylog.MethodLogger // serverHeaderBinlogged indicates whether server header has been logged. It @@ -1750,10 +1784,8 @@ func (ss *serverStream) SendMsg(m any) (err error) { binlog.Log(ss.ctx, sm) } } - if len(ss.statsHandler) != 0 { - for _, sh := range ss.statsHandler { - sh.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now())) - } + if ss.statsHandler != nil { + ss.statsHandler.HandleRPC(ss.s.Context(), outPayload(false, m, dataLen, payloadLen, time.Now())) } return nil } @@ -1784,11 +1816,11 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } }() var payInfo *payloadInfo - if len(ss.statsHandler) != 0 || len(ss.binlogs) != 0 { + if ss.statsHandler != nil || len(ss.binlogs) != 0 { payInfo = &payloadInfo{} defer payInfo.free() } - if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, payInfo, ss.decompressorV1, true); err != nil { if err == io.EOF { if len(ss.binlogs) != 0 { chc := &binarylog.ClientHalfClose{} @@ -1808,16 +1840,14 @@ func (ss *serverStream) RecvMsg(m any) (err error) { return toRPCErr(err) } ss.recvFirstMsg = true - if len(ss.statsHandler) != 0 { - for _, sh := range ss.statsHandler { - sh.HandleRPC(ss.s.Context(), &stats.InPayload{ - RecvTime: time.Now(), - Payload: m, - Length: payInfo.uncompressedBytes.Len(), - WireLength: payInfo.compressedLength + headerLen, - CompressedLength: payInfo.compressedLength, - }) - } + if ss.statsHandler != nil { + ss.statsHandler.HandleRPC(ss.s.Context(), &stats.InPayload{ + RecvTime: time.Now(), + Payload: m, + Length: payInfo.uncompressedBytes.Len(), + WireLength: payInfo.compressedLength + headerLen, + CompressedLength: payInfo.compressedLength, + }) } if len(ss.binlogs) != 0 { cm := &binarylog.ClientMessage{ @@ -1834,7 +1864,7 @@ func (ss *serverStream) RecvMsg(m any) (err error) { } // Special handling for non-client-stream rpcs. // This recv expects EOF or errors, so we don't collect inPayload. - if err := recv(ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { + if err := recv(&ss.p, ss.codec, ss.s, ss.decompressorV0, m, ss.maxReceiveMessageSize, nil, ss.decompressorV1, true); err == io.EOF { return nil } else if err != nil { return err diff --git a/vendor/google.golang.org/grpc/version.go b/vendor/google.golang.org/grpc/version.go index 76f2e0d06..76c2eed77 100644 --- a/vendor/google.golang.org/grpc/version.go +++ b/vendor/google.golang.org/grpc/version.go @@ -19,4 +19,4 @@ package grpc // Version is the current grpc version. -const Version = "1.76.0" +const Version = "1.79.3" diff --git a/vendor/modules.txt b/vendor/modules.txt index 44e5ca281..d12136ecc 100644 --- a/vendor/modules.txt +++ b/vendor/modules.txt @@ -2732,7 +2732,7 @@ golang.org/x/net/internal/timeseries golang.org/x/net/proxy golang.org/x/net/publicsuffix golang.org/x/net/trace -# golang.org/x/oauth2 v0.33.0 +# golang.org/x/oauth2 v0.34.0 ## explicit; go 1.24.0 golang.org/x/oauth2 golang.org/x/oauth2/clientcredentials @@ -2808,16 +2808,16 @@ gomodules.xyz/wait # gomodules.xyz/x v0.0.17 ## explicit; go 1.22.0 gomodules.xyz/x/filepath -# google.golang.org/genproto/googleapis/api v0.0.0-20250818200422-3122310a409c -## explicit; go 1.23.0 +# google.golang.org/genproto/googleapis/api v0.0.0-20251202230838-ff82c1b0f217 +## explicit; go 1.24.0 google.golang.org/genproto/googleapis/api google.golang.org/genproto/googleapis/api/annotations google.golang.org/genproto/googleapis/api/httpbody -# google.golang.org/genproto/googleapis/rpc v0.0.0-20251029180050-ab9386a59fda +# google.golang.org/genproto/googleapis/rpc v0.0.0-20251202230838-ff82c1b0f217 ## explicit; go 1.24.0 google.golang.org/genproto/googleapis/rpc/errdetails google.golang.org/genproto/googleapis/rpc/status -# google.golang.org/grpc v1.76.0 +# google.golang.org/grpc v1.79.3 ## explicit; go 1.24.0 google.golang.org/grpc google.golang.org/grpc/attributes @@ -2828,7 +2828,6 @@ google.golang.org/grpc/balancer/endpointsharding google.golang.org/grpc/balancer/grpclb/state google.golang.org/grpc/balancer/pickfirst google.golang.org/grpc/balancer/pickfirst/internal -google.golang.org/grpc/balancer/pickfirst/pickfirstleaf google.golang.org/grpc/balancer/roundrobin google.golang.org/grpc/binarylog/grpc_binarylog_v1 google.golang.org/grpc/channelz @@ -2838,6 +2837,7 @@ google.golang.org/grpc/credentials google.golang.org/grpc/credentials/insecure google.golang.org/grpc/encoding google.golang.org/grpc/encoding/gzip +google.golang.org/grpc/encoding/internal google.golang.org/grpc/encoding/proto google.golang.org/grpc/experimental/stats google.golang.org/grpc/grpclog @@ -2847,6 +2847,7 @@ google.golang.org/grpc/health/grpc_health_v1 google.golang.org/grpc/internal google.golang.org/grpc/internal/backoff google.golang.org/grpc/internal/balancer/gracefulswitch +google.golang.org/grpc/internal/balancer/weight google.golang.org/grpc/internal/balancerload google.golang.org/grpc/internal/binarylog google.golang.org/grpc/internal/buffer