diff --git a/packages/orchestrator/pkg/sandbox/block/cache_test.go b/packages/orchestrator/pkg/sandbox/block/cache_test.go
index 63ebbdda07..56557afafd 100644
--- a/packages/orchestrator/pkg/sandbox/block/cache_test.go
+++ b/packages/orchestrator/pkg/sandbox/block/cache_test.go
@@ -4,12 +4,14 @@ import (
 	"bytes"
 	"crypto/rand"
 	"fmt"
+	"io"
 	"math"
 	"os"
 	"syscall"
 	"testing"
 	"unsafe"
 
+	"github.com/google/uuid"
 	"github.com/stretchr/testify/require"
 	"golang.org/x/sys/unix"
 
@@ -218,6 +220,183 @@ func TestEmptyRanges(t *testing.T) {
 	})
 }
 
+func TestCacheExportToDiff_ZeroDirtyBlockEmittedAsDirtyPayload(t *testing.T) {
+	t.Parallel()
+
+	const blockSize = header.RootfsBlockSize
+	cache, err := NewCache(blockSize, blockSize, t.TempDir()+"/cache", false)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		require.NoError(t, cache.Close())
+	})
+
+	zeroBlock := make([]byte, blockSize)
+	n, err := cache.WriteAt(zeroBlock, 0)
+	require.NoError(t, err)
+	require.Equal(t, int(blockSize), n)
+
+	out, err := os.CreateTemp(t.TempDir(), "diff-*")
+	require.NoError(t, err)
+	defer out.Close()
+
+	diffMetadata, err := cache.ExportToDiff(t.Context(), out)
+	require.NoError(t, err)
+
+	require.EqualValues(t, 1, diffMetadata.Dirty.Count(), "zero-filled dirty block should be emitted as dirty payload")
+	require.EqualValues(t, 0, diffMetadata.Empty.Count(), "zero-filled dirty block should not be tracked in empty metadata")
+
+	stat, err := out.Stat()
+	require.NoError(t, err)
+	require.EqualValues(t, blockSize, stat.Size(), "zero-filled dirty block should write block payload bytes")
+}
+
+func TestCacheExportToDiff_ZeroDirtyBlockMapsToSnapshotBuild(t *testing.T) {
+	t.Parallel()
+
+	const blockSize = header.RootfsBlockSize
+	cache, err := NewCache(blockSize, blockSize, t.TempDir()+"/cache", false)
+	require.NoError(t, err)
+
+	t.Cleanup(func() {
+		require.NoError(t, cache.Close())
+	})
+
+	zeroBlock := make([]byte, blockSize)
+	n, err := cache.WriteAt(zeroBlock, 0)
+	require.NoError(t, err)
+	require.Equal(t, int(blockSize), n)
+
+	out, err := os.CreateTemp(t.TempDir(), "diff-*")
+	require.NoError(t, err)
+	defer out.Close()
+
+	diffMetadata, err := cache.ExportToDiff(t.Context(), out)
+	require.NoError(t, err)
+
+	baseBuildID := uuid.New()
+	originalHeader, err := header.NewHeader(
+		header.NewTemplateMetadata(baseBuildID, uint64(blockSize), uint64(blockSize)),
+		nil,
+	)
+	require.NoError(t, err)
+
+	snapshotBuildID := uuid.New()
+	diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID)
+	require.NoError(t, err)
+
+	_, _, mappedBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0)
+	require.NoError(t, err)
+
+	require.NotNil(t, mappedBuildID)
+	require.Equal(t, snapshotBuildID, *mappedBuildID, "zero-filled dirty block should map to the snapshot diff when empty detection is skipped")
+	require.NotEqual(t, uuid.Nil, *mappedBuildID, "zero-filled dirty block should no longer be represented as an empty mapping")
+}
+
+func TestCacheExportToDiff_MixedDirtyBlocksKeepsZeroBlockInDiff(t *testing.T) {
+	t.Parallel()
+
+	const blockSize = header.RootfsBlockSize
+	const size = blockSize * 3
+
+	cache, err := NewCache(size, blockSize, t.TempDir()+"/cache", false)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, cache.Close())
+	})
+
+	zeroBlock := make([]byte, blockSize)
+	nonZeroBlock := bytes.Repeat([]byte{0xAB}, int(blockSize))
+
+	_, err = cache.WriteAt(zeroBlock, 0)
+	require.NoError(t, err)
+
+	_, err = cache.WriteAt(nonZeroBlock, blockSize)
+	require.NoError(t, err)
+
+	out, err := os.CreateTemp(t.TempDir(), "diff-*")
+	require.NoError(t, err)
+	defer out.Close()
+
+	diffMetadata, err := cache.ExportToDiff(t.Context(), out)
+	require.NoError(t, err)
+
+	require.EqualValues(t, 2, diffMetadata.Dirty.Count())
+	require.EqualValues(t, 0, diffMetadata.Empty.Count(), "mixed export should still skip empty tracking for zero-filled dirty blocks")
+
+	_, err = out.Seek(0, io.SeekStart)
+	require.NoError(t, err)
+	exported, err := io.ReadAll(out)
+	require.NoError(t, err)
+	expected := make([]byte, 0, len(zeroBlock)+len(nonZeroBlock))
+	expected = append(expected, zeroBlock...)
+	expected = append(expected, nonZeroBlock...)
+	require.Equal(t, expected, exported)
+
+	baseBuildID := uuid.New()
+	originalHeader, err := header.NewHeader(
+		header.NewTemplateMetadata(baseBuildID, uint64(blockSize), uint64(size)),
+		nil,
+	)
+	require.NoError(t, err)
+
+	snapshotBuildID := uuid.New()
+	diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID)
+	require.NoError(t, err)
+
+	_, _, firstBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0)
+	require.NoError(t, err)
+	require.Equal(t, snapshotBuildID, *firstBlockBuildID, "zero-filled dirty block should still map to the snapshot diff")
+
+	_, _, secondBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), blockSize)
+	require.NoError(t, err)
+	require.Equal(t, snapshotBuildID, *secondBlockBuildID)
+
+	_, _, thirdBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 2*blockSize)
+	require.NoError(t, err)
+	require.Equal(t, baseBuildID, *thirdBlockBuildID, "clean blocks should keep the base mapping")
+}
+
+func TestCacheExportToDiff_NonContiguousDirtyBlocksPreserveRangeOrder(t *testing.T) {
+	t.Parallel()
+
+	const blockSize = header.RootfsBlockSize
+	const size = blockSize * 5
+
+	cache, err := NewCache(size, blockSize, t.TempDir()+"/cache", false)
+	require.NoError(t, err)
+	t.Cleanup(func() {
+		require.NoError(t, cache.Close())
+	})
+
+	firstBlock := bytes.Repeat([]byte{0x11}, int(blockSize))
+	secondBlock := bytes.Repeat([]byte{0x22}, int(blockSize))
+
+	_, err = cache.WriteAt(firstBlock, 0)
+	require.NoError(t, err)
+
+	_, err = cache.WriteAt(secondBlock, 3*blockSize)
+	require.NoError(t, err)
+
+	out, err := os.CreateTemp(t.TempDir(), "diff-*")
+	require.NoError(t, err)
+	defer out.Close()
+
+	diffMetadata, err := cache.ExportToDiff(t.Context(), out)
+	require.NoError(t, err)
+
+	require.EqualValues(t, 2, diffMetadata.Dirty.Count())
+	require.True(t, diffMetadata.Dirty.Test(0))
+	require.True(t, diffMetadata.Dirty.Test(3))
+	require.EqualValues(t, 0, diffMetadata.Empty.Count())
+
+	_, err = out.Seek(0, io.SeekStart)
+	require.NoError(t, err)
+	exported, err := io.ReadAll(out)
+	require.NoError(t, err)
+	require.Equal(t, append(firstBlock, secondBlock...), exported)
+}
+
 func compareData(readBytes []byte, expectedBytes []byte) error {
 	// The bytes.Equal is the first place in this flow that actually touches the uffd managed memory and triggers the pagefault, so any deadlocks will manifest here.
 	if !bytes.Equal(readBytes, expectedBytes) {
diff --git a/tests/integration/internal/tests/orchestrator/filesystem_pause_resume_integrity_test.go b/tests/integration/internal/tests/orchestrator/filesystem_pause_resume_integrity_test.go
new file mode 100644
index 0000000000..4be2314e3a
--- /dev/null
+++ b/tests/integration/internal/tests/orchestrator/filesystem_pause_resume_integrity_test.go
@@ -0,0 +1,263 @@
+package orchestrator
+
+import (
+	"fmt"
+	"net/http"
+	"os"
+	"strconv"
+	"strings"
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
+	"github.com/e2b-dev/infra/tests/integration/internal/api"
+	"github.com/e2b-dev/infra/tests/integration/internal/setup"
+	"github.com/e2b-dev/infra/tests/integration/internal/utils"
+)
+
+type fsWritePattern string
+
+const (
+	fsWriteContiguous fsWritePattern = "contiguous"
+	fsWriteScattered  fsWritePattern = "scattered"
+)
+
+func TestSandboxFilesystemPauseResumeIntegrity(t *testing.T) {
+	t.Parallel()
+
+	t.Run("contiguous write hash survives pause", func(t *testing.T) {
+		t.Parallel()
+		runFilesystemPauseResumeIntegrityCase(t, filesystemPauseResumeCase{
+			name:       "contiguous",
+			filePath:   "/home/user/fs-integrity-contiguous.bin",
+			writeGiB:   0,
+			writeBytes: 64 * 1024 * 1024,
+			cycles:     1,
+			pattern:    fsWriteContiguous,
+		})
+	})
+
+	t.Run("scattered write hash survives pause", func(t *testing.T) {
+		t.Parallel()
+		runFilesystemPauseResumeIntegrityCase(t, filesystemPauseResumeCase{
+			name:       "scattered",
+			filePath:   "/home/user/fs-integrity-scattered.bin",
+			writeGiB:   0,
+			writeBytes: 32 * 1024 * 1024,
+			cycles:     1,
+			pattern:    fsWriteScattered,
+		})
+	})
+
+	t.Run("zeroed ranges and truncate survive pause", func(t *testing.T) {
+		t.Parallel()
+		runFilesystemPauseResumeTruncateCase(t)
+	})
+}
+
+func TestSandboxFilesystemPauseResumeIntegrityStress(t *testing.T) {
+	t.Parallel()
+
+	if os.Getenv("TESTS_FS_INTEGRITY_STRESS") != "1" {
+		t.Skip("set TESTS_FS_INTEGRITY_STRESS=1 to run the large pause/resume filesystem stress test")
+	}
+
+	writeGiB := getenvInt(t, "TESTS_FS_INTEGRITY_GIB", 2)
+	cycles := getenvInt(t, "TESTS_FS_INTEGRITY_CYCLES", 10)
+
+	t.Run("contiguous", func(t *testing.T) {
+		t.Parallel()
+		runFilesystemPauseResumeIntegrityCase(t, filesystemPauseResumeCase{
+			name:       fmt.Sprintf("stress-contiguous-%dGiB", writeGiB),
+			filePath:   fmt.Sprintf("/home/user/fs-integrity-stress-contiguous-%dGiB.bin", writeGiB),
+			writeGiB:   writeGiB,
+			writeBytes: int64(writeGiB) * 1024 * 1024 * 1024,
+			cycles:     cycles,
+			pattern:    fsWriteContiguous,
+		})
+	})
+
+	t.Run("scattered", func(t *testing.T) {
+		t.Parallel()
+		runFilesystemPauseResumeIntegrityCase(t, filesystemPauseResumeCase{
+			name:       fmt.Sprintf("stress-scattered-%dGiB", writeGiB),
+			filePath:   fmt.Sprintf("/home/user/fs-integrity-stress-scattered-%dGiB.bin", writeGiB),
+			writeGiB:   writeGiB,
+			writeBytes: int64(writeGiB) * 1024 * 1024 * 1024,
+			cycles:     cycles,
+			pattern:    fsWriteScattered,
+		})
+	})
+}
+
+type filesystemPauseResumeCase struct {
+	name       string
+	filePath   string
+	writeGiB   int
+	writeBytes int64
+	cycles     int
+	pattern    fsWritePattern
+}
+
+func runFilesystemPauseResumeIntegrityCase(t *testing.T, tc filesystemPauseResumeCase) {
+	t.Helper()
+
+	c := setup.GetAPIClient()
+	sbx := utils.SetupSandboxWithCleanup(t, c, utils.WithAutoPause(false))
+	sbxID := sbx.SandboxID
+	envdClient := setup.GetEnvdClient(t, t.Context())
+
+	exec := func(script string) string {
+		t.Helper()
+		out, err := utils.ExecCommandAsRootWithOutput(t, t.Context(), sbx, envdClient, "bash", "-lc", script)
+		require.NoError(t, err)
+
+		return strings.TrimSpace(out)
+	}
+	pause := func() {
+		t.Helper()
+		resp, err := c.PostSandboxesSandboxIDPauseWithResponse(t.Context(), sbxID, setup.WithAPIKey())
+		require.NoError(t, err)
+		require.Equal(t, http.StatusNoContent, resp.StatusCode())
+	}
+	resume := func() {
+		t.Helper()
+		resp, err := c.PostSandboxesSandboxIDResumeWithResponse(t.Context(), sbxID, api.PostSandboxesSandboxIDResumeJSONRequestBody{}, setup.WithAPIKey())
+		require.NoError(t, err)
+		require.Equal(t, http.StatusCreated, resp.StatusCode())
+	}
+
+	exec(buildWriteScript(tc.filePath, tc.writeBytes, tc.pattern))
+	expectedHash := exec(`sha256sum "` + tc.filePath + `" | awk '{print $1}'`)
+	expectedSize := exec(`stat -c %s "` + tc.filePath + `"`)
+
+	for i := range tc.cycles {
+		pause()
+		resume()
+
+		assert.Equal(t, expectedHash, exec(`sha256sum "`+tc.filePath+`" | awk '{print $1}'`), "hash mismatch after cycle %d", i+1)
+		assert.Equal(t, expectedSize, exec(`stat -c %s "`+tc.filePath+`"`), "size mismatch after cycle %d", i+1)
+	}
+}
+
+func runFilesystemPauseResumeTruncateCase(t *testing.T) {
+	t.Helper()
+
+	c := setup.GetAPIClient()
+	sbx := utils.SetupSandboxWithCleanup(t, c, utils.WithAutoPause(false))
+	sbxID := sbx.SandboxID
+	envdClient := setup.GetEnvdClient(t, t.Context())
+
+	exec := func(script string) string {
+		t.Helper()
+		out, err := utils.ExecCommandAsRootWithOutput(t, t.Context(), sbx, envdClient, "bash", "-lc", script)
+		require.NoError(t, err)
+
+		return strings.TrimSpace(out)
+	}
+	pause := func() {
+		t.Helper()
+		resp, err := c.PostSandboxesSandboxIDPauseWithResponse(t.Context(), sbxID, setup.WithAPIKey())
+		require.NoError(t, err)
+		require.Equal(t, http.StatusNoContent, resp.StatusCode())
+	}
+	resume := func() {
+		t.Helper()
+		resp, err := c.PostSandboxesSandboxIDResumeWithResponse(t.Context(), sbxID, api.PostSandboxesSandboxIDResumeJSONRequestBody{}, setup.WithAPIKey())
+		require.NoError(t, err)
+		require.Equal(t, http.StatusCreated, resp.StatusCode())
+	}
+
+	const filePath = "/home/user/fs-integrity-zero-truncate.bin"
+
+	exec(fmt.Sprintf(`
+python3 - <<'PY'
+import os
+path = %q
+chunk = 1024 * 1024
+with open(path, "wb", buffering=0) as f:
+    f.write(b"\x07" * (16 * chunk))
+    f.flush()
+    os.fsync(f.fileno())
+with open(path, "r+b", buffering=0) as f:
+    zero = b"\x00" * chunk
+    for off in (2 * chunk, 6 * chunk, 10 * chunk):
+        f.seek(off)
+        f.write(zero)
+    f.truncate(12 * chunk)
+    f.flush()
+    os.fsync(f.fileno())
+PY`, filePath))
+	expectedHash := exec(`sha256sum "` + filePath + `" | awk '{print $1}'`)
+	expectedSize := exec(`stat -c %s "` + filePath + `"`)
+
+	cycles := getenvInt(t, "TESTS_FS_INTEGRITY_TRUNCATE_CYCLES", 1)
+
+	for i := range cycles {
+		pause()
+		resume()
+
+		assert.Equal(t, expectedHash, exec(`sha256sum "`+filePath+`" | awk '{print $1}'`), "hash mismatch after cycle %d", i+1)
+		assert.Equal(t, expectedSize, exec(`stat -c %s "`+filePath+`"`), "size mismatch after cycle %d", i+1)
+	}
+}
+
+func buildWriteScript(filePath string, writeBytes int64, pattern fsWritePattern) string {
+	strideBytes := 8 * 1024 * 1024
+
+	switch pattern {
+	case fsWriteContiguous:
+		return fmt.Sprintf(`
+python3 - <<'PY'
+import os
+path = %q
+remaining = %d
+chunk = 1024 * 1024
+buf = b"\x01" * chunk
+with open(path, "wb", buffering=0) as f:
+    while remaining > 0:
+        n = min(chunk, remaining)
+        f.write(buf[:n])
+        remaining -= n
+    f.flush()
+    os.fsync(f.fileno())
+PY`, filePath, writeBytes)
+	case fsWriteScattered:
+		return fmt.Sprintf(`
+python3 - <<'PY'
+import os
+path = %q
+remaining = %d
+chunk = 1024 * 1024
+stride = %d
+buf = b"\x01" * chunk
+offset = 0
+with open(path, "wb", buffering=0) as f:
+    while remaining > 0:
+        n = min(chunk, remaining)
+        f.seek(offset)
+        f.write(buf[:n])
+        remaining -= n
+        offset += stride
+    f.flush()
+    os.fsync(f.fileno())
+PY`, filePath, writeBytes, strideBytes)
+	default:
+		panic(fmt.Sprintf("unsupported pattern: %s", pattern))
+	}
+}
+
+func getenvInt(t *testing.T, key string, fallback int) int {
+	t.Helper()
+
+	value := os.Getenv(key)
+	if value == "" {
+		return fallback
+	}
+
+	parsed, err := strconv.Atoi(value)
+	require.NoErrorf(t, err, "invalid integer value for %s", key)
+
+	return parsed
+}