diff --git a/cmd/dev_server/flags.go b/cmd/dev_server/flags.go index 16b0edc6..01ba78d8 100644 --- a/cmd/dev_server/flags.go +++ b/cmd/dev_server/flags.go @@ -2,5 +2,6 @@ package dev_server const ( ContextFlag = "context" + OverrideFlag = "override" SourceEnvironmentFlag = "source" ) diff --git a/cmd/dev_server/start_server.go b/cmd/dev_server/start_server.go index 0f5d881b..f990cba6 100644 --- a/cmd/dev_server/start_server.go +++ b/cmd/dev_server/start_server.go @@ -2,6 +2,7 @@ package dev_server import ( "context" + "encoding/json" "errors" "log" "os/exec" @@ -10,10 +11,12 @@ import ( "github.com/spf13/cobra" "github.com/spf13/viper" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" "github.com/launchdarkly/ldcli/cmd/cliflags" resourcescmd "github.com/launchdarkly/ldcli/cmd/resources" "github.com/launchdarkly/ldcli/cmd/validators" "github.com/launchdarkly/ldcli/internal/dev_server" + "github.com/launchdarkly/ldcli/internal/dev_server/model" ) func NewStartServerCmd(client dev_server.Client) *cobra.Command { @@ -28,17 +31,61 @@ func NewStartServerCmd(client dev_server.Client) *cobra.Command { cmd.SetUsageTemplate(resourcescmd.SubcommandUsageTemplate()) + cmd.Flags().String(cliflags.ProjectFlag, "", "The project key") + _ = viper.BindPFlag(cliflags.ProjectFlag, cmd.Flags().Lookup(cliflags.ProjectFlag)) + + cmd.Flags().String(SourceEnvironmentFlag, "", "environment to copy flag values from") + _ = viper.BindPFlag(SourceEnvironmentFlag, cmd.Flags().Lookup(SourceEnvironmentFlag)) + + cmd.Flags().String(ContextFlag, "", `Stringified JSON representation of your context object ex. {"kind": "multi", "user": { "email": "test@gmail.com", "username": "foo", "key": "bar"}`) + _ = viper.BindPFlag(ContextFlag, cmd.Flags().Lookup(ContextFlag)) + + cmd.Flags().String(OverrideFlag, "", `Stringified JSON representation of flag overrides ex. {"flagName": true, "stringFlagName": "test" }`) + _ = viper.BindPFlag(OverrideFlag, cmd.Flags().Lookup(OverrideFlag)) + return cmd } func startServer(client dev_server.Client) func(*cobra.Command, []string) error { return func(cmd *cobra.Command, args []string) error { ctx := context.Background() + + var initialSetting model.InitialProjectSettings + + if viper.IsSet(cliflags.ProjectFlag) && viper.IsSet(SourceEnvironmentFlag) { + + initialSetting = model.InitialProjectSettings{ + Enabled: true, + ProjectKey: viper.GetString(cliflags.ProjectFlag), + EnvKey: viper.GetString(SourceEnvironmentFlag), + } + if viper.IsSet(ContextFlag) { + var c ldcontext.Context + contextString := viper.GetString(ContextFlag) + err := c.UnmarshalJSON([]byte(contextString)) + if err != nil { + return err + } + initialSetting.Context = &c + } + + if viper.IsSet(OverrideFlag) { + var override map[string]model.FlagValue + overrideString := viper.GetString(OverrideFlag) + err := json.Unmarshal([]byte(overrideString), &override) + if err != nil { + return err + } + initialSetting.Overrides = override + } + } + params := dev_server.ServerParams{ - AccessToken: viper.GetString(cliflags.AccessTokenFlag), - BaseURI: viper.GetString(cliflags.BaseURIFlag), - DevStreamURI: viper.GetString(cliflags.DevStreamURIFlag), - Port: viper.GetString(cliflags.PortFlag), + AccessToken: viper.GetString(cliflags.AccessTokenFlag), + BaseURI: viper.GetString(cliflags.BaseURIFlag), + DevStreamURI: viper.GetString(cliflags.DevStreamURIFlag), + Port: viper.GetString(cliflags.PortFlag), + InitialProjectSettings: initialSetting, } client.RunServer(ctx, params) diff --git a/internal/dev_server/adapters/context.go b/internal/dev_server/adapters/context.go new file mode 100644 index 00000000..a0d59611 --- /dev/null +++ b/internal/dev_server/adapters/context.go @@ -0,0 +1,12 @@ +package adapters + +import ( + "context" + ldapi "github.com/launchdarkly/api-client-go/v14" +) + +func WithApiAndSdk(ctx context.Context, client ldapi.APIClient, streamingUrl string) context.Context { + ctx = WithSdk(ctx, newSdk(streamingUrl)) + ctx = WithApi(ctx, NewApi(client)) + return ctx +} diff --git a/internal/dev_server/adapters/middleware.go b/internal/dev_server/adapters/middleware.go index 3f2d0a07..f1a66064 100644 --- a/internal/dev_server/adapters/middleware.go +++ b/internal/dev_server/adapters/middleware.go @@ -13,8 +13,7 @@ func Middleware(client ldapi.APIClient, streamingUrl string) func(handler http.H return func(handler http.Handler) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { ctx := request.Context() - ctx = WithSdk(ctx, newSdk(streamingUrl)) - ctx = WithApi(ctx, NewApi(client)) + ctx = WithApiAndSdk(ctx, client, streamingUrl) request = request.WithContext(ctx) handler.ServeHTTP(writer, request) }) diff --git a/internal/dev_server/dev_server.go b/internal/dev_server/dev_server.go index 7166ec57..0e8712dd 100644 --- a/internal/dev_server/dev_server.go +++ b/internal/dev_server/dev_server.go @@ -25,10 +25,11 @@ type Client interface { } type ServerParams struct { - AccessToken string - BaseURI string - DevStreamURI string - Port string + AccessToken string + BaseURI string + DevStreamURI string + Port string + InitialProjectSettings model.InitialProjectSettings } type LDClient struct { @@ -49,6 +50,7 @@ func (c LDClient) RunServer(ctx context.Context, serverParams ServerParams) { if err != nil { log.Fatal(err) } + observers := model.NewObservers() ss := api.NewStrictServer() apiServer := api.NewStrictHandlerWithOptions(ss, nil, api.StrictHTTPServerOptions{ RequestErrorHandlerFunc: api.RequestErrorHandler, @@ -57,7 +59,7 @@ func (c LDClient) RunServer(ctx context.Context, serverParams ServerParams) { r := mux.NewRouter() r.Use(adapters.Middleware(*ldClient, serverParams.DevStreamURI)) r.Use(model.StoreMiddleware(sqlStore)) - r.Use(model.ObserversMiddleware(model.NewObservers())) + r.Use(model.ObserversMiddleware(observers)) r.Handle("/", http.RedirectHandler("/ui/", http.StatusFound)) r.Handle("/ui", http.RedirectHandler("/ui/", http.StatusMovedPermanently)) r.PathPrefix("/ui/").Handler(http.StripPrefix("/ui/", ui.AssetHandler)) @@ -66,9 +68,18 @@ func (c LDClient) RunServer(ctx context.Context, serverParams ServerParams) { handler = handlers.CombinedLoggingHandler(os.Stdout, handler) handler = handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))(handler) + ctx = adapters.WithApiAndSdk(ctx, *ldClient, serverParams.DevStreamURI) + ctx = model.SetObserversOnContext(ctx, observers) + ctx = model.ContextWithStore(ctx, sqlStore) + syncErr := model.CreateOrSyncProject(ctx, serverParams.InitialProjectSettings) + if syncErr != nil { + log.Fatal(syncErr) + } + addr := fmt.Sprintf("0.0.0.0:%s", serverParams.Port) log.Printf("Server running on %s", addr) log.Printf("Access the UI for toggling overrides at http://localhost:%s/ui or by running `ldcli dev-server ui`", serverParams.Port) + server := http.Server{ Addr: addr, Handler: handler, diff --git a/internal/dev_server/model/sync.go b/internal/dev_server/model/sync.go new file mode 100644 index 00000000..dd0f323a --- /dev/null +++ b/internal/dev_server/model/sync.go @@ -0,0 +1,53 @@ +package model + +import ( + "context" + "log" + + "github.com/pkg/errors" + + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" +) + +type FlagValue = ldvalue.Value + +type InitialProjectSettings struct { + Enabled bool + ProjectKey string + EnvKey string + Context *ldcontext.Context `json:"context,omitempty"` + Overrides map[string]FlagValue `json:"overrides,omitempty"` +} + +func CreateOrSyncProject(ctx context.Context, settings InitialProjectSettings) error { + if !settings.Enabled { + return nil + } + + log.Printf("Initial project [%s] with env [%s]", settings.ProjectKey, settings.EnvKey) + var project Project + project, createError := CreateProject(ctx, settings.ProjectKey, settings.EnvKey, settings.Context) + if createError != nil { + if errors.Is(createError, ErrAlreadyExists) { + log.Printf("Project [%s] exists, refreshing data", settings.ProjectKey) + var updateErr error + project, updateErr = UpdateProject(ctx, settings.ProjectKey, settings.Context, &settings.EnvKey) + if updateErr != nil { + return updateErr + } + + } else { + return createError + } + } + for flagKey, val := range settings.Overrides { + _, err := UpsertOverride(ctx, settings.ProjectKey, flagKey, val) + if err != nil { + return err + } + } + + log.Printf("Successfully synced Initial project [%s]", project.Key) + return nil +} diff --git a/internal/dev_server/model/sync_test.go b/internal/dev_server/model/sync_test.go new file mode 100644 index 00000000..68c49f44 --- /dev/null +++ b/internal/dev_server/model/sync_test.go @@ -0,0 +1,173 @@ +package model_test + +import ( + "context" + "testing" + + "github.com/pkg/errors" + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + ldapi "github.com/launchdarkly/api-client-go/v14" + "github.com/launchdarkly/go-sdk-common/v3/ldcontext" + "github.com/launchdarkly/go-sdk-common/v3/ldvalue" + "github.com/launchdarkly/go-server-sdk/v7/interfaces/flagstate" + adapters_mocks "github.com/launchdarkly/ldcli/internal/dev_server/adapters/mocks" + "github.com/launchdarkly/ldcli/internal/dev_server/model" + "github.com/launchdarkly/ldcli/internal/dev_server/model/mocks" +) + +func TestInitialSync(t *testing.T) { + + ctx := context.Background() + mockController := gomock.NewController(t) + observers := model.NewObservers() + ctx, api, sdk := adapters_mocks.WithMockApiAndSdk(ctx, mockController) + store := mocks.NewMockStore(mockController) + ctx = model.ContextWithStore(ctx, store) + ctx = model.SetObserversOnContext(ctx, observers) + projKey := "proj" + sourceEnvKey := "env" + sdkKey := "thing" + + allFlagsState := flagstate.NewAllFlagsBuilder(). + AddFlag("boolFlag", flagstate.FlagState{Value: ldvalue.Bool(true)}). + Build() + + trueVariationId, falseVariationId := "true", "false" + allFlags := []ldapi.FeatureFlag{{ + Name: "bool flag", + Kind: "bool", + Key: "boolFlag", + Variations: []ldapi.Variation{ + { + Id: &trueVariationId, + Value: true, + }, + { + Id: &falseVariationId, + Value: false, + }, + }, + }} + + t.Run("Returns no error if disabled", func(t *testing.T) { + input := model.InitialProjectSettings{ + Enabled: false, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: nil, + } + err := model.CreateOrSyncProject(ctx, input) + assert.NoError(t, err) + }) + + t.Run("Returns error if it cant fetch flag state", func(t *testing.T) { + api.EXPECT().GetSdkKey(gomock.Any(), projKey, sourceEnvKey).Return("", errors.New("fetch flag state fails")) + input := model.InitialProjectSettings{ + Enabled: true, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: nil, + } + err := model.CreateOrSyncProject(ctx, input) + assert.NotNil(t, err) + assert.Equal(t, "fetch flag state fails", err.Error()) + }) + + t.Run("Returns error if it can't fetch flags", func(t *testing.T) { + api.EXPECT().GetSdkKey(gomock.Any(), projKey, sourceEnvKey).Return(sdkKey, nil) + sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), sdkKey).Return(allFlagsState, nil) + api.EXPECT().GetAllFlags(gomock.Any(), projKey).Return(nil, errors.New("fetch flags failed")) + input := model.InitialProjectSettings{ + Enabled: true, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: nil, + } + err := model.CreateOrSyncProject(ctx, input) + assert.NotNil(t, err) + assert.Equal(t, "fetch flags failed", err.Error()) + }) + + t.Run("Returns error if it fails to insert the project", func(t *testing.T) { + api.EXPECT().GetSdkKey(gomock.Any(), projKey, sourceEnvKey).Return(sdkKey, nil) + sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), sdkKey).Return(allFlagsState, nil) + api.EXPECT().GetAllFlags(gomock.Any(), projKey).Return(allFlags, nil) + store.EXPECT().InsertProject(gomock.Any(), gomock.Any()).Return(errors.New("insert fails")) + + input := model.InitialProjectSettings{ + Enabled: true, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: nil, + } + err := model.CreateOrSyncProject(ctx, input) + assert.NotNil(t, err) + assert.Equal(t, "insert fails", err.Error()) + }) + + t.Run("Successfully creates project", func(t *testing.T) { + api.EXPECT().GetSdkKey(gomock.Any(), projKey, sourceEnvKey).Return(sdkKey, nil) + sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), sdkKey).Return(allFlagsState, nil) + api.EXPECT().GetAllFlags(gomock.Any(), projKey).Return(allFlags, nil) + store.EXPECT().InsertProject(gomock.Any(), gomock.Any()).Return(nil) + + input := model.InitialProjectSettings{ + Enabled: true, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: nil, + } + err := model.CreateOrSyncProject(ctx, input) + + assert.NoError(t, err) + }) + t.Run("Successfully creates project with override", func(t *testing.T) { + override := model.Override{ + ProjectKey: projKey, + FlagKey: "boolFlag", + Value: ldvalue.Bool(true), + Active: true, + Version: 1, + } + + proj := model.Project{ + Key: projKey, + SourceEnvironmentKey: sourceEnvKey, + Context: ldcontext.New(t.Name()), + AllFlagsState: map[string]model.FlagState{ + "boolFlag": { + Version: 0, + Value: ldvalue.Bool(false), + }, + }, + } + + api.EXPECT().GetSdkKey(gomock.Any(), projKey, sourceEnvKey).Return(sdkKey, nil) + sdk.EXPECT().GetAllFlagsState(gomock.Any(), gomock.Any(), sdkKey).Return(allFlagsState, nil) + api.EXPECT().GetAllFlags(gomock.Any(), projKey).Return(allFlags, nil) + store.EXPECT().InsertProject(gomock.Any(), gomock.Any()).Return(nil) + store.EXPECT().UpsertOverride(gomock.Any(), override).Return(override, nil) + store.EXPECT().GetDevProject(gomock.Any(), projKey).Return(&proj, nil) + + input := model.InitialProjectSettings{ + Enabled: true, + ProjectKey: projKey, + EnvKey: sourceEnvKey, + Context: nil, + Overrides: map[string]model.FlagValue{ + "boolFlag": ldvalue.Bool(true), + }, + } + err := model.CreateOrSyncProject(ctx, input) + + assert.NoError(t, err) + }) + +}