Skip to content

Commit 8ff44da

Browse files
fix: return fatal on certain error codes during first stream cycle
Signed-off-by: Alexandra Oberaigner <[email protected]>
1 parent 20b0ccd commit 8ff44da

File tree

13 files changed

+259
-83
lines changed

13 files changed

+259
-83
lines changed

providers/flagd/e2e/inprocess_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestInProcessProviderE2E(t *testing.T) {
2626
}
2727

2828
// Run tests with in-process specific tags
29-
tags := "@in-process && ~@unixsocket && ~@metadata && ~@customCert && ~@contextEnrichment && ~@sync-payload"
29+
tags := "@in-process && ~@unixsocket && ~@metadata && ~@customCert && ~@contextEnrichment && ~@sync-payload && ~@sync-port"
3030

3131
if err := runner.RunGherkinTestsWithSubtests(t, featurePaths, tags); err != nil {
3232
t.Fatalf("Gherkin tests failed: %v", err)

providers/flagd/e2e/rpc_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ func TestRPCProviderE2E(t *testing.T) {
2626
}
2727

2828
// Run tests with RPC-specific tags - exclude unimplemented scenarios
29-
tags := "@rpc && ~@unixsocket && ~@targetURI && ~@sync && ~@metadata && ~@grace && ~@customCert && ~@caching"
29+
tags := "@rpc && ~@unixsocket && ~@targetURI && ~@sync && ~@metadata && ~@grace && ~@customCert && ~@caching && ~@forbidden"
3030

3131
if err := runner.RunGherkinTestsWithSubtests(t, featurePaths, tags); err != nil {
3232
t.Fatalf("Gherkin tests failed: %v", err)

providers/flagd/pkg/configuration.go

Lines changed: 45 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,15 @@ package flagd
33
import (
44
"errors"
55
"fmt"
6+
"os"
7+
"strconv"
8+
"strings"
9+
610
"github.com/go-logr/logr"
711
"github.com/open-feature/flagd/core/pkg/sync"
812
"github.com/open-feature/go-sdk-contrib/providers/flagd/internal/cache"
913
"github.com/open-feature/go-sdk-contrib/providers/flagd/internal/logger"
1014
"google.golang.org/grpc"
11-
"os"
12-
"strconv"
13-
"strings"
1415
)
1516

1617
type ResolverType string
@@ -26,6 +27,7 @@ const (
2627
defaultHost = "localhost"
2728
defaultResolver = rpc
2829
defaultGracePeriod = 5
30+
defaultFatalStatusCodes = ""
2931

3032
rpc ResolverType = "rpc"
3133
inProcess ResolverType = "in-process"
@@ -45,6 +47,7 @@ const (
4547
flagdOfflinePathEnvironmentVariableName = "FLAGD_OFFLINE_FLAG_SOURCE_PATH"
4648
flagdTargetUriEnvironmentVariableName = "FLAGD_TARGET_URI"
4749
flagdGracePeriodVariableName = "FLAGD_RETRY_GRACE_PERIOD"
50+
flagdFatalStatusCodesVariableName = "FLAGD_FATAL_STATUS_CODES"
4851
)
4952

5053
type ProviderConfiguration struct {
@@ -66,6 +69,7 @@ type ProviderConfiguration struct {
6669
CustomSyncProviderUri string
6770
GrpcDialOptionsOverride []grpc.DialOption
6871
RetryGracePeriod int
72+
FatalStatusCodes []string
6973

7074
log logr.Logger
7175
}
@@ -80,6 +84,7 @@ func newDefaultConfiguration(log logr.Logger) *ProviderConfiguration {
8084
Resolver: defaultResolver,
8185
Tls: defaultTLS,
8286
RetryGracePeriod: defaultGracePeriod,
87+
FatalStatusCodes: strings.Split(defaultFatalStatusCodes, ","),
8388
}
8489

8590
p.updateFromEnvVar()
@@ -130,6 +135,7 @@ func validateProviderConfiguration(p *ProviderConfiguration) error {
130135

131136
// updateFromEnvVar is a utility to update configurations based on current environment variables
132137
func (cfg *ProviderConfiguration) updateFromEnvVar() {
138+
133139
portS := os.Getenv(flagdPortEnvironmentVariableName)
134140
if portS != "" {
135141
port, err := strconv.Atoi(portS)
@@ -159,17 +165,7 @@ func (cfg *ProviderConfiguration) updateFromEnvVar() {
159165
cfg.CertPath = certificatePath
160166
}
161167

162-
if maxCacheSizeS := os.Getenv(flagdMaxCacheSizeEnvironmentVariableName); maxCacheSizeS != "" {
163-
maxCacheSizeFromEnv, err := strconv.Atoi(maxCacheSizeS)
164-
if err != nil {
165-
cfg.log.Error(err,
166-
fmt.Sprintf("invalid env config for %s provided, using default value: %d",
167-
flagdMaxCacheSizeEnvironmentVariableName, defaultMaxCacheSize,
168-
))
169-
} else {
170-
cfg.MaxCacheSize = maxCacheSizeFromEnv
171-
}
172-
}
168+
cfg.MaxCacheSize = getIntFromEnvVarOrDefault(flagdMaxCacheSizeEnvironmentVariableName, defaultMaxCacheSize, cfg.log)
173169

174170
if cacheValue := os.Getenv(flagdCacheEnvironmentVariableName); cacheValue != "" {
175171
switch cache.Type(cacheValue) {
@@ -185,18 +181,8 @@ func (cfg *ProviderConfiguration) updateFromEnvVar() {
185181
}
186182
}
187183

188-
if maxEventStreamRetriesS := os.Getenv(
189-
flagdMaxEventStreamRetriesEnvironmentVariableName); maxEventStreamRetriesS != "" {
190-
191-
maxEventStreamRetries, err := strconv.Atoi(maxEventStreamRetriesS)
192-
if err != nil {
193-
cfg.log.Error(err,
194-
fmt.Sprintf("invalid env config for %s provided, using default value: %d",
195-
flagdMaxEventStreamRetriesEnvironmentVariableName, defaultMaxEventStreamRetries))
196-
} else {
197-
cfg.EventStreamConnectionMaxAttempts = maxEventStreamRetries
198-
}
199-
}
184+
cfg.EventStreamConnectionMaxAttempts = getIntFromEnvVarOrDefault(
185+
flagdMaxEventStreamRetriesEnvironmentVariableName, defaultMaxEventStreamRetries, cfg.log)
200186

201187
if resolver := os.Getenv(flagdResolverEnvironmentVariableName); resolver != "" {
202188
switch strings.ToLower(resolver) {
@@ -230,12 +216,34 @@ func (cfg *ProviderConfiguration) updateFromEnvVar() {
230216
if gracePeriod := os.Getenv(flagdGracePeriodVariableName); gracePeriod != "" {
231217
if seconds, err := strconv.Atoi(gracePeriod); err == nil {
232218
cfg.RetryGracePeriod = seconds
233-
} else {
234-
// Handle parsing error
235-
cfg.log.Error(err, fmt.Sprintf("invalid grace period '%s'", gracePeriod))
219+
cfg.RetryGracePeriod = getIntFromEnvVarOrDefault(flagdGracePeriodVariableName, defaultGracePeriod, cfg.log)
236220
}
237221
}
238222

223+
if fatalStatusCodes := os.Getenv(flagdFatalStatusCodesVariableName); fatalStatusCodes != "" {
224+
fatalStatusCodesArr := strings.Split(fatalStatusCodes, ",")
225+
for i, fatalStatusCode := range fatalStatusCodesArr {
226+
fatalStatusCodesArr[i] = strings.TrimSpace(fatalStatusCode)
227+
}
228+
cfg.FatalStatusCodes = fatalStatusCodesArr
229+
}
230+
}
231+
232+
// Helper
233+
234+
func getIntFromEnvVarOrDefault(envVarName string, defaultValue int, log logr.Logger) int {
235+
if valueFromEnv := os.Getenv(envVarName); valueFromEnv != "" {
236+
intValue, err := strconv.Atoi(valueFromEnv)
237+
if err != nil {
238+
log.Error(err,
239+
fmt.Sprintf("invalid env config for %s provided, using default value: %d",
240+
envVarName, defaultValue,
241+
))
242+
} else {
243+
return intValue
244+
}
245+
}
246+
return defaultValue
239247
}
240248

241249
// ProviderOptions
@@ -415,3 +423,11 @@ func WithRetryGracePeriod(gracePeriod int) ProviderOption {
415423
p.RetryGracePeriod = gracePeriod
416424
}
417425
}
426+
427+
// WithFatalStatusCodes allows to set a list of gRPC status codes, which will cause streams to give up
428+
// and put the provider in a PROVIDER_FATAL state
429+
func WithFatalStatusCodes(fatalStatusCodes []string) ProviderOption {
430+
return func(p *ProviderConfiguration) {
431+
p.FatalStatusCodes = fatalStatusCodes
432+
}
433+
}

providers/flagd/pkg/provider.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ func NewProvider(opts ...ProviderOption) (*Provider, error) {
7474
CustomSyncProviderUri: provider.providerConfiguration.CustomSyncProviderUri,
7575
GrpcDialOptionsOverride: provider.providerConfiguration.GrpcDialOptionsOverride,
7676
RetryGracePeriod: provider.providerConfiguration.RetryGracePeriod,
77+
FatalStatusCodes: provider.providerConfiguration.FatalStatusCodes,
7778
})
7879
default:
7980
service = process.NewInProcessService(process.Configuration{
Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
package process
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"google.golang.org/grpc/codes"
7+
"time"
8+
)
9+
10+
const (
11+
// Default timeouts and retry intervals
12+
defaultKeepaliveTime = 30 * time.Second
13+
defaultKeepaliveTimeout = 5 * time.Second
14+
)
15+
16+
type RetryPolicy struct {
17+
MaxAttempts int `json:"MaxAttempts"`
18+
InitialBackoff string `json:"InitialBackoff"`
19+
MaxBackoff string `json:"MaxBackoff"`
20+
BackoffMultiplier float64 `json:"BackoffMultiplier"`
21+
RetryableStatusCodes []string `json:"RetryableStatusCodes"`
22+
}
23+
24+
func (g *Sync) buildRetryPolicy() string {
25+
var policy = map[string]interface{}{
26+
"methodConfig": []map[string]interface{}{
27+
{
28+
"name": []map[string]string{
29+
{"service": "flagd.sync.v1.FlagSyncService"},
30+
},
31+
"retryPolicy": RetryPolicy{
32+
MaxAttempts: 3,
33+
InitialBackoff: "1s",
34+
MaxBackoff: "5s",
35+
BackoffMultiplier: 2.0,
36+
RetryableStatusCodes: []string{"UNKNOWN", "UNAVAILABLE"},
37+
},
38+
},
39+
},
40+
}
41+
retryPolicyBytes, _ := json.Marshal(policy)
42+
retryPolicy := string(retryPolicyBytes)
43+
44+
return retryPolicy
45+
}
46+
47+
// Set of non-retryable gRPC status codes for faster lookup
48+
var nonRetryableCodes map[codes.Code]struct{}
49+
50+
// initNonRetryableStatusCodesSet initializes the set of non-retryable gRPC status codes for quick lookup
51+
func (g *Sync) initNonRetryableStatusCodesSet() {
52+
nonRetryableCodes = make(map[codes.Code]struct{})
53+
54+
for _, codeStr := range g.FatalStatusCodes {
55+
// Wrap the string in quotes to match the expected JSON format
56+
jsonStr := fmt.Sprintf(`"%s"`, codeStr)
57+
58+
var code codes.Code
59+
if err := code.UnmarshalJSON([]byte(jsonStr)); err != nil {
60+
g.Logger.Warn(fmt.Sprintf("unknown status code: %s, error: %v", codeStr, err))
61+
continue
62+
}
63+
64+
nonRetryableCodes[code] = struct{}{}
65+
}
66+
}
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
package process
2+
3+
import (
4+
"github.com/open-feature/flagd/core/pkg/logger"
5+
"go.uber.org/zap"
6+
"google.golang.org/grpc/codes"
7+
"testing"
8+
)
9+
10+
func TestSync_initNonRetryableStatusCodesSet(t *testing.T) {
11+
tests := []struct {
12+
name string
13+
fatalStatusCodes []string
14+
expectedCodes []codes.Code
15+
notExpectedCodes []codes.Code
16+
}{
17+
{
18+
name: "valid status codes",
19+
fatalStatusCodes: []string{"UNAVAILABLE", "INTERNAL", "DEADLINE_EXCEEDED"},
20+
expectedCodes: []codes.Code{codes.Unavailable, codes.Internal, codes.DeadlineExceeded},
21+
notExpectedCodes: []codes.Code{codes.OK, codes.Unknown},
22+
},
23+
{
24+
name: "empty array",
25+
fatalStatusCodes: []string{},
26+
expectedCodes: []codes.Code{},
27+
notExpectedCodes: []codes.Code{codes.Unavailable, codes.Internal},
28+
},
29+
{
30+
name: "invalid status codes",
31+
fatalStatusCodes: []string{"INVALID_CODE", "UNKNOWN_STATUS"},
32+
expectedCodes: []codes.Code{},
33+
notExpectedCodes: []codes.Code{codes.Unavailable, codes.Internal},
34+
},
35+
{
36+
name: "mixed valid and invalid codes",
37+
fatalStatusCodes: []string{"UNAVAILABLE", "INVALID_CODE", "INTERNAL"},
38+
expectedCodes: []codes.Code{codes.Unavailable, codes.Internal},
39+
notExpectedCodes: []codes.Code{codes.OK, codes.Unknown},
40+
},
41+
}
42+
43+
for _, tt := range tests {
44+
t.Run(tt.name, func(t *testing.T) {
45+
// Reset the global map before each test
46+
nonRetryableCodes = nil
47+
48+
s := &Sync{
49+
FatalStatusCodes: tt.fatalStatusCodes,
50+
Logger: &logger.Logger{
51+
Logger: zap.NewNop(),
52+
},
53+
}
54+
55+
s.initNonRetryableStatusCodesSet()
56+
57+
// Verify expected codes are present
58+
for _, code := range tt.expectedCodes {
59+
if _, exists := nonRetryableCodes[code]; !exists {
60+
t.Errorf("expected code %v to be in nonRetryableCodes, but it was not found", code)
61+
}
62+
}
63+
64+
// Verify not expected codes are absent
65+
for _, code := range tt.notExpectedCodes {
66+
if _, exists := nonRetryableCodes[code]; exists {
67+
t.Errorf("did not expect code %v to be in nonRetryableCodes, but it was found", code)
68+
}
69+
}
70+
71+
// Verify the map size matches expected
72+
if len(nonRetryableCodes) != len(tt.expectedCodes) {
73+
t.Errorf("expected map size %d, got %d", len(tt.expectedCodes), len(nonRetryableCodes))
74+
}
75+
})
76+
}
77+
}

0 commit comments

Comments
 (0)