diff --git a/internal/devconfig/configfile/ast.go b/internal/devconfig/configfile/ast.go index b11df73347a..f42ab82ccb4 100644 --- a/internal/devconfig/configfile/ast.go +++ b/internal/devconfig/configfile/ast.go @@ -425,3 +425,32 @@ func (c *configAST) beforeComment(path ...any) []byte { }), ) } + +func (c *configAST) createMemberIfMissing(key string) *hujson.ObjectMember { + i := c.memberIndex(c.root.Value.(*hujson.Object), key) + if i == -1 { + c.root.Value.(*hujson.Object).Members = append(c.root.Value.(*hujson.Object).Members, hujson.ObjectMember{ + Name: hujson.Value{Value: hujson.String(key)}, + }) + i = len(c.root.Value.(*hujson.Object).Members) - 1 + } + return &c.root.Value.(*hujson.Object).Members[i] +} + +func mapToObjectMembers(env map[string]string) []hujson.ObjectMember { + members := make([]hujson.ObjectMember, 0, len(env)) + for k, v := range env { + members = append(members, hujson.ObjectMember{ + Name: hujson.Value{Value: hujson.String(k)}, + Value: hujson.Value{Value: hujson.String(v)}, + }) + } + return members +} + +func (c *configAST) setEnv(env map[string]string) { + c.createMemberIfMissing("env").Value.Value = &hujson.Object{ + Members: mapToObjectMembers(env), + } + c.root.Format() +} diff --git a/internal/devconfig/configfile/ast_test.go b/internal/devconfig/configfile/ast_test.go new file mode 100644 index 00000000000..ac63b9094d4 --- /dev/null +++ b/internal/devconfig/configfile/ast_test.go @@ -0,0 +1,69 @@ +package configfile + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/tailscale/hujson" +) + +func TestSetEnv(t *testing.T) { + tests := []struct { + name string + initial string + env map[string]string + expected string + }{ + { + name: "add env to empty config", + initial: "{}", + env: map[string]string{ + "FOO": "bar", + "BAZ": "qux", + }, + expected: `{"env": {"FOO": "bar", "BAZ": "qux"}} +`, + }, + { + name: "update existing env", + initial: `{ + "env": { + "EXISTING": "value" + } +}`, + env: map[string]string{ + "FOO": "bar", + }, + expected: `{ + "env": {"FOO": "bar"} +} +`, + }, + { + name: "clear env with empty map", + initial: `{ + "env": { + "EXISTING": "value" + } +}`, + env: map[string]string{}, + expected: `{ + "env": {} +} +`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + val, err := hujson.Parse([]byte(tt.initial)) + assert.NoError(t, err) + + ast := &configAST{root: val} + ast.setEnv(tt.env) + + actual := string(ast.root.Pack()) + assert.Equal(t, tt.expected, actual) + }) + } +} diff --git a/internal/devconfig/configfile/env.go b/internal/devconfig/configfile/env.go index 83161aed658..5effb5b933f 100644 --- a/internal/devconfig/configfile/env.go +++ b/internal/devconfig/configfile/env.go @@ -44,3 +44,8 @@ func (c *ConfigFile) ParseEnvsFromDotEnv() (map[string]string, error) { return envMap, nil } + +func (c *ConfigFile) SetEnv(env map[string]string) { + c.Env = env + c.ast.setEnv(env) +} diff --git a/pkg/autodetect/autodetect.go b/pkg/autodetect/autodetect.go index a37ee460100..64b114d4d09 100644 --- a/pkg/autodetect/autodetect.go +++ b/pkg/autodetect/autodetect.go @@ -36,6 +36,11 @@ func populateConfig(ctx context.Context, path string, config *devconfig.Config) for _, pkg := range pkgs { config.PackageMutator().Add(pkg) } + env, err := env(ctx, path) + if err != nil { + return err + } + config.Root.SetEnv(env) return nil } @@ -57,6 +62,14 @@ func packages(ctx context.Context, path string) ([]string, error) { return mostRelevantDetector.Packages(ctx) } +func env(ctx context.Context, path string) (map[string]string, error) { + mostRelevantDetector, err := relevantDetector(path) + if err != nil || mostRelevantDetector == nil { + return nil, err + } + return mostRelevantDetector.Env(ctx) +} + // relevantDetector returns the most relevant detector for the given path. // We could modify this to return a list of detectors and their scores or // possibly grouped detectors by category (e.g. python, server, etc.) diff --git a/pkg/autodetect/detector/detector.go b/pkg/autodetect/detector/detector.go index 6388cc9b1e7..bcdbc78928f 100644 --- a/pkg/autodetect/detector/detector.go +++ b/pkg/autodetect/detector/detector.go @@ -5,4 +5,5 @@ import "context" type Detector interface { Relevance(path string) (float64, error) Packages(ctx context.Context) ([]string, error) + Env(ctx context.Context) (map[string]string, error) } diff --git a/pkg/autodetect/detector/go.go b/pkg/autodetect/detector/go.go index d7c44f007ca..cdd49ffc9ec 100644 --- a/pkg/autodetect/detector/go.go +++ b/pkg/autodetect/detector/go.go @@ -38,6 +38,10 @@ func (d *GoDetector) Packages(ctx context.Context) ([]string, error) { return []string{"go@" + goVersion}, nil } +func (d *GoDetector) Env(ctx context.Context) (map[string]string, error) { + return map[string]string{}, nil +} + func parseGoVersion(goModContent string) string { // Use a regular expression to find the Go version directive re := regexp.MustCompile(`(?m)^go\s+(\d+\.\d+(\.\d+)?)`) diff --git a/pkg/autodetect/detector/nodejs.go b/pkg/autodetect/detector/nodejs.go index 785f6fa2cfa..5becaa4169b 100644 --- a/pkg/autodetect/detector/nodejs.go +++ b/pkg/autodetect/detector/nodejs.go @@ -41,6 +41,10 @@ func (d *NodeJSDetector) Packages(ctx context.Context) ([]string, error) { return []string{"nodejs@" + d.nodeVersion(ctx)}, nil } +func (d *NodeJSDetector) Env(ctx context.Context) (map[string]string, error) { + return map[string]string{"DEVBOX_COREPACK_ENABLED": "1"}, nil +} + func (d *NodeJSDetector) nodeVersion(ctx context.Context) string { if d.packageJSON == nil || d.packageJSON.Engines.Node == "" { return "latest" // Default to latest if not specified diff --git a/pkg/autodetect/detector/nodejs_test.go b/pkg/autodetect/detector/nodejs_test.go index 9d63e2e75d6..7b84e8dbecb 100644 --- a/pkg/autodetect/detector/nodejs_test.go +++ b/pkg/autodetect/detector/nodejs_test.go @@ -17,6 +17,7 @@ func TestNodeJSDetector_Relevance(t *testing.T) { fs fstest.MapFS expected float64 expectedPackages []string + expectedEnv map[string]string }{ { name: "package.json in root", @@ -27,6 +28,7 @@ func TestNodeJSDetector_Relevance(t *testing.T) { }, expected: 1, expectedPackages: []string{"nodejs@latest"}, + expectedEnv: map[string]string{"DEVBOX_COREPACK_ENABLED": "1"}, }, { name: "package.json with node version", @@ -41,6 +43,7 @@ func TestNodeJSDetector_Relevance(t *testing.T) { }, expected: 1, expectedPackages: []string{"nodejs@18.0.0"}, + expectedEnv: map[string]string{"DEVBOX_COREPACK_ENABLED": "1"}, }, { name: "no nodejs files", @@ -54,12 +57,14 @@ func TestNodeJSDetector_Relevance(t *testing.T) { }, expected: 0, expectedPackages: []string{}, + expectedEnv: map[string]string{}, }, { name: "empty directory", fs: fstest.MapFS{}, expected: 0, expectedPackages: []string{}, + expectedEnv: map[string]string{}, }, } @@ -74,17 +79,21 @@ func TestNodeJSDetector_Relevance(t *testing.T) { require.NoError(t, err) } - d := &NodeJSDetector{Root: dir} - err := d.Init() + detector := &NodeJSDetector{Root: dir} + err := detector.Init() require.NoError(t, err) - score, err := d.Relevance(dir) + score, err := detector.Relevance(dir) require.NoError(t, err) assert.Equal(t, curTest.expected, score) if score > 0 { - packages, err := d.Packages(context.Background()) + packages, err := detector.Packages(context.Background()) require.NoError(t, err) assert.Equal(t, curTest.expectedPackages, packages) + + env, err := detector.Env(context.Background()) + require.NoError(t, err) + assert.Equal(t, curTest.expectedEnv, env) } }) } @@ -96,3 +105,10 @@ func TestNodeJSDetector_Packages(t *testing.T) { require.NoError(t, err) assert.Equal(t, []string{"nodejs@latest"}, packages) } + +func TestNodeJSDetector_Env(t *testing.T) { + d := &NodeJSDetector{} + env, err := d.Env(context.Background()) + require.NoError(t, err) + assert.Equal(t, map[string]string{"DEVBOX_COREPACK_ENABLED": "1"}, env) +} diff --git a/pkg/autodetect/detector/php.go b/pkg/autodetect/detector/php.go index ed01079f50e..fc1a0ee5416 100644 --- a/pkg/autodetect/detector/php.go +++ b/pkg/autodetect/detector/php.go @@ -49,6 +49,10 @@ func (d *PHPDetector) Packages(ctx context.Context) ([]string, error) { return packages, nil } +func (d *PHPDetector) Env(ctx context.Context) (map[string]string, error) { + return map[string]string{}, nil +} + func (d *PHPDetector) phpVersion(ctx context.Context) string { require := d.composerJSON.Require diff --git a/pkg/autodetect/detector/poetry.go b/pkg/autodetect/detector/poetry.go index 65d73854e0e..1533684d551 100644 --- a/pkg/autodetect/detector/poetry.go +++ b/pkg/autodetect/detector/poetry.go @@ -58,6 +58,10 @@ func (d *PoetryDetector) Packages(ctx context.Context) ([]string, error) { return []string{"python@" + pythonVersion, "poetry@" + poetryVersion}, nil } +func (d *PoetryDetector) Env(ctx context.Context) (map[string]string, error) { + return d.PythonDetector.Env(ctx) +} + func determineBestVersion(ctx context.Context, pkg, version string) string { if version == "" { return "latest" diff --git a/pkg/autodetect/detector/python.go b/pkg/autodetect/detector/python.go index a542ab1b4d8..5acbb3d3d56 100644 --- a/pkg/autodetect/detector/python.go +++ b/pkg/autodetect/detector/python.go @@ -28,6 +28,10 @@ func (d *PythonDetector) Packages(ctx context.Context) ([]string, error) { return []string{"python@latest"}, nil } +func (d *PythonDetector) Env(ctx context.Context) (map[string]string, error) { + return map[string]string{}, nil +} + func (d *PythonDetector) maxRelevance() float64 { return 1.0 }