diff --git a/packages/orchestrator/pkg/sandbox/fc/config.go b/packages/orchestrator/pkg/sandbox/fc/config.go index 23cca45da5..c30150bc0e 100644 --- a/packages/orchestrator/pkg/sandbox/fc/config.go +++ b/packages/orchestrator/pkg/sandbox/fc/config.go @@ -1,9 +1,12 @@ package fc import ( + "errors" + "os" "path/filepath" "github.com/e2b-dev/infra/packages/orchestrator/pkg/cfg" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) const ( @@ -31,10 +34,33 @@ func (t Config) SandboxKernelDir() string { } func (t Config) HostKernelPath(config cfg.BuilderConfig) string { + // Prefer arch-prefixed path ({version}/{arch}/vmlinux.bin) for multi-arch support. + // Fall back to legacy flat path ({version}/vmlinux.bin) for existing production nodes. + archPath := filepath.Join(config.HostKernelsDir, t.KernelVersion, utils.TargetArch(), SandboxKernelFile) + if _, err := os.Stat(archPath); err == nil { + return archPath + } else if !errors.Is(err, os.ErrNotExist) { + // Non-existence errors (e.g. permission denied) should not silently fall back + // to the legacy path, as that could use the wrong binary. + return archPath + } + return filepath.Join(config.HostKernelsDir, t.KernelVersion, SandboxKernelFile) } func (t Config) FirecrackerPath(config cfg.BuilderConfig) string { + // Prefer arch-prefixed path ({version}/{arch}/firecracker) for multi-arch support. + // Fall back to legacy flat path ({version}/firecracker) for existing production nodes + // that haven't migrated to the arch-prefixed layout yet. + archPath := filepath.Join(config.FirecrackerVersionsDir, t.FirecrackerVersion, utils.TargetArch(), FirecrackerBinaryName) + if _, err := os.Stat(archPath); err == nil { + return archPath + } else if !errors.Is(err, os.ErrNotExist) { + // Non-existence errors (e.g. permission denied) should not silently fall back + // to the legacy path, as that could use the wrong binary. + return archPath + } + return filepath.Join(config.FirecrackerVersionsDir, t.FirecrackerVersion, FirecrackerBinaryName) } diff --git a/packages/orchestrator/pkg/sandbox/fc/config_test.go b/packages/orchestrator/pkg/sandbox/fc/config_test.go new file mode 100644 index 0000000000..0dff191793 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/fc/config_test.go @@ -0,0 +1,113 @@ +package fc + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/cfg" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" +) + +func TestFirecrackerPath_ArchPrefixed(t *testing.T) { + t.Parallel() + dir := t.TempDir() + arch := utils.TargetArch() + + // Create the arch-prefixed binary + archDir := filepath.Join(dir, "v1.12.0", arch) + require.NoError(t, os.MkdirAll(archDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(archDir, "firecracker"), []byte("binary"), 0o755)) + + config := cfg.BuilderConfig{FirecrackerVersionsDir: dir} + fc := Config{FirecrackerVersion: "v1.12.0"} + + result := fc.FirecrackerPath(config) + + assert.Equal(t, filepath.Join(dir, "v1.12.0", arch, "firecracker"), result) +} + +func TestFirecrackerPath_LegacyFallback(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Only create the legacy flat binary (no arch subdirectory) + require.NoError(t, os.MkdirAll(filepath.Join(dir, "v1.12.0"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "v1.12.0", "firecracker"), []byte("binary"), 0o755)) + + config := cfg.BuilderConfig{FirecrackerVersionsDir: dir} + fc := Config{FirecrackerVersion: "v1.12.0"} + + result := fc.FirecrackerPath(config) + + assert.Equal(t, filepath.Join(dir, "v1.12.0", "firecracker"), result) +} + +func TestFirecrackerPath_NeitherExists(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // No binary at all — should return legacy flat path + config := cfg.BuilderConfig{FirecrackerVersionsDir: dir} + fc := Config{FirecrackerVersion: "v1.12.0"} + + result := fc.FirecrackerPath(config) + + assert.Equal(t, filepath.Join(dir, "v1.12.0", "firecracker"), result) +} + +func TestHostKernelPath_ArchPrefixed(t *testing.T) { + t.Parallel() + dir := t.TempDir() + arch := utils.TargetArch() + + // Create the arch-prefixed kernel + archDir := filepath.Join(dir, "vmlinux-6.1.102", arch) + require.NoError(t, os.MkdirAll(archDir, 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(archDir, "vmlinux.bin"), []byte("kernel"), 0o644)) + + config := cfg.BuilderConfig{HostKernelsDir: dir} + fc := Config{KernelVersion: "vmlinux-6.1.102"} + + result := fc.HostKernelPath(config) + + assert.Equal(t, filepath.Join(dir, "vmlinux-6.1.102", arch, "vmlinux.bin"), result) +} + +func TestHostKernelPath_LegacyFallback(t *testing.T) { + t.Parallel() + dir := t.TempDir() + + // Only create the legacy flat kernel + require.NoError(t, os.MkdirAll(filepath.Join(dir, "vmlinux-6.1.102"), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "vmlinux-6.1.102", "vmlinux.bin"), []byte("kernel"), 0o644)) + + config := cfg.BuilderConfig{HostKernelsDir: dir} + fc := Config{KernelVersion: "vmlinux-6.1.102"} + + result := fc.HostKernelPath(config) + + assert.Equal(t, filepath.Join(dir, "vmlinux-6.1.102", "vmlinux.bin"), result) +} + +func TestHostKernelPath_PrefersArchOverLegacy(t *testing.T) { + t.Parallel() + dir := t.TempDir() + arch := utils.TargetArch() + + // Create BOTH arch-prefixed and legacy flat kernels + require.NoError(t, os.MkdirAll(filepath.Join(dir, "vmlinux-6.1.102", arch), 0o755)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "vmlinux-6.1.102", arch, "vmlinux.bin"), []byte("arch-kernel"), 0o644)) + require.NoError(t, os.WriteFile(filepath.Join(dir, "vmlinux-6.1.102", "vmlinux.bin"), []byte("legacy-kernel"), 0o644)) + + config := cfg.BuilderConfig{HostKernelsDir: dir} + fc := Config{KernelVersion: "vmlinux-6.1.102"} + + result := fc.HostKernelPath(config) + + // Should prefer the arch-prefixed path + assert.Equal(t, filepath.Join(dir, "vmlinux-6.1.102", arch, "vmlinux.bin"), result) +} diff --git a/packages/orchestrator/pkg/template/build/core/oci/oci.go b/packages/orchestrator/pkg/template/build/core/oci/oci.go index 34f8aa81b5..dc57314c62 100644 --- a/packages/orchestrator/pkg/template/build/core/oci/oci.go +++ b/packages/orchestrator/pkg/template/build/core/oci/oci.go @@ -56,9 +56,12 @@ func (e *ImageTooLargeError) Error() string { ) } -var DefaultPlatform = containerregistry.Platform{ - OS: "linux", - Architecture: "amd64", +// DefaultPlatform returns the OCI platform for image pulls, respecting TARGET_ARCH. +func DefaultPlatform() containerregistry.Platform { + return containerregistry.Platform{ + OS: "linux", + Architecture: utils.TargetArch(), + } } // wrapImagePullError converts technical Docker registry errors into user-friendly messages. @@ -96,7 +99,7 @@ func GetPublicImage(ctx context.Context, dockerhubRepository dockerhub.RemoteRep return nil, fmt.Errorf("invalid image reference '%s': %w", tag, err) } - platform := DefaultPlatform + platform := DefaultPlatform() // When no auth provider is provided and the image is from the default registry // use docker remote repository proxy with cached images @@ -149,7 +152,7 @@ func GetImage(ctx context.Context, artifactRegistry artifactsregistry.ArtifactsR childCtx, childSpan := tracer.Start(ctx, "pull-docker-image") defer childSpan.End() - platform := DefaultPlatform + platform := DefaultPlatform() img, err := artifactRegistry.GetImage(childCtx, templateId, buildId, platform) if err != nil { @@ -469,7 +472,7 @@ func verifyImagePlatform(img containerregistry.Image, platform containerregistry return fmt.Errorf("error getting image config file: %w", err) } if config.Architecture != platform.Architecture { - return fmt.Errorf("image is not %s", platform.Architecture) + return fmt.Errorf("image architecture %q does not match expected %q", config.Architecture, platform.Architecture) } return nil diff --git a/packages/orchestrator/pkg/template/build/core/oci/oci_test.go b/packages/orchestrator/pkg/template/build/core/oci/oci_test.go index aa0fa421cd..5b774f8682 100644 --- a/packages/orchestrator/pkg/template/build/core/oci/oci_test.go +++ b/packages/orchestrator/pkg/template/build/core/oci/oci_test.go @@ -26,6 +26,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/dockerhub" templatemanager "github.com/e2b-dev/infra/packages/shared/pkg/grpc/template-manager" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) func createFileTar(t *testing.T, fileName string) *bytes.Buffer { @@ -213,7 +214,7 @@ func TestGetPublicImageWithGeneralAuth(t *testing.T) { // Set the config to include the proper platform configFile, err := testImage.ConfigFile() require.NoError(t, err) - configFile.Architecture = "amd64" + configFile.Architecture = utils.TargetArch() configFile.OS = "linux" testImage, err = mutate.ConfigFile(testImage, configFile) require.NoError(t, err) diff --git a/packages/shared/pkg/utils/env.go b/packages/shared/pkg/utils/env.go index 82305689d7..b5563cbada 100644 --- a/packages/shared/pkg/utils/env.go +++ b/packages/shared/pkg/utils/env.go @@ -3,9 +3,40 @@ package utils import ( "fmt" "os" + "runtime" "strings" + "sync" ) +// archAliases normalizes common architecture names to Go convention. +var archAliases = map[string]string{ + "amd64": "amd64", + "x86_64": "amd64", + "arm64": "arm64", + "aarch64": "arm64", +} + +var archWarningOnce sync.Once + +// TargetArch returns the target architecture for binary paths and OCI platform. +// If TARGET_ARCH is set, it is normalized to Go convention ("amd64" or "arm64"); +// otherwise defaults to the host architecture (runtime.GOARCH). +func TargetArch() string { + if arch := os.Getenv("TARGET_ARCH"); arch != "" { + if normalized, ok := archAliases[arch]; ok { + return normalized + } + + archWarningOnce.Do(func() { + fmt.Fprintf(os.Stderr, "WARNING: unrecognized TARGET_ARCH=%q, falling back to %s\n", arch, runtime.GOARCH) + }) + + return runtime.GOARCH + } + + return runtime.GOARCH +} + // RequiredEnv returns the value of the environment variable for key if it is set, non-empty and not only whitespace. // It panics otherwise. // diff --git a/packages/shared/pkg/utils/env_test.go b/packages/shared/pkg/utils/env_test.go new file mode 100644 index 0000000000..9684ab83eb --- /dev/null +++ b/packages/shared/pkg/utils/env_test.go @@ -0,0 +1,66 @@ +package utils + +import ( + "runtime" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTargetArch_DefaultsToHostArch(t *testing.T) { + t.Setenv("TARGET_ARCH", "") + + result := TargetArch() + + assert.Equal(t, runtime.GOARCH, result) +} + +func TestTargetArch_RespectsValidOverride(t *testing.T) { + tests := []struct { + name string + arch string + expected string + }{ + {name: "amd64", arch: "amd64", expected: "amd64"}, + {name: "arm64", arch: "arm64", expected: "arm64"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("TARGET_ARCH", tt.arch) + + result := TargetArch() + + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTargetArch_NormalizesAliases(t *testing.T) { + tests := []struct { + name string + arch string + expected string + }{ + {name: "x86_64 → amd64", arch: "x86_64", expected: "amd64"}, + {name: "aarch64 → arm64", arch: "aarch64", expected: "arm64"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Setenv("TARGET_ARCH", tt.arch) + + result := TargetArch() + + assert.Equal(t, tt.expected, result) + }) + } +} + +func TestTargetArch_FallsBackOnUnknown(t *testing.T) { + t.Setenv("TARGET_ARCH", "mips") + + result := TargetArch() + + assert.Equal(t, runtime.GOARCH, result) +}