Skip to content
68 changes: 68 additions & 0 deletions spanner/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,32 @@ func NewClientWithConfig(ctx context.Context, database string, config ClientConf
return newClientWithConfig(ctx, database, config, nil, opts...)
}

type fallbackWrapper struct {
*grpcgcp.GCPFallback
primaryConn gtransport.ConnPool
fallbackConn gtransport.ConnPool
}

// Conn returns nil because GCPFallback hides the underlying ClientConn.
// The Spanner client handles this by using the interface methods (Invoke/NewStream).
func (fw *fallbackWrapper) Conn() *grpc.ClientConn {
return nil
}

func (fw *fallbackWrapper) Num() int {
return fw.primaryConn.Num()
}

func (fw *fallbackWrapper) Close() error {
fw.GCPFallback.Close()
err1 := fw.primaryConn.Close()
err2 := fw.fallbackConn.Close()
if err1 != nil {
return err1
}
return err2
}

func newClientWithConfig(ctx context.Context, database string, config ClientConfig, gme *grpcgcp.GCPMultiEndpoint, opts ...option.ClientOption) (c *Client, err error) {
// Validate database path.
if err := validDatabaseName(database); err != nil {
Expand Down Expand Up @@ -496,6 +522,48 @@ func newClientWithConfig(ctx context.Context, database string, config ClientConf
// Use GCPMultiEndpoint if provided.
pool = &gmeWrapper{gme}
endpointClientOpts = append(endpointClientOpts, opts...)
} else if isFallbackEnabled, _ := strconv.ParseBool(os.Getenv("GOOGLE_SPANNER_ENABLE_GCP_FALLBACK")); isFallbackEnabled && isDirectPathEnabled {
var primaryConn gtransport.ConnPool
var fallbackConn gtransport.ConnPool
reqIDInjector := new(requestIDHeaderInjector)
opts = append(opts,
option.WithGRPCDialOption(grpc.WithChainStreamInterceptor(reqIDInjector.interceptStream)),
option.WithGRPCDialOption(grpc.WithChainUnaryInterceptor(reqIDInjector.interceptUnary)),
)
allOpts := allClientOpts(config.NumChannels, config.Compression, opts...)
endpointClientOpts = append(endpointClientOpts, allOpts...)
primaryConn, err = gtransport.DialPool(ctx, allOpts...)
if err != nil {
return nil, err
}

fallbackConnOpts := append(allOpts, internaloption.EnableDirectPath(false))
fallbackConn, err = gtransport.DialPool(ctx, fallbackConnOpts...)
if err != nil {
primaryConn.Close()
return nil, err
}

if hasNumChannelsConfig && ((primaryConn.Num() != config.NumChannels) || (fallbackConn.Num() != config.NumChannels)) {
primaryConn.Close()
fallbackConn.Close()
return nil, spannerErrorf(codes.InvalidArgument, "Connection pool mismatch: NumChannels=%v, primaryConn.Num()=%v, fallbackConn.Num()=%v", config.NumChannels, primaryConn.Num(), fallbackConn.Num())
}

fbOpts := grpcgcp.NewGCPFallbackOptions()
fbOpts.EnableFallback = true
fbOpts.ErrorRateThreshold = 1
fbOpts.MinFailedCalls = 1
fbOpts.MeterProvider = config.OpenTelemetryMeterProvider

gcpFallback, err := grpcgcp.NewGCPFallback(ctx, primaryConn, fallbackConn, fbOpts)
if err != nil {
primaryConn.Close()
fallbackConn.Close()
return nil, err
}

pool = &fallbackWrapper{gcpFallback, primaryConn, fallbackConn}
} else {
// Create gtransport ConnPool as usual if MultiEndpoint is not used.
// gRPC options.
Expand Down
2 changes: 2 additions & 0 deletions spanner/sessionclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,8 @@ func (sc *sessionClient) nextClient() (spannerClient, error) {
if _, ok := sc.connPool.(*gmeWrapper); ok {
// Pass GCPMultiEndpoint as a pool.
clientOpt = gtransport.WithConnPool(sc.connPool)
} else if _, ok := sc.connPool.(*fallbackWrapper); ok {
clientOpt = gtransport.WithConnPool(sc.connPool)
} else {
// Pick a grpc.ClientConn from a regular pool.
conn := sc.connPool.Conn()
Expand Down
Loading