diff --git a/CHANGELOG.md b/CHANGELOG.md index cf22ffb82..45c1c4628 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ - feat: support type aliases and redefinitions [#721](https://github.com/hypermodeinc/modus/pull/721) - feat: support MySQL database connections [#722](https://github.com/hypermodeinc/modus/pull/722) +- chore: refactoring / tests [#723](https://github.com/hypermodeinc/modus/pull/723) ## 2025-01-09 - CLI 0.16.6 diff --git a/runtime/app/app.go b/runtime/app/app.go index 6ef64522d..1c893ccf3 100644 --- a/runtime/app/app.go +++ b/runtime/app/app.go @@ -15,34 +15,75 @@ import ( "path/filepath" "runtime" "sync" - "time" -) -// ShutdownTimeout is the time to wait for the server to shutdown gracefully. -const ShutdownTimeout = 5 * time.Second + "github.com/fatih/color" +) var mu = &sync.RWMutex{} +var config *AppConfig var shuttingDown = false +func init() { + // Set the global color mode + SetOutputColorMode() + + // Create the the app configuration + mu.Lock() + defer mu.Unlock() + config = CreateAppConfig() +} + +// SetOutputColorMode applies the FORCE_COLOR environment variable to override the default color mode. +func SetOutputColorMode() { + forceColor := os.Getenv("FORCE_COLOR") + if forceColor != "" && forceColor != "0" { + color.NoColor = false + } +} + +// Config returns the global app configuration. +func Config() *AppConfig { + mu.RLock() + defer mu.RUnlock() + return config +} + +// SetConfig sets the global app configuration. +// This is typically only called in tests. +func SetConfig(c *AppConfig) { + mu.Lock() + defer mu.Unlock() + config = c +} + +// IsDevEnvironment returns true if the application is running in a development environment. +func IsDevEnvironment() bool { + return Config().IsDevEnvironment() +} + +// IsShuttingDown returns true if the application is in the process of a graceful shutdown. func IsShuttingDown() bool { mu.RLock() defer mu.RUnlock() return shuttingDown } +// SetShuttingDown sets the application to a shutting down state during a graceful shutdown. func SetShuttingDown() { mu.Lock() defer mu.Unlock() shuttingDown = true } +// GetRootSourcePath returns the root path of the source code. +// It is used to trim the paths in stack traces when included in telemetry. func GetRootSourcePath() string { _, filename, _, ok := runtime.Caller(0) if !ok { return "" } - return path.Join(path.Dir(filename), "../") + "/" + return path.Dir(path.Dir(filename)) + "/" } func ModusHomeDir() string { diff --git a/runtime/app/app_test.go b/runtime/app/app_test.go new file mode 100644 index 000000000..16610fd7c --- /dev/null +++ b/runtime/app/app_test.go @@ -0,0 +1,67 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package app_test + +import ( + "os" + "path" + "testing" + + "github.com/fatih/color" + "github.com/hypermodeinc/modus/runtime/app" +) + +func TestGetRootSourcePath(t *testing.T) { + cwd, _ := os.Getwd() + expectedPath := path.Dir(cwd) + "/" + actualPath := app.GetRootSourcePath() + + if actualPath != expectedPath { + t.Errorf("Expected path: %s, but got: %s", expectedPath, actualPath) + } +} +func TestIsShuttingDown(t *testing.T) { + if app.IsShuttingDown() { + t.Errorf("Expected initial state to be not shutting down") + } + + app.SetShuttingDown() + + if !app.IsShuttingDown() { + t.Errorf("Expected state to be shutting down") + } +} + +func TestSetConfig(t *testing.T) { + initialConfig := app.Config() + if initialConfig == nil { + t.Errorf("Expected initial config to be non-nil") + } + + newConfig := &app.AppConfig{} + app.SetConfig(newConfig) + + if app.Config() != newConfig { + t.Errorf("Expected config to be updated") + } +} + +func TestForceColor(t *testing.T) { + if !color.NoColor { + t.Errorf("Expected NoColor to be true") + } + + os.Setenv("FORCE_COLOR", "1") + app.SetOutputColorMode() + + if color.NoColor { + t.Errorf("Expected NoColor to be false") + } +} diff --git a/runtime/app/config.go b/runtime/app/config.go new file mode 100644 index 000000000..4430d9ae7 --- /dev/null +++ b/runtime/app/config.go @@ -0,0 +1,159 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package app + +import ( + "flag" + "fmt" + "os" + "strings" + "time" +) + +type AppConfig struct { + environment string + port int + appPath string + useAwsStorage bool + s3Bucket string + s3Path string + refreshInterval time.Duration + useJsonLogging bool +} + +func (c *AppConfig) Environment() string { + return c.environment +} + +func (c *AppConfig) Port() int { + return c.port +} + +func (c *AppConfig) AppPath() string { + return c.appPath +} + +func (c *AppConfig) UseAwsStorage() bool { + return c.useAwsStorage +} + +func (c *AppConfig) S3Bucket() string { + return c.s3Bucket +} + +func (c *AppConfig) S3Path() string { + return c.s3Path +} + +func (c *AppConfig) RefreshInterval() time.Duration { + return c.refreshInterval +} + +func (c *AppConfig) UseJsonLogging() bool { + return c.useJsonLogging +} + +func (c *AppConfig) IsDevEnvironment() bool { + // support either name (but prefer "dev") + return c.environment == "dev" || c.environment == "development" +} + +func (c *AppConfig) WithEnvironment(environment string) *AppConfig { + cfg := *c + cfg.environment = environment + return &cfg +} + +func (c *AppConfig) WithPort(port int) *AppConfig { + cfg := *c + cfg.port = port + return &cfg +} + +func (c *AppConfig) WithAppPath(appPath string) *AppConfig { + cfg := *c + cfg.appPath = appPath + return &cfg +} + +func (c *AppConfig) WithS3Storage(s3Bucket, s3Path string) *AppConfig { + cfg := *c + cfg.useAwsStorage = true + cfg.s3Bucket = s3Bucket + cfg.s3Path = s3Path + return &cfg +} + +func (c *AppConfig) WithRefreshInterval(interval time.Duration) *AppConfig { + cfg := *c + cfg.refreshInterval = interval + return &cfg +} + +func (c *AppConfig) WithJsonLogging() *AppConfig { + cfg := *c + cfg.useJsonLogging = true + return &cfg +} + +// Creates a new AppConfig instance with default values. +func NewAppConfig() *AppConfig { + return &AppConfig{ + port: 8686, + environment: "prod", + refreshInterval: time.Second * 5, + } +} + +// Creates the app configuration from the command line flags and environment variables. +func CreateAppConfig() *AppConfig { + + cfg := NewAppConfig() + + fs := flag.NewFlagSet("", flag.ContinueOnError) + + fs.StringVar(&cfg.appPath, "appPath", cfg.appPath, "REQUIRED - The path to the Modus app to load and run.") + fs.IntVar(&cfg.port, "port", cfg.port, "The HTTP port to listen on.") + + fs.BoolVar(&cfg.useAwsStorage, "useAwsStorage", cfg.useAwsStorage, "Use AWS S3 for storage instead of the local filesystem.") + fs.StringVar(&cfg.s3Bucket, "s3bucket", cfg.s3Bucket, "The S3 bucket to use, if using AWS storage.") + fs.StringVar(&cfg.s3Path, "s3path", cfg.s3Path, "The path within the S3 bucket to use, if using AWS storage.") + + fs.DurationVar(&cfg.refreshInterval, "refresh", cfg.refreshInterval, "The refresh interval to reload any changes.") + fs.BoolVar(&cfg.useJsonLogging, "jsonlogs", cfg.useJsonLogging, "Use JSON format for logging.") + + var showVersion bool + const versionUsage = "Show the Runtime version number and exit." + fs.BoolVar(&showVersion, "version", false, versionUsage) + fs.BoolVar(&showVersion, "v", false, versionUsage+" (shorthand)") + + args := make([]string, 0, len(os.Args)) + for i := 1; i < len(os.Args); i++ { + if !strings.HasPrefix(os.Args[i], "-test.") { + args = append(args, os.Args[i]) + } + } + + if err := fs.Parse(args); err != nil { + fmt.Fprintf(os.Stderr, "Error parsing command line flags: %v\n", err) + os.Exit(1) + } + + if showVersion { + fmt.Println(ProductVersion()) + os.Exit(0) + } + + if env := os.Getenv("MODUS_ENV"); env != "" { + cfg.environment = env + } + + return cfg +} diff --git a/runtime/app/config_test.go b/runtime/app/config_test.go new file mode 100644 index 000000000..ffa3b794f --- /dev/null +++ b/runtime/app/config_test.go @@ -0,0 +1,135 @@ +/* + * Copyright 2024 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2024 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package app_test + +import ( + "os" + "reflect" + "testing" + "time" + + "github.com/hypermodeinc/modus/runtime/app" +) + +func TestIsDevEnvironment(t *testing.T) { + tests := []struct { + name string + environment string + expected bool + }{ + {"Environment is dev", "dev", true}, + {"Environment is development", "development", true}, + {"Environment is prod", "prod", false}, + {"Environment is test", "test", false}, + {"Environment is empty", "", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + cfg := app.NewAppConfig().WithEnvironment(tt.environment) + if got := cfg.IsDevEnvironment(); got != tt.expected { + t.Errorf("IsDevEnvironment() = %v, want %v", got, tt.expected) + } + }) + } +} + +func TestNewAppConfig(t *testing.T) { + cfg := app.NewAppConfig() + + if cfg.Port() != 8686 { + t.Errorf("Expected port to be 8686, but got %d", cfg.Port()) + } + + if cfg.RefreshInterval() != time.Second*5 { + t.Errorf("Expected refresh interval to be 5s, but got %s", cfg.RefreshInterval()) + } + + if cfg.Environment() != "prod" { + t.Errorf("Expected environment to be prod, but got %s", cfg.Environment()) + } + + if cfg.IsDevEnvironment() { + t.Errorf("Expected IsDevEnvironment to be false") + } +} + +func TestCreateAppConfig(t *testing.T) { + tests := []struct { + name string + vars map[string]string + args []string + expected *app.AppConfig + }{ + { + name: "default values", + expected: app.NewAppConfig(), + }, + { + name: "custom values", + vars: map[string]string{ + "MODUS_ENV": "dev", + }, + args: []string{ + "-appPath=/path/to/app", + "-port=9090", + "-useAwsStorage=true", + "-s3bucket=my-bucket", + "-s3path=my-path", + "-refresh=10s", + "-jsonlogs=true", + }, + expected: app.NewAppConfig(). + WithEnvironment("dev"). + WithPort(9090). + WithAppPath("/path/to/app"). + WithS3Storage("my-bucket", "my-path"). + WithRefreshInterval(10 * time.Second). + WithJsonLogging(), + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + closer := applySettings(tt.vars, tt.args) + t.Cleanup(closer) + + cfg := app.CreateAppConfig() + + if !reflect.DeepEqual(cfg, tt.expected) { + t.Errorf("Expected config to be %v, but got %v", tt.expected, cfg) + } + }) + } +} + +func applySettings(vars map[string]string, args []string) func() { + originalVars := map[string]string{} + for name, value := range vars { + if originalValue, ok := os.LookupEnv(name); ok { + originalVars[name] = originalValue + } + os.Setenv(name, value) + } + + originalArgs := os.Args + os.Args = append([]string{os.Args[0]}, args...) + + return func() { + for name := range vars { + if value, ok := originalVars[name]; ok { + os.Setenv(name, value) + } else { + os.Unsetenv(name) + } + } + os.Args = originalArgs + } +} diff --git a/runtime/config/version.go b/runtime/app/version.go similarity index 51% rename from runtime/config/version.go rename to runtime/app/version.go index eda8eb55e..01a0d6244 100644 --- a/runtime/config/version.go +++ b/runtime/app/version.go @@ -7,7 +7,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package config +package app import ( "os/exec" @@ -17,6 +17,10 @@ import ( var version string func init() { + adjustVersion() +} + +func adjustVersion() { // The "version" variable is set by the makefile using -ldflags when using "make build" or goreleaser. // If it is not set, then we are running in development mode with "go run" or "go build" without the makefile, // so we will describe the version from git at run time. @@ -27,18 +31,31 @@ func init() { } } -func GetVersionNumber() string { +func VersionNumber() string { return version } -func GetProductVersion() string { - return "Modus Runtime " + GetVersionNumber() +func ProductVersion() string { + return "Modus Runtime " + VersionNumber() } func describeVersion() string { - result, err := exec.Command("git", "describe", "--tags", "--always", "--match", "runtime/*").Output() - if err != nil { - return "(unknown)" + if isShallowGit() { + result, err := exec.Command("git", "rev-parse", "--short", "HEAD").Output() + if err != nil { + return "(unknown)" + } + return "v0.0.0-?-g" + strings.TrimSpace(string(result)) + } else { + result, err := exec.Command("git", "describe", "--tags", "--always", "--match", "runtime/*").Output() + if err != nil { + return "(unknown)" + } + return strings.TrimPrefix(strings.TrimSpace(string(result)), "runtime/") } - return strings.TrimPrefix(strings.TrimSpace(string(result)), "runtime/") +} + +func isShallowGit() bool { + result, err := exec.Command("git", "rev-parse", "--is-shallow-repository").Output() + return err == nil && strings.TrimSpace(string(result)) == "true" } diff --git a/runtime/app/version_test.go b/runtime/app/version_test.go new file mode 100644 index 000000000..0816f3d0e --- /dev/null +++ b/runtime/app/version_test.go @@ -0,0 +1,69 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package app + +import ( + "strings" + "testing" +) + +func TestVersionNumber(t *testing.T) { + original := version + t.Cleanup(func() { + version = original + }) + + tests := []struct { + name string + version string + }{ + {"Empty version", ""}, + {"Non-empty version", "1.2.3"}, + {"Non-empty version with 'v'", "v1.2.3"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + + version = tt.version + adjustVersion() + + v := VersionNumber() + if v == "" { + t.Errorf("Expected version number to be non-empty") + } + + if v[0] != 'v' { + t.Errorf("Expected version number to start with 'v'") + } + + if len(v) < 6 { + t.Errorf("Expected version number to be at least 6 characters long") + } + + if strings.Count(v, ".") < 2 { + t.Errorf("Expected version number to have at least two dots") + } + }) + } +} + +func TestProductVersion(t *testing.T) { + v := VersionNumber() + pv := ProductVersion() + + if !strings.HasSuffix(pv, " "+v) { + t.Errorf("Expected product version to end with version number") + } + + if len(pv) < len(v)+2 { + t.Errorf("Expected product version to be at least %d characters long", len(v)+2) + } +} diff --git a/runtime/aws/config.go b/runtime/aws/config.go deleted file mode 100644 index 40d9cfc60..000000000 --- a/runtime/aws/config.go +++ /dev/null @@ -1,66 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package aws - -import ( - "context" - "fmt" - - hmConfig "github.com/hypermodeinc/modus/runtime/config" - "github.com/hypermodeinc/modus/runtime/logger" - "github.com/hypermodeinc/modus/runtime/utils" - - "github.com/aws/aws-sdk-go-v2/aws" - "github.com/aws/aws-sdk-go-v2/config" - "github.com/aws/aws-sdk-go-v2/service/sts" -) - -var awsConfig aws.Config - -func GetAwsConfig() aws.Config { - return awsConfig -} - -func Initialize(ctx context.Context) { - if !(hmConfig.UseAwsStorage) { - return - } - - err := initialize(ctx) - if err != nil { - logger.Fatal(ctx).Err(err).Msg("Failed to initialize AWS. Exiting.") - } -} - -func initialize(ctx context.Context) error { - span, ctx := utils.NewSentrySpanForCurrentFunc(ctx) - defer span.Finish() - - cfg, err := config.LoadDefaultConfig(ctx) - if err != nil { - return fmt.Errorf("error loading AWS configuration: %w", err) - } - - client := sts.NewFromConfig(cfg) - identity, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) - if err != nil { - return fmt.Errorf("error getting AWS caller identity: %w", err) - } - - awsConfig = cfg - - logger.Info(ctx). - Str("region", awsConfig.Region). - Str("account", *identity.Account). - Str("userid", *identity.UserId). - Msg("AWS configuration loaded.") - - return nil -} diff --git a/runtime/config/commandline.go b/runtime/config/commandline.go deleted file mode 100644 index c93d4e396..000000000 --- a/runtime/config/commandline.go +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package config - -import ( - "flag" - "fmt" - "os" - "time" -) - -var Port int -var AppPath string -var UseAwsStorage bool -var S3Bucket string -var S3Path string -var RefreshInterval time.Duration -var UseJsonLogging bool - -func parseCommandLineFlags() { - flag.StringVar(&AppPath, "appPath", "", "REQUIRED - The path to the Modus app to load and run.") - flag.IntVar(&Port, "port", 8686, "The HTTP port to listen on.") - - flag.BoolVar(&UseAwsStorage, "useAwsStorage", false, "Use AWS S3 for storage instead of the local filesystem.") - flag.StringVar(&S3Bucket, "s3bucket", "", "The S3 bucket to use, if using AWS storage.") - flag.StringVar(&S3Path, "s3path", "", "The path within the S3 bucket to use, if using AWS storage.") - - flag.DurationVar(&RefreshInterval, "refresh", time.Second*5, "The refresh interval to reload any changes.") - flag.BoolVar(&UseJsonLogging, "jsonlogs", false, "Use JSON format for logging.") - - var showVersion bool - const versionUsage = "Show the Runtime version number and exit." - flag.BoolVar(&showVersion, "version", false, versionUsage) - // flag.BoolVar(&showVersion, "v", false, versionUsage+" (shorthand)") - - flag.Parse() - - if showVersion { - fmt.Println(GetProductVersion()) - os.Exit(0) - } -} diff --git a/runtime/config/commandline_test.go b/runtime/config/commandline_test.go deleted file mode 100644 index 5d4c24370..000000000 --- a/runtime/config/commandline_test.go +++ /dev/null @@ -1,105 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package config - -import ( - "flag" - "os" - "testing" - "time" -) - -func TestParseCommandLineFlags(t *testing.T) { - tests := []struct { - name string - args []string - expectedPort int - expectedAppPath string - expectedUseAwsStorage bool - expectedS3Bucket string - expectedS3Path string - expectedRefreshInterval time.Duration - expectedUseJsonLogging bool - }{ - { - name: "default values", - args: []string{}, - expectedPort: 8686, - expectedAppPath: "", - expectedUseAwsStorage: false, - expectedS3Bucket: "", - expectedS3Path: "", - expectedRefreshInterval: time.Second * 5, - expectedUseJsonLogging: false, - }, - { - name: "custom values", - args: []string{ - "-appPath=/path/to/app", - "-port=9090", - "-useAwsStorage=true", - "-s3bucket=my-bucket", - "-s3path=my-path", - "-refresh=10s", - "-jsonlogs=true", - }, - expectedPort: 9090, - expectedAppPath: "/path/to/app", - expectedUseAwsStorage: true, - expectedS3Bucket: "my-bucket", - expectedS3Path: "my-path", - expectedRefreshInterval: 10 * time.Second, - expectedUseJsonLogging: true, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - // Reset flags and variables - flag.CommandLine = flag.NewFlagSet(os.Args[0], flag.ExitOnError) - Port = 0 - AppPath = "" - UseAwsStorage = false - S3Bucket = "" - S3Path = "" - RefreshInterval = 0 - UseJsonLogging = false - - // Set command line arguments - os.Args = append([]string{os.Args[0]}, tt.args...) - - // Parse flags - parseCommandLineFlags() - - // Check values - if Port != tt.expectedPort { - t.Errorf("expected Port %d, got %d", tt.expectedPort, Port) - } - if AppPath != tt.expectedAppPath { - t.Errorf("expected AppPath %s, got %s", tt.expectedAppPath, AppPath) - } - if UseAwsStorage != tt.expectedUseAwsStorage { - t.Errorf("expected UseAwsStorage %v, got %v", tt.expectedUseAwsStorage, UseAwsStorage) - } - if S3Bucket != tt.expectedS3Bucket { - t.Errorf("expected S3Bucket %s, got %s", tt.expectedS3Bucket, S3Bucket) - } - if S3Path != tt.expectedS3Path { - t.Errorf("expected S3Path %s, got %s", tt.expectedS3Path, S3Path) - } - if RefreshInterval != tt.expectedRefreshInterval { - t.Errorf("expected RefreshInterval %v, got %v", tt.expectedRefreshInterval, RefreshInterval) - } - if UseJsonLogging != tt.expectedUseJsonLogging { - t.Errorf("expected UseJsonLogging %v, got %v", tt.expectedUseJsonLogging, UseJsonLogging) - } - }) - } -} diff --git a/runtime/config/config.go b/runtime/config/config.go deleted file mode 100644 index f81ef673c..000000000 --- a/runtime/config/config.go +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package config - -import ( - "os" - - "github.com/fatih/color" -) - -func Initialize() { - forceColor := os.Getenv("FORCE_COLOR") - if forceColor != "" && forceColor != "0" { - color.NoColor = false - } - - parseCommandLineFlags() - readEnvironmentVariables() -} diff --git a/runtime/config/environment.go b/runtime/config/environment.go deleted file mode 100644 index 07c5a0bf5..000000000 --- a/runtime/config/environment.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package config - -import ( - "os" -) - -/* - -DESIGN NOTES: - -- The MODUS_ENV environment variable is used to determine the environment name. -- We prefer to use short names, "prod", "stage", "dev", etc, but the actual value is arbitrary, so you can use longer names if you prefer. -- If it is not set, the default environment name is "prod". This is a safe-by-default approach. -- It is preferable to actually set the MODUS_ENV to the appropriate environment when running the application. -- During development, the Modus CLI will set the MODUS_ENV to "dev" automatically. -- The "dev" environment is special in several ways, such as relaxed security requirements, and omitting certain telemetry. -- There is nothing special about "prod", other than it is the default. -- You can also use "stage", "test", etc, as needed - but they will behave like "prod". The only difference is the name returned by the health endpoint, logs, and telemetry. - -*/ - -var environment string -var namespace string - -func GetEnvironmentName() string { - return environment -} - -func GetNamespace() string { - return namespace -} - -func readEnvironmentVariables() { - environment = os.Getenv("MODUS_ENV") - - // default to prod - if environment == "" { - environment = "prod" - } - - // If running in Kubernetes, also capture the namespace environment variable. - namespace = os.Getenv("NAMESPACE") -} - -func IsDevEnvironment() bool { - // support either name (but prefer "dev") - return environment == "dev" || environment == "development" -} diff --git a/runtime/config/environment_test.go b/runtime/config/environment_test.go deleted file mode 100644 index 6bc7f26bb..000000000 --- a/runtime/config/environment_test.go +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright 2024 Hypermode Inc. - * Licensed under the terms of the Apache License, Version 2.0 - * See the LICENSE file that accompanied this code for further details. - * - * SPDX-FileCopyrightText: 2024 Hypermode Inc. - * SPDX-License-Identifier: Apache-2.0 - */ - -package config - -import ( - "os" - "testing" -) - -func TestEnvironmentNames(t *testing.T) { - tests := []struct { - name string - envValue string - expectedResult string - isDev bool - }{ - { - name: "Environment variable not set", - envValue: "", - expectedResult: "prod", - isDev: false, - }, - { - name: "Environment variable set to dev", - envValue: "dev", - expectedResult: "dev", - isDev: true, - }, - { - name: "Environment variable set to development", - envValue: "development", - expectedResult: "development", - isDev: true, - }, - { - name: "Environment variable set to stage", - envValue: "stage", - expectedResult: "stage", - isDev: false, - }, - { - name: "Environment variable set to test", - envValue: "test", - expectedResult: "test", - isDev: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - os.Setenv("MODUS_ENV", tt.envValue) - readEnvironmentVariables() - result := GetEnvironmentName() - if result != tt.expectedResult { - t.Errorf("Expected environment to be %s, but got %s", tt.expectedResult, result) - } - if IsDevEnvironment() != tt.isDev { - t.Errorf("Expected IsDevEnvironment to be %v, but got %v", tt.isDev, IsDevEnvironment()) - } - }) - } -} diff --git a/runtime/db/db.go b/runtime/db/db.go index 7644a2c6b..23dfd71ee 100644 --- a/runtime/db/db.go +++ b/runtime/db/db.go @@ -16,7 +16,7 @@ import ( "sync" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/logger" "github.com/hypermodeinc/modus/runtime/utils" @@ -55,7 +55,7 @@ func logDbWarningOrError(ctx context.Context, err error, msg string) { if _, ok := err.(*pgconn.ConnectError); ok { logger.Warn(ctx).Err(err).Msgf("Database connection error. %s", msg) } else if errors.Is(err, errDbNotConfigured) { - if !config.IsDevEnvironment() { + if !app.IsDevEnvironment() { logger.Warn(ctx).Msgf("Database has not been configured. %s", msg) } } else { @@ -393,7 +393,7 @@ func QueryCollectionVectorsFromCheckpoint(ctx context.Context, collectionName, s func Initialize(ctx context.Context) { // this will initialize the pool and start the worker _, err := globalRuntimePostgresWriter.GetPool(ctx) - if err != nil && !config.IsDevEnvironment() { + if err != nil && !app.IsDevEnvironment() { logger.Warn(ctx).Err(err).Msg("Metadata database is not available.") } go globalRuntimePostgresWriter.worker(ctx) diff --git a/runtime/db/inferencehistory.go b/runtime/db/inferencehistory.go index af84e1210..c38fcd82f 100644 --- a/runtime/db/inferencehistory.go +++ b/runtime/db/inferencehistory.go @@ -15,7 +15,7 @@ import ( "time" "github.com/hypermodeinc/modus/lib/manifest" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/metrics" "github.com/hypermodeinc/modus/runtime/plugins" "github.com/hypermodeinc/modus/runtime/secrets" @@ -184,7 +184,7 @@ func getInferenceDataJson(val any) ([]byte, error) { func WritePluginInfo(ctx context.Context, plugin *plugins.Plugin) { - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { err := writePluginInfoToModusdb(plugin) if err != nil { logDbWarningOrError(ctx, err, "Plugin info not written to ModusDB.") @@ -306,7 +306,7 @@ func WriteInferenceHistoryToDB(ctx context.Context, batch []inferenceHistory) { return } - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { err := writeInferenceHistoryToModusDb(batch) if err != nil { logDbWarningOrError(ctx, err, "Inference history not written to ModusDB.") diff --git a/runtime/db/modusdb.go b/runtime/db/modusdb.go index 1e38e96bf..eb12ccb8f 100644 --- a/runtime/db/modusdb.go +++ b/runtime/db/modusdb.go @@ -15,7 +15,6 @@ import ( "runtime" "github.com/hypermodeinc/modus/runtime/app" - "github.com/hypermodeinc/modus/runtime/config" "github.com/hypermodeinc/modus/runtime/logger" "github.com/hypermodeinc/modusdb" ) @@ -23,7 +22,7 @@ import ( var GlobalModusDbEngine *modusdb.Engine func InitModusDb(ctx context.Context) { - if config.IsDevEnvironment() && runtime.GOOS != "windows" { + if app.IsDevEnvironment() && runtime.GOOS != "windows" { dataDir := filepath.Join(app.ModusHomeDir(), "data") if eng, err := modusdb.NewEngine(modusdb.NewDefaultConfig(dataDir)); err != nil { logger.Fatal(ctx).Err(err).Msg("Failed to initialize modusdb.") diff --git a/runtime/envfiles/envfiles.go b/runtime/envfiles/envfiles.go index 6ade216d1..b735c9e97 100644 --- a/runtime/envfiles/envfiles.go +++ b/runtime/envfiles/envfiles.go @@ -16,7 +16,7 @@ import ( "strings" "sync" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/logger" "github.com/joho/godotenv" @@ -28,7 +28,7 @@ var originalProcessEnvironmentVariables = os.Environ() // Allow the env files to use short or long names. func getSupportedEnvironmentNames() []string { - environment := config.GetEnvironmentName() + environment := app.Config().Environment() switch strings.ToLower(environment) { case "dev", "development": return []string{"dev", "development"} @@ -70,7 +70,7 @@ func LoadEnvFiles(ctx context.Context) error { files = append(files, ".env") for _, file := range files { - path := filepath.Join(config.AppPath, file) + path := filepath.Join(app.Config().AppPath(), file) if _, err := os.Stat(path); err == nil { if err := godotenv.Load(path); err != nil { logger.Warn(ctx).Err(err).Msgf("Failed to load %s file.", file) diff --git a/runtime/go.mod b/runtime/go.mod index 448422111..f1ba66fe4 100644 --- a/runtime/go.mod +++ b/runtime/go.mod @@ -13,7 +13,6 @@ require ( require ( github.com/OneOfOne/xxhash v1.2.8 github.com/archdx/zerolog-sentry v1.8.5 - github.com/aws/aws-sdk-go-v2 v1.33.0 github.com/aws/aws-sdk-go-v2/config v1.29.1 github.com/aws/aws-sdk-go-v2/service/s3 v1.73.2 github.com/aws/aws-sdk-go-v2/service/sts v1.33.9 @@ -79,6 +78,7 @@ require ( github.com/IBM/sarama v1.45.0 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect github.com/agnivade/levenshtein v1.1.1 // indirect + github.com/aws/aws-sdk-go-v2 v1.33.0 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.7 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.54 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.24 // indirect diff --git a/runtime/graphql/engine/engine.go b/runtime/graphql/engine/engine.go index ea6367a2b..4c870238b 100644 --- a/runtime/graphql/engine/engine.go +++ b/runtime/graphql/engine/engine.go @@ -19,7 +19,7 @@ import ( "github.com/fatih/color" "github.com/hypermodeinc/modus/lib/metadata" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/graphql/datasource" "github.com/hypermodeinc/modus/runtime/graphql/schemagen" "github.com/hypermodeinc/modus/runtime/logger" @@ -81,7 +81,7 @@ func generateSchema(ctx context.Context, md *metadata.Metadata) (*gql.Schema, *d } if utils.DebugModeEnabled() { - if config.UseJsonLogging { + if app.Config().UseJsonLogging() { logger.Debug(ctx).Str("schema", generated.Schema).Msg("Generated schema") } else { fmt.Fprintf(os.Stderr, "\n%s\n", color.BlueString(generated.Schema)) diff --git a/runtime/graphql/graphql.go b/runtime/graphql/graphql.go index 2d0b685b3..4f0a89f73 100644 --- a/runtime/graphql/graphql.go +++ b/runtime/graphql/graphql.go @@ -16,7 +16,7 @@ import ( "strconv" "strings" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/graphql/engine" "github.com/hypermodeinc/modus/runtime/logger" "github.com/hypermodeinc/modus/runtime/manifestdata" @@ -61,7 +61,7 @@ func Initialize() { func handleGraphQLRequest(w http.ResponseWriter, r *http.Request) { // In dev, redirect non-GraphQL requests to the explorer - if config.IsDevEnvironment() && + if app.IsDevEnvironment() && r.Method == http.MethodGet && !strings.Contains(r.Header.Get("Accept"), "application/json") { http.Redirect(w, r, "/explorer", http.StatusTemporaryRedirect) @@ -78,7 +78,7 @@ func handleGraphQLRequest(w http.ResponseWriter, r *http.Request) { http.Error(w, msg, http.StatusBadRequest) // NOTE: We only log these in dev, to avoid a bad actor spamming the logs in prod. - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { logger.Warn(ctx).Err(err).Msg(msg) } return @@ -143,7 +143,7 @@ func handleGraphQLRequest(w http.ResponseWriter, r *http.Request) { _, _ = requestErrors.WriteResponse(w) // NOTE: We only log these in dev, to avoid a bad actor spamming the logs in prod. - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { // cleanup empty arrays from error message before logging errMsg := strings.Replace(err.Error(), ", locations: []", "", 1) errMsg = strings.Replace(errMsg, ", path: []", "", 1) diff --git a/runtime/httpserver/health.go b/runtime/httpserver/health.go index 001b5287d..b7d7ad28d 100644 --- a/runtime/httpserver/health.go +++ b/runtime/httpserver/health.go @@ -12,13 +12,13 @@ package httpserver import ( "net/http" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/utils" ) var healthHandler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - env := config.GetEnvironmentName() - ver := config.GetVersionNumber() + env := app.Config().Environment() + ver := app.VersionNumber() w.WriteHeader(http.StatusOK) utils.WriteJsonContentHeader(w) _, _ = w.Write([]byte(`{"status":"ok","environment":"` + env + `","version":"` + ver + `"}`)) diff --git a/runtime/httpserver/server.go b/runtime/httpserver/server.go index 968e5bf33..de5c56cab 100644 --- a/runtime/httpserver/server.go +++ b/runtime/httpserver/server.go @@ -19,10 +19,10 @@ import ( "os" "os/signal" "syscall" + "time" "github.com/hypermodeinc/modus/lib/manifest" "github.com/hypermodeinc/modus/runtime/app" - "github.com/hypermodeinc/modus/runtime/config" "github.com/hypermodeinc/modus/runtime/explorer" "github.com/hypermodeinc/modus/runtime/graphql" "github.com/hypermodeinc/modus/runtime/logger" @@ -40,20 +40,25 @@ var urlColor = color.New(color.FgHiCyan) var noticeColor = color.New(color.FgGreen, color.Italic) var warningColor = color.New(color.FgYellow) +// ShutdownTimeout is the time to wait for the server to shutdown gracefully. +const shutdownTimeout = 5 * time.Second + func Start(ctx context.Context, local bool) { + port := app.Config().Port() + if local { // If we are running locally, only listen on localhost. // This prevents getting nagged for firewall permissions each launch. // Listen on IPv4, and also on IPv6 if available. - addresses := []string{fmt.Sprintf("127.0.0.1:%d", config.Port)} + addresses := []string{fmt.Sprintf("127.0.0.1:%d", port)} if isIPv6Available() { - addresses = append(addresses, fmt.Sprintf("[::1]:%d", config.Port)) + addresses = append(addresses, fmt.Sprintf("[::1]:%d", port)) } startHttpServer(ctx, addresses...) } else { // Otherwise, listen on all interfaces. - addr := fmt.Sprintf(":%d", config.Port) + addr := fmt.Sprintf(":%d", port) startHttpServer(ctx, addr) } } @@ -105,7 +110,7 @@ func startHttpServer(ctx context.Context, addresses ...string) { // Shutdown all servers gracefully. for _, server := range servers { - shutdownCtx, shutdownRelease := context.WithTimeout(ctx, app.ShutdownTimeout) + shutdownCtx, shutdownRelease := context.WithTimeout(ctx, shutdownTimeout) defer shutdownRelease() if err := server.Shutdown(shutdownCtx); err != nil { logger.Fatal(ctx).Err(err).Msg("HTTP server shutdown error.") @@ -134,10 +139,12 @@ func GetMainHandler(options ...func(map[string]http.Handler)) http.Handler { "/metrics": metrics.MetricsHandler, } - if config.IsDevEnvironment() { + cfg := app.Config() + if cfg.IsDevEnvironment() { defaultRoutes["/explorer/"] = explorer.ExplorerHandler defaultRoutes["/"] = http.RedirectHandler("/explorer/", http.StatusSeeOther) } + port := cfg.Port() for _, opt := range options { opt(defaultRoutes) @@ -177,7 +184,7 @@ func GetMainHandler(options ...func(map[string]http.Handler)) http.Handler { routes[info.Path] = metrics.InstrumentHandler(handler, name) - url := fmt.Sprintf("http://localhost:%d%s", config.Port, info.Path) + url := fmt.Sprintf("http://localhost:%d%s", port, info.Path) logger.Info(ctx).Str("url", url).Msg("Registered GraphQL endpoint.") endpoints = append(endpoints, endpoint{"GraphQL", name, url}) @@ -188,7 +195,7 @@ func GetMainHandler(options ...func(map[string]http.Handler)) http.Handler { mux.ReplaceRoutes(routes) - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { fmt.Fprintln(os.Stderr) switch len(endpoints) { @@ -201,7 +208,7 @@ func GetMainHandler(options ...func(map[string]http.Handler)) http.Handler { itemColor.Fprintf(os.Stderr, "• %s (%s): ", ep.apiType, ep.name) urlColor.Fprintln(os.Stderr, ep.url) - explorerURL := fmt.Sprintf("http://localhost:%d/explorer", config.Port) + explorerURL := fmt.Sprintf("http://localhost:%d/explorer", port) titleColor.Fprintf(os.Stderr, "\nView endpoint: ") urlColor.Fprintln(os.Stderr, explorerURL) @@ -212,7 +219,7 @@ func GetMainHandler(options ...func(map[string]http.Handler)) http.Handler { urlColor.Fprintln(os.Stderr, ep.url) } - explorerURL := fmt.Sprintf("http://localhost:%d/explorer", config.Port) + explorerURL := fmt.Sprintf("http://localhost:%d/explorer", port) titleColor.Fprintf(os.Stderr, "\nView your endpoints at: ") urlColor.Fprintln(os.Stderr, explorerURL) } diff --git a/runtime/integration_tests/postgresql_integration_test.go b/runtime/integration_tests/postgresql_integration_test.go index fd9567a67..0ad6772f2 100644 --- a/runtime/integration_tests/postgresql_integration_test.go +++ b/runtime/integration_tests/postgresql_integration_test.go @@ -25,7 +25,7 @@ import ( "testing" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/httpserver" "github.com/hypermodeinc/modus/runtime/services" @@ -108,9 +108,11 @@ func updateManifest(t *testing.T, jsonManifest []byte) func() { func TestMain(m *testing.M) { // setup config - config.AppPath = testPluginsPath - config.RefreshInterval = refreshPluginInterval - config.Port = httpListenPort + cfg := app.NewAppConfig(). + WithAppPath(testPluginsPath). + WithRefreshInterval(refreshPluginInterval). + WithPort(httpListenPort) + app.SetConfig(cfg) // Create the main background context ctx := context.Background() diff --git a/runtime/logger/logger.go b/runtime/logger/logger.go index 11257ebc3..18a94c205 100644 --- a/runtime/logger/logger.go +++ b/runtime/logger/logger.go @@ -16,7 +16,7 @@ import ( "sync" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/utils" zls "github.com/archdx/zerolog-sentry" @@ -29,7 +29,7 @@ var zlsCloser io.Closer func Initialize() *zerolog.Logger { var writer io.Writer - if config.UseJsonLogging { + if app.Config().UseJsonLogging() { // In JSON mode, we'll log UTC with millisecond precision. // Note that Go uses this specific value for its formatting exemplars. zerolog.TimeFieldFormat = utils.TimeFormat @@ -42,7 +42,7 @@ func Initialize() *zerolog.Logger { // We'll still log with millisecond precision. zerolog.TimeFieldFormat = zerolog.TimeFormatUnixMs consoleWriter := zerolog.ConsoleWriter{Out: os.Stderr} - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { consoleWriter.TimeFormat = "15:04:05.000" consoleWriter.FieldsExclude = []string{ "build_id", @@ -66,9 +66,9 @@ func Initialize() *zerolog.Logger { } // Log the runtime version to every log line, except in development. - if !config.IsDevEnvironment() { + if !app.IsDevEnvironment() { log.Logger = log.Logger.With(). - Str("runtime_version", config.GetVersionNumber()). + Str("runtime_version", app.VersionNumber()). Logger() } diff --git a/runtime/main.go b/runtime/main.go index e168bdf66..9d1af6a98 100644 --- a/runtime/main.go +++ b/runtime/main.go @@ -13,7 +13,6 @@ import ( "context" "github.com/hypermodeinc/modus/runtime/app" - "github.com/hypermodeinc/modus/runtime/config" "github.com/hypermodeinc/modus/runtime/envfiles" "github.com/hypermodeinc/modus/runtime/httpserver" "github.com/hypermodeinc/modus/runtime/logger" @@ -23,17 +22,14 @@ import ( func main() { - // Initialize the configuration - config.Initialize() - // Create the main background context ctx := context.Background() // Initialize the logger log := logger.Initialize() log.Info(). - Str("version", config.GetVersionNumber()). - Str("environment", config.GetEnvironmentName()). + Str("version", app.VersionNumber()). + Str("environment", app.Config().Environment()). Msg("Starting Modus Runtime.") err := envfiles.LoadEnvFiles(ctx) @@ -42,8 +38,7 @@ func main() { } // Initialize Sentry (if enabled) - rootSourcePath := app.GetRootSourcePath() - utils.InitSentry(rootSourcePath) + utils.InitSentry() defer utils.FlushSentryEvents() // Start the background services @@ -51,7 +46,7 @@ func main() { defer services.Stop(ctx) // Set local mode in development - local := config.IsDevEnvironment() + local := app.IsDevEnvironment() // Start the HTTP server to listen for requests. // Note, this function blocks, and handles shutdown gracefully. diff --git a/runtime/middleware/jwt.go b/runtime/middleware/jwt.go index 3c8246a82..f4ce84b28 100644 --- a/runtime/middleware/jwt.go +++ b/runtime/middleware/jwt.go @@ -19,7 +19,7 @@ import ( "os" "strings" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/envfiles" "github.com/hypermodeinc/modus/runtime/logger" "github.com/hypermodeinc/modus/runtime/utils" @@ -49,7 +49,7 @@ func initKeys(ctx context.Context) { if publicPemKeysJson != "" { keys, err := publicPemKeysJsonToKeys(publicPemKeysJson) if err != nil { - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { logger.Fatal(ctx).Err(err).Msg("Auth PEM public keys deserializing error") } logger.Error(ctx).Err(err).Msg("Auth PEM public keys deserializing error") @@ -60,7 +60,7 @@ func initKeys(ctx context.Context) { if jwksEndpointsJson != "" { keys, err := jwksEndpointsJsonToKeys(ctx, jwksEndpointsJson) if err != nil { - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { logger.Fatal(ctx).Err(err).Msg("Auth JWKS public keys deserializing error") } logger.Error(ctx).Err(err).Msg("Auth JWKS public keys deserializing error") @@ -91,7 +91,7 @@ func HandleJWT(next http.Handler) http.Handler { } if len(globalAuthKeys.getPemPublicKeys()) == 0 && len(globalAuthKeys.getJwksPublicKeys()) == 0 { - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { if tokenStr == "" { next.ServeHTTP(w, r) return diff --git a/runtime/models/hypermode.go b/runtime/models/hypermode.go index d33cf2204..91de74136 100644 --- a/runtime/models/hypermode.go +++ b/runtime/models/hypermode.go @@ -17,7 +17,7 @@ import ( "strings" "github.com/hypermodeinc/modus/lib/manifest" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/secrets" ) @@ -26,7 +26,7 @@ var _hypermodeModelHost string func getHypermodeModelEndpointUrl(model *manifest.ModelInfo) (string, error) { // In development, use the shared Hypermode model server. // Note: Authentication via the Hypermode CLI is required. - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { endpoint := fmt.Sprintf("https://models.hypermode.host/%s", strings.ToLower(model.SourceModel)) return endpoint, nil } @@ -45,7 +45,7 @@ func getHypermodeModelEndpointUrl(model *manifest.ModelInfo) (string, error) { func authenticateHypermodeModelRequest(ctx context.Context, req *http.Request, connection *manifest.HTTPConnectionInfo) error { // In development, Hypermode models require authentication. - if config.IsDevEnvironment() { + if app.IsDevEnvironment() { return secrets.ApplyAuthToLocalHypermodeModelRequest(ctx, connection, req) } diff --git a/runtime/models/models.go b/runtime/models/models.go index a87e4d0ae..a28bb7985 100644 --- a/runtime/models/models.go +++ b/runtime/models/models.go @@ -17,7 +17,7 @@ import ( "strings" "github.com/hypermodeinc/modus/lib/manifest" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/db" "github.com/hypermodeinc/modus/runtime/httpclient" "github.com/hypermodeinc/modus/runtime/manifestdata" @@ -93,7 +93,7 @@ func PostToModelEndpoint[TResult any](ctx context.Context, model *manifest.Model var empty TResult var httpe *utils.HttpError if errors.As(err, &httpe) { - if config.IsDevEnvironment() && httpe.StatusCode == http.StatusNotFound { + if app.IsDevEnvironment() && httpe.StatusCode == http.StatusNotFound { return empty, fmt.Errorf("model %s is not available in the local dev environment", model.SourceModel) } } diff --git a/runtime/services/services.go b/runtime/services/services.go index 61a71866f..bbcc5529f 100644 --- a/runtime/services/services.go +++ b/runtime/services/services.go @@ -12,7 +12,6 @@ package services import ( "context" - "github.com/hypermodeinc/modus/runtime/aws" "github.com/hypermodeinc/modus/runtime/collections" "github.com/hypermodeinc/modus/runtime/db" "github.com/hypermodeinc/modus/runtime/dgraphclient" @@ -50,7 +49,6 @@ func Start(ctx context.Context) context.Context { sqlclient.Initialize() dgraphclient.Initialize() neo4jclient.Initialize() - aws.Initialize(ctx) secrets.Initialize(ctx) storage.Initialize(ctx) db.Initialize(ctx) diff --git a/runtime/storage/awsstorage.go b/runtime/storage/awsstorage.go index 540c4c975..4013898ab 100644 --- a/runtime/storage/awsstorage.go +++ b/runtime/storage/awsstorage.go @@ -15,34 +15,57 @@ import ( "io" "path" - "github.com/hypermodeinc/modus/runtime/aws" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/logger" + "github.com/hypermodeinc/modus/runtime/utils" + "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/s3" + "github.com/aws/aws-sdk-go-v2/service/sts" ) type awsStorageProvider struct { s3Client *s3.Client + s3Bucket string + s3Path string } func (stg *awsStorageProvider) initialize(ctx context.Context) { - if config.S3Bucket == "" { + span, ctx := utils.NewSentrySpanForCurrentFunc(ctx) + defer span.Finish() + + appConfig := app.Config() + stg.s3Bucket = appConfig.S3Bucket() + stg.s3Path = appConfig.S3Path() + + if stg.s3Bucket == "" { logger.Fatal(ctx).Msg("An S3 bucket is required when using AWS storage. Exiting.") } - // Initialize the S3 service client. - // This is safe to hold onto for the lifetime of the application. - // See https://github.com/aws/aws-sdk-go-v2/discussions/2566 - cfg := aws.GetAwsConfig() + cfg, err := config.LoadDefaultConfig(ctx) + if err != nil { + logger.Fatal(ctx).Err(err).Msg("Failed to load AWS configuration. Exiting.") + } + + client := sts.NewFromConfig(cfg) + identity, err := client.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) + if err != nil { + logger.Fatal(ctx).Err(err).Msg("Failed to get AWS caller identity. Exiting.") + } + stg.s3Client = s3.NewFromConfig(cfg) + + logger.Info(ctx). + Str("region", cfg.Region). + Str("account", *identity.Account). + Str("userid", *identity.UserId). + Msg("AWS configuration loaded.") } func (stg *awsStorageProvider) listFiles(ctx context.Context, patterns ...string) ([]FileInfo, error) { - input := &s3.ListObjectsV2Input{ - Bucket: &config.S3Bucket, - Prefix: &config.S3Path, + Bucket: &stg.s3Bucket, + Prefix: &stg.s3Path, } result, err := stg.s3Client.ListObjectsV2(ctx, input) @@ -77,9 +100,9 @@ func (stg *awsStorageProvider) listFiles(ctx context.Context, patterns ...string } func (stg *awsStorageProvider) getFileContents(ctx context.Context, name string) ([]byte, error) { - key := path.Join(config.S3Path, name) + key := path.Join(stg.s3Path, name) input := &s3.GetObjectInput{ - Bucket: &config.S3Bucket, + Bucket: &stg.s3Bucket, Key: &key, } diff --git a/runtime/storage/localstorage.go b/runtime/storage/localstorage.go index 458e596f9..a9d55378f 100644 --- a/runtime/storage/localstorage.go +++ b/runtime/storage/localstorage.go @@ -17,38 +17,41 @@ import ( "path/filepath" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/logger" "github.com/gofrs/flock" ) type localStorageProvider struct { + appPath string } func (stg *localStorageProvider) initialize(ctx context.Context) { - if config.AppPath == "" { + + stg.appPath = app.Config().AppPath() + if stg.appPath == "" { logger.Fatal(ctx).Msg("The -appPath command line argument is required. Exiting.") } - if _, err := os.Stat(config.AppPath); os.IsNotExist(err) { + if _, err := os.Stat(stg.appPath); os.IsNotExist(err) { logger.Info(ctx). - Str("path", config.AppPath). + Str("path", stg.appPath). Msg("Creating app directory.") - err := os.MkdirAll(config.AppPath, 0755) + err := os.MkdirAll(stg.appPath, 0755) if err != nil { logger.Fatal(ctx).Err(err). Msg("Failed to create local app directory. Exiting.") } } else { logger.Info(ctx). - Str("path", config.AppPath). + Str("path", stg.appPath). Msg("Using local app directory.") } } func (stg *localStorageProvider) listFiles(ctx context.Context, patterns ...string) ([]FileInfo, error) { - entries, err := os.ReadDir(config.AppPath) + entries, err := os.ReadDir(stg.appPath) if err != nil { return nil, fmt.Errorf("failed to list files in storage directory: %w", err) } @@ -86,7 +89,7 @@ func (stg *localStorageProvider) listFiles(ctx context.Context, patterns ...stri } func (stg *localStorageProvider) getFileContents(ctx context.Context, name string) (content []byte, err error) { - path := filepath.Join(config.AppPath, name) + path := filepath.Join(stg.appPath, name) // Acquire a read lock on the file to prevent reading a file that is still being written to. // For example, this can easily happen when using `modus dev` and the user is editing the manifest file. diff --git a/runtime/storage/storage.go b/runtime/storage/storage.go index 7722af831..61a324649 100644 --- a/runtime/storage/storage.go +++ b/runtime/storage/storage.go @@ -13,8 +13,7 @@ import ( "context" "time" - "github.com/hypermodeinc/modus/runtime/config" - + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/utils" ) @@ -36,7 +35,7 @@ func Initialize(ctx context.Context) { span, ctx := utils.NewSentrySpanForCurrentFunc(ctx) defer span.Finish() - if config.UseAwsStorage { + if app.Config().UseAwsStorage() { provider = &awsStorageProvider{} } else { provider = &localStorageProvider{} diff --git a/runtime/storage/storagemonitor.go b/runtime/storage/storagemonitor.go index 090b12503..728d4423a 100644 --- a/runtime/storage/storagemonitor.go +++ b/runtime/storage/storagemonitor.go @@ -13,7 +13,7 @@ import ( "context" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/hypermodeinc/modus/runtime/logger" ) @@ -44,7 +44,7 @@ func NewStorageMonitor(patterns ...string) *StorageMonitor { func (sm *StorageMonitor) Start(ctx context.Context) { go func() { - ticker := time.NewTicker(config.RefreshInterval) + ticker := time.NewTicker(app.Config().RefreshInterval()) defer ticker.Stop() var loggedError = false diff --git a/runtime/utils/buffers_test.go b/runtime/utils/buffers_test.go new file mode 100644 index 000000000..31846a5c5 --- /dev/null +++ b/runtime/utils/buffers_test.go @@ -0,0 +1,41 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package utils_test + +import ( + "os" + "testing" + + "github.com/hypermodeinc/modus/runtime/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewOutputBuffers(t *testing.T) { + buffers := utils.NewOutputBuffers() + + outStream := buffers.StdOut() + errStream := buffers.StdErr() + + assert.NotNil(t, outStream) + assert.NotNil(t, errStream) + + assert.NotSame(t, outStream, errStream) + assert.NotSame(t, os.Stdout, outStream) + assert.NotSame(t, os.Stderr, errStream) + + assert.Equal(t, 0, outStream.Cap()) + assert.Equal(t, 0, errStream.Cap()) + + outStream.Write([]byte("Hello, World!")) + errStream.Write([]byte("Hello, Error!")) + + assert.Equal(t, "Hello, World!", outStream.String()) + assert.Equal(t, "Hello, Error!", errStream.String()) +} diff --git a/runtime/utils/cast.go b/runtime/utils/cast.go index 2eabddd8f..6f64da57e 100644 --- a/runtime/utils/cast.go +++ b/runtime/utils/cast.go @@ -106,7 +106,7 @@ func Cast[T any](obj any) (T, error) { if e != nil { return result, e } - result = any(v).(T) + result = any(uintptr(v)).(T) default: return result, fmt.Errorf("unsupported type: %T", obj) } diff --git a/runtime/utils/cast_test.go b/runtime/utils/cast_test.go new file mode 100644 index 000000000..8effe0de9 --- /dev/null +++ b/runtime/utils/cast_test.go @@ -0,0 +1,169 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package utils_test + +import ( + "encoding/json" + "testing" + + "github.com/hypermodeinc/modus/runtime/utils" + "github.com/stretchr/testify/assert" +) + +func testCast[T any](t *testing.T, expected T, inputs ...any) { + for _, input := range inputs { + actual, e := utils.Cast[T](input) + assert.Nil(t, e) + assert.Equal(t, expected, actual, "input: %v (%[1]T)", input) + } +} + +func TestCast_Int(t *testing.T) { + testCast(t, int(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Int8(t *testing.T) { + testCast(t, int8(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Int16(t *testing.T) { + testCast(t, int16(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Int32(t *testing.T) { + testCast(t, int32(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Int64(t *testing.T) { + testCast(t, int64(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uint(t *testing.T) { + testCast(t, uint(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uint8(t *testing.T) { + testCast(t, uint8(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uint16(t *testing.T) { + testCast(t, uint16(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uint32(t *testing.T) { + testCast(t, uint32(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uint64(t *testing.T) { + testCast(t, uint64(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Float32(t *testing.T) { + testCast(t, float32(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Float64(t *testing.T) { + testCast(t, float64(42), + 42, + 42.0, + "42", + json.Number("42"), + func() *int { x := 42; return &x }(), + ) +} + +func TestCast_Uintptr(t *testing.T) { + testCast(t, uintptr(42), + 42, + 42.0, + "42", + json.Number("42"), + uintptr(42), + uint32(42), + func() *uint32 { x := uint32(42); return &x }(), + // TODO: *uintptr fails + ) +} + +func TestCast_Bool(t *testing.T) { + testCast(t, true, + 1, + 1.0, + "true", + json.Number("1"), + func() *bool { x := true; return &x }(), + ) +} diff --git a/runtime/utils/cleaner.go b/runtime/utils/cleaner.go index a5422e70c..97bc47ea5 100644 --- a/runtime/utils/cleaner.go +++ b/runtime/utils/cleaner.go @@ -15,6 +15,8 @@ import ( ) type Cleaner interface { + Len() int + Cap() int Clean() error AddCleanup(fn func() error) AddCleaner(c Cleaner) @@ -34,6 +36,14 @@ func NewCleanerN(capacity int) Cleaner { } } +func (c *cleaner) Len() int { + return len(c.cleanupFuncs) +} + +func (c *cleaner) Cap() int { + return cap(c.cleanupFuncs) +} + func (c *cleaner) AddCleanup(fn func() error) { c.cleanupFuncs = append(c.cleanupFuncs, fn) } diff --git a/runtime/utils/cleaner_test.go b/runtime/utils/cleaner_test.go new file mode 100644 index 000000000..834e9074a --- /dev/null +++ b/runtime/utils/cleaner_test.go @@ -0,0 +1,92 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package utils_test + +import ( + "errors" + "testing" + + "github.com/hypermodeinc/modus/runtime/utils" + "github.com/stretchr/testify/assert" +) + +func TestNewCleaner(t *testing.T) { + c := utils.NewCleaner() + + assert.NotNil(t, c) + assert.Equal(t, 0, c.Len()) + assert.Equal(t, 0, c.Cap()) +} + +func TestNewCleanerN(t *testing.T) { + c := utils.NewCleanerN(10) + + assert.NotNil(t, c) + assert.Equal(t, 0, c.Len()) + assert.Equal(t, 10, c.Cap()) +} + +func TestCleaner_AddCleanup(t *testing.T) { + c := utils.NewCleaner() + c.AddCleanup(func() error { return nil }) + c.AddCleanup(func() error { return nil }) + + assert.Equal(t, 2, c.Len()) +} + +func TestCleaner_AddCleaner(t *testing.T) { + c1 := utils.NewCleaner() + c1.AddCleanup(func() error { return nil }) + c1.AddCleanup(func() error { return nil }) + + c2 := utils.NewCleaner() + c2.AddCleanup(func() error { return nil }) + + c1.AddCleaner(c2) + assert.Equal(t, 3, c1.Len()) + + c1.AddCleaner(nil) + assert.Equal(t, 3, c1.Len()) + + c3 := utils.NewCleaner() + c3.AddCleaner(c2) + assert.Equal(t, 1, c3.Len()) +} + +func TestCleaner_Clean(t *testing.T) { + c := utils.NewCleaner() + assert.Nil(t, c.Clean()) + + a := false + c.AddCleanup(func() error { a = true; return nil }) + assert.Nil(t, c.Clean()) + assert.True(t, a) + + a = false + b := false + c.AddCleanup(func() error { b = true; return nil }) + assert.Nil(t, c.Clean()) + assert.True(t, a) + assert.True(t, b) + + c = utils.NewCleaner() + c.AddCleanup(func() error { return errors.New("error") }) + assert.NotNil(t, c.Clean()) + + c = utils.NewCleaner() + c.AddCleanup(func() error { return nil }) + c.AddCleanup(func() error { return errors.New("error") }) + assert.NotNil(t, c.Clean()) + + c = utils.NewCleaner() + c.AddCleanup(func() error { return errors.New("error") }) + c.AddCleanup(func() error { return errors.New("error") }) + assert.NotNil(t, c.Clean()) +} diff --git a/runtime/utils/console_test.go b/runtime/utils/console_test.go new file mode 100644 index 000000000..5316b9580 --- /dev/null +++ b/runtime/utils/console_test.go @@ -0,0 +1,114 @@ +/* + * Copyright 2025 Hypermode Inc. + * Licensed under the terms of the Apache License, Version 2.0 + * See the LICENSE file that accompanied this code for further details. + * + * SPDX-FileCopyrightText: 2025 Hypermode Inc. + * SPDX-License-Identifier: Apache-2.0 + */ + +package utils_test + +import ( + "bytes" + "testing" + + "github.com/hypermodeinc/modus/runtime/utils" + "github.com/stretchr/testify/assert" +) + +func TestIsError(t *testing.T) { + tests := []struct { + logMessage utils.LogMessage + expected bool + }{ + {utils.LogMessage{Level: "debug", Message: "This is a debug message"}, false}, + {utils.LogMessage{Level: "info", Message: "This is an info message"}, false}, + {utils.LogMessage{Level: "warning", Message: "This is a warning message"}, false}, + {utils.LogMessage{Level: "error", Message: "This is an error message"}, true}, + {utils.LogMessage{Level: "fatal", Message: "This is a fatal message"}, true}, + } + + for _, test := range tests { + result := test.logMessage.IsError() + if result != test.expected { + t.Errorf("For log message '%v', expected IsError to be '%v', but got '%v'", test.logMessage, test.expected, result) + } + } +} + +func TestSplitConsoleOutputLine(t *testing.T) { + tests := []struct { + input string + expectedLevel string + expectedMessage string + }{ + {"Debug: This is a debug message", "debug", "This is a debug message"}, + {"Info: This is an info message", "info", "This is an info message"}, + {"Warning: This is a warning message", "warning", "This is a warning message"}, + {"Error: This is an error message", "error", "This is an error message"}, + {"abort: This is a fatal message", "fatal", "This is a fatal message"}, + {"panic: This is another fatal message", "fatal", "This is another fatal message"}, + {"This is a message without level", "", "This is a message without level"}, + } + + for _, test := range tests { + level, message := utils.SplitConsoleOutputLine(test.input) + if level != test.expectedLevel || message != test.expectedMessage { + t.Errorf("For input '%s', expected level '%s' and message '%s', but got level '%s' and message '%s'", test.input, test.expectedLevel, test.expectedMessage, level, message) + } + } +} + +type mockOutputBuffers struct { + stdOut *bytes.Buffer + stdErr *bytes.Buffer +} + +func (m mockOutputBuffers) StdOut() *bytes.Buffer { + return m.stdOut +} + +func (m mockOutputBuffers) StdErr() *bytes.Buffer { + return m.stdErr +} + +func TestTransformConsoleOutput(t *testing.T) { + tests := []struct { + stdOut string + stdErr string + expected []utils.LogMessage + }{ + { + stdOut: "", + stdErr: "", + expected: []utils.LogMessage{}, + }, + { + stdOut: "Info: This is an info message\n", + stdErr: "Error: This is an error message\n", + expected: []utils.LogMessage{ + {Level: "info", Message: "This is an info message"}, + {Level: "error", Message: "This is an error message"}, + }, + }, + { + stdOut: "Debug: This is a debug message\nWarning: This is a warning message\n", + stdErr: "panic: This is a fatal message\n", + expected: []utils.LogMessage{ + {Level: "debug", Message: "This is a debug message"}, + {Level: "warning", Message: "This is a warning message"}, + {Level: "fatal", Message: "This is a fatal message"}, + }, + }, + } + + for _, test := range tests { + buffers := mockOutputBuffers{ + stdOut: bytes.NewBufferString(test.stdOut), + stdErr: bytes.NewBufferString(test.stdErr), + } + result := utils.TransformConsoleOutput(buffers) + assert.Equal(t, test.expected, result) + } +} diff --git a/runtime/utils/sentry.go b/runtime/utils/sentry.go index 47e3c67fa..9cf4a04ed 100644 --- a/runtime/utils/sentry.go +++ b/runtime/utils/sentry.go @@ -17,15 +17,16 @@ import ( "strings" "time" - "github.com/hypermodeinc/modus/runtime/config" + "github.com/hypermodeinc/modus/runtime/app" "github.com/getsentry/sentry-go" ) -var rootSourcePath string var sentryInitialized bool -func InitSentry(rootPath string) { +var rootSourcePath = app.GetRootSourcePath() + +func InitSentry() { // Don't initialize Sentry when running in debug mode. if DebugModeEnabled() { @@ -42,17 +43,16 @@ func InitSentry(rootPath string) { // but default to the environment name from MODUS_ENV. environment := os.Getenv("SENTRY_ENVIRONMENT") if environment == "" { - environment = config.GetEnvironmentName() + environment = app.Config().Environment() } // Allow the Sentry release to be overridden by the SENTRY_RELEASE environment variable, // but default to the Modus version number. release := os.Getenv("SENTRY_RELEASE") if release == "" { - release = config.GetVersionNumber() + release = app.VersionNumber() } - rootSourcePath = rootPath err := sentry.Init(sentry.ClientOptions{ Dsn: dsn, Environment: environment, @@ -155,7 +155,8 @@ func sentryAddExtras(event *sentry.Event) { event.Extra = make(map[string]interface{}) } - ns := config.GetNamespace() + // Capture the k8s namespace environment variable. + ns := os.Getenv("NAMESPACE") if ns != "" { event.Extra["namespace"] = ns }