diff --git a/.github/actions/start-services/action.yml b/.github/actions/start-services/action.yml index 4cc7c396ab..001b44af9c 100644 --- a/.github/actions/start-services/action.yml +++ b/.github/actions/start-services/action.yml @@ -1,6 +1,24 @@ name: "Start Services" description: "Sets up and starts the required services, including PostgreSQL." +inputs: + compress_enabled: + description: "Enable compression (true/false)" + required: false + default: "false" + compress_type: + description: "Compression type (zstd, lz4)" + required: false + default: "" + compress_level: + description: "Compression level (zstd: 1=fastest, 2=default; lz4: 0)" + required: false + default: "" + compress_workers: + description: "Number of frame encode workers" + required: false + default: "" + runs: using: "composite" steps: @@ -107,6 +125,10 @@ runs: API_GRPC_ADDRESS: "localhost:5009" DEFAULT_PERSISTENT_VOLUME_TYPE: "test-volume-type" SANDBOX_STORAGE_BACKEND: "redis" + COMPRESS_ENABLED: ${{ inputs.compress_enabled }} + COMPRESS_TYPE: ${{ inputs.compress_type }} + COMPRESS_LEVEL: ${{ inputs.compress_level }} + COMPRESS_FRAME_ENCODE_WORKERS: ${{ inputs.compress_workers }} run: | mkdir -p $SHARED_CHUNK_CACHE_PATH mkdir -p ~/logs diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index af27a2ae71..d14ca2b23a 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -30,6 +30,11 @@ jobs: - name: Start Services uses: ./.github/actions/start-services + with: + compress_enabled: "true" + compress_type: "lz4" + compress_level: "0" + compress_workers: "8" - name: Run Integration Tests env: @@ -38,9 +43,7 @@ jobs: TESTS_ORCHESTRATOR_HOST: "localhost:5008" TESTS_ENVD_PROXY: "http://localhost:3002" TESTS_CLIENT_PROXY: "http://localhost:3002" - run: | - # Run the integration tests - make test-integration + run: make test-integration - name: Check for Data Races in Service Logs if: always() @@ -75,12 +78,12 @@ 
jobs: if: ${{ always() && inputs.publish == true }} uses: actions/upload-artifact@v6 with: - name: Integration Tests Results - path: ./tests/integration/test-results.xml + name: ${{ inputs.compression && 'Compressed ' || '' }}Integration Tests Results + path: ./tests/integration/test-results*.xml - name: Upload Service Logs if: ${{ always() && inputs.publish == true }} uses: actions/upload-artifact@v6 with: - name: Service Logs + name: ${{ inputs.compression && 'Compressed ' || '' }}Service Logs path: ~/logs/*.log diff --git a/.mockery.yaml b/.mockery.yaml index 7cb277c42a..20c5b979d7 100644 --- a/.mockery.yaml +++ b/.mockery.yaml @@ -39,34 +39,38 @@ packages: interfaces: featureFlagsClient: config: - dir: packages/shared/pkg/storage/mocks - filename: mockfeatureflagsclient.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_featureflagsclient.go + pkgname: storage + inpackage: true structname: MockFeatureFlagsClient Blob: config: - dir: packages/shared/pkg/storage/mocks - filename: mockobjectprovider.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_blob.go + pkgname: storage + inpackage: true Seekable: config: - dir: packages/shared/pkg/storage/mocks - filename: mockseekableobjectprovider.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_seekable.go + pkgname: storage + inpackage: true StorageProvider: config: - dir: packages/shared/pkg/storage/mocks/provider - filename: mockstorageprovider.go - pkgname: providermocks - + dir: packages/shared/pkg/storage + filename: mock_storageprovider.go + pkgname: storage + inpackage: true io: interfaces: Reader: config: - dir: packages/shared/pkg/storage/mocks - filename: mockioreader.go - pkgname: storagemocks + dir: packages/shared/pkg/storage + filename: mock_ioreader.go + pkgname: storage + inpackage: true github.com/e2b-dev/infra/packages/shared/pkg/utils: interfaces: @@ -76,6 +80,14 @@ packages: filename: mocks_test.go pkgname: utils + 
github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build: + interfaces: + Diff: + config: + dir: packages/orchestrator/pkg/sandbox/build/mocks + filename: mockdiff.go + pkgname: buildmocks + github.com/e2b-dev/infra/packages/api/internal/handlers: interfaces: featureFlagsClient: diff --git a/packages/orchestrator/chunks.proto b/packages/orchestrator/chunks.proto index 89993d17f2..55a1a539db 100644 --- a/packages/orchestrator/chunks.proto +++ b/packages/orchestrator/chunks.proto @@ -14,6 +14,12 @@ message PeerAvailability { // use_storage is true when the GCS upload has completed and the caller // should switch to reading from GCS/NFS directly instead of this peer. bool use_storage = 2; + // memfile_header contains the serialized V4 header (with FrameTables) + // for the memfile, included when use_storage is true and the upload was compressed. + bytes memfile_header = 3; + // rootfs_header contains the serialized V4 header (with FrameTables) + // for the rootfs, included when use_storage is true and the upload was compressed. 
+ bytes rootfs_header = 4; } message GetBuildFileSizeRequest { diff --git a/packages/orchestrator/cmd/show-build-diff/main.go b/packages/orchestrator/cmd/show-build-diff/main.go index 33c53fe9b5..d8aadcec69 100644 --- a/packages/orchestrator/cmd/show-build-diff/main.go +++ b/packages/orchestrator/cmd/show-build-diff/main.go @@ -142,7 +142,10 @@ func main() { ) } - mergedHeader := header.MergeMappings(baseHeader.Mapping, onlyDiffMappings) + mergedHeader, err := header.MergeMappings(baseHeader.Mapping, onlyDiffMappings) + if err != nil { + log.Fatalf("merge mappings: %v", err) + } fmt.Printf("\n\nMERGED METADATA\n") fmt.Printf("========\n") diff --git a/packages/orchestrator/pkg/cfg/model.go b/packages/orchestrator/pkg/cfg/model.go index 9484515a81..75303d878b 100644 --- a/packages/orchestrator/pkg/cfg/model.go +++ b/packages/orchestrator/pkg/cfg/model.go @@ -29,8 +29,9 @@ type BuilderConfig struct { DefaultCacheDir string `env:"DEFAULT_CACHE_DIR,expand" envDefault:"${ORCHESTRATOR_BASE_PATH}/build"` - StorageConfig storage.Config - NetworkConfig network.Config + StorageConfig storage.Config + CompressConfig storage.CompressConfig + NetworkConfig network.Config } func makePathsAbsolute(c *BuilderConfig) error { diff --git a/packages/orchestrator/pkg/sandbox/block/cache_test.go b/packages/orchestrator/pkg/sandbox/block/cache_test.go index 56557afafd..49dcbb170e 100644 --- a/packages/orchestrator/pkg/sandbox/block/cache_test.go +++ b/packages/orchestrator/pkg/sandbox/block/cache_test.go @@ -285,12 +285,11 @@ func TestCacheExportToDiff_ZeroDirtyBlockMapsToSnapshotBuild(t *testing.T) { diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID) require.NoError(t, err) - _, _, mappedBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0) + mapped, err := diffHeader.GetShiftedMapping(t.Context(), 0) require.NoError(t, err) - require.NotNil(t, mappedBuildID) - require.Equal(t, snapshotBuildID, *mappedBuildID, "zero-filled dirty block 
should map to the snapshot diff when empty detection is skipped") - require.NotEqual(t, uuid.Nil, *mappedBuildID, "zero-filled dirty block should no longer be represented as an empty mapping") + require.Equal(t, snapshotBuildID, mapped.BuildId, "zero-filled dirty block should map to the snapshot diff when empty detection is skipped") + require.NotEqual(t, uuid.Nil, mapped.BuildId, "zero-filled dirty block should no longer be represented as an empty mapping") } func TestCacheExportToDiff_MixedDirtyBlocksKeepsZeroBlockInDiff(t *testing.T) { @@ -344,17 +343,17 @@ func TestCacheExportToDiff_MixedDirtyBlocksKeepsZeroBlockInDiff(t *testing.T) { diffHeader, err := diffMetadata.ToDiffHeader(t.Context(), originalHeader, snapshotBuildID) require.NoError(t, err) - _, _, firstBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 0) + firstBlock, err := diffHeader.GetShiftedMapping(t.Context(), 0) require.NoError(t, err) - require.Equal(t, snapshotBuildID, *firstBlockBuildID, "zero-filled dirty block should still map to the snapshot diff") + require.Equal(t, snapshotBuildID, firstBlock.BuildId, "zero-filled dirty block should still map to the snapshot diff") - _, _, secondBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), blockSize) + secondBlock, err := diffHeader.GetShiftedMapping(t.Context(), blockSize) require.NoError(t, err) - require.Equal(t, snapshotBuildID, *secondBlockBuildID) + require.Equal(t, snapshotBuildID, secondBlock.BuildId) - _, _, thirdBlockBuildID, err := diffHeader.GetShiftedMapping(t.Context(), 2*blockSize) + thirdBlock, err := diffHeader.GetShiftedMapping(t.Context(), 2*blockSize) require.NoError(t, err) - require.Equal(t, baseBuildID, *thirdBlockBuildID, "clean blocks should keep the base mapping") + require.Equal(t, baseBuildID, thirdBlock.BuildId, "clean blocks should keep the base mapping") } func TestCacheExportToDiff_NonContiguousDirtyBlocksPreserveRangeOrder(t *testing.T) { diff --git 
a/packages/orchestrator/pkg/sandbox/block/chunk.go b/packages/orchestrator/pkg/sandbox/block/chunk.go deleted file mode 100644 index ad2017d2aa..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk.go +++ /dev/null @@ -1,301 +0,0 @@ -package block - -import ( - "context" - "errors" - "fmt" - "io" - "strconv" - - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/metric" - "go.uber.org/zap" - "golang.org/x/sync/errgroup" - "golang.org/x/sync/singleflight" - - "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" - "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" - "github.com/e2b-dev/infra/packages/shared/pkg/logger" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" - "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" -) - -const ( - pullType = "pull-type" - pullTypeLocal = "local" - pullTypeRemote = "remote" - - failureReason = "failure-reason" - - failureTypeLocalRead = "local-read" - failureTypeLocalReadAgain = "local-read-again" - failureTypeRemoteRead = "remote-read" - failureTypeCacheFetch = "cache-fetch" -) - -type precomputedAttrs struct { - successFromCache metric.MeasurementOption - successFromRemote metric.MeasurementOption - - failCacheRead metric.MeasurementOption - failRemoteFetch metric.MeasurementOption - failLocalReadAgain metric.MeasurementOption - - // RemoteReads timer (runFetch) - remoteSuccess metric.MeasurementOption - remoteFailure metric.MeasurementOption -} - -var chunkerAttrs = precomputedAttrs{ - successFromCache: telemetry.PrecomputeAttrs( - telemetry.Success, - attribute.String(pullType, pullTypeLocal)), - - successFromRemote: telemetry.PrecomputeAttrs( - telemetry.Success, - attribute.String(pullType, pullTypeRemote)), - - failCacheRead: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalRead)), - - 
failRemoteFetch: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeRemote), - attribute.String(failureReason, failureTypeCacheFetch)), - - failLocalReadAgain: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(pullType, pullTypeLocal), - attribute.String(failureReason, failureTypeLocalReadAgain)), - - remoteSuccess: telemetry.PrecomputeAttrs( - telemetry.Success), - - remoteFailure: telemetry.PrecomputeAttrs( - telemetry.Failure, - attribute.String(failureReason, failureTypeRemoteRead)), -} - -// Chunker is the interface satisfied by both FullFetchChunker and StreamingChunker. -type Chunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - ReadAt(ctx context.Context, b []byte, off int64) (int, error) - WriteTo(ctx context.Context, w io.Writer) (int64, error) - Close() error - FileSize() (int64, error) -} - -// NewChunker creates a Chunker based on the chunker-config feature flag. -// It reads the flag internally so callers don't need to parse flag values. -func NewChunker( - ctx context.Context, - featureFlags *featureflags.Client, - size, blockSize int64, - upstream storage.Seekable, - cachePath string, - metrics metrics.Metrics, -) (Chunker, error) { - useStreaming, minReadBatchSizeKB := getChunkerConfig(ctx, featureFlags) - - if useStreaming { - return NewStreamingChunker(size, blockSize, upstream, cachePath, metrics, int64(minReadBatchSizeKB)*1024, featureFlags) - } - - return NewFullFetchChunker(size, blockSize, upstream, cachePath, metrics) -} - -// getChunkerConfig fetches the chunker-config feature flag and returns the parsed values. 
-func getChunkerConfig(ctx context.Context, ff *featureflags.Client) (useStreaming bool, minReadBatchSizeKB int) { - value := ff.JSONFlag(ctx, featureflags.ChunkerConfigFlag) - - if v := value.GetByKey("useStreaming"); v.IsDefined() { - useStreaming = v.BoolValue() - } - - if v := value.GetByKey("minReadBatchSizeKB"); v.IsDefined() { - minReadBatchSizeKB = v.IntValue() - } - - return useStreaming, minReadBatchSizeKB -} - -type FullFetchChunker struct { - base storage.SeekableReader - cache *Cache - metrics metrics.Metrics - - size int64 - - fetchers singleflight.Group -} - -func NewFullFetchChunker( - size, blockSize int64, - base storage.SeekableReader, - cachePath string, - metrics metrics.Metrics, -) (*FullFetchChunker, error) { - cache, err := NewCache(size, blockSize, cachePath, false) - if err != nil { - return nil, fmt.Errorf("failed to create file cache: %w", err) - } - - chunker := &FullFetchChunker{ - size: size, - base: base, - cache: cache, - metrics: metrics, - } - - return chunker, nil -} - -func (c *FullFetchChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) - } - - return copy(b, slice), nil -} - -func (c *FullFetchChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - for i := int64(0); i < c.size; i += storage.MemoryChunkSize { - chunk := make([]byte, storage.MemoryChunkSize) - - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } - } - - return c.size, nil -} - -func (c *FullFetchChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { - timer := c.metrics.SlicesTimerFactory.Begin() - - b, err := c.cache.Slice(off, 
length) - if err == nil { - timer.RecordRaw(ctx, length, chunkerAttrs.successFromCache) - - return b, nil - } - - if !errors.As(err, &BytesNotAvailableError{}) { - timer.RecordRaw(ctx, length, chunkerAttrs.failCacheRead) - - return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) - } - - chunkErr := c.fetchToCache(ctx, off, length) - if chunkErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failRemoteFetch) - - return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, chunkErr) - } - - b, cacheErr := c.cache.Slice(off, length) - if cacheErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failLocalReadAgain) - - return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) - } - - timer.RecordRaw(ctx, length, chunkerAttrs.successFromRemote) - - return b, nil -} - -// fetchToCache ensures that the data at the given offset and length is available in the cache. -func (c *FullFetchChunker) fetchToCache(ctx context.Context, off, length int64) error { - var eg errgroup.Group - - chunks := header.BlocksOffsets(length, storage.MemoryChunkSize) - - startingChunk := header.BlockIdx(off, storage.MemoryChunkSize) - startingChunkOffset := header.BlockOffset(startingChunk, storage.MemoryChunkSize) - - for _, chunkOff := range chunks { - // Ensure the closure captures the correct block offset. 
- fetchOff := startingChunkOffset + chunkOff - - eg.Go(func() (err error) { - defer func() { - if r := recover(); r != nil { - logger.L().Error(ctx, "recovered from panic in the fetch handler", zap.Any("error", r)) - err = fmt.Errorf("recovered from panic in the fetch handler: %v", r) - } - }() - - key := strconv.FormatInt(fetchOff, 10) - - _, err, _ = c.fetchers.Do(key, func() (any, error) { - // Check early to prevent overwriting data, Slice requires thread safety - if c.cache.isCached(fetchOff, storage.MemoryChunkSize) { - return nil, nil - } - - select { - case <-ctx.Done(): - return nil, fmt.Errorf("error fetching range %d-%d: %w", fetchOff, fetchOff+storage.MemoryChunkSize, ctx.Err()) - default: - } - - // The size of the buffer is adjusted if the last chunk is not a multiple of the block size. - b, releaseCacheCloseLock, err := c.cache.addressBytes(fetchOff, storage.MemoryChunkSize) - if err != nil { - return nil, err - } - - defer releaseCacheCloseLock() - - fetchSW := c.metrics.RemoteReadsTimerFactory.Begin() - - readBytes, err := c.base.ReadAt(ctx, b, fetchOff) - if err != nil { - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteFailure) - - return nil, fmt.Errorf("failed to read chunk from base %d: %w", fetchOff, err) - } - - if readBytes != len(b) { - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteFailure) - - return nil, fmt.Errorf("failed to read chunk from base %d: expected %d bytes, got %d bytes", fetchOff, len(b), readBytes) - } - - c.cache.setIsCached(fetchOff, int64(readBytes)) - - fetchSW.RecordRaw(ctx, int64(readBytes), chunkerAttrs.remoteSuccess) - - return nil, nil - }) - - return err - }) - } - - err := eg.Wait() - if err != nil { - return fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) - } - - return nil -} - -func (c *FullFetchChunker) Close() error { - return c.cache.Close() -} - -func (c *FullFetchChunker) FileSize() (int64, error) { - return c.cache.FileSize() -} diff --git 
a/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go b/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go deleted file mode 100644 index 93534b4d3b..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk_bench_test.go +++ /dev/null @@ -1,60 +0,0 @@ -package block - -import ( - "context" - "path/filepath" - "testing" - - sdkmetric "go.opentelemetry.io/otel/sdk/metric" - - blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" -) - -const ( - cbBlockSize int64 = 4096 - cbNumBlocks int64 = 16384 // 64 MiB - cbCacheSize int64 = cbNumBlocks * cbBlockSize - cbChunkSize int64 = 4 * 1024 * 1024 // 4 MiB — MemoryChunkSize - cbChunkCount int64 = cbCacheSize / cbChunkSize -) - -// BenchmarkChunkerSlice_CacheHit benchmarks the full FullFetchChunker.Slice -// hot path on a cache hit: bitmap check + mmap slice return + OTEL -// timer.Success with attribute construction. -func BenchmarkChunkerSlice_CacheHit(b *testing.B) { - provider := sdkmetric.NewMeterProvider() - b.Cleanup(func() { provider.Shutdown(context.Background()) }) - - m, err := blockmetrics.NewMetrics(provider) - if err != nil { - b.Fatal(err) - } - - chunker, err := NewFullFetchChunker( - cbCacheSize, cbBlockSize, - nil, // base is never called on cache hit - filepath.Join(b.TempDir(), "cache"), - m, - ) - if err != nil { - b.Fatal(err) - } - b.Cleanup(func() { chunker.Close() }) - - // Pre-populate the cache so every Slice hits. 
- chunker.cache.setIsCached(0, cbCacheSize) - - ctx := context.Background() - - b.ResetTimer() - for i := range b.N { - off := int64(i%int(cbChunkCount)) * cbChunkSize - s, sliceErr := chunker.Slice(ctx, off, cbChunkSize) - if sliceErr != nil { - b.Fatal(sliceErr) - } - if len(s) == 0 { - b.Fatal("empty slice") - } - } -} diff --git a/packages/orchestrator/pkg/sandbox/block/chunk_test.go b/packages/orchestrator/pkg/sandbox/block/chunk_test.go deleted file mode 100644 index c9350a80f5..0000000000 --- a/packages/orchestrator/pkg/sandbox/block/chunk_test.go +++ /dev/null @@ -1,165 +0,0 @@ -package block - -import ( - "context" - "errors" - "fmt" - "sync/atomic" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "golang.org/x/sync/errgroup" - - "github.com/e2b-dev/infra/packages/shared/pkg/storage" -) - -// failingUpstream returns an error on ReadAt for specific offsets. -type failingUpstream struct { - data []byte - failCount atomic.Int32 // incremented on each failed ReadAt - failErr error -} - -func (u *failingUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - if u.failErr != nil { - u.failCount.Add(1) - - return 0, u.failErr - } - - end := min(off+int64(len(buffer)), int64(len(u.data))) - n := copy(buffer, u.data[off:end]) - - return n, nil -} - -func (u *failingUpstream) Size(_ context.Context) (int64, error) { - return int64(len(u.data)), nil -} - -func TestFullFetchChunker_BasicSlice(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice) -} - -func 
TestFullFetchChunker_RetryAfterError(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - - upstream := &failingUpstream{ - data: data, - failErr: errors.New("connection pool exhausted"), - } - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - // First call fails - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.Error(t, err) - - firstFailCount := upstream.failCount.Load() - require.Positive(t, firstFailCount) - - // Clear the error — simulate saturation passing - upstream.failErr = nil - - // Retry should succeed — singleflight does not cache errors - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice) -} - -func TestFullFetchChunker_ConcurrentSameChunk(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - readCount := atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &readCount, - } - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - numGoroutines := 10 - results := make([][]byte, numGoroutines) - - var eg errgroup.Group - - for i := range numGoroutines { - eg.Go(func() error { - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - if err != nil { - return fmt.Errorf("goroutine %d failed: %w", i, err) - } - - results[i] = make([]byte, len(slice)) - copy(results[i], slice) - - return nil - }) - } - - require.NoError(t, eg.Wait()) - - for i := range numGoroutines { - assert.Equal(t, data[:testBlockSize], results[i], "goroutine %d got wrong data", i) - } -} - -func TestFullFetchChunker_DifferentChunksIndependent(t *testing.T) { - t.Parallel() - - // Two 4MB 
chunks - size := storage.MemoryChunkSize * 2 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewFullFetchChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - ) - require.NoError(t, err) - defer chunker.Close() - - // Read from chunk 0 - slice0, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[:testBlockSize], slice0) - - // Read from chunk 1 - off1 := int64(storage.MemoryChunkSize) - slice1, err := chunker.Slice(t.Context(), off1, testBlockSize) - require.NoError(t, err) - assert.Equal(t, data[off1:off1+testBlockSize], slice1) -} diff --git a/packages/orchestrator/pkg/sandbox/block/device.go b/packages/orchestrator/pkg/sandbox/block/device.go index 39a1cae845..5cd6c0ba79 100644 --- a/packages/orchestrator/pkg/sandbox/block/device.go +++ b/packages/orchestrator/pkg/sandbox/block/device.go @@ -8,19 +8,30 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" ) +// BytesNotAvailableError indicates the requested range is not yet cached. type BytesNotAvailableError struct{} func (BytesNotAvailableError) Error() string { return "The requested bytes are not available on the device" } +type FramedReader interface { + ReadAt(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) +} + +type FramedSlicer interface { + Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) +} + +// Slicer provides plain block reads (no FrameTable). Used by UFFD/NBD. 
type Slicer interface { Slice(ctx context.Context, off, length int64) ([]byte, error) BlockSize() int64 } type ReadonlyDevice interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.Closer Slicer BlockSize() int64 diff --git a/packages/orchestrator/pkg/sandbox/block/fetch_session.go b/packages/orchestrator/pkg/sandbox/block/fetch_session.go new file mode 100644 index 0000000000..6ad34c475d --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/block/fetch_session.go @@ -0,0 +1,141 @@ +package block + +import ( + "context" + "fmt" + "sync" + "sync/atomic" +) + +type fetchSession struct { + // chunk is what we are fetching, can be >= 1 block. chunkOff/chunkLen are absolute offsets in U-space. + chunkOff int64 + chunkLen int64 + cache *Cache + + mu sync.Mutex + fetchErr error + signal chan struct{} // closed on each advance; nil when terminated + + // bytesReady is the byte count (from chunkOff) up to which all blocks + // are fully written and marked cached. Atomic so registerAndWait can + // do a lock-free fast-path check: bytesReady only increases. + bytesReady atomic.Int64 +} + +// terminated reports whether the session reached a terminal state. +// Must be called with mu held. +func (s *fetchSession) terminated() bool { + return s.fetchErr != nil || s.bytesReady.Load() == s.chunkLen +} + +func newFetchSession(chunkOff, chunkLen int64, cache *Cache) *fetchSession { + return &fetchSession{ + chunkOff: chunkOff, + chunkLen: chunkLen, + cache: cache, + signal: make(chan struct{}), + } +} + +// registerAndWait blocks until the block at blockOff is cached, the session +// terminates, or ctx is cancelled. Each caller requests exactly one block. +func (s *fetchSession) registerAndWait(ctx context.Context, blockOff int64) error { + blockSize := s.cache.blockSize + + // endByte is the byte offset (relative to chunkOff) that must be ready + // for our block to be fully written. 
+ relEnd := blockOff + blockSize - s.chunkOff + endByte := min(relEnd, s.chunkLen) + + for { + // Lock-free fast path: bytesReady only increases, so >= endByte + // guarantees data is available. + if s.bytesReady.Load() >= endByte { + return nil + } + + s.mu.Lock() + + // Re-check under lock. + if s.bytesReady.Load() >= endByte { + s.mu.Unlock() + + return nil + } + + // Terminal but block not covered — only happens on error. + // setDone sets bytesReady=chunkLen, so terminated() with bytesReady < endByte + // means fetchErr != nil. Check cache in case a prior session already fetched this block. + if s.terminated() { + fetchErr := s.fetchErr + s.mu.Unlock() + + if s.cache.isCached(blockOff, blockSize) { + return nil + } + + if fetchErr == nil { + return fmt.Errorf("fetch session terminated without error but block %d not cached (bytesReady=%d, endByte=%d)", + blockOff/blockSize, s.bytesReady.Load(), endByte) + } + + return fmt.Errorf("fetch failed: %w", fetchErr) + } + + ch := s.signal + s.mu.Unlock() + + select { + case <-ch: + continue + case <-ctx.Done(): + return ctx.Err() + } + } +} + +// advance updates progress and wakes all waiters by closing the current +// broadcast channel and replacing it with a fresh one. +func (s *fetchSession) advance(bytesReady int64) { + s.mu.Lock() + s.bytesReady.Store(bytesReady) + old := s.signal + s.signal = make(chan struct{}) + s.mu.Unlock() + + close(old) +} + +// setDone marks the session as successfully completed and wakes all waiters. +func (s *fetchSession) setDone() { + s.mu.Lock() + s.bytesReady.Store(s.chunkLen) + old := s.signal + s.signal = nil + s.mu.Unlock() + + close(old) +} + +// setError records the error and wakes all waiters. +// When onlyIfRunning is true, it is a no-op if the session already +// terminated (used for panic recovery to avoid overriding a successful +// completion or double-closing the broadcast channel). 
+func (s *fetchSession) setError(err error, onlyIfRunning bool) { + s.mu.Lock() + if onlyIfRunning && s.terminated() { + s.mu.Unlock() + + return + } + + s.fetchErr = err + old := s.signal + s.signal = nil + s.mu.Unlock() + + if old != nil { + close(old) + } +} diff --git a/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go b/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go index d0c464b661..8f0a2e5717 100644 --- a/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go +++ b/packages/orchestrator/pkg/sandbox/block/mocks/mockreadonlydevice.go @@ -173,8 +173,8 @@ func (_c *MockReadonlyDevice_Header_Call) RunAndReturn(run func() *header.Header } // ReadAt provides a mock function for the type MockReadonlyDevice -func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) +func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { + ret := _mock.Called(ctx, p, off) if len(ret) == 0 { panic("no return value specified for ReadAt") @@ -183,15 +183,15 @@ func (_mock *MockReadonlyDevice) ReadAt(ctx context.Context, buffer []byte, off var r0 int var r1 error if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) + return returnFunc(ctx, p, off) } if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) + r0 = returnFunc(ctx, p, off) } else { r0 = ret.Get(0).(int) } if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) + r1 = returnFunc(ctx, p, off) } else { r1 = ret.Error(1) } @@ -205,13 +205,13 @@ type MockReadonlyDevice_ReadAt_Call struct { // ReadAt is a helper method to define mock.On call // - ctx context.Context -// - buffer []byte +// - p []byte // - off int64 -func (_e *MockReadonlyDevice_Expecter) 
ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockReadonlyDevice_ReadAt_Call { - return &MockReadonlyDevice_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} +func (_e *MockReadonlyDevice_Expecter) ReadAt(ctx interface{}, p interface{}, off interface{}) *MockReadonlyDevice_ReadAt_Call { + return &MockReadonlyDevice_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, p, off)} } -func (_c *MockReadonlyDevice_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockReadonlyDevice_ReadAt_Call { +func (_c *MockReadonlyDevice_ReadAt_Call) Run(run func(ctx context.Context, p []byte, off int64)) *MockReadonlyDevice_ReadAt_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -239,7 +239,7 @@ func (_c *MockReadonlyDevice_ReadAt_Call) Return(n int, err error) *MockReadonly return _c } -func (_c *MockReadonlyDevice_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockReadonlyDevice_ReadAt_Call { +func (_c *MockReadonlyDevice_ReadAt_Call) RunAndReturn(run func(ctx context.Context, p []byte, off int64) (int, error)) *MockReadonlyDevice_ReadAt_Call { _c.Call.Return(run) return _c } diff --git a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go index 7e40b35c4e..d5e1aedd4a 100644 --- a/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go +++ b/packages/orchestrator/pkg/sandbox/block/streaming_chunk.go @@ -1,22 +1,20 @@ package block import ( - "cmp" "context" "errors" "fmt" "io" - "slices" "sync" - "sync/atomic" "time" - "golang.org/x/sync/errgroup" + "go.opentelemetry.io/otel/attribute" + "go.opentelemetry.io/otel/metric" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" + 
"github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) const ( @@ -27,174 +25,52 @@ const ( // defaultMinReadBatchSize is the floor for the read batch size when blockSize // is very small (e.g. 4KB rootfs). The actual batch is max(blockSize, minReadBatchSize). defaultMinReadBatchSize = 16 * 1024 // 16 KB -) - -type rangeWaiter struct { - // endByte is the byte offset (relative to chunkOff) at which this waiter's - // entire requested range is cached. Equal to the end of the last block - // overlapping the requested range. Always a multiple of blockSize. - endByte int64 - ch chan error // buffered cap 1 -} - -type fetchSession struct { - mu sync.Mutex - chunkOff int64 - chunkLen int64 - cache *Cache - waiters []*rangeWaiter // sorted by endByte ascending - fetchErr error - - // bytesReady is the byte count (from chunkOff) up to which all blocks are - // fully written to mmap and marked cached. Always a multiple of blockSize - // during progressive reads. Used to cheaply determine which sorted waiters - // are satisfied without calling isCached. - // - // Atomic so registerAndWait can do a lock-free fast-path check: - // bytesReady only increases, so a Load() >= endByte guarantees data - // availability without taking the mutex. - bytesReady atomic.Int64 -} - -// terminated reports whether the fetch session has reached a terminal state -// (done or errored). Must be called with s.mu held. -func (s *fetchSession) terminated() bool { - return s.fetchErr != nil || s.bytesReady.Load() == s.chunkLen -} - -// registerAndWait adds a waiter for the given range and blocks until the range -// is cached or the context is cancelled. Returns nil if the range was already -// cached before registering. 
-func (s *fetchSession) registerAndWait(ctx context.Context, off, length int64) error { - blockSize := s.cache.BlockSize() - lastBlockIdx := (off + length - 1 - s.chunkOff) / blockSize - endByte := (lastBlockIdx + 1) * blockSize - - // Lock-free fast path: bytesReady only increases, so >= endByte - // guarantees data is available without taking the lock. - if s.bytesReady.Load() >= endByte { - return nil - } - - s.mu.Lock() - - // Re-check under lock. - if endByte <= s.bytesReady.Load() { - s.mu.Unlock() - - return nil - } - - // Terminal but range not covered — only happens on error - // (Done sets bytesReady=chunkLen). Check cache for prior session data. - if s.terminated() { - fetchErr := s.fetchErr - s.mu.Unlock() - if s.cache.isCached(off, length) { - return nil - } - - if fetchErr != nil { - return fmt.Errorf("fetch failed: %w", fetchErr) - } - - return fmt.Errorf("fetch completed but range %d-%d not cached", off, off+length) - } - - // Fetch in progress — register waiter. - w := &rangeWaiter{endByte: endByte, ch: make(chan error, 1)} - idx, _ := slices.BinarySearchFunc(s.waiters, endByte, func(w *rangeWaiter, target int64) int { - return cmp.Compare(w.endByte, target) - }) - s.waiters = slices.Insert(s.waiters, idx, w) - s.mu.Unlock() - - select { - case err := <-w.ch: - return err - case <-ctx.Done(): - return ctx.Err() - } -} - -// notifyWaiters notifies waiters whose ranges are satisfied. -// -// Because waiters are sorted by endByte and the fetch fills the chunk -// sequentially, we only need to walk from the front until we hit a waiter -// whose endByte exceeds bytesReady — all subsequent waiters are unsatisfied. -// -// In terminal states (done/errored) all remaining waiters are notified. -// Must be called with s.mu held. -func (s *fetchSession) notifyWaiters(sendErr error) { - ready := s.bytesReady.Load() - - // Terminal: notify every remaining waiter. 
- if s.terminated() { - for _, w := range s.waiters { - if sendErr != nil && w.endByte > ready { - w.ch <- sendErr - } - close(w.ch) - } - s.waiters = nil - return - } - - // Progress: pop satisfied waiters from the sorted front. - i := 0 - for i < len(s.waiters) && s.waiters[i].endByte <= ready { - close(s.waiters[i].ch) - i++ - } - s.waiters = s.waiters[i:] -} +) -type StreamingChunker struct { - upstream storage.StreamingReader - cache *Cache - metrics metrics.Metrics - fetchTimeout time.Duration - featureFlags *featureflags.Client - minReadBatchSize int64 +type Chunker struct { + upstream storage.StreamingReader + cache *Cache + metrics metrics.Metrics + fetchTimeout time.Duration + featureFlags *featureflags.Client size int64 - fetchMu sync.Mutex - fetchMap map[int64]*fetchSession + fetchMu sync.Mutex + fetchSessions []*fetchSession } -func NewStreamingChunker( +var ( + _ FramedReader = (*Chunker)(nil) + _ FramedSlicer = (*Chunker)(nil) +) + +func NewChunker( + _ context.Context, + ff *featureflags.Client, size, blockSize int64, upstream storage.StreamingReader, cachePath string, metrics metrics.Metrics, - minReadBatchSize int64, - ff *featureflags.Client, -) (*StreamingChunker, error) { +) (*Chunker, error) { cache, err := NewCache(size, blockSize, cachePath, false) if err != nil { return nil, fmt.Errorf("failed to create file cache: %w", err) } - if minReadBatchSize <= 0 { - minReadBatchSize = defaultMinReadBatchSize - } - - return &StreamingChunker{ - size: size, - upstream: upstream, - cache: cache, - metrics: metrics, - featureFlags: ff, - fetchTimeout: defaultFetchTimeout, - minReadBatchSize: minReadBatchSize, - fetchMap: make(map[int64]*fetchSession), + return &Chunker{ + size: size, + upstream: upstream, + cache: cache, + metrics: metrics, + featureFlags: ff, + fetchTimeout: defaultFetchTimeout, }, nil } -func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int, error) { - slice, err := c.Slice(ctx, off, int64(len(b))) +func (c 
*Chunker) ReadAt(ctx context.Context, b []byte, off int64, ft *storage.FrameTable) (int, error) { + slice, err := c.Slice(ctx, off, int64(len(b)), ft) if err != nil { return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", off, off+int64(len(b)), err) } @@ -202,158 +78,112 @@ func (c *StreamingChunker) ReadAt(ctx context.Context, b []byte, off int64) (int return copy(b, slice), nil } -func (c *StreamingChunker) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - chunk := make([]byte, storage.MemoryChunkSize) - - for i := int64(0); i < c.size; i += storage.MemoryChunkSize { - n, err := c.ReadAt(ctx, chunk, i) - if err != nil { - return 0, fmt.Errorf("failed to slice cache at %d-%d: %w", i, i+storage.MemoryChunkSize, err) - } - - _, err = w.Write(chunk[:n]) - if err != nil { - return 0, fmt.Errorf("failed to write chunk %d to writer: %w", i, err) - } +func (c *Chunker) Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) { + attrs := chunkerAttrs + if ft.IsCompressed() { + attrs = chunkerAttrsCompressed } - - return c.size, nil -} - -func (c *StreamingChunker) Slice(ctx context.Context, off, length int64) ([]byte, error) { timer := c.metrics.SlicesTimerFactory.Begin() // Fast path: already cached b, err := c.cache.Slice(off, length) if err == nil { - timer.RecordRaw(ctx, length, chunkerAttrs.successFromCache) + timer.RecordRaw(ctx, length, attrs.successFromCache) return b, nil } if !errors.As(err, &BytesNotAvailableError{}) { - timer.RecordRaw(ctx, length, chunkerAttrs.failCacheRead) + timer.RecordRaw(ctx, length, attrs.failCacheRead) return nil, fmt.Errorf("failed read from cache at offset %d: %w", off, err) } - // Compute which 4MB chunks overlap with the requested range - firstChunkOff := header.BlockOffset(header.BlockIdx(off, storage.MemoryChunkSize), storage.MemoryChunkSize) - lastChunkOff := header.BlockOffset(header.BlockIdx(off+length-1, storage.MemoryChunkSize), storage.MemoryChunkSize) - - var eg errgroup.Group 
- - for fetchOff := firstChunkOff; fetchOff <= lastChunkOff; fetchOff += storage.MemoryChunkSize { - eg.Go(func() error { - // Clip request to this chunk's boundaries - chunkEnd := fetchOff + storage.MemoryChunkSize - clippedOff := max(off, fetchOff) - clippedEnd := min(off+length, chunkEnd, c.size) - clippedLen := clippedEnd - clippedOff - - if clippedLen <= 0 { - return nil - } - - session, justGotCached := c.getOrCreateSession(ctx, fetchOff) - if justGotCached { - return nil - } - - return session.registerAndWait(ctx, clippedOff, clippedLen) - }) - } - - if err := eg.Wait(); err != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failRemoteFetch) + if err := c.fetch(ctx, off, ft); err != nil { + timer.RecordRaw(ctx, length, attrs.failRemoteFetch) return nil, fmt.Errorf("failed to ensure data at %d-%d: %w", off, off+length, err) } b, cacheErr := c.cache.Slice(off, length) if cacheErr != nil { - timer.RecordRaw(ctx, length, chunkerAttrs.failLocalReadAgain) + timer.RecordRaw(ctx, length, attrs.failLocalReadAgain) return nil, fmt.Errorf("failed to read from cache after ensuring data at %d-%d: %w", off, off+length, cacheErr) } - timer.RecordRaw(ctx, length, chunkerAttrs.successFromRemote) + timer.RecordRaw(ctx, length, attrs.successFromRemote) return b, nil } -// getOrCreateSession returns a fetch session for the chunk at fetchOff, or -// (nil, true) if the data is already fully cached. -// -// Slice() checks isCached() before calling this method as a lock-free fast -// path. A TOCTOU race exists between that check and the fetchMap lookup: -// a fetch can finish (writing the dirty bitmap) and delete itself from -// fetchMap in between, so the caller misses both. To close this we re-check -// isCached under fetchMu. This is safe because runFetch calls setIsCached -// before acquiring fetchMu to delete, so the lock provides a happens-before -// guarantee that the bitmap writes are visible here. 
-func (c *StreamingChunker) getOrCreateSession(ctx context.Context, fetchOff int64) (_ *fetchSession, cached bool) { - chunkLen := min(int64(storage.MemoryChunkSize), c.size-fetchOff) - +// getOrCreateSession returns a fetch session for the chunk at [off, off+length), +// or (nil, true) if the data is already fully cached. +func (c *Chunker) getOrCreateSession(ctx context.Context, off, length int64, ft *storage.FrameTable) (_ *fetchSession, cached bool) { c.fetchMu.Lock() - if existing, ok := c.fetchMap[fetchOff]; ok { - c.fetchMu.Unlock() + for _, s := range c.fetchSessions { + if s.chunkOff <= off && s.chunkOff+s.chunkLen >= off+length { + c.fetchMu.Unlock() - return existing, false + return s, false + } } - if c.cache.isCached(fetchOff, chunkLen) { + // Re-check cache under fetchMu. A fetch can finish (marking blocks + // cached via setIsCached) and remove itself from sessions between + // the lock-free Slice() and the session scan above. The lock + // provides a happens-before guarantee that the bitmap writes are visible. + if c.cache.isCached(off, length) { c.fetchMu.Unlock() return nil, true } - s := &fetchSession{ - chunkOff: fetchOff, - chunkLen: chunkLen, - cache: c.cache, - } - c.fetchMap[fetchOff] = s + s := newFetchSession(off, length, c.cache) + c.fetchSessions = append(c.fetchSessions, s) c.fetchMu.Unlock() // Detach from the caller's cancel signal so the shared fetch goroutine // continues even if the first caller's context is cancelled. Trace/value // context is preserved for metrics. - go c.runFetch(context.WithoutCancel(ctx), s) + go c.runFetch(context.WithoutCancel(ctx), s, off, ft) return s, false } -func (s *fetchSession) setDone() { - s.mu.Lock() - defer s.mu.Unlock() - - s.bytesReady.Store(s.chunkLen) - s.notifyWaiters(nil) -} +// fetch ensures the frame/chunk covering off is fetched into the mmap cache, +// then waits until the block at off is available. Deduplicates concurrent +// requests for the same region via the session list. 
+func (c *Chunker) fetch(ctx context.Context, off int64, ft *storage.FrameTable) error { + var chunkOff, chunkLen int64 + if ft.IsCompressed() { + frameStarts, frameSize, err := ft.FrameFor(off) + if err != nil { + return fmt.Errorf("failed to get frame for offset %d: %w", off, err) + } -func (s *fetchSession) setError(err error, onlyIfRunning bool) { - s.mu.Lock() - defer s.mu.Unlock() + chunkOff = frameStarts.U + chunkLen = int64(frameSize.U) + } else { + chunkOff = (off / storage.MemoryChunkSize) * storage.MemoryChunkSize + chunkLen = min(int64(storage.MemoryChunkSize), c.size-chunkOff) + } - if onlyIfRunning && s.terminated() { - return + session, justGotCached := c.getOrCreateSession(ctx, chunkOff, chunkLen, ft) + if justGotCached { + return nil } - s.fetchErr = err - s.notifyWaiters(err) + return session.registerAndWait(ctx, off) } -func (c *StreamingChunker) runFetch(ctx context.Context, s *fetchSession) { +// runFetch fetches data from storage into the mmap cache. Runs in a background goroutine. +func (c *Chunker) runFetch(ctx context.Context, s *fetchSession, offsetU int64, ft *storage.FrameTable) { ctx, cancel := context.WithTimeout(ctx, c.fetchTimeout) defer cancel() - defer func() { - c.fetchMu.Lock() - delete(c.fetchMap, s.chunkOff) - c.fetchMu.Unlock() - }() + defer c.releaseSession(s) // Panic recovery: ensure waiters are always notified even if the fetch // goroutine panics (e.g. nil pointer in upstream reader, mmap fault). 
@@ -373,87 +203,183 @@ func (c *StreamingChunker) runFetch(ctx context.Context, s *fetchSession) { } defer releaseLock() + attrs := chunkerAttrs + if ft.IsCompressed() { + attrs = chunkerAttrsCompressed + } fetchTimer := c.metrics.RemoteReadsTimerFactory.Begin() - err = c.progressiveRead(ctx, s, mmapSlice) + readBytes, err := c.progressiveRead(ctx, s, mmapSlice, offsetU, ft) if err != nil { - fetchTimer.RecordRaw(ctx, s.chunkLen, chunkerAttrs.remoteFailure) + fetchTimer.RecordRaw(ctx, readBytes, attrs.remoteFailure) s.setError(err, false) return } - fetchTimer.RecordRaw(ctx, s.chunkLen, chunkerAttrs.remoteSuccess) + fetchTimer.RecordRaw(ctx, readBytes, attrs.remoteSuccess) s.setDone() } -func (c *StreamingChunker) progressiveRead(ctx context.Context, s *fetchSession, mmapSlice []byte) error { - reader, err := c.upstream.OpenRangeReader(ctx, s.chunkOff, s.chunkLen) +func (c *Chunker) progressiveRead(ctx context.Context, s *fetchSession, mmapSlice []byte, offsetU int64, ft *storage.FrameTable) (int64, error) { + reader, err := c.upstream.OpenRangeReader(ctx, offsetU, s.chunkLen, ft) if err != nil { - return fmt.Errorf("failed to open range reader at %d: %w", s.chunkOff, err) + return 0, fmt.Errorf("failed to open range reader at %d: %w", offsetU, err) } defer reader.Close() blockSize := c.cache.BlockSize() readBatch := max(blockSize, c.getMinReadBatchSize(ctx)) var totalRead int64 - var prevCompleted int64 for totalRead < s.chunkLen { - // Read in batches of max(blockSize, 16KB) to align notification + // Read in batches of max(blockSize, minReadBatchSize) to align notification // granularity with the read size and minimize lock/notify overhead. 
readEnd := min(totalRead+readBatch, s.chunkLen) - n, readErr := reader.Read(mmapSlice[totalRead:readEnd]) + n, readErr := io.ReadFull(reader, mmapSlice[totalRead:readEnd]) totalRead += int64(n) - completedBlocks := totalRead / blockSize - if completedBlocks > prevCompleted { - newBytes := (completedBlocks - prevCompleted) * blockSize - c.cache.setIsCached(s.chunkOff+prevCompleted*blockSize, newBytes) - prevCompleted = completedBlocks - - s.mu.Lock() - s.bytesReady.Store(completedBlocks * blockSize) - s.notifyWaiters(nil) - s.mu.Unlock() + if n > 0 { + c.cache.setIsCached(s.chunkOff+totalRead-int64(n), int64(n)) + s.advance(totalRead) } - if errors.Is(readErr, io.EOF) { - // Mark final partial block if any - if totalRead > prevCompleted*blockSize { - c.cache.setIsCached(s.chunkOff+prevCompleted*blockSize, totalRead-prevCompleted*blockSize) + if readErr != nil { + if totalRead >= s.chunkLen { + break // all bytes received; trailing EOF is expected } - // Remaining waiters are notified in runFetch via the Done state. - break - } - if readErr != nil { - return fmt.Errorf("failed reading at offset %d after %d bytes: %w", s.chunkOff, totalRead, readErr) + return totalRead, fmt.Errorf("failed reading at offset %d after %d bytes: %w", offsetU, totalRead, readErr) } } - return nil + return totalRead, nil +} + +// releaseSession removes s from the active list (swap-delete). +func (c *Chunker) releaseSession(s *fetchSession) { + c.fetchMu.Lock() + defer c.fetchMu.Unlock() + + for i, a := range c.fetchSessions { + if a == s { + c.fetchSessions[i] = c.fetchSessions[len(c.fetchSessions)-1] + c.fetchSessions[len(c.fetchSessions)-1] = nil + c.fetchSessions = c.fetchSessions[:len(c.fetchSessions)-1] + + return + } + } } -// getMinReadBatchSize returns the effective min read batch size. When a feature -// flags client is available, the value is read just-in-time from the flag so -// it can be tuned without restarting the service. 
-func (c *StreamingChunker) getMinReadBatchSize(ctx context.Context) int64 { +// getMinReadBatchSize returns the effective min read batch size. +// Queried per-fetch so it can be tuned via feature flags without a restart. +func (c *Chunker) getMinReadBatchSize(ctx context.Context) int64 { if c.featureFlags != nil { - _, minKB := getChunkerConfig(ctx, c.featureFlags) - if minKB > 0 { - return int64(minKB) * 1024 + if v := c.featureFlags.IntFlag(ctx, featureflags.MinChunkerReadSizeKB); v > 0 { + return int64(v) * 1024 } } - return c.minReadBatchSize + return defaultMinReadBatchSize } -func (c *StreamingChunker) Close() error { +func (c *Chunker) Close() error { return c.cache.Close() } -func (c *StreamingChunker) FileSize() (int64, error) { +func (c *Chunker) FileSize() (int64, error) { return c.cache.FileSize() } + +const ( + compressedAttr = "compressed" + pullType = "pull-type" + pullTypeLocal = "local" + pullTypeRemote = "remote" + + failureReason = "failure-reason" + + failureTypeLocalRead = "local-read" + failureTypeLocalReadAgain = "local-read-again" + failureTypeRemoteRead = "remote-read" + failureTypeCacheFetch = "cache-fetch" +) + +type precomputedAttrs struct { + successFromCache metric.MeasurementOption + successFromRemote metric.MeasurementOption + + failCacheRead metric.MeasurementOption + failRemoteFetch metric.MeasurementOption + failLocalReadAgain metric.MeasurementOption + + // RemoteReads timer (runFetch) + remoteSuccess metric.MeasurementOption + remoteFailure metric.MeasurementOption +} + +var chunkerAttrs = precomputedAttrs{ + successFromCache: telemetry.PrecomputeAttrs( + telemetry.Success, + attribute.String(pullType, pullTypeLocal)), + + successFromRemote: telemetry.PrecomputeAttrs( + telemetry.Success, + attribute.String(pullType, pullTypeRemote)), + + failCacheRead: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalRead)), + + failRemoteFetch: 
telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeRemote), + attribute.String(failureReason, failureTypeCacheFetch)), + + failLocalReadAgain: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalReadAgain)), + + remoteSuccess: telemetry.PrecomputeAttrs( + telemetry.Success), + + remoteFailure: telemetry.PrecomputeAttrs( + telemetry.Failure, + attribute.String(failureReason, failureTypeRemoteRead)), +} + +var chunkerAttrsCompressed = precomputedAttrs{ + successFromCache: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal)), + + successFromRemote: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeRemote)), + + failCacheRead: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalRead)), + + failRemoteFetch: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeRemote), + attribute.String(failureReason, failureTypeCacheFetch)), + + failLocalReadAgain: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(pullType, pullTypeLocal), + attribute.String(failureReason, failureTypeLocalReadAgain)), + + remoteSuccess: telemetry.PrecomputeAttrs( + telemetry.Success, attribute.Bool(compressedAttr, true)), + + remoteFailure: telemetry.PrecomputeAttrs( + telemetry.Failure, attribute.Bool(compressedAttr, true), + attribute.String(failureReason, failureTypeRemoteRead)), +} diff --git a/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go b/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go index f91c99e0e5..c288302ea0 100644 --- 
a/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go +++ b/packages/orchestrator/pkg/sandbox/block/streaming_chunk_test.go @@ -3,15 +3,12 @@ package block import ( "bytes" "context" - "crypto/rand" "fmt" "io" - mathrand "math/rand/v2" + "math/rand/v2" "sync/atomic" "testing" - "time" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/metric/noop" "golang.org/x/sync/errgroup" @@ -23,571 +20,420 @@ import ( const ( testBlockSize = header.PageSize // 4KB + testFrameSize = 256 * 1024 // 256 KB per frame for fast tests + testFileSize = testFrameSize * 4 ) -// slowUpstream simulates GCS: implements both SeekableReader and StreamingReader. -// OpenRangeReader returns a reader that yields blockSize bytes per Read() call -// with a configurable delay between calls. -type slowUpstream struct { - data []byte - blockSize int64 - delay time.Duration -} - -var ( - _ storage.SeekableReader = (*slowUpstream)(nil) - _ storage.StreamingReader = (*slowUpstream)(nil) -) +func newTestMetrics(tb testing.TB) metrics.Metrics { + tb.Helper() -func (s *slowUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - end := min(off+int64(len(buffer)), int64(len(s.data))) - n := copy(buffer, s.data[off:end]) - - return n, nil -} + m, err := metrics.NewMetrics(noop.NewMeterProvider()) + require.NoError(tb, err) -func (s *slowUpstream) Size(_ context.Context) (int64, error) { - return int64(len(s.data)), nil + return m } -func (s *slowUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(s.data))) +func makeTestData(size int) []byte { + rng := rand.New(rand.NewPCG(42, 0)) //nolint:gosec // deterministic test data + data := make([]byte, size) + for i := range data { + data[i] = byte(rng.IntN(256)) + } - return &slowReader{ - data: s.data[off:end], - blockSize: int(s.blockSize), - delay: s.delay, - }, nil + return data } -type slowReader struct { 
- data []byte - pos int - blockSize int - delay time.Duration +// fakeSeekable implements storage.Seekable backed by in-memory data. +// When ctrl is non-nil, reads are gated through its channels for concurrency tests. +type fakeSeekable struct { + data []byte + failAfter int64 // >0: truncate reads at this offset; 0 = disabled + fetchCount atomic.Int64 + ctrl *testControl // nil = ungated immediate reads } -func (r *slowReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF - } - - if r.delay > 0 { - time.Sleep(r.delay) - } - - end := min(r.pos+r.blockSize, len(r.data)) - - n := copy(p, r.data[r.pos:end]) - r.pos += n - - if r.pos >= len(r.data) { - return n, io.EOF - } - - return n, nil -} +var _ storage.Seekable = (*fakeSeekable)(nil) -func (r *slowReader) Close() error { - return nil +// testControl provides channel-based flow control for fakeSeekable. +type testControl struct { + advance chan struct{} // close to release reads + consumed chan struct{} // receives after each read step + opened chan struct{} // receives when OpenRangeReader is called + closed chan struct{} // receives when reader is closed (fetch done) + onOpen func() // optional callback on OpenRangeReader } -// fastUpstream simulates NFS: same interfaces but no delay. -type fastUpstream = slowUpstream +func newTestChunker(t *testing.T, file storage.Seekable, size int64) *Chunker { + t.Helper() + c, err := NewChunker(context.Background(), nil, size, testBlockSize, file, t.TempDir()+"/cache", newTestMetrics(t)) + require.NoError(t, err) -// streamingFunc adapts a function into a StreamingReader. 
-type streamingFunc func(ctx context.Context, off, length int64) (io.ReadCloser, error) + return c +} -func (f streamingFunc) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - return f(ctx, off, length) +func (s *fakeSeekable) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil } -// errorAfterNUpstream fails after reading n bytes. -type errorAfterNUpstream struct { - data []byte - failAfter int64 - blockSize int64 +func (s *fakeSeekable) StoreFile(context.Context, string, *storage.CompressConfig) (*storage.FrameTable, [32]byte, error) { + panic("not used") } -var _ storage.StreamingReader = (*errorAfterNUpstream)(nil) +func (s *fakeSeekable) OpenRangeReader(_ context.Context, offsetU int64, length int64, frameTable *storage.FrameTable) (io.ReadCloser, error) { + s.fetchCount.Add(1) -func (u *errorAfterNUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) + if s.ctrl != nil { + if s.ctrl.onOpen != nil { + s.ctrl.onOpen() + } - return &errorAfterNReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - failAfter: int(u.failAfter - off), - }, nil -} + select { + case s.ctrl.opened <- struct{}{}: + default: + } -type errorAfterNReader struct { - data []byte - pos int - blockSize int - failAfter int -} + end := min(offsetU+length, int64(len(s.data))) -func (r *errorAfterNReader) Read(p []byte) (int, error) { - if r.pos >= len(r.data) { - return 0, io.EOF + return &controlledReader{ + data: s.data[offsetU:end], + step: max(defaultMinReadBatchSize, testBlockSize), + advance: s.ctrl.advance, + consumed: s.ctrl.consumed, + closed: s.ctrl.closed, + }, nil } - if r.pos >= r.failAfter { - return 0, fmt.Errorf("simulated upstream error") - } + var fetchOff, fetchLen int64 + if frameTable.IsCompressed() { + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + return nil, fmt.Errorf("frame lookup: %w", err) + } 
- end := min(r.pos+r.blockSize, len(r.data)) + fetchOff = frameStart.C + fetchLen = int64(frameSize.C) + } else { + fetchOff = offsetU + fetchLen = length + } - n := copy(p, r.data[r.pos:end]) - r.pos += n + end := min(fetchOff+fetchLen, int64(len(s.data))) + if s.failAfter > 0 { + end = min(end, s.failAfter) + } - if r.pos >= len(r.data) { - return n, io.EOF + r := io.Reader(bytes.NewReader(s.data[fetchOff:end])) + if frameTable.IsCompressed() { + return storage.NewDecompressingReader(r, frameTable.CompressionType()) } - return n, nil + return io.NopCloser(r), nil } -func (r *errorAfterNReader) Close() error { - return nil -} +func makeCompressedTestData(tb testing.TB, data []byte) (*storage.FrameTable, *fakeSeekable) { + tb.Helper() -func newTestMetrics(t *testing.T) metrics.Metrics { - t.Helper() + ft, compressed, _, err := storage.CompressBytes(context.Background(), data, &storage.CompressConfig{ + Enabled: true, + Type: "lz4", + EncoderConcurrency: 1, + FrameEncodeWorkers: 1, + FrameSizeKB: testFrameSize / 1024, + TargetPartSizeMB: 50, + }) + require.NoError(tb, err) - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(t, err) + return ft, &fakeSeekable{data: compressed} +} - return m +type chunkerTestCase struct { + name string + newChunker func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) } -func makeTestData(t *testing.T, size int) []byte { - t.Helper() +var allChunkerTestCases = []chunkerTestCase{ + { + name: "Compressed", + newChunker: func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) { + t.Helper() + ft, getter := makeCompressedTestData(t, data) - data := make([]byte, size) - _, err := rand.Read(data) - require.NoError(t, err) + return newTestChunker(t, getter, int64(len(data))), ft + }, + }, + { + name: "Uncompressed", + newChunker: func(t *testing.T, data []byte) (*Chunker, *storage.FrameTable) { + t.Helper() - return data + return newTestChunker(t, &fakeSeekable{data: data}, int64(len(data))), nil + }, 
+ }, } -func TestStreamingChunker_BasicSlice(t *testing.T) { +func TestChunker_BasicSlice(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() - // Read first page - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice) + }) + } } -func TestStreamingChunker_CacheHit(t *testing.T) { +// TestChunker_CacheHit verifies that a second read of the same block +// is served from cache without an additional upstream fetch. +func TestChunker_CacheHit(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - readCount := atomic.Int64{} + data := makeTestData(testFileSize) - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &readCount, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + // Uncompressed only — we need direct access to the fakeSeekable to count fetches. + file := &fakeSeekable{data: data} + chunker := newTestChunker(t, file, int64(len(data))) defer chunker.Close() - // First read: triggers fetch - _, err = chunker.Slice(t.Context(), 0, testBlockSize) + // First read triggers a fetch. 
+ slice1, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) + require.Equal(t, data[:testBlockSize], slice1) - // Wait for the full chunk to be fetched - time.Sleep(50 * time.Millisecond) - - firstCount := readCount.Load() - require.Positive(t, firstCount) + firstFetches := file.fetchCount.Load() + require.Positive(t, firstFetches) - // Second read: should hit cache - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + // Second read of the same block — should hit cache. + slice2, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) - - // No additional reads should have happened - assert.Equal(t, firstCount, readCount.Load()) -} - -type countingUpstream struct { - inner *fastUpstream - readCount *atomic.Int64 + require.Equal(t, data[:testBlockSize], slice2) + require.Equal(t, firstFetches, file.fetchCount.Load(), "expected no additional upstream fetch") } -var ( - _ storage.SeekableReader = (*countingUpstream)(nil) - _ storage.StreamingReader = (*countingUpstream)(nil) -) - -func (c *countingUpstream) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - c.readCount.Add(1) - - return c.inner.ReadAt(ctx, buffer, off) -} - -func (c *countingUpstream) Size(ctx context.Context) (int64, error) { - return c.inner.Size(ctx) -} - -func (c *countingUpstream) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - c.readCount.Add(1) - - return c.inner.OpenRangeReader(ctx, off, length) -} - -func TestStreamingChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { +// TestChunker_FullChunkCachedAfterPartialRequest verifies that requesting the +// first block triggers a full background fetch of the entire chunk/frame, so +// the last block becomes available without additional upstream fetches. 
+func TestChunker_FullChunkCachedAfterPartialRequest(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - openCount := atomic.Int64{} - - upstream := &countingUpstream{ - inner: &fastUpstream{data: data, blockSize: testBlockSize}, - readCount: &openCount, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Request only the FIRST block of the 4MB chunk. - _, err = chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - // The background goroutine should continue fetching the remaining data. - // Use a blocking Slice call (with timeout) instead of require.Eventually - // to avoid racing condition goroutines against defer chunker.Close(). - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) - defer cancel() + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() - slice, err := chunker.Slice(ctx, lastOff, testBlockSize) - require.NoError(t, err) - require.True(t, bytes.Equal(data[lastOff:], slice)) + _, err := chunker.Slice(t.Context(), 0, testBlockSize, ft) + require.NoError(t, err) - // Exactly one OpenRangeReader call should have been made for the entire - // chunk, not one per requested block. - assert.Equal(t, int64(1), openCount.Load(), - "expected 1 OpenRangeReader call (full chunk fetched in background), got %d", openCount.Load()) + // The second Slice joins the in-flight session (or hits + // cache if the fetch already completed). Either way it blocks + // until the data is available — no polling needed. 
+ lastOff := int64(testFileSize) - testBlockSize + slice, err := chunker.Slice(t.Context(), lastOff, testBlockSize, ft) + require.NoError(t, err) + require.Equal(t, data[lastOff:lastOff+testBlockSize], slice) + }) + } } -func TestStreamingChunker_ConcurrentSameChunk(t *testing.T) { +// TestChunker_ConcurrentSameChunk verifies that concurrent requests for the same +// chunk don't cause duplicate upstream fetches. +func TestChunker_ConcurrentSameChunk(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - // Use a slow upstream so requests will overlap - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 50 * time.Microsecond, - } + data := makeTestData(testFileSize) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + var fetchCount atomic.Int64 + chunker := newControlledChunker(t, data) + chunker.onOpen = func() { fetchCount.Add(1) } defer chunker.Close() - numGoroutines := 10 - offsets := make([]int64, numGoroutines) - for i := range numGoroutines { - offsets[i] = int64(i) * testBlockSize - } - - results := make([][]byte, numGoroutines) + const numGoroutines = 10 var eg errgroup.Group - - for i := range numGoroutines { + started := make(chan struct{}) + for range numGoroutines { eg.Go(func() error { - slice, err := chunker.Slice(t.Context(), offsets[i], testBlockSize) - if err != nil { - return fmt.Errorf("goroutine %d failed: %w", i, err) - } - results[i] = make([]byte, len(slice)) - copy(results[i], slice) + <-started + _, sliceErr := chunker.Slice(t.Context(), 0, testBlockSize, nil) - return nil + return sliceErr }) } + // Release goroutines, wait for the fetch to start (blocked on advance), + // then release data. 
+ close(started) + <-chunker.opened + close(chunker.advance) + require.NoError(t, eg.Wait()) - for i := range numGoroutines { - require.Equal(t, data[offsets[i]:offsets[i]+testBlockSize], results[i], - "goroutine %d got wrong data", i) - } + require.Equal(t, int64(1), fetchCount.Load(), + "expected 1 fetch (dedup), got %d", fetchCount.Load()) } -func TestStreamingChunker_EarlyReturn(t *testing.T) { +func TestChunker_EarlyReturn(t *testing.T) { t.Parallel() - type testCase struct { - name string - blockSize int64 - delay time.Duration - // blockIndices are block indices within the chunk, listed in the - // expected completion order (earlier blocks are notified first). - blockIndices []int - } - - cases := []testCase{ - { - name: "hugepage", - blockSize: header.HugepageSize, // 2MB → 2 blocks per 4MB chunk - delay: 50 * time.Millisecond, - blockIndices: []int{0, 1}, - }, - { - name: "4K", - blockSize: header.PageSize, // 4KB → 1024 blocks per 4MB chunk - delay: 100 * time.Microsecond, - blockIndices: []int{1, 512, 1022}, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - t.Parallel() - - data := makeTestData(t, storage.MemoryChunkSize) - - gate := make(chan struct{}) - upstream := streamingFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) { - <-gate - end := min(off+length, int64(len(data))) - - return &slowReader{ - data: data[off:end], - blockSize: int(tc.blockSize), - delay: tc.delay, - }, nil - }) - - chunker, err := NewStreamingChunker( - int64(len(data)), tc.blockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - n := len(tc.blockIndices) - completionOrder := make(chan int, n) - - var eg errgroup.Group - for i, blockIdx := range tc.blockIndices { - off := int64(blockIdx) * tc.blockSize - eg.Go(func() error { - _, err := chunker.Slice(t.Context(), off, tc.blockSize) - if err != nil { - return fmt.Errorf("request %d (block %d) failed: 
%w", i, blockIdx, err) - } - completionOrder <- i - - return nil - }) - } - - // Let all goroutines register as waiters before the fetch begins. - time.Sleep(10 * time.Millisecond) - close(gate) + data := makeTestData(testFileSize) + chunker := newControlledChunker(t, data) + defer chunker.Close() - require.NoError(t, eg.Wait()) - close(completionOrder) + lastOff := int64(len(data)) - testBlockSize - got := make([]int, 0, n) - for idx := range completionOrder { - got = append(got, idx) - } - - expected := make([]int, n) - for i := range expected { - expected[i] = i - } + type result struct { + data []byte + err error + } - assert.Equal(t, expected, got, - "requests should complete in offset order (earlier blocks first)") - }) + earlyDone := make(chan result, 1) + lateDone := make(chan result, 1) + + go func() { + slice, sliceErr := chunker.Slice(t.Context(), 0, testBlockSize, nil) + earlyDone <- result{data: bytes.Clone(slice), err: sliceErr} // clone: slice backed by mutable mmap + }() + go func() { + slice, sliceErr := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) + lateDone <- result{data: bytes.Clone(slice), err: sliceErr} + }() + + // Release reads, wait for one block to be consumed. + close(chunker.advance) + <-chunker.consumed + + // Offset 0 is within the first readSize — should be available now. + r := <-earlyDone + require.NoError(t, r.err) + require.Equal(t, data[:testBlockSize], r.data) + + // Last offset hasn't been reached yet. + select { + case <-lateDone: + t.Fatal("late reader completed before its data was delivered") + default: } + + // Fetch completes (advance is closed), late reader unblocks. + r = <-lateDone + require.NoError(t, r.err) + require.Equal(t, data[lastOff:lastOff+testBlockSize], r.data) } -func TestStreamingChunker_ErrorKeepsPartialData(t *testing.T) { +// TestChunker_ErrorKeepsPartialData verifies that an upstream error at the +// midpoint of a chunk still allows data before the error to be served. 
+func TestChunker_ErrorKeepsPartialData(t *testing.T) { t.Parallel() - chunkSize := storage.MemoryChunkSize - data := makeTestData(t, chunkSize) - failAfter := int64(chunkSize / 2) // Fail at 2MB - - upstream := &errorAfterNUpstream{ - data: data, - failAfter: failAfter, - blockSize: testBlockSize, - } + data := makeTestData(testFileSize) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + chunker := newTestChunker(t, &fakeSeekable{data: data, failAfter: int64(testFileSize / 2)}, int64(len(data))) defer chunker.Close() - // Request the last page — this should fail because upstream dies at 2MB - lastOff := int64(chunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) + lastOff := int64(testFileSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) require.Error(t, err) - // But first page (within first 2MB) should still be cached and servable - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_ContextCancellation(t *testing.T) { +// TestChunker_ContextCancellation verifies that a cancelled caller context +// doesn't kill the background fetch — another caller can still get data. 
+func TestChunker_ContextCancellation(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - upstream := &slowUpstream{ - data: data, - blockSize: testBlockSize, - delay: 1 * time.Millisecond, - } - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + data := makeTestData(testFileSize) + chunker := newControlledChunker(t, data) defer chunker.Close() - // Request with a context that we'll cancel quickly - ctx, cancel := context.WithTimeout(t.Context(), 1*time.Millisecond) - defer cancel() - - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = chunker.Slice(ctx, lastOff, testBlockSize) - // This should fail with context cancellation - require.Error(t, err) + ctx, cancel := context.WithCancel(t.Context()) - // But another caller with a valid context should still get the data - // because the fetch goroutine uses background context - time.Sleep(200 * time.Millisecond) // Wait for fetch to complete - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) - require.NoError(t, err) - require.Equal(t, data[:testBlockSize], slice) -} + done := make(chan error, 1) + go func() { + _, sliceErr := chunker.Slice(ctx, 0, testBlockSize, nil) + done <- sliceErr + }() -func TestStreamingChunker_LastBlockPartial(t *testing.T) { - t.Parallel() + // Wait for the fetch goroutine to be blocked on the reader, then cancel. 
+ <-chunker.opened + cancel() - // File size not aligned to blockSize - size := storage.MemoryChunkSize - 100 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + require.Error(t, <-done) - // Read the last partial block - lastBlockOff := (int64(size) / testBlockSize) * testBlockSize - remaining := int64(size) - lastBlockOff + // Release the fetch — it runs with context.WithoutCancel so it continues. + close(chunker.advance) + <-chunker.closed - slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining) + // Fetch completed — data is now cached. + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) - require.Equal(t, data[lastBlockOff:], slice) + require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_MultiChunkSlice(t *testing.T) { +// TestChunker_LastBlockPartial verifies correct handling of a file whose size +// is not aligned to blockSize — the final block is shorter than blockSize. 
+func TestChunker_LastBlockPartial(t *testing.T) { t.Parallel() - // Two 4MB chunks - size := storage.MemoryChunkSize * 2 - data := makeTestData(t, size) - upstream := &fastUpstream{data: data, blockSize: testBlockSize} - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() + size := testFileSize - 100 + data := makeTestData(size) - // Request spanning two chunks: last page of chunk 0 + first page of chunk 1 - off := int64(storage.MemoryChunkSize) - testBlockSize - length := testBlockSize * 2 + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - slice, err := chunker.Slice(t.Context(), off, int64(length)) - require.NoError(t, err) - require.Equal(t, data[off:off+int64(length)], slice) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() + + lastBlockOff := (int64(size) / testBlockSize) * testBlockSize + remaining := int64(size) - lastBlockOff + + slice, err := chunker.Slice(t.Context(), lastBlockOff, remaining, ft) + require.NoError(t, err) + require.Equal(t, data[lastBlockOff:], slice) + }) + } } -// panicUpstream panics during Read after delivering a configurable number of bytes. -type panicUpstream struct { +// panicSeekable panics during Read after delivering panicAfter bytes. 
+type panicSeekable struct { data []byte - blockSize int64 - panicAfter int64 // byte offset at which to panic (0 = panic immediately) + panicAfter int64 +} + +var _ storage.Seekable = (*panicSeekable)(nil) + +func (s *panicSeekable) Size(_ context.Context) (int64, error) { + return int64(len(s.data)), nil } -var _ storage.StreamingReader = (*panicUpstream)(nil) +func (s *panicSeekable) StoreFile(context.Context, string, *storage.CompressConfig) (*storage.FrameTable, [32]byte, error) { + panic("not used") +} -func (u *panicUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) +func (s *panicSeekable) OpenRangeReader(_ context.Context, off int64, length int64, _ *storage.FrameTable) (io.ReadCloser, error) { + end := min(off+length, int64(len(s.data))) return &panicReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - panicAfter: int(u.panicAfter - off), + data: s.data[off:end], + panicAfter: int(s.panicAfter - off), }, nil } type panicReader struct { data []byte pos int - blockSize int panicAfter int } @@ -600,7 +446,7 @@ func (r *panicReader) Read(p []byte) (int, error) { return 0, io.EOF } - end := min(r.pos+r.blockSize, len(r.data)) + end := min(r.pos+len(p), len(r.data)) n := copy(p, r.data[r.pos:end]) r.pos += n @@ -611,340 +457,125 @@ func (r *panicReader) Close() error { return nil } -func TestStreamingChunker_PanicRecovery(t *testing.T) { +func TestChunker_PanicRecovery(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - panicAt := int64(storage.MemoryChunkSize / 2) // Panic at 2MB - - upstream := &panicUpstream{ - data: data, - blockSize: testBlockSize, - panicAfter: panicAt, - } + data := makeTestData(testFileSize) + panicAt := int64(testFileSize / 2) - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) + chunker := 
newTestChunker(t, &panicSeekable{data: data, panicAfter: panicAt}, int64(len(data))) defer chunker.Close() // Request data past the panic point — should get an error, not hang or crash - lastOff := int64(storage.MemoryChunkSize) - testBlockSize - _, err = chunker.Slice(t.Context(), lastOff, testBlockSize) + lastOff := int64(testFileSize) - testBlockSize + _, err := chunker.Slice(t.Context(), lastOff, testBlockSize, nil) require.Error(t, err) - assert.Contains(t, err.Error(), "panicked") // Data before the panic point should still be cached - slice, err := chunker.Slice(t.Context(), 0, testBlockSize) + slice, err := chunker.Slice(t.Context(), 0, testBlockSize, nil) require.NoError(t, err) require.Equal(t, data[:testBlockSize], slice) } -func TestStreamingChunker_ConcurrentSameChunk_SharedSession(t *testing.T) { +func TestChunker_ConcurrentStress(t *testing.T) { t.Parallel() - data := makeTestData(t, storage.MemoryChunkSize) - - gate := make(chan struct{}) - openCount := atomic.Int64{} - - // OpenRangeReader blocks on the gate, keeping the session in fetchMap - // until both callers have entered. This removes the scheduling-dependent - // race in the old slow-upstream version of this test. - upstream := streamingFunc(func(_ context.Context, off, length int64) (io.ReadCloser, error) { - openCount.Add(1) - <-gate - - end := min(off+length, int64(len(data))) - - return io.NopCloser(bytes.NewReader(data[off:end])), nil - }) - - chunker, err := NewStreamingChunker( - int64(len(data)), testBlockSize, - upstream, t.TempDir()+"/cache", - newTestMetrics(t), - 0, nil, - ) - require.NoError(t, err) - defer chunker.Close() - - // Two different ranges inside the same 4MB chunk. 
- offA := int64(0) - offB := int64(storage.MemoryChunkSize) - testBlockSize // last block - - var eg errgroup.Group - var sliceA, sliceB []byte - - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offA, testBlockSize) - if err != nil { - return err - } - sliceA = make([]byte, len(s)) - copy(sliceA, s) - - return nil - }) - eg.Go(func() error { - s, err := chunker.Slice(t.Context(), offB, testBlockSize) - if err != nil { - return err - } - sliceB = make([]byte, len(s)) - copy(sliceB, s) + for _, tc := range allChunkerTestCases { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() - return nil - }) + data := makeTestData(testFileSize) + chunker, ft := tc.newChunker(t, data) + defer chunker.Close() - // Let both goroutines enter getOrCreateSession, then release the fetch. - time.Sleep(10 * time.Millisecond) - close(gate) + const numGoroutines = 50 + const opsPerGoroutine = 5 + readLen := int64(testBlockSize) - require.NoError(t, eg.Wait()) + var eg errgroup.Group - assert.Equal(t, data[offA:offA+testBlockSize], sliceA) - assert.Equal(t, data[offB:offB+testBlockSize], sliceB) - assert.Equal(t, int64(1), openCount.Load(), - "expected exactly 1 OpenRangeReader call (shared session), got %d", openCount.Load()) -} - -// --- Benchmarks --- -// -// Uses a bandwidth-limited upstream with real time.Sleep to simulate GCS and -// NFS backends. Measures actual wall-clock latency per caller. -// -// Backend parameters (tuned to match observed production latencies): -// GCS: 20ms TTFB + 100 MB/s → 4MB chunk ≈ 62ms (observed ~60ms) -// NFS: 1ms TTFB + 500 MB/s → 4MB chunk ≈ 9ms (observed ~9-10ms) -// -// All sub-benchmarks share a pre-generated offset sequence so results are -// directly comparable across chunker types and backends. -// -// Recommended invocation (~1 minute): -// go test -bench BenchmarkRandomAccess -benchtime 150x -count=3 -run '^$' ./... 
- -func newBenchmarkMetrics(b *testing.B) metrics.Metrics { - b.Helper() + for i := range numGoroutines { + eg.Go(func() error { + for j := range opsPerGoroutine { + off := int64(((i*opsPerGoroutine)+j)%(len(data)/int(readLen))) * readLen + slice, err := chunker.Slice(t.Context(), off, readLen, ft) + if err != nil { + return fmt.Errorf("goroutine %d op %d: %w", i, j, err) + } + if !bytes.Equal(data[off:off+readLen], slice) { + return fmt.Errorf("goroutine %d op %d: data mismatch at off=%d", i, j, off) + } + } - m, err := metrics.NewMetrics(noop.NewMeterProvider()) - require.NoError(b, err) + return nil + }) + } - return m + require.NoError(t, eg.Wait()) + }) + } } -// realisticUpstream simulates a storage backend with configurable time-to-first-byte -// and bandwidth. ReadAt blocks for the full transfer duration (bulk fetch model). -// OpenRangeReader returns a bandwidth-limited progressive reader. -type realisticUpstream struct { - data []byte - blockSize int64 - ttfb time.Duration - bytesPerSec float64 +// controlledChunker wraps a Chunker with channel-based flow control for tests. +// advance gates reads; opened/consumed/closed signal fetch lifecycle events. 
+type controlledChunker struct { + *Chunker + *testControl } -var ( - _ storage.SeekableReader = (*realisticUpstream)(nil) - _ storage.StreamingReader = (*realisticUpstream)(nil) -) - -func (u *realisticUpstream) ReadAt(_ context.Context, buffer []byte, off int64) (int, error) { - transferTime := time.Duration(float64(len(buffer)) / u.bytesPerSec * float64(time.Second)) - time.Sleep(u.ttfb + transferTime) +func newControlledChunker(t *testing.T, data []byte) *controlledChunker { + t.Helper() - end := min(off+int64(len(buffer)), int64(len(u.data))) - n := copy(buffer, u.data[off:end]) + ctrl := &testControl{ + advance: make(chan struct{}), + consumed: make(chan struct{}, 10), + opened: make(chan struct{}, 10), + closed: make(chan struct{}, 10), + } - return n, nil -} + file := &fakeSeekable{data: data, ctrl: ctrl} -func (u *realisticUpstream) Size(_ context.Context) (int64, error) { - return int64(len(u.data)), nil + return &controlledChunker{ + Chunker: newTestChunker(t, file, int64(len(data))), + testControl: ctrl, + } } -func (u *realisticUpstream) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { - end := min(off+length, int64(len(u.data))) - - return &bandwidthReader{ - data: u.data[off:end], - blockSize: int(u.blockSize), - ttfb: u.ttfb, - bytesPerSec: u.bytesPerSec, - }, nil +// controlledReader yields data in fixed-size steps, blocking on advance +// before each Read. After advance is closed, reads proceed immediately. +type controlledReader struct { + data []byte + pos int + step int + advance chan struct{} + consumed chan struct{} + closed chan struct{} } -// bandwidthReader delivers data at a steady rate after an initial TTFB delay. -// Uses cumulative timing (time since first byte) so OS scheduling jitter does -// not compound across blocks. 
-type bandwidthReader struct { - data []byte - pos int - blockSize int - ttfb time.Duration - bytesPerSec float64 - startTime time.Time - started bool -} - -func (r *bandwidthReader) Read(p []byte) (int, error) { - if !r.started { - r.started = true - time.Sleep(r.ttfb) - r.startTime = time.Now() - } - +func (r *controlledReader) Read(p []byte) (int, error) { if r.pos >= len(r.data) { return 0, io.EOF } - end := min(r.pos+r.blockSize, len(r.data)) + <-r.advance + + end := min(r.pos+min(len(p), r.step), len(r.data)) n := copy(p, r.data[r.pos:end]) r.pos += n - // Enforce bandwidth: sleep until this many bytes should have arrived. - expectedArrival := r.startTime.Add(time.Duration(float64(r.pos) / r.bytesPerSec * float64(time.Second))) - if wait := time.Until(expectedArrival); wait > 0 { - time.Sleep(wait) - } - - if r.pos >= len(r.data) { - return n, io.EOF + select { + case r.consumed <- struct{}{}: + default: } return n, nil } -func (r *bandwidthReader) Close() error { - return nil -} - -type benchChunker interface { - Slice(ctx context.Context, off, length int64) ([]byte, error) - Close() error -} - -func BenchmarkRandomAccess(b *testing.B) { - size := int64(storage.MemoryChunkSize) - data := make([]byte, size) - - backends := []struct { - name string - upstream *realisticUpstream - }{ - { - name: "GCS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 20 * time.Millisecond, - bytesPerSec: 100e6, // 100 MB/s — full 4MB chunk ≈ 62ms (observed ~60ms) - }, - }, - { - name: "NFS", - upstream: &realisticUpstream{ - data: data, - blockSize: testBlockSize, - ttfb: 1 * time.Millisecond, - bytesPerSec: 500e6, // 500 MB/s — full 4MB chunk ≈ 9ms (observed ~9-10ms) - }, - }, - } - - chunkerTypes := []struct { - name string - newChunker func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker - }{ - { - name: "StreamingChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker 
{ - b.Helper() - c, err := NewStreamingChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m, 0, nil) - require.NoError(b, err) - - return c - }, - }, - { - name: "FullFetchChunker", - newChunker: func(b *testing.B, m metrics.Metrics, upstream *realisticUpstream) benchChunker { - b.Helper() - c, err := NewFullFetchChunker(size, testBlockSize, upstream, b.TempDir()+"/cache", m) - require.NoError(b, err) - - return c - }, - }, - } - - // Realistic concurrency: UFFD faults are limited by vCPU count (typically - // 1-2 for Firecracker VMs) and NBD requests are largely sequential. - const numCallers = 3 - - // Pre-generate a fixed sequence of random offsets so all sub-benchmarks - // use identical access patterns, making results directly comparable. - const maxIters = 500 - numBlocks := size / testBlockSize - rng := mathrand.New(mathrand.NewPCG(42, 0)) - - allOffsets := make([][]int64, maxIters) - for i := range allOffsets { - offsets := make([]int64, numCallers) - for j := range offsets { - offsets[j] = rng.Int64N(numBlocks) * testBlockSize - } - allOffsets[i] = offsets +func (r *controlledReader) Close() error { + select { + case r.closed <- struct{}{}: + default: } - for _, backend := range backends { - for _, ct := range chunkerTypes { - b.Run(backend.name+"/"+ct.name, func(b *testing.B) { - m := newBenchmarkMetrics(b) - - b.ReportMetric(0, "ns/op") - - var sumAvg, sumMax float64 - - for i := range b.N { - offsets := allOffsets[i%maxIters] - - chunker := ct.newChunker(b, m, backend.upstream) - - latencies := make([]time.Duration, numCallers) - - var eg errgroup.Group - for ci, off := range offsets { - eg.Go(func() error { - start := time.Now() - _, err := chunker.Slice(context.Background(), off, testBlockSize) - latencies[ci] = time.Since(start) - - return err - }) - } - require.NoError(b, eg.Wait()) - - var totalLatency time.Duration - var maxLatency time.Duration - for _, l := range latencies { - totalLatency += l - maxLatency = max(maxLatency, l) - } - - 
avgUs := float64(totalLatency.Microseconds()) / float64(numCallers) - sumAvg += avgUs - sumMax = max(sumMax, float64(maxLatency.Microseconds())) - - chunker.Close() - } - - b.ReportMetric(sumAvg/float64(b.N), "avg-us/caller") - b.ReportMetric(sumMax, "worst-us/caller") - }) - } - } + return nil } diff --git a/packages/orchestrator/pkg/sandbox/build/build.go b/packages/orchestrator/pkg/sandbox/build/build.go index 149caab3d6..6fa33a3579 100644 --- a/packages/orchestrator/pkg/sandbox/build/build.go +++ b/packages/orchestrator/pkg/sandbox/build/build.go @@ -2,8 +2,10 @@ package build import ( "context" + "errors" "fmt" "io" + "sync/atomic" "github.com/google/uuid" @@ -14,7 +16,8 @@ import ( ) type File struct { - header *header.Header + header atomic.Pointer[header.Header] + swapFailed atomic.Bool // set if header deserialization fails during P2P transition store *DiffStore fileType DiffType persistence storage.StorageProvider @@ -28,25 +31,34 @@ func NewFile( persistence storage.StorageProvider, metrics blockmetrics.Metrics, ) *File { - return &File{ - header: header, + f := &File{ store: store, fileType: fileType, persistence: persistence, metrics: metrics, } + f.header.Store(header) + + return f +} + +// Header returns the current header. After a peer transition the header may +// have been atomically swapped to a V4 header containing FrameTables. 
+func (b *File) Header() *header.Header { + return b.header.Load() } func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err error) { for n < len(p) { - mappedOffset, mappedLength, buildID, err := b.header.GetShiftedMapping(ctx, off+int64(n)) + h := b.header.Load() + + mappedToBuild, err := h.GetShiftedMapping(ctx, off+int64(n)) if err != nil { return 0, fmt.Errorf("failed to get mapping: %w", err) } remainingReadLength := int64(len(p)) - int64(n) - - readLength := min(mappedLength, remainingReadLength) + readLength := min(int64(mappedToBuild.Length), remainingReadLength) if readLength <= 0 { logger.L().Error(ctx, fmt.Sprintf( @@ -54,13 +66,13 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro len(p)-n, off, readLength, - buildID, + mappedToBuild.BuildId, b.fileType, - mappedOffset, + mappedToBuild.Offset, n, int64(n)+readLength, n, - mappedLength, + mappedToBuild.Length, remainingReadLength, )) @@ -70,22 +82,33 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // Skip reading when the uuid is nil. // We will use this to handle base builds that are already diffs. // The passed slice p must start as empty, otherwise we would need to copy the empty values there. 
- if *buildID == uuid.Nil { + if mappedToBuild.BuildId == uuid.Nil { n += int(readLength) continue } - mappedBuild, err := b.getBuild(ctx, buildID) + size := b.buildFileSize(h, mappedToBuild.BuildId) + mappedBuild, err := b.getBuild(ctx, mappedToBuild.BuildId, size, mappedToBuild.FrameTable.CompressionType()) if err != nil { return 0, fmt.Errorf("failed to get build: %w", err) } buildN, err := mappedBuild.ReadAt(ctx, p[n:int64(n)+readLength], - mappedOffset, + int64(mappedToBuild.Offset), + mappedToBuild.FrameTable, ) if err != nil { + var transErr *storage.PeerTransitionedError + if errors.As(err, &transErr) && !b.swapFailed.Load() { + if swapErr := b.swapHeader(transErr); swapErr != nil { + return 0, fmt.Errorf("failed to swap header: %w", swapErr) + } + + continue // retry with the new header + } + return 0, fmt.Errorf("failed to read from source: %w", err) } @@ -97,32 +120,94 @@ func (b *File) ReadAt(ctx context.Context, p []byte, off int64) (n int, err erro // The slice access must be in the predefined blocksize of the build. func (b *File) Slice(ctx context.Context, off, _ int64) ([]byte, error) { - mappedOffset, _, buildID, err := b.header.GetShiftedMapping(ctx, off) - if err != nil { - return nil, fmt.Errorf("failed to get mapping: %w", err) + for { + h := b.header.Load() + + mappedBuild, err := h.GetShiftedMapping(ctx, off) + if err != nil { + return nil, fmt.Errorf("failed to get mapping: %w", err) + } + + // Pass empty huge page when the build id is nil. 
+ if mappedBuild.BuildId == uuid.Nil { + return header.EmptyHugePage, nil + } + + size := b.buildFileSize(h, mappedBuild.BuildId) + diff, err := b.getBuild(ctx, mappedBuild.BuildId, size, mappedBuild.FrameTable.CompressionType()) + if err != nil { + return nil, fmt.Errorf("failed to get build: %w", err) + } + + result, err := diff.Slice(ctx, int64(mappedBuild.Offset), int64(h.Metadata.BlockSize), mappedBuild.FrameTable) + if err != nil { + var transErr *storage.PeerTransitionedError + if errors.As(err, &transErr) && !b.swapFailed.Load() { + if swapErr := b.swapHeader(transErr); swapErr != nil { + return nil, fmt.Errorf("failed to swap header: %w", swapErr) + } + + continue // retry with the new header + } + + return nil, err + } + + return result, nil } +} - // Pass empty huge page when the build id is nil. - if *buildID == uuid.Nil { - return header.EmptyHugePage, nil +// swapHeader atomically replaces the header when the peer signals upload +// completion. Only the first goroutine to CAS succeeds; others just retry +// with the already-swapped header. On deserialization failure, marks the +// swap as failed so the ReadAt/Slice loop doesn't retry indefinitely. +func (b *File) swapHeader(transErr *storage.PeerTransitionedError) error { + var headerBytes []byte + + switch b.fileType { + case Memfile: + headerBytes = transErr.MemfileHeader + case Rootfs: + headerBytes = transErr.RootfsHeader } - build, err := b.getBuild(ctx, buildID) + if len(headerBytes) == 0 { + return fmt.Errorf("no header bytes available") + } + + newH, err := header.DeserializeBytes(headerBytes) if err != nil { - return nil, fmt.Errorf("failed to get build: %w", err) + b.swapFailed.Store(true) + + return fmt.Errorf("failed to swap header: %w", err) + } + + old := b.header.Load() + b.header.CompareAndSwap(old, newH) + + return nil +} + +// buildFileSize returns the uncompressed file size for buildID from the +// header's BuildFiles map. 
Returns 0 for V3 headers (no BuildFiles), which +// signals the read path to fall back to a Size() RPC. +func (b *File) buildFileSize(h *header.Header, buildID uuid.UUID) int64 { + if info, ok := h.BuildFiles[buildID]; ok { + return info.Size } - return build.Slice(ctx, mappedOffset, int64(b.header.Metadata.BlockSize)) + return 0 } -func (b *File) getBuild(ctx context.Context, buildID *uuid.UUID) (Diff, error) { +func (b *File) getBuild(ctx context.Context, buildID uuid.UUID, uncompressedSize int64, ct storage.CompressionType) (Diff, error) { storageDiff, err := newStorageDiff( b.store.cachePath, buildID.String(), b.fileType, - int64(b.header.Metadata.BlockSize), + int64(b.Header().Metadata.BlockSize), b.metrics, b.persistence, + uncompressedSize, ct, b.store.flags, ) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/build/diff.go b/packages/orchestrator/pkg/sandbox/build/diff.go index b817235aa9..a895ae7bea 100644 --- a/packages/orchestrator/pkg/sandbox/build/diff.go +++ b/packages/orchestrator/pkg/sandbox/build/diff.go @@ -27,10 +27,11 @@ const ( type Diff interface { io.Closer storage.SeekableReader - block.Slicer + block.FramedSlicer CacheKey() DiffStoreKey CachePath() (string, error) FileSize() (int64, error) + BlockSize() int64 Init(ctx context.Context) error } @@ -42,7 +43,7 @@ func (n *NoDiff) CachePath() (string, error) { return "", NoDiffError{} } -func (n *NoDiff) Slice(_ context.Context, _, _ int64) ([]byte, error) { +func (n *NoDiff) Slice(_ context.Context, _, _ int64, _ *storage.FrameTable) ([]byte, error) { return nil, NoDiffError{} } @@ -50,7 +51,7 @@ func (n *NoDiff) Close() error { return nil } -func (n *NoDiff) ReadAt(_ context.Context, _ []byte, _ int64) (int, error) { +func (n *NoDiff) ReadAt(_ context.Context, _ []byte, _ int64, _ *storage.FrameTable) (int, error) { return 0, NoDiffError{} } diff --git a/packages/orchestrator/pkg/sandbox/build/local_diff.go b/packages/orchestrator/pkg/sandbox/build/local_diff.go index 
df5fec4ea7..117d5ebf2a 100644 --- a/packages/orchestrator/pkg/sandbox/build/local_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/local_diff.go @@ -6,6 +6,7 @@ import ( "os" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) type LocalDiffFile struct { @@ -114,16 +115,16 @@ func (b *localDiff) Close() error { return b.cache.Close() } -func (b *localDiff) ReadAt(_ context.Context, p []byte, off int64) (int, error) { +func (b *localDiff) ReadAt(_ context.Context, p []byte, off int64, _ *storage.FrameTable) (int, error) { return b.cache.ReadAt(p, off) } -func (b *localDiff) Slice(_ context.Context, off, length int64) ([]byte, error) { +func (b *localDiff) Slice(_ context.Context, off, length int64, _ *storage.FrameTable) ([]byte, error) { return b.cache.Slice(off, length) } func (b *localDiff) Size(_ context.Context) (int64, error) { - return b.cache.Size() + return b.FileSize() } func (b *localDiff) FileSize() (int64, error) { diff --git a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go index ea61e38b25..b52ed79aad 100644 --- a/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go +++ b/packages/orchestrator/pkg/sandbox/build/mocks/mockdiff.go @@ -8,6 +8,7 @@ import ( "context" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" mock "github.com/stretchr/testify/mock" ) @@ -328,8 +329,8 @@ func (_c *MockDiff_Init_Call) RunAndReturn(run func(ctx context.Context) error) } // ReadAt provides a mock function for the type MockDiff -func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) +func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable) (int, error) { + ret := _mock.Called(ctx, buffer, off, ft) if len(ret) == 0 { panic("no 
return value specified for ReadAt") @@ -337,16 +338,16 @@ func (_mock *MockDiff) ReadAt(ctx context.Context, buffer []byte, off int64) (in var r0 int var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64, *storage.FrameTable) (int, error)); ok { + return returnFunc(ctx, buffer, off, ft) } - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64, *storage.FrameTable) int); ok { + r0 = returnFunc(ctx, buffer, off, ft) } else { r0 = ret.Get(0).(int) } - if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) + if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64, *storage.FrameTable) error); ok { + r1 = returnFunc(ctx, buffer, off, ft) } else { r1 = ret.Error(1) } @@ -362,11 +363,12 @@ type MockDiff_ReadAt_Call struct { // - ctx context.Context // - buffer []byte // - off int64 -func (_e *MockDiff_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockDiff_ReadAt_Call { - return &MockDiff_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} +// - ft *storage.FrameTable +func (_e *MockDiff_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}, ft interface{}) *MockDiff_ReadAt_Call { + return &MockDiff_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off, ft)} } -func (_c *MockDiff_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockDiff_ReadAt_Call { +func (_c *MockDiff_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable)) *MockDiff_ReadAt_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -380,10 +382,15 @@ func (_c *MockDiff_ReadAt_Call) Run(run 
func(ctx context.Context, buffer []byte, if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *storage.FrameTable + if args[3] != nil { + arg3 = args[3].(*storage.FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -394,7 +401,7 @@ func (_c *MockDiff_ReadAt_Call) Return(n int, err error) *MockDiff_ReadAt_Call { return _c } -func (_c *MockDiff_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockDiff_ReadAt_Call { +func (_c *MockDiff_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64, ft *storage.FrameTable) (int, error)) *MockDiff_ReadAt_Call { _c.Call.Return(run) return _c } @@ -460,8 +467,8 @@ func (_c *MockDiff_Size_Call) RunAndReturn(run func(ctx context.Context) (int64, } // Slice provides a mock function for the type MockDiff -func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64) ([]byte, error) { - ret := _mock.Called(ctx, off, length) +func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64, ft *storage.FrameTable) ([]byte, error) { + ret := _mock.Called(ctx, off, length, ft) if len(ret) == 0 { panic("no return value specified for Slice") @@ -469,18 +476,18 @@ func (_mock *MockDiff) Slice(ctx context.Context, off int64, length int64) ([]by var r0 []byte var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) ([]byte, error)); ok { - return returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *storage.FrameTable) ([]byte, error)); ok { + return returnFunc(ctx, off, length, ft) } - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) []byte); ok { - r0 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *storage.FrameTable) []byte); ok { + r0 = returnFunc(ctx, off, length, ft) } else { if ret.Get(0) != nil { r0 = ret.Get(0).([]byte) } } - if returnFunc, ok := 
ret.Get(1).(func(context.Context, int64, int64) error); ok { - r1 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64, *storage.FrameTable) error); ok { + r1 = returnFunc(ctx, off, length, ft) } else { r1 = ret.Error(1) } @@ -496,11 +503,12 @@ type MockDiff_Slice_Call struct { // - ctx context.Context // - off int64 // - length int64 -func (_e *MockDiff_Expecter) Slice(ctx interface{}, off interface{}, length interface{}) *MockDiff_Slice_Call { - return &MockDiff_Slice_Call{Call: _e.mock.On("Slice", ctx, off, length)} +// - ft *storage.FrameTable +func (_e *MockDiff_Expecter) Slice(ctx interface{}, off interface{}, length interface{}, ft interface{}) *MockDiff_Slice_Call { + return &MockDiff_Slice_Call{Call: _e.mock.On("Slice", ctx, off, length, ft)} } -func (_c *MockDiff_Slice_Call) Run(run func(ctx context.Context, off int64, length int64)) *MockDiff_Slice_Call { +func (_c *MockDiff_Slice_Call) Run(run func(ctx context.Context, off int64, length int64, ft *storage.FrameTable)) *MockDiff_Slice_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -514,10 +522,15 @@ func (_c *MockDiff_Slice_Call) Run(run func(ctx context.Context, off int64, leng if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *storage.FrameTable + if args[3] != nil { + arg3 = args[3].(*storage.FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -528,7 +541,7 @@ func (_c *MockDiff_Slice_Call) Return(bytes []byte, err error) *MockDiff_Slice_C return _c } -func (_c *MockDiff_Slice_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) ([]byte, error)) *MockDiff_Slice_Call { +func (_c *MockDiff_Slice_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64, ft *storage.FrameTable) ([]byte, error)) *MockDiff_Slice_Call { _c.Call.Return(run) return _c } diff --git a/packages/orchestrator/pkg/sandbox/build/storage_diff.go 
b/packages/orchestrator/pkg/sandbox/build/storage_diff.go index eca9b11bb8..55942436e8 100644 --- a/packages/orchestrator/pkg/sandbox/build/storage_diff.go +++ b/packages/orchestrator/pkg/sandbox/build/storage_diff.go @@ -3,7 +3,6 @@ package build import ( "context" "fmt" - "io" "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block" blockmetrics "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/block/metrics" @@ -12,21 +11,18 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) -func storagePath(buildId string, diffType DiffType) string { - return fmt.Sprintf("%s/%s", buildId, diffType) -} - type StorageDiff struct { - chunker *utils.SetOnce[block.Chunker] + chunker *utils.SetOnce[*block.Chunker] cachePath string cacheKey DiffStoreKey storagePath string storageObjectType storage.SeekableObjectType - blockSize int64 - metrics blockmetrics.Metrics - persistence storage.StorageProvider - featureFlags *featureflags.Client + blockSize int64 + metrics blockmetrics.Metrics + persistence storage.StorageProvider + featureFlags *featureflags.Client + uncompressedSize int64 // 0 means unknown (fall back to Size() call) } var _ Diff = (*StorageDiff)(nil) @@ -46,9 +42,10 @@ func newStorageDiff( blockSize int64, metrics blockmetrics.Metrics, persistence storage.StorageProvider, - featureFlags *featureflags.Client, + uncompressedSize int64, + ct storage.CompressionType, + ff *featureflags.Client, ) (*StorageDiff, error) { - storagePath := storagePath(buildId, diffType) storageObjectType, ok := storageObjectType(diffType) if !ok { return nil, UnknownDiffTypeError{diffType} @@ -57,14 +54,15 @@ func newStorageDiff( cachePath := GenerateDiffCachePath(basePath, buildId, diffType) return &StorageDiff{ - storagePath: storagePath, + storagePath: storage.Paths{BuildID: buildId}.DataFile(string(diffType), ct), storageObjectType: storageObjectType, cachePath: cachePath, - chunker: utils.NewSetOnce[block.Chunker](), + chunker: 
utils.NewSetOnce[*block.Chunker](), blockSize: blockSize, metrics: metrics, persistence: persistence, - featureFlags: featureFlags, + featureFlags: ff, + uncompressedSize: uncompressedSize, cacheKey: GetDiffStoreKey(buildId, diffType), }, nil } @@ -90,12 +88,15 @@ func (b *StorageDiff) Init(ctx context.Context) error { return err } - size, err := obj.Size(ctx) - if err != nil { - errMsg := fmt.Errorf("failed to get object size: %w", err) - b.chunker.SetError(errMsg) + size := b.uncompressedSize + if size == 0 { + size, err = obj.Size(ctx) + if err != nil { + errMsg := fmt.Errorf("failed to get object size: %w", err) + b.chunker.SetError(errMsg) - return errMsg + return errMsg + } } c, err := block.NewChunker(ctx, b.featureFlags, size, b.blockSize, obj, b.cachePath, b.metrics) @@ -118,31 +119,22 @@ func (b *StorageDiff) Close() error { return c.Close() } -func (b *StorageDiff) ReadAt(ctx context.Context, p []byte, off int64) (int, error) { +func (b *StorageDiff) ReadAt(ctx context.Context, p []byte, off int64, ft *storage.FrameTable) (int, error) { c, err := b.chunker.Wait() if err != nil { return 0, err } - return c.ReadAt(ctx, p, off) + return c.ReadAt(ctx, p, off, ft) } -func (b *StorageDiff) Slice(ctx context.Context, off, length int64) ([]byte, error) { +func (b *StorageDiff) Slice(ctx context.Context, off, length int64, ft *storage.FrameTable) ([]byte, error) { c, err := b.chunker.Wait() if err != nil { return nil, err } - return c.Slice(ctx, off, length) -} - -func (b *StorageDiff) WriteTo(ctx context.Context, w io.Writer) (int64, error) { - c, err := b.chunker.Wait() - if err != nil { - return 0, err - } - - return c.WriteTo(ctx, w) + return c.Slice(ctx, off, length, ft) } // The local file might not be synced. 
diff --git a/packages/orchestrator/pkg/sandbox/build_upload.go b/packages/orchestrator/pkg/sandbox/build_upload.go new file mode 100644 index 0000000000..542164ac4e --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload.go @@ -0,0 +1,211 @@ +package sandbox + +import ( + "context" + "fmt" + "sync" + + "github.com/google/uuid" + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox/build" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// BuildUploader uploads a paused snapshot's files to storage. +type BuildUploader interface { + // UploadData uploads data files, snapfile, and metadata. + UploadData(ctx context.Context) error + // FinalizeHeaders uploads final headers after all upstream layers are done. + // Returns serialized V4 header bytes for peer transition (nil for uncompressed). + FinalizeHeaders(ctx context.Context) (memfileHeader, rootfsHeader []byte, err error) +} + +// NewBuildUploader creates a BuildUploader for the given snapshot. +// If cfg is non-nil, compression is used (V4 headers). Otherwise, uncompressed (V3 headers). +// pending is shared across layers for multi-layer builds; nil is fine for single-layer. +func NewBuildUploader(snapshot *Snapshot, persistence storage.StorageProvider, paths storage.Paths, cfg *storage.CompressConfig, pending *PendingBuildInfo) BuildUploader { + base := buildUploader{ + paths: paths, + persistence: persistence, + snapshot: snapshot, + } + + if cfg != nil { + if pending == nil { + pending = &PendingBuildInfo{} + } + + return &compressedUploader{ + buildUploader: base, + pending: pending, + cfg: cfg, + } + } + + return &uncompressedUploader{buildUploader: base} +} + +// buildUploader contains fields and helpers shared by both implementations. 
+type buildUploader struct { + paths storage.Paths + persistence storage.StorageProvider + snapshot *Snapshot +} + +// diffPath returns the cache path for a diff, or nil if the diff is NoDiff. +func diffPath(d build.Diff) (*string, error) { + if _, ok := d.(*build.NoDiff); ok { + return nil, nil + } + + p, err := d.CachePath() + if err != nil { + return nil, err + } + + return &p, nil +} + +func (b *buildUploader) uploadUncompressedFile(ctx context.Context, local, remote string, objType storage.SeekableObjectType) error { + object, err := b.persistence.OpenSeekable(ctx, remote, objType) + if err != nil { + return err + } + + if _, _, err := object.StoreFile(ctx, local, nil); err != nil { + return fmt.Errorf("error when uploading %s: %w", remote, err) + } + + return nil +} + +// Snap-file is small enough so we don't use composite upload. +func (b *buildUploader) uploadSnapfile(ctx context.Context, path string) error { + object, err := b.persistence.OpenBlob(ctx, b.paths.Snapfile(), storage.SnapfileObjectType) + if err != nil { + return err + } + + if err = uploadFileAsBlob(ctx, object, path); err != nil { + return fmt.Errorf("error when uploading snapfile: %w", err) + } + + return nil +} + +// Metadata is small enough so we don't use composite upload. 
+func (b *buildUploader) uploadMetadata(ctx context.Context, path string) error { + object, err := b.persistence.OpenBlob(ctx, b.paths.Metadata(), storage.MetadataObjectType) + if err != nil { + return err + } + + if err := uploadFileAsBlob(ctx, object, path); err != nil { + return fmt.Errorf("error when uploading metadata: %w", err) + } + + return nil +} + +func (b *buildUploader) uploadCompressedFile(ctx context.Context, local, remote string, objType storage.SeekableObjectType, cfg *storage.CompressConfig) (*storage.FrameTable, [32]byte, error) { + object, err := b.persistence.OpenSeekable(ctx, remote, objType) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("error opening framed file for %s: %w", remote, err) + } + + ft, checksum, err := object.StoreFile(ctx, local, cfg) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("error compressing %s to %s: %w", local, remote, err) + } + + return ft, checksum, nil +} + +func (b *buildUploader) scheduleAlwaysUploads(eg *errgroup.Group, ctx context.Context) { + eg.Go(func() error { + return b.uploadSnapfile(ctx, b.snapshot.Snapfile.Path()) + }) + + eg.Go(func() error { + return b.uploadMetadata(ctx, b.snapshot.Metafile.Path()) + }) +} + +// pendingBuildInfo pairs a FrameTable with the uncompressed file size and +// uncompressed-data checksum so all can be stored in the header after uploads complete. +type pendingBuildInfo struct { + ft *storage.FrameTable + fileSize int64 + checksum [32]byte +} + +// PendingBuildInfo collects FrameTables and file sizes from compressed data +// uploads across all layers. After all data files are uploaded, the collected +// tables are applied to headers before the compressed headers are serialized +// and uploaded. 
+type PendingBuildInfo sync.Map + +func pendingBuildInfoKey(buildID, fileType string) string { + return buildID + "/" + fileType +} + +func (p *PendingBuildInfo) add(key string, ft *storage.FrameTable, fileSize int64, checksum [32]byte) { + if ft == nil { + return + } + + (*sync.Map)(p).Store(key, pendingBuildInfo{ft: ft, fileSize: fileSize, checksum: checksum}) +} + +func (p *PendingBuildInfo) get(key string) *pendingBuildInfo { + v, ok := (*sync.Map)(p).Load(key) + if !ok { + return nil + } + + info, ok := v.(pendingBuildInfo) + if !ok { + return nil + } + + return &info +} + +func (p *PendingBuildInfo) applyToHeader(h *headers.Header, fileType string) error { + if h == nil { + return nil + } + + // Track frame cursor per build to avoid O(N²) rescanning. + cursors := make(map[string]int) + + for _, mapping := range h.Mapping { + key := pendingBuildInfoKey(mapping.BuildId.String(), fileType) + info := p.get(key) + + if info == nil { + continue + } + + cursor := cursors[key] + next, err := mapping.SetFramesFrom(info.ft, cursor) + if err != nil { + return fmt.Errorf("apply frames to mapping at offset %d for build %s: %w", + mapping.Offset, mapping.BuildId.String(), err) + } + cursors[key] = next + + // Populate BuildFiles with size and checksum for this build. 
+ if h.BuildFiles == nil { + h.BuildFiles = make(map[uuid.UUID]headers.BuildFileInfo) + } + h.BuildFiles[mapping.BuildId] = headers.BuildFileInfo{ + Size: info.fileSize, + Checksum: info.checksum, + } + } + + return nil +} diff --git a/packages/orchestrator/pkg/sandbox/build_upload_v3.go b/packages/orchestrator/pkg/sandbox/build_upload_v3.go new file mode 100644 index 0000000000..06a4768a7e --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload_v3.go @@ -0,0 +1,79 @@ +package sandbox + +import ( + "context" + "fmt" + + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// uncompressedUploader implements BuildUploader for V3 (uncompressed) builds. +type uncompressedUploader struct { + buildUploader +} + +func (u *uncompressedUploader) UploadData(ctx context.Context) error { + memfilePath, err := diffPath(u.snapshot.MemfileDiff) + if err != nil { + return fmt.Errorf("error getting memfile diff path: %w", err) + } + + rootfsPath, err := diffPath(u.snapshot.RootfsDiff) + if err != nil { + return fmt.Errorf("error getting rootfs diff path: %w", err) + } + + eg, ctx := errgroup.WithContext(ctx) + + // V3 headers + eg.Go(func() error { + if u.snapshot.MemfileDiffHeader == nil { + return nil + } + + _, err := headers.StoreHeader(ctx, u.persistence, u.paths.MemfileHeader(), u.snapshot.MemfileDiffHeader) + + return err + }) + + eg.Go(func() error { + if u.snapshot.RootfsDiffHeader == nil { + return nil + } + + _, err := headers.StoreHeader(ctx, u.persistence, u.paths.RootfsHeader(), u.snapshot.RootfsDiffHeader) + + return err + }) + + // Uncompressed data + eg.Go(func() error { + if memfilePath == nil { + return nil + } + + return u.uploadUncompressedFile(ctx, *memfilePath, u.paths.Memfile(), storage.MemfileObjectType) + }) + + eg.Go(func() error { + if rootfsPath == nil { + return nil + } + + return u.uploadUncompressedFile(ctx, *rootfsPath, 
u.paths.Rootfs(), storage.RootFSObjectType) + }) + + u.scheduleAlwaysUploads(eg, ctx) + + return eg.Wait() +} + +func (u *uncompressedUploader) FinalizeHeaders(context.Context) ([]byte, []byte, error) { + return nil, nil, nil +} + +// Ensure uncompressedUploader implements BuildUploader. +var _ BuildUploader = (*uncompressedUploader)(nil) diff --git a/packages/orchestrator/pkg/sandbox/build_upload_v4.go b/packages/orchestrator/pkg/sandbox/build_upload_v4.go new file mode 100644 index 0000000000..4b5f776334 --- /dev/null +++ b/packages/orchestrator/pkg/sandbox/build_upload_v4.go @@ -0,0 +1,130 @@ +package sandbox + +import ( + "context" + "fmt" + + "golang.org/x/sync/errgroup" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" + headers "github.com/e2b-dev/infra/packages/shared/pkg/storage/header" +) + +// compressedUploader implements BuildUploader for V4 (compressed) builds. +type compressedUploader struct { + buildUploader + + pending *PendingBuildInfo + cfg *storage.CompressConfig +} + +func (c *compressedUploader) UploadData(ctx context.Context) error { + memfilePath, err := diffPath(c.snapshot.MemfileDiff) + if err != nil { + return fmt.Errorf("error getting memfile diff path: %w", err) + } + + rootfsPath, err := diffPath(c.snapshot.RootfsDiff) + if err != nil { + return fmt.Errorf("error getting rootfs diff path: %w", err) + } + + eg, ctx := errgroup.WithContext(ctx) + + if memfilePath != nil { + localPath := *memfilePath + eg.Go(func() error { + ft, checksum, err := c.uploadCompressedFile(ctx, localPath, c.paths.MemfileCompressed(c.cfg.CompressionType()), storage.MemfileObjectType, c.cfg) + if err != nil { + return fmt.Errorf("compressed memfile upload: %w", err) + } + + uncompressedSize, _ := ft.Size() + c.pending.add(pendingBuildInfoKey(c.paths.BuildID, storage.MemfileName), ft, uncompressedSize, checksum) + + return nil + }) + } + + if rootfsPath != nil { + localPath := *rootfsPath + eg.Go(func() error { + ft, checksum, err := 
c.uploadCompressedFile(ctx, localPath, c.paths.RootfsCompressed(c.cfg.CompressionType()), storage.RootFSObjectType, c.cfg) + if err != nil { + return fmt.Errorf("compressed rootfs upload: %w", err) + } + if ft == nil { + return fmt.Errorf("compressed rootfs upload returned nil FrameTable") + } + + uncompressedSize, _ := ft.Size() + c.pending.add(pendingBuildInfoKey(c.paths.BuildID, storage.RootfsName), ft, uncompressedSize, checksum) + + return nil + }) + } + + c.scheduleAlwaysUploads(eg, ctx) + + return eg.Wait() +} + +// FinalizeHeaders applies pending frame tables to headers and uploads them as V4 format. +// +// The snapshot headers are cloned before mutation because the originals may be +// concurrently read by sandboxes resumed from the template cache (e.g. the +// optimize phase's UFFD handlers). +func (c *compressedUploader) FinalizeHeaders(ctx context.Context) (memfileHeader, rootfsHeader []byte, err error) { + eg, ctx := errgroup.WithContext(ctx) + + if c.snapshot.MemfileDiffHeader != nil { + eg.Go(func() error { + h := c.snapshot.MemfileDiffHeader.CloneForUpload() + + if err := c.pending.applyToHeader(h, storage.MemfileName); err != nil { + return fmt.Errorf("apply frames to memfile header: %w", err) + } + + h.Metadata.Version = headers.MetadataVersionCompressed + + data, err := headers.StoreHeader(ctx, c.persistence, c.paths.MemfileHeader(), h) + if err != nil { + return err + } + + memfileHeader = data + + return nil + }) + } + + if c.snapshot.RootfsDiffHeader != nil { + eg.Go(func() error { + h := c.snapshot.RootfsDiffHeader.CloneForUpload() + + if err := c.pending.applyToHeader(h, storage.RootfsName); err != nil { + return fmt.Errorf("apply frames to rootfs header: %w", err) + } + + h.Metadata.Version = headers.MetadataVersionCompressed + + data, err := headers.StoreHeader(ctx, c.persistence, c.paths.RootfsHeader(), h) + if err != nil { + return err + } + + rootfsHeader = data + + return nil + }) + } + + if err = eg.Wait(); err != nil { + return nil, 
nil, err + } + + return memfileHeader, rootfsHeader, nil +} + +// Ensure compressedUploader implements BuildUploader. +var _ BuildUploader = (*compressedUploader)(nil) diff --git a/packages/orchestrator/pkg/sandbox/nbd/dispatch.go b/packages/orchestrator/pkg/sandbox/nbd/dispatch.go index 3a40e79c71..ad051e3f64 100644 --- a/packages/orchestrator/pkg/sandbox/nbd/dispatch.go +++ b/packages/orchestrator/pkg/sandbox/nbd/dispatch.go @@ -11,13 +11,13 @@ import ( "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) var ErrShuttingDown = errors.New("shutting down. Cannot serve any new requests") type Provider interface { - storage.SeekableReader + ReadAt(ctx context.Context, p []byte, off int64) (int, error) + Size(ctx context.Context) (int64, error) io.WriterAt } diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/blob.go b/packages/orchestrator/pkg/sandbox/template/peerclient/blob.go index c566158f74..d1d5dcf474 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/blob.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/blob.go @@ -85,7 +85,7 @@ func openPeerBlobStream( ctx context.Context, client orchestrator.ChunkServiceClient, req *orchestrator.GetBuildBlobRequest, - uploaded *atomic.Bool, + uploaded *atomic.Pointer[UploadedHeaders], ) (func() ([]byte, error), error) { stream, err := client.GetBuildBlob(ctx, req) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go index 350d0d972d..b8587ca582 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/blob_test.go @@ -15,8 +15,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" 
"github.com/e2b-dev/infra/packages/shared/pkg/storage" - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerBlob_WriteTo_PeerSucceeds(t *testing.T) { @@ -36,7 +34,7 @@ func TestPeerBlob_WriteTo_PeerSucceeds(t *testing.T) { client: client, buildID: "build-1", fileName: "snapfile", - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, }} var buf bytes.Buffer @@ -55,21 +53,20 @@ func TestPeerBlob_WriteTo_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(stream, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := dst.Write([]byte("from gcs")) return int64(n), err }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ client: client, buildID: "build-1", fileName: "snapfile", - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, openFn: func(ctx context.Context) (storage.Blob, error) { return base.OpenBlob(ctx, "build-1/snapfile", storage.SnapfileObjectType) }, @@ -88,21 +85,20 @@ func TestPeerBlob_WriteTo_PeerError_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(nil, errors.New("connection refused")) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := 
dst.Write([]byte("from gcs")) return int64(n), err }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ client: client, buildID: "build-1", fileName: "snapfile", - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, openFn: func(ctx context.Context) (storage.Blob, error) { return base.OpenBlob(ctx, "build-1/snapfile", storage.SnapfileObjectType) }, @@ -117,14 +113,14 @@ func TestPeerBlob_WriteTo_PeerError_FallsBackToBase(t *testing.T) { func TestPeerBlob_WriteTo_UploadedSetMidStream_CompletesFromPeerThenFallsBack(t *testing.T) { t.Parallel() - uploaded := &atomic.Bool{} + uploaded := &atomic.Pointer[UploadedHeaders]{} // Peer streams three chunks; the second Recv sets uploaded=true // (simulating a concurrent operation receiving UseStorage). stream := orchestratormocks.NewMockChunkService_GetBuildBlobClient(t) stream.EXPECT().Recv().Return(&orchestrator.GetBuildBlobResponse{Data: []byte("aaa")}, nil).Once() stream.EXPECT().Recv().RunAndReturn(func() (*orchestrator.GetBuildBlobResponse, error) { - uploaded.Store(true) + uploaded.Store(&UploadedHeaders{}) return &orchestrator.GetBuildBlobResponse{Data: []byte("bbb")}, nil }).Once() @@ -134,14 +130,13 @@ func TestPeerBlob_WriteTo_UploadedSetMidStream_CompletesFromPeerThenFallsBack(t client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildBlob(mock.Anything, mock.Anything).Return(stream, nil).Once() - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().WriteTo(mock.Anything, mock.Anything).RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { n, err := dst.Write([]byte("from storage")) return int64(n), err }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) 
base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ @@ -160,7 +155,7 @@ func TestPeerBlob_WriteTo_UploadedSetMidStream_CompletesFromPeerThenFallsBack(t require.NoError(t, err) assert.Equal(t, int64(9), n1) assert.Equal(t, "aaabbbccc", buf1.String()) - assert.True(t, uploaded.Load()) + assert.NotNil(t, uploaded.Load()) // Second download: uploaded is now true, skips peer and goes to base storage. var buf2 bytes.Buffer @@ -178,7 +173,7 @@ func TestPeerBlob_Exists_PeerHasFile(t *testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == "snapfile" })).Return(&orchestrator.GetBuildFileExistsResponse{}, nil) - blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{client: client, buildID: "build-1", fileName: "snapfile", uploaded: &atomic.Bool{}}} + blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{client: client, buildID: "build-1", fileName: "snapfile", uploaded: &atomic.Pointer[UploadedHeaders]{}}} ok, err := blob.Exists(t.Context()) require.NoError(t, err) assert.True(t, ok) @@ -190,17 +185,16 @@ func TestPeerBlob_Exists_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileExists(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileExistsResponse{Availability: &orchestrator.PeerAvailability{NotAvailable: true}}, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().Exists(mock.Anything).Return(true, nil) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ client: client, buildID: "build-1", fileName: "snapfile", - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, openFn: 
func(ctx context.Context) (storage.Blob, error) { return base.OpenBlob(ctx, "build-1/snapfile", storage.SnapfileObjectType) }, @@ -217,13 +211,12 @@ func TestPeerBlob_Exists_UseStorage_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileExists(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileExistsResponse{Availability: &orchestrator.PeerAvailability{UseStorage: true}}, nil) - baseBlob := storagemocks.NewMockBlob(t) + baseBlob := storage.NewMockBlob(t) baseBlob.EXPECT().Exists(mock.Anything).Return(true, nil) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenBlob(mock.Anything, "build-1/snapfile", storage.SnapfileObjectType).Return(baseBlob, nil) - uploaded := &atomic.Bool{} + uploaded := &atomic.Pointer[UploadedHeaders]{} blob := &peerBlob{peerHandle: peerHandle[storage.Blob]{ client: client, buildID: "build-1", @@ -237,5 +230,5 @@ func TestPeerBlob_Exists_UseStorage_FallsBackToBase(t *testing.T) { ok, err := blob.Exists(t.Context()) require.NoError(t, err) assert.True(t, ok) - assert.True(t, uploaded.Load(), "uploaded flag should be set after UseStorage response") + assert.NotNil(t, uploaded.Load(), "uploaded flag should be set after UseStorage response") } diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/resolver.go b/packages/orchestrator/pkg/sandbox/template/peerclient/resolver.go index 49bd708bd7..999ccf5c2a 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/resolver.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/resolver.go @@ -30,9 +30,17 @@ type Resolver interface { Close() } +// UploadedHeaders holds the serialized V4 headers received from the peer's +// use_storage response. These are used by build.File to atomically swap headers +// when transitioning from P2P to compressed GCS reads. 
+type UploadedHeaders struct { + MemfileHeader []byte + RootfsHeader []byte +} + type resolveResult struct { client orchestrator.ChunkServiceClient - uploaded *atomic.Bool + uploaded *atomic.Pointer[UploadedHeaders] addr string } @@ -49,11 +57,11 @@ func (nopResolver) Close() {} // peerResolver is the real implementation that looks up peers via the Registry. type peerResolver struct { - registry Registry - selfAddress string - peerConns sync.Map // address → *grpc.ClientConn - uploadedBuilds sync.Map // buildID → *atomic.Bool - dialGroup singleflight.Group + registry Registry + selfAddress string + peerConns sync.Map // address → *grpc.ClientConn + uploaded sync.Map // buildID → *atomic.Pointer[UploadedHeaders] + dialGroup singleflight.Group } func NewResolver(registry Registry, selfAddress string) Resolver { @@ -104,32 +112,33 @@ func (r *peerResolver) isSelfAddress(address string) bool { return address == r.selfAddress } -// uploadedFlag returns a shared atomic flag for the given build ID. -// Once any reader sets the flag (via use_storage), all subsequent opens for -// that build skip the peer. -func (r *peerResolver) uploadedFlag(buildID string) *atomic.Bool { - if v, ok := r.uploadedBuilds.Load(buildID); ok { - return v.(*atomic.Bool) +// uploadedPtr returns a shared atomic pointer for the given build ID. +// Non-nil value means the build is uploaded (use_storage). The UploadedHeaders +// may contain serialized V4 headers for the peer transition protocol, or be +// empty (for uncompressed builds). 
+func (r *peerResolver) uploadedPtr(buildID string) *atomic.Pointer[UploadedHeaders] { + if v, ok := r.uploaded.Load(buildID); ok { + return v.(*atomic.Pointer[UploadedHeaders]) } - flag := &atomic.Bool{} - actual, _ := r.uploadedBuilds.LoadOrStore(buildID, flag) + ptr := &atomic.Pointer[UploadedHeaders]{} + actual, _ := r.uploaded.LoadOrStore(buildID, ptr) - return actual.(*atomic.Bool) + return actual.(*atomic.Pointer[UploadedHeaders]) } // Purge removes the uploaded state for a build, called on template // cache eviction so the entry doesn't accumulate forever. func (r *peerResolver) Purge(buildID string) { - r.uploadedBuilds.Delete(buildID) + r.uploaded.Delete(buildID) } // resolve looks up the peer for the given build and returns a gRPC client if // a remote peer is found. Returns a nil client when the base provider should // be used instead (uploaded, no peer, self, or error). func (r *peerResolver) resolve(ctx context.Context, buildID string) (attribute.KeyValue, resolveResult) { - uploaded := r.uploadedFlag(buildID) - if uploaded.Load() { + hdrs := r.uploadedPtr(buildID) + if hdrs.Load() != nil { return attrResolveUploaded, resolveResult{} } @@ -153,7 +162,7 @@ func (r *peerResolver) resolve(ctx context.Context, buildID string) (attribute.K return attrResolvePeer, resolveResult{ client: orchestrator.NewChunkServiceClient(conn), - uploaded: uploaded, + uploaded: hdrs, addr: addr, } } diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go index df71917a00..5de4e6d4f5 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable.go @@ -83,12 +83,18 @@ func (s *peerSeekable) ReadAt(ctx context.Context, buf []byte, off int64) (int, return peerAttempt[int]{value: n, bytes: int64(n), hit: true}, nil }, func(ctx context.Context, base storage.Seekable) (int, error) { - return 
base.ReadAt(ctx, buf, off) + rc, err := base.OpenRangeReader(ctx, off, int64(len(buf)), nil) + if err != nil { + return 0, err + } + defer rc.Close() + + return io.ReadFull(rc, buf) }, ) } -func (s *peerSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (s *peerSeekable) OpenRangeReader(ctx context.Context, off int64, length int64, frameTable *storage.FrameTable) (io.ReadCloser, error) { return withPeerFallback(ctx, &s.peerHandle, "peer-seekable-open-range-reader", attrOpRangeReader, func(ctx context.Context) (peerAttempt[io.ReadCloser], error) { streamCtx, cancel := context.WithCancel(ctx) @@ -112,19 +118,29 @@ func (s *peerSeekable) OpenRangeReader(ctx context.Context, off, length int64) ( }, nil }, func(ctx context.Context, base storage.Seekable) (io.ReadCloser, error) { - return base.OpenRangeReader(ctx, off, length) + // Signal the caller to swap to V4 headers if compressed headers are available. + if s.uploaded != nil { + if hdrs := s.uploaded.Load(); hdrs != nil && (len(hdrs.MemfileHeader) > 0 || len(hdrs.RootfsHeader) > 0) { + return nil, &storage.PeerTransitionedError{ + MemfileHeader: hdrs.MemfileHeader, + RootfsHeader: hdrs.RootfsHeader, + } + } + } + + return base.OpenRangeReader(ctx, off, length, frameTable) }, ) } -func (s *peerSeekable) StoreFile(ctx context.Context, path string) error { +func (s *peerSeekable) StoreFile(ctx context.Context, path string, cfg *storage.CompressConfig) (*storage.FrameTable, [32]byte, error) { // Writes always go to the base provider (GCS/S3); the peer is read-only. 
fallback, err := s.getOrOpenBase(ctx) if err != nil { - return err + return nil, [32]byte{}, err } - return fallback.StoreFile(ctx, path) + return fallback.StoreFile(ctx, path, cfg) } // openPeerSeekableStream opens a ReadAtBuildSeekable stream, checks peer availability, @@ -133,7 +149,7 @@ func openPeerSeekableStream( ctx context.Context, client orchestrator.ChunkServiceClient, req *orchestrator.ReadAtBuildSeekableRequest, - uploaded *atomic.Bool, + uploaded *atomic.Pointer[UploadedHeaders], ) (func() ([]byte, error), error) { stream, err := client.ReadAtBuildSeekable(ctx, req) if err != nil { diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go index 2c0c913a07..995eaf32d5 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/seekable_test.go @@ -15,8 +15,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerSeekable_Size_PeerSucceeds(t *testing.T) { @@ -27,7 +25,7 @@ func TestPeerSeekable_Size_PeerSucceeds(t *testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == storage.MemfileName })).Return(&orchestrator.GetBuildFileSizeResponse{TotalSize: 4096}, nil) - s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Bool{}}} + s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Pointer[UploadedHeaders]{}}} size, err := s.Size(t.Context()) 
require.NoError(t, err) assert.Equal(t, int64(4096), size) @@ -39,17 +37,17 @@ func TestPeerSeekable_Size_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().GetBuildFileSize(mock.Anything, mock.Anything).Return(&orchestrator.GetBuildFileSizeResponse{Availability: &orchestrator.PeerAvailability{NotAvailable: true}}, nil) - baseSeekable := storagemocks.NewMockSeekable(t) + baseSeekable := storage.NewMockSeekable(t) baseSeekable.EXPECT().Size(mock.Anything).Return(int64(8192), nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ client: client, buildID: "build-1", fileName: storage.MemfileName, - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, openFn: func(ctx context.Context) (storage.Seekable, error) { return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) }, @@ -72,7 +70,7 @@ func TestPeerSeekable_ReadAt_PeerSucceeds(t *testing.T) { return req.GetOffset() == 0 && req.GetLength() == int64(len(data)) })).Return(stream, nil) - s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Bool{}}} + s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Pointer[UploadedHeaders]{}}} buf := make([]byte, len(data)) n, err := s.ReadAt(t.Context(), buf, 0) require.NoError(t, err) @@ -90,21 +88,18 @@ func TestPeerSeekable_ReadAt_PeerNotAvailable_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.Anything).Return(stream, nil) - baseSeekable := storagemocks.NewMockSeekable(t) - 
baseSeekable.EXPECT().ReadAt(mock.Anything, mock.Anything, int64(0)).RunAndReturn(func(_ context.Context, buf []byte, _ int64) (int, error) { - n := copy(buf, baseData) + baseSeekable := storage.NewMockSeekable(t) + baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData)), (*storage.FrameTable)(nil)). + Return(io.NopCloser(bytes.NewReader(baseData)), nil) - return n, nil - }) - - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ client: client, buildID: "build-1", fileName: storage.MemfileName, - uploaded: &atomic.Bool{}, + uploaded: &atomic.Pointer[UploadedHeaders]{}, openFn: func(ctx context.Context) (storage.Seekable, error) { return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) }, @@ -130,8 +125,8 @@ func TestPeerSeekable_OpenRangeReader_PeerSucceeds(t *testing.T) { return req.GetOffset() == 10 && req.GetLength() == int64(len(data)) })).Return(stream, nil) - s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Bool{}}} - rc, err := s.OpenRangeReader(t.Context(), 10, int64(len(data))) + s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{client: client, buildID: "build-1", fileName: storage.MemfileName, uploaded: &atomic.Pointer[UploadedHeaders]{}}} + rc, err := s.OpenRangeReader(t.Context(), 10, int64(len(data)), nil) require.NoError(t, err) defer rc.Close() @@ -147,22 +142,94 @@ func TestPeerSeekable_OpenRangeReader_PeerError_FallsBackToBase(t *testing.T) { client := orchestratormocks.NewMockChunkServiceClient(t) client.EXPECT().ReadAtBuildSeekable(mock.Anything, mock.Anything).Return(nil, errors.New("peer unavailable")) - baseSeekable := storagemocks.NewMockSeekable(t) - 
baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData))).Return(io.NopCloser(bytes.NewReader(baseData)), nil) + baseSeekable := storage.NewMockSeekable(t) + baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData)), (*storage.FrameTable)(nil)).Return(io.NopCloser(bytes.NewReader(baseData)), nil) + + base := storage.NewMockStorageProvider(t) + base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) + + s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ + client: client, + buildID: "build-1", + fileName: storage.MemfileName, + uploaded: &atomic.Pointer[UploadedHeaders]{}, + openFn: func(ctx context.Context) (storage.Seekable, error) { + return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) + }, + }} + rc, err := s.OpenRangeReader(t.Context(), 0, int64(len(baseData)), nil) + require.NoError(t, err) + defer rc.Close() + + got, err := io.ReadAll(rc) + require.NoError(t, err) + assert.Equal(t, baseData, got) +} + +func TestPeerSeekable_OpenRangeReader_UploadedHeaders_ReturnsPeerTransitionedError(t *testing.T) { + t.Parallel() + + memHeader := []byte("mem-header-v4") + rootHeader := []byte("root-header-v4") + + client := orchestratormocks.NewMockChunkServiceClient(t) + + uploaded := &atomic.Pointer[UploadedHeaders]{} + uploaded.Store(&UploadedHeaders{ + MemfileHeader: memHeader, + RootfsHeader: rootHeader, + }) + + baseSeekable := storage.NewMockSeekable(t) + base := storage.NewMockStorageProvider(t) + base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) + + s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ + client: client, + buildID: "build-1", + fileName: storage.MemfileName, + uploaded: uploaded, + openFn: func(ctx context.Context) (storage.Seekable, error) { + return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) + }, + }} + + // 
frameTable=nil triggers the transition header check in the fallback path + _, err := s.OpenRangeReader(t.Context(), 0, 100, nil) + require.Error(t, err) + + var transErr *storage.PeerTransitionedError + require.ErrorAs(t, err, &transErr) + assert.Equal(t, memHeader, transErr.MemfileHeader) + assert.Equal(t, rootHeader, transErr.RootfsHeader) +} + +func TestPeerSeekable_OpenRangeReader_UploadedSkipsPeer(t *testing.T) { + t.Parallel() + + client := orchestratormocks.NewMockChunkServiceClient(t) - base := providermocks.NewMockStorageProvider(t) + uploaded := &atomic.Pointer[UploadedHeaders]{} + uploaded.Store(&UploadedHeaders{}) + + baseData := []byte("from gcs") + baseSeekable := storage.NewMockSeekable(t) + baseSeekable.EXPECT().OpenRangeReader(mock.Anything, int64(0), int64(len(baseData)), (*storage.FrameTable)(nil)).Return(io.NopCloser(bytes.NewReader(baseData)), nil) + + base := storage.NewMockStorageProvider(t) base.EXPECT().OpenSeekable(mock.Anything, "build-1/memfile", storage.MemfileObjectType).Return(baseSeekable, nil) s := &peerSeekable{peerHandle: peerHandle[storage.Seekable]{ client: client, buildID: "build-1", fileName: storage.MemfileName, - uploaded: &atomic.Bool{}, + uploaded: uploaded, openFn: func(ctx context.Context) (storage.Seekable, error) { return base.OpenSeekable(ctx, "build-1/memfile", storage.MemfileObjectType) }, }} - rc, err := s.OpenRangeReader(t.Context(), 0, int64(len(baseData))) + + rc, err := s.OpenRangeReader(t.Context(), 0, int64(len(baseData)), nil) require.NoError(t, err) defer rc.Close() diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go b/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go index f683b4f2f6..f408f98a5a 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/storage.go @@ -111,15 +111,15 @@ var _ storage.StorageProvider = (*peerStorageProvider)(nil) type peerStorageProvider struct { base 
storage.StorageProvider peerClient orchestrator.ChunkServiceClient - // uploaded is set to true when the peer signals that GCS upload is complete - // (use_storage=true). Once set, all subsequent reads skip the peer and go to base. - uploaded *atomic.Bool + // uploaded is set when the peer signals GCS upload is complete (use_storage=true). + // Once non-nil, all subsequent reads skip the peer and go to base. + uploaded *atomic.Pointer[UploadedHeaders] } func newPeerStorageProvider( base storage.StorageProvider, peerClient orchestrator.ChunkServiceClient, - uploaded *atomic.Bool, + uploaded *atomic.Pointer[UploadedHeaders], ) storage.StorageProvider { return &peerStorageProvider{ base: base, @@ -168,14 +168,18 @@ func (p *peerStorageProvider) GetDetails() string { return p.base.GetDetails() } -// checkPeerAvailability also marks the uploaded flag when UseStorage is set. -func checkPeerAvailability(avail *orchestrator.PeerAvailability, uploaded *atomic.Bool) bool { +// checkPeerAvailability marks the build as uploaded when UseStorage is set. 
+func checkPeerAvailability(avail *orchestrator.PeerAvailability, uploaded *atomic.Pointer[UploadedHeaders]) bool { if avail.GetNotAvailable() { return false } if avail.GetUseStorage() { - uploaded.Store(true) + hdrs := &UploadedHeaders{ + MemfileHeader: avail.GetMemfileHeader(), + RootfsHeader: avail.GetRootfsHeader(), + } + uploaded.Store(hdrs) return false } @@ -187,7 +191,7 @@ type peerHandle[Base any] struct { client orchestrator.ChunkServiceClient buildID string fileName string - uploaded *atomic.Bool + uploaded *atomic.Pointer[UploadedHeaders] mu sync.Mutex base Base @@ -238,7 +242,7 @@ func withPeerFallback[Base, T any]( )) defer span.End() - if !h.uploaded.Load() { + if h.uploaded.Load() == nil { timer := peerReadTimerFactory.Begin(opAttr) res, err := peerFn(ctx) diff --git a/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go b/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go index ca8ea8106d..8ec3f79c70 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerclient/storage_test.go @@ -13,7 +13,6 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator" orchestratormocks "github.com/e2b-dev/infra/packages/shared/pkg/grpc/orchestrator/mocks" "github.com/e2b-dev/infra/packages/shared/pkg/storage" - providermocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks/provider" ) func TestPeerStorageProvider_OpenBlob_ExtractsFileName(t *testing.T) { @@ -28,9 +27,9 @@ func TestPeerStorageProvider_OpenBlob_ExtractsFileName(t *testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == "snapfile" })).Return(stream, nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) - p := newPeerStorageProvider(base, client, &atomic.Bool{}) + p := newPeerStorageProvider(base, client, &atomic.Pointer[UploadedHeaders]{}) blob, err := p.OpenBlob(t.Context(), "build-1/snapfile", 
storage.SnapfileObjectType) require.NoError(t, err) @@ -48,13 +47,13 @@ func TestPeerStorageProvider_OpenSeekable_ExtractsFileName(t *testing.T) { return req.GetBuildId() == "build-1" && req.GetFileName() == "memfile" })).Return(&orchestrator.GetBuildFileSizeResponse{TotalSize: 512}, nil) - base := providermocks.NewMockStorageProvider(t) + base := storage.NewMockStorageProvider(t) - p := newPeerStorageProvider(base, client, &atomic.Bool{}) - seekable, err := p.OpenSeekable(t.Context(), "build-1/memfile", storage.MemfileObjectType) + p := newPeerStorageProvider(base, client, &atomic.Pointer[UploadedHeaders]{}) + ff, err := p.OpenSeekable(t.Context(), "build-1/memfile", storage.MemfileObjectType) require.NoError(t, err) - size, err := seekable.Size(t.Context()) + size, err := ff.Size(t.Context()) require.NoError(t, err) assert.Equal(t, int64(512), size) } diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/header.go b/packages/orchestrator/pkg/sandbox/template/peerserver/header.go index 9691e217c0..44de5c56bd 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/header.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/header.go @@ -35,7 +35,7 @@ func (f *headerSource) Stream(ctx context.Context, sender Sender) error { return ErrNotAvailable } - data, err := header.Serialize(h.Metadata, h.Mapping) + data, err := header.SerializeHeader(h) if err != nil { span.RecordError(err) diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go b/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go index aa3604225b..39c1040fe9 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/resolve.go @@ -17,7 +17,7 @@ var ErrUnknownFile = fmt.Errorf("unknown file") // Returns ErrNotAvailable when the build is not in the local cache. // Returns ErrUnknownFile for unrecognised file names. 
func ResolveSeekable(cache Cache, buildID, fileName string) (SeekableSource, error) { - switch fileName { + switch storage.StripCompression(fileName) { case storage.MemfileName, storage.RootfsName: diff, ok := cache.LookupDiff(buildID, build.DiffType(fileName)) if !ok { diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go index 319b9d3c99..a69d563912 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable.go @@ -18,8 +18,8 @@ type seekableSource struct { diff build.Diff } -func (f *seekableSource) Size(ctx context.Context) (int64, error) { - return f.diff.Size(ctx) +func (f *seekableSource) Size(_ context.Context) (int64, error) { + return f.diff.FileSize() } func (f *seekableSource) Exists(_ context.Context) (bool, error) { @@ -33,7 +33,8 @@ func (f *seekableSource) Stream(ctx context.Context, offset, length int64, sende )) defer span.End() - data, err := f.diff.Slice(ctx, offset, length) + // P2P always serves uncompressed bytes — pass nil FrameTable. 
+ data, err := f.diff.Slice(ctx, offset, length, nil) if err != nil { span.RecordError(err) diff --git a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go index 66724591bd..9ddf805ac3 100644 --- a/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go +++ b/packages/orchestrator/pkg/sandbox/template/peerserver/seekable_test.go @@ -17,7 +17,7 @@ func TestSeekableSource_Size(t *testing.T) { t.Parallel() diff := buildmocks.NewMockDiff(t) - diff.EXPECT().Size(mock.Anything).Return(int64(1234), nil) + diff.EXPECT().FileSize().Return(int64(1234), nil) cache := peerservermocks.NewMockCache(t) cache.EXPECT().LookupDiff("build-1", build.DiffType(storage.MemfileName)).Return(diff, true) @@ -36,7 +36,7 @@ func TestSeekableSource_Stream(t *testing.T) { data := []byte("diff bytes") diff := buildmocks.NewMockDiff(t) - diff.EXPECT().Slice(mock.Anything, int64(0), int64(len(data))).Return(data, nil) + diff.EXPECT().Slice(mock.Anything, int64(0), int64(len(data)), (*storage.FrameTable)(nil)).Return(data, nil) diff.EXPECT().BlockSize().Return(int64(len(data))) cache := peerservermocks.NewMockCache(t) diff --git a/packages/orchestrator/pkg/sandbox/template/storage.go b/packages/orchestrator/pkg/sandbox/template/storage.go index 9a3c0852e0..22190ce4e7 100644 --- a/packages/orchestrator/pkg/sandbox/template/storage.go +++ b/packages/orchestrator/pkg/sandbox/template/storage.go @@ -19,7 +19,6 @@ const ( ) type Storage struct { - header *header.Header source *build.File } @@ -58,7 +57,7 @@ func NewStorage( if h == nil { var hdrPath string - headerObjectType, ok := storageHeaderObjectType(fileType) + _, ok := storageHeaderObjectType(fileType) if !ok { return nil, build.UnknownDiffTypeError{DiffType: fileType} } @@ -70,20 +69,10 @@ func NewStorage( hdrPath = paths.RootfsHeader() } - headerObject, err := persistence.OpenBlob(ctx, hdrPath, headerObjectType) - if err != nil { - return 
nil, err - } - - diffHeader, err := header.Deserialize(ctx, headerObject) - - // If we can't find the diff header in storage, we switch to templates without a headers + var err error + h, err = header.LoadHeader(ctx, persistence, hdrPath) if err != nil && !errors.Is(err, storage.ErrObjectNotExist) { - return nil, fmt.Errorf("failed to deserialize header: %w", err) - } - - if err == nil { - h = diffHeader + return nil, err } } @@ -147,7 +136,6 @@ func NewStorage( return &Storage{ source: b, - header: h, }, nil } @@ -156,11 +144,11 @@ func (d *Storage) ReadAt(ctx context.Context, p []byte, off int64) (int, error) } func (d *Storage) Size(_ context.Context) (int64, error) { - return int64(d.header.Metadata.Size), nil + return int64(d.source.Header().Metadata.Size), nil } func (d *Storage) BlockSize() int64 { - return int64(d.header.Metadata.BlockSize) + return int64(d.source.Header().Metadata.BlockSize) } func (d *Storage) Slice(ctx context.Context, off, length int64) ([]byte, error) { @@ -168,7 +156,7 @@ func (d *Storage) Slice(ctx context.Context, off, length int64) ([]byte, error) } func (d *Storage) Header() *header.Header { - return d.header + return d.source.Header() } func (d *Storage) Close() error { diff --git a/packages/orchestrator/pkg/sandbox/template_build.go b/packages/orchestrator/pkg/sandbox/template_build.go index 374d39fba4..db17c1cee6 100644 --- a/packages/orchestrator/pkg/sandbox/template_build.go +++ b/packages/orchestrator/pkg/sandbox/template_build.go @@ -45,7 +45,7 @@ func (t *TemplateBuild) uploadMemfileHeader(ctx context.Context, h *headers.Head return err } - serialized, err := headers.Serialize(h.Metadata, h.Mapping) + serialized, err := headers.SerializeHeader(h) if err != nil { return fmt.Errorf("error when serializing memfile header: %w", err) } @@ -64,8 +64,7 @@ func (t *TemplateBuild) uploadMemfile(ctx context.Context, memfilePath string) e return err } - err = object.StoreFile(ctx, memfilePath) - if err != nil { + if _, _, err = 
object.StoreFile(ctx, memfilePath, nil); err != nil { return fmt.Errorf("error when uploading memfile: %w", err) } @@ -78,14 +77,14 @@ func (t *TemplateBuild) uploadRootfsHeader(ctx context.Context, h *headers.Heade return err } - serialized, err := headers.Serialize(h.Metadata, h.Mapping) + serialized, err := headers.SerializeHeader(h) if err != nil { - return fmt.Errorf("error when serializing memfile header: %w", err) + return fmt.Errorf("error when serializing rootfs header: %w", err) } err = object.Put(ctx, serialized) if err != nil { - return fmt.Errorf("error when uploading memfile header: %w", err) + return fmt.Errorf("error when uploading rootfs header: %w", err) } return nil @@ -97,8 +96,7 @@ func (t *TemplateBuild) uploadRootfs(ctx context.Context, rootfsPath string) err return err } - err = object.StoreFile(ctx, rootfsPath) - if err != nil { + if _, _, err = object.StoreFile(ctx, rootfsPath, nil); err != nil { return fmt.Errorf("error when uploading rootfs: %w", err) } diff --git a/packages/orchestrator/pkg/server/chunks.go b/packages/orchestrator/pkg/server/chunks.go index 91488dbebc..387532e590 100644 --- a/packages/orchestrator/pkg/server/chunks.go +++ b/packages/orchestrator/pkg/server/chunks.go @@ -13,10 +13,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) -var ( - peerNotAvailable = &orchestrator.PeerAvailability{NotAvailable: true} - peerUseStorage = &orchestrator.PeerAvailability{UseStorage: true} -) +var peerNotAvailable = &orchestrator.PeerAvailability{NotAvailable: true} // seekableStreamSender implements peerserver.Sender over a gRPC server stream (for seekable files). 
type seekableStreamSender struct { @@ -47,13 +44,28 @@ func toGRPCError(err error) error { } } +func (s *Server) buildUploadedResponse(buildID string) *orchestrator.PeerAvailability { + cacheItem := s.uploadedBuilds.Get(buildID) + if cacheItem == nil { + return nil + } + + hdrs := cacheItem.Value() + + return &orchestrator.PeerAvailability{ + UseStorage: true, + MemfileHeader: hdrs.memfileHeader, + RootfsHeader: hdrs.rootfsHeader, + } +} + func (s *Server) GetBuildFileSize(ctx context.Context, req *orchestrator.GetBuildFileSizeRequest) (*orchestrator.GetBuildFileSizeResponse, error) { telemetry.SetAttributes(ctx, telemetry.WithBuildID(req.GetBuildId()), attribute.String("file_name", req.GetFileName())) - if s.uploadedBuilds.Get(req.GetBuildId()) != nil { + if avail := s.buildUploadedResponse(req.GetBuildId()); avail != nil { telemetry.SetAttributes(ctx, attribute.Bool("uploaded", true)) - return &orchestrator.GetBuildFileSizeResponse{Availability: peerUseStorage}, nil + return &orchestrator.GetBuildFileSizeResponse{Availability: avail}, nil } src, err := peerserver.ResolveSeekable(s.templateCache, req.GetBuildId(), req.GetFileName()) @@ -78,10 +90,10 @@ func (s *Server) GetBuildFileSize(ctx context.Context, req *orchestrator.GetBuil func (s *Server) GetBuildFileExists(ctx context.Context, req *orchestrator.GetBuildFileExistsRequest) (*orchestrator.GetBuildFileExistsResponse, error) { telemetry.SetAttributes(ctx, telemetry.WithBuildID(req.GetBuildId()), attribute.String("file_name", req.GetFileName())) - if s.uploadedBuilds.Get(req.GetBuildId()) != nil { + if avail := s.buildUploadedResponse(req.GetBuildId()); avail != nil { telemetry.SetAttributes(ctx, attribute.Bool("uploaded", true)) - return &orchestrator.GetBuildFileExistsResponse{Availability: peerUseStorage}, nil + return &orchestrator.GetBuildFileExistsResponse{Availability: avail}, nil } src, err := peerserver.ResolveBlob(s.templateCache, req.GetBuildId(), req.GetFileName()) @@ -122,10 +134,10 @@ func (s 
*Server) ReadAtBuildSeekable(req *orchestrator.ReadAtBuildSeekableReques attribute.Int64("length", length), ) - if s.uploadedBuilds.Get(req.GetBuildId()) != nil { + if avail := s.buildUploadedResponse(req.GetBuildId()); avail != nil { telemetry.SetAttributes(ctx, attribute.Bool("uploaded", true)) - return stream.Send(&orchestrator.ReadAtBuildSeekableResponse{Availability: peerUseStorage}) + return stream.Send(&orchestrator.ReadAtBuildSeekableResponse{Availability: avail}) } src, err := peerserver.ResolveSeekable(s.templateCache, req.GetBuildId(), req.GetFileName()) @@ -157,10 +169,10 @@ func (s *Server) GetBuildBlob(req *orchestrator.GetBuildBlobRequest, stream orch attribute.String("file_name", req.GetFileName()), ) - if s.uploadedBuilds.Get(req.GetBuildId()) != nil { + if avail := s.buildUploadedResponse(req.GetBuildId()); avail != nil { telemetry.SetAttributes(ctx, attribute.Bool("uploaded", true)) - return stream.Send(&orchestrator.GetBuildBlobResponse{Availability: peerUseStorage}) + return stream.Send(&orchestrator.GetBuildBlobResponse{Availability: avail}) } src, err := peerserver.ResolveBlob(s.templateCache, req.GetBuildId(), req.GetFileName()) diff --git a/packages/orchestrator/pkg/server/main.go b/packages/orchestrator/pkg/server/main.go index b6e893ef82..08385b8d7a 100644 --- a/packages/orchestrator/pkg/server/main.go +++ b/packages/orchestrator/pkg/server/main.go @@ -29,6 +29,11 @@ import ( // templates they refer to and are cleaned up automatically. 
const uploadedBuildsTTL = 1 * time.Hour +type uploadedBuildHeaders struct { + memfileHeader []byte + rootfsHeader []byte +} + type Server struct { orchestrator.UnimplementedSandboxServiceServer orchestrator.UnimplementedChunkServiceServer @@ -46,7 +51,7 @@ type Server struct { sbxEventsService *events.EventsService startingSandboxes *semaphore.Weighted peerRegistry peerclient.Registry - uploadedBuilds *ttlcache.Cache[string, struct{}] + uploadedBuilds *ttlcache.Cache[string, *uploadedBuildHeaders] sandboxCreateDuration metric.Int64Histogram } @@ -66,8 +71,8 @@ type ServiceConfig struct { } func New(cfg ServiceConfig) (*Server, error) { - uploadedBuilds := ttlcache.New[string, struct{}]( - ttlcache.WithTTL[string, struct{}](uploadedBuildsTTL), + uploadedBuilds := ttlcache.New( + ttlcache.WithTTL[string, *uploadedBuildHeaders](uploadedBuildsTTL), ) go uploadedBuilds.Start() diff --git a/packages/orchestrator/pkg/server/sandboxes.go b/packages/orchestrator/pkg/server/sandboxes.go index 19d32c64c4..05db2a662e 100644 --- a/packages/orchestrator/pkg/server/sandboxes.go +++ b/packages/orchestrator/pkg/server/sandboxes.go @@ -639,9 +639,13 @@ func (s *Server) Checkpoint(ctx context.Context, in *orchestrator.SandboxCheckpo // be paused or resumed later. 
uploadCtx, cancel := context.WithTimeout(context.WithoutCancel(ctx), uploadTimeout) defer cancel() - defer res.completeUpload(uploadCtx) - if err := res.snapshot.Upload(uploadCtx, s.persistence, res.paths); err != nil { + memHdr, rootHdr, err := res.uploadSnapshot(uploadCtx, s.persistence, s.config.CompressConfig, s.featureFlags) + if completeErr := res.completeUpload(uploadCtx, memHdr, rootHdr); completeErr != nil { + telemetry.ReportCriticalError(uploadCtx, "error completing upload", completeErr, telemetry.WithSandboxID(in.GetSandboxId())) + } + + if err != nil { telemetry.ReportCriticalError(ctx, "error uploading snapshot for checkpoint", err, telemetry.WithSandboxID(in.GetSandboxId())) s.sandboxFactory.Sandboxes.MarkStopping(ctx, resumedSbx.Runtime.SandboxID, resumedSbx.LifecycleID) @@ -695,12 +699,25 @@ type snapshotResult struct { meta metadata.Template snapshot *sandbox.Snapshot paths storage.Paths - completeUpload func(ctx context.Context) + completeUpload func(ctx context.Context, memfileHdr, rootfsHdr []byte) error +} + +// uploadSnapshot uploads snapshot files to GCS and returns serialized V4 +// header bytes for peer transition (nil for uncompressed builds). +func (r *snapshotResult) uploadSnapshot(ctx context.Context, persistence storage.StorageProvider, baseCompressCfg storage.CompressConfig, flags *featureflags.Client) (memfileHdr, rootfsHdr []byte, err error) { + cfg := storage.ResolveCompressConfig(ctx, baseCompressCfg, flags, storage.FileTypeMemfile, storage.UseCasePause) + uploader := sandbox.NewBuildUploader(r.snapshot, persistence, r.paths, cfg, nil) + + if err := uploader.UploadData(ctx); err != nil { + return nil, nil, err + } + + return uploader.FinalizeHeaders(ctx) } // snapshotAndCacheSandbox creates a snapshot of a sandbox and adds it to the local // template cache. The caller is responsible for starting the GCS upload via -// startSnapshotUploadAsync or uploadSnapshotWithPrefetchAsync. +// uploadSnapshotAsync. 
func (s *Server) snapshotAndCacheSandbox( ctx context.Context, sbx *sandbox.Sandbox, @@ -746,14 +763,19 @@ func (s *Server) snapshotAndCacheSandbox( logger.L().Warn(ctx, "failed to register peer address for routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) } - completeUpload := func(ctx context.Context) { + completeUpload := func(ctx context.Context, memfileHdr, rootfsHdr []byte) error { // Signal in-flight peer streams to switch to GCS. - s.uploadedBuilds.Set(meta.Template.BuildID, struct{}{}, ttlcache.DefaultTTL) + s.uploadedBuilds.Set(meta.Template.BuildID, &uploadedBuildHeaders{ + memfileHeader: memfileHdr, + rootfsHeader: rootfsHdr, + }, ttlcache.DefaultTTL) // Remove from Redis so new nodes go directly to GCS. if err := s.peerRegistry.Unregister(ctx, meta.Template.BuildID); err != nil { logger.L().Warn(ctx, "failed to unregister peer address from routing", zap.String("build_id", meta.Template.BuildID), zap.Error(err)) } + + return nil } return &snapshotResult{ @@ -768,7 +790,7 @@ func (s *Server) snapshotAndCacheSandbox( meta: meta, snapshot: snapshot, paths: paths, - completeUpload: func(context.Context) {}, + completeUpload: func(context.Context, []byte, []byte) error { return nil }, }, nil } @@ -780,16 +802,17 @@ func (s *Server) uploadSnapshotAsync(ctx context.Context, sbx *sandbox.Sandbox, go func() { defer cancel() - defer res.completeUpload(ctx) - err := res.snapshot.Upload(ctx, s.persistence, res.paths) + memHdr, rootHdr, err := res.uploadSnapshot(ctx, s.persistence, s.config.CompressConfig, s.featureFlags) if err != nil { sbxlogger.I(sbx).Error(ctx, "error uploading snapshot files", zap.Error(err)) - - return + } else { + sbxlogger.E(sbx).Info(ctx, "Snapshot files uploaded to GCS") } - sbxlogger.E(sbx).Info(ctx, "Snapshot files uploaded to GCS") + if completeErr := res.completeUpload(ctx, memHdr, rootHdr); completeErr != nil { + sbxlogger.I(sbx).Error(ctx, "error completing upload", zap.Error(completeErr)) + } }() } diff --git 
a/packages/orchestrator/pkg/template/build/builder.go b/packages/orchestrator/pkg/template/build/builder.go index 31c6b39ed9..b523050965 100644 --- a/packages/orchestrator/pkg/template/build/builder.go +++ b/packages/orchestrator/pkg/template/build/builder.go @@ -259,6 +259,8 @@ func runBuild( uploadTracker := layer.NewUploadTracker() + compressCfg := storage.ResolveCompressConfig(ctx, builder.config.CompressConfig, builder.featureFlags, storage.FileTypeMemfile, storage.UseCaseBuild) + layerExecutor := layer.NewLayerExecutor( bc, builder.logger, @@ -269,6 +271,7 @@ func runBuild( builder.buildStorage, index, uploadTracker, + compressCfg, ) baseBuilder := base.New( diff --git a/packages/orchestrator/pkg/template/build/layer/layer_executor.go b/packages/orchestrator/pkg/template/build/layer/layer_executor.go index 216f60a854..4bf98cdca4 100644 --- a/packages/orchestrator/pkg/template/build/layer/layer_executor.go +++ b/packages/orchestrator/pkg/template/build/layer/layer_executor.go @@ -34,6 +34,7 @@ type LayerExecutor struct { buildStorage storage.StorageProvider index cache.Index uploadTracker *UploadTracker + compressConfig *storage.CompressConfig // nil = no compression } func NewLayerExecutor( @@ -46,6 +47,7 @@ func NewLayerExecutor( buildStorage storage.StorageProvider, index cache.Index, uploadTracker *UploadTracker, + compressConfig *storage.CompressConfig, ) *LayerExecutor { return &LayerExecutor{ BuildContext: buildContext, @@ -59,6 +61,7 @@ func NewLayerExecutor( buildStorage: buildStorage, index: index, uploadTracker: uploadTracker, + compressConfig: compressConfig, } } @@ -274,10 +277,15 @@ func (lb *LayerExecutor) PauseAndUpload( } // Upload snapshot async, it's added to the template cache immediately - userLogger.Debug(ctx, fmt.Sprintf("Saving: %s", meta.Template.BuildID)) + if c := lb.compressConfig; c != nil { + userLogger.Debug(ctx, fmt.Sprintf("Saving: %s (compress=%s level=%d)", meta.Template.BuildID, c.Type, c.Level)) + } else { + 
userLogger.Debug(ctx, fmt.Sprintf("Saving: %s", meta.Template.BuildID)) + } // Register this upload and get functions to signal completion and wait for previous uploads completeUpload, waitForPreviousUploads := lb.uploadTracker.StartUpload() + uploader := sandbox.NewBuildUploader(snapshot, lb.templateStorage, storage.Paths{BuildID: meta.Template.BuildID}, lb.compressConfig, lb.uploadTracker.Pending()) lb.UploadErrGroup.Go(func() error { ctx := context.WithoutCancel(ctx) @@ -289,29 +297,28 @@ func (lb *LayerExecutor) PauseAndUpload( // still unblock and the errgroup can properly collect all errors. defer completeUpload() - err := snapshot.Upload( - ctx, - lb.templateStorage, - storage.Paths{BuildID: meta.Template.BuildID}, - ) - if err != nil { - return fmt.Errorf("error uploading snapshot: %w", err) + if err := uploader.UploadData(ctx); err != nil { + return fmt.Errorf("error uploading data files: %w", err) } // Wait for all previous layer uploads to complete before saving the cache entry. // This prevents race conditions where another build hits this cache entry // before its dependencies (previous layers) are available in storage. - err = waitForPreviousUploads(ctx) - if err != nil { + // For compressed builds, this also ensures all ancestor frame tables are + // available so headers can reference mappings from earlier layers. + if err := waitForPreviousUploads(ctx); err != nil { return fmt.Errorf("error waiting for previous uploads: %w", err) } - err = lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ + if _, _, err := uploader.FinalizeHeaders(ctx); err != nil { + return fmt.Errorf("error finalizing headers: %w", err) + } + + if err := lb.index.SaveLayerMeta(ctx, hash, cache.LayerMetadata{ Template: cache.Template{ BuildID: meta.Template.BuildID, }, - }) - if err != nil { + }); err != nil { // Since the data should be basically identical, this is safe to skip. 
if !errors.Is(err, storage.ErrObjectRateLimited) { return fmt.Errorf("error saving UUID to hash mapping: %w", err) diff --git a/packages/orchestrator/pkg/template/build/layer/upload_tracker.go b/packages/orchestrator/pkg/template/build/layer/upload_tracker.go index 213938f147..72db831eea 100644 --- a/packages/orchestrator/pkg/template/build/layer/upload_tracker.go +++ b/packages/orchestrator/pkg/template/build/layer/upload_tracker.go @@ -3,22 +3,39 @@ package layer import ( "context" "sync" + + "github.com/e2b-dev/infra/packages/orchestrator/pkg/sandbox" ) // UploadTracker tracks in-flight uploads and allows waiting for all previous uploads to complete. // This prevents race conditions where a layer's cache entry is saved before its // dependencies (previous layers) are fully uploaded. +// +// It also owns a shared PendingBuildInfo that collects frame tables from compressed +// uploads across all layers. waitForPreviousUploads guarantees that by the time +// layer N finalizes its compressed headers, all upstream layers (0..N-1) have +// completed both their data and header uploads, so all upstream frame tables +// are available for cross-pollination. type UploadTracker struct { mu sync.Mutex waitChs []chan struct{} + + // pending collects frame tables from compressed uploads across all layers. + pending *sandbox.PendingBuildInfo } func NewUploadTracker() *UploadTracker { return &UploadTracker{ waitChs: make([]chan struct{}, 0), + pending: &sandbox.PendingBuildInfo{}, } } +// Pending returns the shared PendingBuildInfo for collecting frame tables. +func (t *UploadTracker) Pending() *sandbox.PendingBuildInfo { + return t.pending +} + // StartUpload registers that a new upload has started. // Returns a function that should be called when the upload completes. 
func (t *UploadTracker) StartUpload() (complete func(), waitForPrevious func(context.Context) error) { diff --git a/packages/shared/go.mod b/packages/shared/go.mod index 093aa0ae83..fe78879680 100644 --- a/packages/shared/go.mod +++ b/packages/shared/go.mod @@ -30,11 +30,13 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/klauspost/compress v1.18.2 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/ngrok/firewall_toolkit v0.0.18 github.com/oapi-codegen/runtime v1.1.1 github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/pierrec/lz4/v4 v4.1.22 github.com/redis/go-redis/extra/redisotel/v9 v9.17.3 github.com/redis/go-redis/v9 v9.17.3 github.com/stretchr/testify v1.11.1 @@ -229,7 +231,6 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect github.com/kamstrup/intmap v0.5.1 // indirect - github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/knadh/koanf/maps v0.1.2 // indirect github.com/knadh/koanf/providers/confmap v1.0.0 // indirect @@ -281,7 +282,6 @@ require ( github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pires/go-proxyproto v0.7.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/packages/shared/pkg/featureflags/context.go b/packages/shared/pkg/featureflags/context.go index 4f0e957ff0..79e52b1557 100644 --- a/packages/shared/pkg/featureflags/context.go +++ b/packages/shared/pkg/featureflags/context.go @@ -164,6 +164,14 @@ func VolumeContext(volumeName string) 
ldcontext.Context { return ldcontext.NewWithKind(VolumeKind, volumeName) } +func CompressFileTypeContext(fileType string) ldcontext.Context { + return ldcontext.NewWithKind(CompressFileTypeKind, fileType) +} + +func CompressUseCaseContext(useCase string) ldcontext.Context { + return ldcontext.NewWithKind(CompressUseCaseKind, useCase) +} + func VersionContext(orchestratorID, commit string) ldcontext.Context { return ldcontext.NewBuilder(orchestratorID). Kind(OrchestratorKind). diff --git a/packages/shared/pkg/featureflags/flags.go b/packages/shared/pkg/featureflags/flags.go index 4f9c31eec8..6e985d4304 100644 --- a/packages/shared/pkg/featureflags/flags.go +++ b/packages/shared/pkg/featureflags/flags.go @@ -18,14 +18,16 @@ const ( SandboxKernelVersionAttribute string = "kernel-version" SandboxFirecrackerVersionAttribute string = "firecracker-version" - TeamKind ldcontext.Kind = "team" - UserKind ldcontext.Kind = "user" - ClusterKind ldcontext.Kind = "cluster" - deploymentKind ldcontext.Kind = "deployment" - TierKind ldcontext.Kind = "tier" - ServiceKind ldcontext.Kind = "service" - TemplateKind ldcontext.Kind = "template" - VolumeKind ldcontext.Kind = "volume" + TeamKind ldcontext.Kind = "team" + UserKind ldcontext.Kind = "user" + ClusterKind ldcontext.Kind = "cluster" + deploymentKind ldcontext.Kind = "deployment" + TierKind ldcontext.Kind = "tier" + ServiceKind ldcontext.Kind = "service" + TemplateKind ldcontext.Kind = "template" + VolumeKind ldcontext.Kind = "volume" + CompressFileTypeKind ldcontext.Kind = "compress-file-type" + CompressUseCaseKind ldcontext.Kind = "compress-use-case" OrchestratorKind ldcontext.Kind = "orchestrator" OrchestratorCommitAttribute string = "commit" @@ -203,6 +205,8 @@ var ( // MaxConcurrentSnapshotBuildQueries limits concurrent GetSnapshotBuilds calls (e.g. sandbox delete). // 0 or negative disables throttling (unlimited concurrency). 
MaxConcurrentSnapshotBuildQueries = newIntFlag("max-concurrent-snapshot-build-queries", 0) + + MinChunkerReadSizeKB = newIntFlag("min-chunker-read-size-kb", 0) // 0 = default (16 KB) ) type StringFlag struct { @@ -312,17 +316,17 @@ func GetTrackedTemplatesSet(ctx context.Context, ff *Client) map[string]struct{} return result } -// ChunkerConfigFlag is a JSON flag controlling the chunker implementation and tuning. -// -// NOTE: Changing useStreaming has no effect on chunkers already created for -// cached templates. A service restart (redeploy) is required for that change -// to take effect. minReadBatchSizeKB is checked just-in-time on each fetch, -// so it takes effect immediately. -// -// JSON format: {"useStreaming": false, "minReadBatchSizeKB": 16} -var ChunkerConfigFlag = newJSONFlag("chunker-config", ldvalue.FromJSONMarshal(map[string]any{ - "useStreaming": false, - "minReadBatchSizeKB": 16, +// CompressConfigFlag controls compression during template builds. +// When compressBuilds is true, builds upload exclusively compressed data +// (no uncompressed fallback). When false, exclusively uncompressed with V3 headers. +var CompressConfigFlag = newJSONFlag("compress-config", ldvalue.FromJSONMarshal(map[string]any{ + "compressBuilds": false, + "compressionType": "zstd", + "compressionLevel": 2, + "frameSizeKB": 2048, + "targetPartSizeMB": 50, + "frameEncodeWorkers": 4, + "encoderConcurrency": 1, })) // TCPFirewallEgressThrottleConfig controls per-sandbox egress throttling via Firecracker's diff --git a/packages/shared/pkg/grpc/orchestrator/chunks.pb.go b/packages/shared/pkg/grpc/orchestrator/chunks.pb.go index 388c9bd808..e02396c301 100644 --- a/packages/shared/pkg/grpc/orchestrator/chunks.pb.go +++ b/packages/shared/pkg/grpc/orchestrator/chunks.pb.go @@ -33,6 +33,12 @@ type PeerAvailability struct { // use_storage is true when the GCS upload has completed and the caller // should switch to reading from GCS/NFS directly instead of this peer. 
UseStorage bool `protobuf:"varint,2,opt,name=use_storage,json=useStorage,proto3" json:"use_storage,omitempty"` + // memfile_header contains the serialized V4 header (with FrameTables) + // for the memfile, included when use_storage is true and the upload was compressed. + MemfileHeader []byte `protobuf:"bytes,3,opt,name=memfile_header,json=memfileHeader,proto3" json:"memfile_header,omitempty"` + // rootfs_header contains the serialized V4 header (with FrameTables) + // for the rootfs, included when use_storage is true and the upload was compressed. + RootfsHeader []byte `protobuf:"bytes,4,opt,name=rootfs_header,json=rootfsHeader,proto3" json:"rootfs_header,omitempty"` } func (x *PeerAvailability) Reset() { @@ -81,6 +87,20 @@ func (x *PeerAvailability) GetUseStorage() bool { return false } +func (x *PeerAvailability) GetMemfileHeader() []byte { + if x != nil { + return x.MemfileHeader + } + return nil +} + +func (x *PeerAvailability) GetRootfsHeader() []byte { + if x != nil { + return x.RootfsHeader + } + return nil +} + type GetBuildFileSizeRequest struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache @@ -537,85 +557,90 @@ func (x *GetBuildBlobResponse) GetAvailability() *PeerAvailability { var File_chunks_proto protoreflect.FileDescriptor var file_chunks_proto_rawDesc = []byte{ - 0x0a, 0x0c, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0x58, - 0x0a, 0x10, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, - 0x74, 0x79, 0x12, 0x23, 0x0a, 0x0d, 0x6e, 0x6f, 0x74, 0x5f, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, - 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x6e, 0x6f, 0x74, 0x41, 0x76, - 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x5f, 0x73, - 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x75, 0x73, - 0x65, 0x53, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x22, 0x51, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x42, 
- 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, 0x65, 0x71, 0x75, + 0x0a, 0x0c, 0x63, 0x68, 0x75, 0x6e, 0x6b, 0x73, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x22, 0xa4, + 0x01, 0x0a, 0x10, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, + 0x69, 0x74, 0x79, 0x12, 0x23, 0x0a, 0x0d, 0x6e, 0x6f, 0x74, 0x5f, 0x61, 0x76, 0x61, 0x69, 0x6c, + 0x61, 0x62, 0x6c, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0c, 0x6e, 0x6f, 0x74, 0x41, + 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x1f, 0x0a, 0x0b, 0x75, 0x73, 0x65, 0x5f, + 0x73, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x08, 0x52, 0x0a, 0x75, + 0x73, 0x65, 0x53, 0x74, 0x6f, 0x72, 0x61, 0x67, 0x65, 0x12, 0x25, 0x0a, 0x0e, 0x6d, 0x65, 0x6d, + 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, 0x72, 0x18, 0x03, 0x20, 0x01, 0x28, + 0x0c, 0x52, 0x0d, 0x6d, 0x65, 0x6d, 0x66, 0x69, 0x6c, 0x65, 0x48, 0x65, 0x61, 0x64, 0x65, 0x72, + 0x12, 0x23, 0x0a, 0x0d, 0x72, 0x6f, 0x6f, 0x74, 0x66, 0x73, 0x5f, 0x68, 0x65, 0x61, 0x64, 0x65, + 0x72, 0x18, 0x04, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x0c, 0x72, 0x6f, 0x6f, 0x74, 0x66, 0x73, 0x48, + 0x65, 0x61, 0x64, 0x65, 0x72, 0x22, 0x51, 0x0a, 0x17, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, + 0x12, 0x19, 0x0a, 0x08, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, + 0x28, 0x09, 0x52, 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, + 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, + 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x70, 0x0a, 0x18, 0x47, 0x65, 0x74, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, 0x65, 0x73, 0x70, + 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x5f, 0x73, 0x69, + 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 
0x28, 0x03, 0x52, 0x09, 0x74, 0x6f, 0x74, 0x61, 0x6c, 0x53, + 0x69, 0x7a, 0x65, 0x12, 0x35, 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, + 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, + 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, + 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x22, 0x53, 0x0a, 0x19, 0x47, 0x65, + 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, + 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x62, 0x75, 0x69, 0x6c, 0x64, + 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, + 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, + 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x22, + 0x53, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, + 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x35, 0x0a, + 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x18, 0x01, 0x20, + 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, + 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, + 0x6c, 0x69, 0x74, 0x79, 0x22, 0x84, 0x01, 0x0a, 0x1a, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x70, 
0x0a, 0x18, 0x47, - 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, - 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x1d, 0x0a, 0x0a, 0x74, 0x6f, 0x74, 0x61, 0x6c, - 0x5f, 0x73, 0x69, 0x7a, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x03, 0x52, 0x09, 0x74, 0x6f, 0x74, - 0x61, 0x6c, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x35, 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, - 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, - 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, - 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x22, 0x53, 0x0a, - 0x19, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, - 0x73, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x62, 0x75, - 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x62, 0x75, - 0x69, 0x6c, 0x64, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 0x6e, 0x61, - 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, - 0x6d, 0x65, 0x22, 0x53, 0x0a, 0x1a, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, - 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x16, 0x0a, 0x06, 0x6f, + 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6f, 0x66, 0x66, + 0x73, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x18, 0x04, 0x20, + 0x01, 0x28, 0x03, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x22, 0x68, 0x0a, 0x1b, 0x52, + 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, + 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, 0x64, 0x61, + 0x74, 0x61, 0x18, 
0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x35, + 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x18, 0x02, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, + 0x69, 0x6c, 0x69, 0x74, 0x79, 0x22, 0x4d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, + 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, + 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, + 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, + 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x61, 0x0a, 0x14, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, + 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, 0x04, + 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, 0x61, 0x12, 0x35, 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, - 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, + 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, - 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x22, 0x84, 0x01, 0x0a, 0x1a, 0x52, 0x65, 0x61, 0x64, - 0x41, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x19, 0x0a, 0x08, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, - 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, - 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, 0x6c, 0x65, 0x5f, 
0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, - 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x12, 0x16, - 0x0a, 0x06, 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x18, 0x03, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, - 0x6f, 0x66, 0x66, 0x73, 0x65, 0x74, 0x12, 0x16, 0x0a, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, - 0x18, 0x04, 0x20, 0x01, 0x28, 0x03, 0x52, 0x06, 0x6c, 0x65, 0x6e, 0x67, 0x74, 0x68, 0x22, 0x68, - 0x0a, 0x1b, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, - 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x12, 0x0a, - 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, 0x61, 0x74, - 0x61, 0x12, 0x35, 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, - 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, 0x41, 0x76, - 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, 0x61, 0x69, - 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x22, 0x4d, 0x0a, 0x13, 0x47, 0x65, 0x74, 0x42, - 0x75, 0x69, 0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, - 0x19, 0x0a, 0x08, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x5f, 0x69, 0x64, 0x18, 0x01, 0x20, 0x01, 0x28, - 0x09, 0x52, 0x07, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x49, 0x64, 0x12, 0x1b, 0x0a, 0x09, 0x66, 0x69, - 0x6c, 0x65, 0x5f, 0x6e, 0x61, 0x6d, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x08, 0x66, - 0x69, 0x6c, 0x65, 0x4e, 0x61, 0x6d, 0x65, 0x22, 0x61, 0x0a, 0x14, 0x47, 0x65, 0x74, 0x42, 0x75, - 0x69, 0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12, - 0x12, 0x0a, 0x04, 0x64, 0x61, 0x74, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0c, 0x52, 0x04, 0x64, - 0x61, 0x74, 0x61, 0x12, 0x35, 0x0a, 0x0c, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, - 0x69, 0x74, 0x79, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x11, 0x2e, 0x50, 0x65, 0x65, 0x72, - 
0x41, 0x76, 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x52, 0x0c, 0x61, 0x76, - 0x61, 0x69, 0x6c, 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x32, 0xb9, 0x02, 0x0a, 0x0c, 0x43, - 0x68, 0x75, 0x6e, 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x47, - 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, - 0x18, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, - 0x7a, 0x65, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x47, 0x65, 0x74, 0x42, - 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x12, 0x4d, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, - 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x1a, 0x2e, 0x47, 0x65, 0x74, - 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, - 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, - 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, - 0x6e, 0x73, 0x65, 0x12, 0x52, 0x0a, 0x13, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, - 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x1b, 0x2e, 0x52, 0x65, 0x61, - 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, - 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, - 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x65, 0x73, - 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x12, 0x3d, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x42, 0x75, - 0x69, 0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x12, 0x14, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, - 0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, - 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 
0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x73, 0x70, - 0x6f, 0x6e, 0x73, 0x65, 0x30, 0x01, 0x42, 0x2f, 0x5a, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, - 0x2f, 0x2f, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x32, 0x62, - 0x2d, 0x64, 0x65, 0x76, 0x2f, 0x69, 0x6e, 0x66, 0x72, 0x61, 0x2f, 0x6f, 0x72, 0x63, 0x68, 0x65, - 0x73, 0x74, 0x72, 0x61, 0x74, 0x6f, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x61, 0x62, 0x69, 0x6c, 0x69, 0x74, 0x79, 0x32, 0xb9, 0x02, 0x0a, 0x0c, 0x43, 0x68, 0x75, 0x6e, + 0x6b, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x12, 0x47, 0x0a, 0x10, 0x47, 0x65, 0x74, 0x42, + 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x12, 0x18, 0x2e, 0x47, + 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x19, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, + 0x64, 0x46, 0x69, 0x6c, 0x65, 0x53, 0x69, 0x7a, 0x65, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x12, 0x4d, 0x0a, 0x12, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, 0x6c, + 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x12, 0x1a, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, + 0x6c, 0x64, 0x46, 0x69, 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, 0x65, 0x71, 0x75, + 0x65, 0x73, 0x74, 0x1a, 0x1b, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x46, 0x69, + 0x6c, 0x65, 0x45, 0x78, 0x69, 0x73, 0x74, 0x73, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, + 0x12, 0x52, 0x0a, 0x13, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, + 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x12, 0x1b, 0x2e, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, + 0x42, 0x75, 0x69, 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x65, 0x71, + 0x75, 0x65, 0x73, 0x74, 0x1a, 0x1c, 0x2e, 0x52, 0x65, 0x61, 0x64, 0x41, 0x74, 0x42, 0x75, 0x69, + 0x6c, 0x64, 0x53, 0x65, 0x65, 0x6b, 0x61, 0x62, 0x6c, 0x65, 0x52, 0x65, 0x73, 
0x70, 0x6f, 0x6e, + 0x73, 0x65, 0x30, 0x01, 0x12, 0x3d, 0x0a, 0x0c, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, + 0x42, 0x6c, 0x6f, 0x62, 0x12, 0x14, 0x2e, 0x47, 0x65, 0x74, 0x42, 0x75, 0x69, 0x6c, 0x64, 0x42, + 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x15, 0x2e, 0x47, 0x65, 0x74, + 0x42, 0x75, 0x69, 0x6c, 0x64, 0x42, 0x6c, 0x6f, 0x62, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, + 0x65, 0x30, 0x01, 0x42, 0x2f, 0x5a, 0x2d, 0x68, 0x74, 0x74, 0x70, 0x73, 0x3a, 0x2f, 0x2f, 0x67, + 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x65, 0x32, 0x62, 0x2d, 0x64, 0x65, + 0x76, 0x2f, 0x69, 0x6e, 0x66, 0x72, 0x61, 0x2f, 0x6f, 0x72, 0x63, 0x68, 0x65, 0x73, 0x74, 0x72, + 0x61, 0x74, 0x6f, 0x72, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/packages/shared/pkg/storage/compress_config.go b/packages/shared/pkg/storage/compress_config.go new file mode 100644 index 0000000000..dd67d3c575 --- /dev/null +++ b/packages/shared/pkg/storage/compress_config.go @@ -0,0 +1,115 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" +) + +// CompressConfig is the base compression configuration, loaded from environment +// variables at startup. Feature flags can override individual fields at runtime +// via ResolveCompressConfig. +type CompressConfig struct { + Enabled bool `env:"COMPRESS_ENABLED" envDefault:"false"` + Type string `env:"COMPRESS_TYPE" envDefault:"zstd"` + Level int `env:"COMPRESS_LEVEL" envDefault:"2"` + FrameSizeKB int `env:"COMPRESS_FRAME_SIZE_KB" envDefault:"2048"` + TargetPartSizeMB int `env:"COMPRESS_TARGET_PART_SIZE_MB" envDefault:"50"` + FrameEncodeWorkers int `env:"COMPRESS_FRAME_ENCODE_WORKERS" envDefault:"4"` + EncoderConcurrency int `env:"COMPRESS_ENCODER_CONCURRENCY" envDefault:"1"` +} + +// CompressionType returns the parsed CompressionType. 
+func (c *CompressConfig) CompressionType() CompressionType { + if c == nil { + return CompressionNone + } + + return parseCompressionType(c.Type) +} + +// FrameSize returns the frame size in bytes. +func (c *CompressConfig) FrameSize() int { + if c == nil || c.FrameSizeKB <= 0 { + return DefaultCompressFrameSize + } + + return c.FrameSizeKB * 1024 +} + +// TargetPartSize returns the target part size in bytes. +func (c *CompressConfig) TargetPartSize() int64 { + if c == nil || c.TargetPartSizeMB <= 0 { + return int64(gcpMultipartUploadChunkSize) + } + + return int64(c.TargetPartSizeMB) * (1 << 20) +} + +// IsEnabled reports whether compression is configured and active. +func (c *CompressConfig) IsEnabled() bool { + return c != nil && c.Enabled && c.CompressionType() != CompressionNone +} + +// Validate checks that the config is internally consistent. +func (c *CompressConfig) Validate() error { + if c == nil || !c.IsEnabled() { + return nil + } + + fs := c.FrameSize() + if fs <= 0 { + return fmt.Errorf("frame size must be positive, got %d KB", c.FrameSizeKB) + } + if MemoryChunkSize%fs != 0 && fs%MemoryChunkSize != 0 { + return fmt.Errorf("frame size (%d) must be a divisor or multiple of MemoryChunkSize (%d)", fs, MemoryChunkSize) + } + + return nil +} + +// Resolve returns a pointer to this config if compression is enabled, or nil. +// Callers use nil to mean "no compression". +func (c *CompressConfig) Resolve() *CompressConfig { + if c == nil || !c.IsEnabled() { + return nil + } + + return c +} + +// ResolveCompressConfig returns the effective compression config for a given +// file type and use case. Feature flags override the base config when active. +// Returns nil when compression is disabled. +// +// fileType and useCase are added to the LD evaluation context so that +// LaunchDarkly targeting rules can differentiate (e.g. compress memfile +// but not rootfs, or compress builds but not pauses). 
+func ResolveCompressConfig(ctx context.Context, base CompressConfig, ff *featureflags.Client, fileType, useCase string) *CompressConfig { + if ff != nil { + ctx = featureflags.AddToContext(ctx, + featureflags.CompressFileTypeContext(fileType), + featureflags.CompressUseCaseContext(useCase), + ) + + v := ff.JSONFlag(ctx, featureflags.CompressConfigFlag).AsValueMap() + + if v.Get("compressBuilds").BoolValue() { + ct := v.Get("compressionType").StringValue() + if parseCompressionType(ct) != CompressionNone { + return &CompressConfig{ + Enabled: true, + Type: ct, + Level: v.Get("compressionLevel").IntValue(), + FrameSizeKB: v.Get("frameSizeKB").IntValue(), + TargetPartSizeMB: v.Get("targetPartSizeMB").IntValue(), + FrameEncodeWorkers: v.Get("frameEncodeWorkers").IntValue(), + EncoderConcurrency: v.Get("encoderConcurrency").IntValue(), + } + } + } + } + + return base.Resolve() +} diff --git a/packages/shared/pkg/storage/compress_decode.go b/packages/shared/pkg/storage/compress_decode.go new file mode 100644 index 0000000000..01e40ed6a5 --- /dev/null +++ b/packages/shared/pkg/storage/compress_decode.go @@ -0,0 +1,128 @@ +package storage + +import ( + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +var lz4DecoderPool sync.Pool + +func getLZ4Decoder(r io.Reader) *lz4.Reader { + if v := lz4DecoderPool.Get(); v != nil { + dec := v.(*lz4.Reader) + dec.Reset(r) + + return dec + } + + return lz4.NewReader(r) +} + +func putLZ4Decoder(dec *lz4.Reader) { + dec.Reset(nil) + lz4DecoderPool.Put(dec) +} + +// zstd concurrency is hardcoded to 1: benchmarks show higher values hurt +// throughput for single 2MiB frame decodes. 
+var zstdDecoderPool sync.Pool + +func getZstdDecoder(r io.Reader) (*zstd.Decoder, error) { + if v := zstdDecoderPool.Get(); v != nil { + dec := v.(*zstd.Decoder) + if err := dec.Reset(r); err != nil { + dec.Close() + + return nil, err + } + + return dec, nil + } + + return zstd.NewReader(r) +} + +func putZstdDecoder(dec *zstd.Decoder) { + dec.Reset(nil) + zstdDecoderPool.Put(dec) +} + +// NewDecompressingReader wraps a reader with the appropriate decompressor. +// Close releases the decompressor back to its pool but does NOT close the +// underlying reader — the caller is responsible for closing it. +func NewDecompressingReader(raw io.Reader, ct CompressionType) (io.ReadCloser, error) { + switch ct { + case CompressionLZ4: + dec := getLZ4Decoder(raw) + + return &pooledDecoder{ + Reader: dec, + close: func() { putLZ4Decoder(dec) }, + }, nil + + case CompressionZstd: + dec, err := getZstdDecoder(raw) + if err != nil { + return nil, fmt.Errorf("failed to create zstd decoder: %w", err) + } + + return &pooledDecoder{ + Reader: dec, + close: func() { putZstdDecoder(dec) }, + }, nil + + default: + return nil, fmt.Errorf("unsupported compression type: %s", ct) + } +} + +// pooledDecoder wraps a decompressor from a sync.Pool. +// Close returns the decompressor to the pool. +type pooledDecoder struct { + io.Reader + + close func() +} + +func (r *pooledDecoder) Close() error { + r.close() + + return nil +} + +// newDecompressingReadCloser wraps raw with the appropriate decompressor and +// takes ownership: Close releases the decompressor back to the pool AND closes raw. +func newDecompressingReadCloser(raw io.ReadCloser, ct CompressionType) (io.ReadCloser, error) { + dec, err := NewDecompressingReader(raw, ct) + if err != nil { + return nil, err + } + + return &decompressingReadCloser{dec: dec, raw: raw}, nil +} + +// decompressingReadCloser reads from the decompressor and closes both the +// decompressor (returning it to the pool) and the underlying raw stream. 
+type decompressingReadCloser struct { + dec io.ReadCloser // decompressor — reads from raw + raw io.Closer // underlying stream +} + +func (c *decompressingReadCloser) Read(p []byte) (int, error) { + return c.dec.Read(p) +} + +func (c *decompressingReadCloser) Close() error { + decErr := c.dec.Close() + rawErr := c.raw.Close() + + if decErr != nil { + return decErr + } + + return rawErr +} diff --git a/packages/shared/pkg/storage/compress_encode.go b/packages/shared/pkg/storage/compress_encode.go new file mode 100644 index 0000000000..cc2ef2e7d6 --- /dev/null +++ b/packages/shared/pkg/storage/compress_encode.go @@ -0,0 +1,121 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +// compressor compresses individual frames. Implementations are pooled and +// reused across frames within a single CompressStream call. +type compressor interface { + compress(src []byte) ([]byte, error) +} + +// lz4Compressor wraps a pooled lz4.Writer. The writer is reused via Reset +// between frames to avoid re-allocating internal hash tables (~64KB). +type lz4Compressor struct { + w *lz4.Writer +} + +func (c *lz4Compressor) compress(src []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(lz4.CompressBlockBound(len(src))) + c.w.Reset(&buf) + + if _, err := c.w.Write(src); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := c.w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} + +// zstdCompressor wraps a pooled zstd.Encoder using EncodeAll. +type zstdCompressor struct { + enc *zstd.Encoder +} + +func (z *zstdCompressor) compress(src []byte) ([]byte, error) { //nolint:unparam // satisfies compressor interface + return z.enc.EncodeAll(src, make([]byte, 0, len(src))), nil +} + +// newCompressorPool returns a pool of compressors for the given config. 
+// Both LZ4 and zstd encoders are pooled and reused via Reset/EncodeAll. +// The config is validated eagerly — if zstd options are invalid, an error +// is returned immediately rather than deferred to pool.Get(). +func newCompressorPool(cfg *CompressConfig) (*sync.Pool, error) { + pool := &sync.Pool{} + + switch cfg.CompressionType() { + case CompressionZstd: + zstdOpts := []zstd.EOption{ + zstd.WithEncoderLevel(zstd.EncoderLevel(cfg.Level)), + zstd.WithEncoderCRC(true), + } + if cfg.FrameSize() > 0 { + zstdOpts = append(zstdOpts, zstd.WithWindowSize(cfg.FrameSize())) + } + if cfg.EncoderConcurrency > 0 { + zstdOpts = append(zstdOpts, zstd.WithEncoderConcurrency(cfg.EncoderConcurrency)) + } + + // Validate options by creating one encoder upfront. + first, err := zstd.NewWriter(nil, zstdOpts...) + if err != nil { + return nil, fmt.Errorf("zstd encoder: %w", err) + } + pool.Put(&zstdCompressor{enc: first}) + + pool.New = func() any { + // Options are already validated; NewWriter won't fail. + enc, _ := zstd.NewWriter(nil, zstdOpts...) + + return &zstdCompressor{enc: enc} + } + case CompressionLZ4: + lz4Opts := []lz4.Option{ + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(false), + lz4.ConcurrencyOption(1), + lz4.CompressionLevelOption(lz4.Fast), + } + + // Validate options by creating one encoder upfront. + first := lz4.NewWriter(nil) + if err := first.Apply(lz4Opts...); err != nil { + return nil, fmt.Errorf("lz4 encoder: %w", err) + } + pool.Put(&lz4Compressor{w: first}) + + pool.New = func() any { + w := lz4.NewWriter(nil) + _ = w.Apply(lz4Opts...) 
//nolint:errcheck // options validated above + + return &lz4Compressor{w: w} + } + default: + return nil, fmt.Errorf("unsupported compression type: %s", cfg.CompressionType()) + } + + return pool, nil +} + +func CompressBytes(ctx context.Context, data []byte, cfg *CompressConfig) (*FrameTable, []byte, [32]byte, error) { + up := &memPartUploader{} + + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + if err != nil { + return nil, nil, [32]byte{}, err + } + + return ft, up.Assemble(), checksum, nil +} diff --git a/packages/shared/pkg/storage/compress_frame_table.go b/packages/shared/pkg/storage/compress_frame_table.go new file mode 100644 index 0000000000..6512485d11 --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table.go @@ -0,0 +1,253 @@ +package storage + +import ( + "fmt" +) + +type CompressionType byte + +const ( + CompressionNone = CompressionType(iota) + CompressionZstd + CompressionLZ4 +) + +func (ct CompressionType) Suffix() string { + switch ct { + case CompressionZstd: + return ".zstd" + case CompressionLZ4: + return ".lz4" + default: + return "" + } +} + +func (ct CompressionType) String() string { + switch ct { + case CompressionZstd: + return "zstd" + case CompressionLZ4: + return "lz4" + default: + return "none" + } +} + +// parseCompressionType converts a string to CompressionType. +// Returns CompressionNone for unrecognised values. 
+func parseCompressionType(s string) CompressionType { + switch s { + case "lz4": + return CompressionLZ4 + case "zstd": + return CompressionZstd + default: + return CompressionNone + } +} + +type FrameOffset struct { + U int64 + C int64 +} + +func (o *FrameOffset) String() string { + return fmt.Sprintf("U:%d/C:%d", o.U, o.C) +} + +func (o *FrameOffset) Add(f FrameSize) { + o.U += int64(f.U) + o.C += int64(f.C) +} + +type FrameSize struct { + U int32 + C int32 +} + +func (s FrameSize) String() string { + return fmt.Sprintf("U:%d/C:%d", s.U, s.C) +} + +type Range struct { + Start int64 + Length int +} + +func (r Range) String() string { + return fmt.Sprintf("%d/%d", r.Start, r.Length) +} + +type FrameTable struct { + compressionType CompressionType + StartAt FrameOffset + Frames []FrameSize +} + +// NewFrameTable creates a FrameTable with the given compression type. +func NewFrameTable(ct CompressionType) *FrameTable { + return &FrameTable{compressionType: ct} +} + +// CompressionType returns the compression type. Nil-safe: returns CompressionNone for nil. +func (ft *FrameTable) CompressionType() CompressionType { + if ft == nil { + return CompressionNone + } + + return ft.compressionType +} + +// IsCompressed reports whether ft is non-nil and has a compression type set. +func (ft *FrameTable) IsCompressed() bool { + return ft != nil && ft.compressionType != CompressionNone +} + +// Range calls fn for each frame overlapping [start, start+length). 
+func (ft *FrameTable) Range(start, length int64, fn func(offset FrameOffset, frame FrameSize) error) error { + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + requestEnd := start + length + if frameEnd <= start { + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + + continue + } + if currentOffset.U >= requestEnd { + break + } + + if err := fn(currentOffset, frame); err != nil { + return err + } + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + } + + return nil +} + +func (ft *FrameTable) Size() (uncompressed, compressed int64) { + for _, frame := range ft.Frames { + uncompressed += int64(frame.U) + compressed += int64(frame.C) + } + + return uncompressed, compressed +} + +// Subset returns frames covering r. Whole frames only (can't split compressed). +func (ft *FrameTable) Subset(r Range) (*FrameTable, error) { + if ft == nil || r.Length == 0 { + return nil, nil + } + if r.Start < ft.StartAt.U { + return nil, fmt.Errorf("requested range starts before the beginning of the frame table") + } + + result, _ := ft.SubsetFrom(r, 0) + if result == nil { + return nil, fmt.Errorf("requested range is beyond the end of the frame table") + } + + return result, nil +} + +// SubsetFrom is like Subset but starts scanning from frame index `from`, +// returning the index of the first frame past the result. Use this to +// efficiently extract consecutive subsets from a sorted sequence of ranges +// without re-scanning from the beginning each time. +func (ft *FrameTable) SubsetFrom(r Range, from int) (*FrameTable, int) { + if ft == nil || r.Length == 0 { + return nil, from + } + + result := &FrameTable{ + compressionType: ft.compressionType, + } + + // Advance currentOffset to frame `from`. 
+ currentOffset := ft.StartAt + for i := range from { + if i >= len(ft.Frames) { + break + } + currentOffset.Add(ft.Frames[i]) + } + + startSet := false + requestedEnd := r.Start + int64(r.Length) + nextFrom := from + + for i := from; i < len(ft.Frames); i++ { + frame := ft.Frames[i] + frameEnd := currentOffset.U + int64(frame.U) + + if frameEnd <= r.Start { + currentOffset.Add(frame) + nextFrom = i + 1 + + continue + } + if currentOffset.U >= requestedEnd { + break + } + + if !startSet { + result.StartAt = currentOffset + startSet = true + nextFrom = i + } + result.Frames = append(result.Frames, frame) + currentOffset.Add(frame) + } + + if !startSet { + return nil, nextFrom + } + + return result, nextFrom +} + +// FrameFor finds the frame containing the given offset and returns its start position and full size. +func (ft *FrameTable) FrameFor(offset int64) (starts FrameOffset, size FrameSize, err error) { + if ft == nil { + return FrameOffset{}, FrameSize{}, fmt.Errorf("FrameFor called with nil frame table - data is not compressed") + } + + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + if offset >= currentOffset.U && offset < frameEnd { + return currentOffset, frame, nil + } + currentOffset.Add(frame) + } + + return FrameOffset{}, FrameSize{}, fmt.Errorf("offset %d is beyond the end of the frame table", offset) +} + +// GetFetchRange translates a U-space range to C-space using the frame table. 
+func (ft *FrameTable) GetFetchRange(rangeU Range) (Range, error) { + fetchRange := rangeU + if ft.IsCompressed() { + start, size, err := ft.FrameFor(rangeU.Start) + if err != nil { + return Range{}, fmt.Errorf("getting frame for offset %d: %w", rangeU.Start, err) + } + endOffset := rangeU.Start + int64(rangeU.Length) + frameEnd := start.U + int64(size.U) + if endOffset > frameEnd { + return Range{}, fmt.Errorf("range %v spans beyond frame ending at %d", rangeU, frameEnd) + } + fetchRange = Range{ + Start: start.C, + Length: int(size.C), + } + } + + return fetchRange, nil +} diff --git a/packages/shared/pkg/storage/compress_frame_table_test.go b/packages/shared/pkg/storage/compress_frame_table_test.go new file mode 100644 index 0000000000..c06647eede --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table_test.go @@ -0,0 +1,246 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// threeFrameFT returns a FrameTable with three 1MB uncompressed frames +// and varying compressed sizes, starting at the given offset. +func threeFrameFT(startU, startC int64) *FrameTable { + ft := &FrameTable{ + compressionType: CompressionLZ4, + StartAt: FrameOffset{U: startU, C: startC}, + Frames: []FrameSize{ + {U: 1 << 20, C: 500_000}, // frame 0 + {U: 1 << 20, C: 600_000}, // frame 1 + {U: 1 << 20, C: 400_000}, // frame 2 + }, + } + + return ft +} + +// collectRange calls ft.Range and returns the offsets visited. 
+func collectRange(ft *FrameTable, start, length int64) ([]FrameOffset, error) { + var offsets []FrameOffset + err := ft.Range(start, length, func(offset FrameOffset, _ FrameSize) error { + offsets = append(offsets, offset) + + return nil + }) + + return offsets, err +} + +func TestRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("selects all frames", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 0, 3<<20) + require.NoError(t, err) + require.Len(t, offsets, 3) + }) + + t.Run("selects single middle frame", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 1<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + require.Equal(t, int64(1<<20), offsets[0].U) + require.Equal(t, int64(500_000), offsets[0].C) + }) + + t.Run("partial overlap selects touched frames", func(t *testing.T) { + t.Parallel() + // 1 byte spanning frames 0 and 1 boundary. + offsets, err := collectRange(ft, (1<<20)-1, 2) + require.NoError(t, err) + require.Len(t, offsets, 2) + }) + + t.Run("beyond end returns nothing", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 3<<20, 1) + require.NoError(t, err) + require.Empty(t, offsets) + }) + + t.Run("callback error propagates", func(t *testing.T) { + t.Parallel() + sentinel := fmt.Errorf("stop") + err := ft.Range(0, 3<<20, func(_ FrameOffset, _ FrameSize) error { + return sentinel + }) + require.ErrorIs(t, err, sentinel) + }) + + t.Run("respects StartAt on subset", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 1 << 20, Length: 2 << 20}) + require.NoError(t, err) + + // Query for offset 2MB — the second frame of the subset. + offsets, err := collectRange(sub, 2<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + require.Equal(t, int64(2<<20), offsets[0].U) + require.Equal(t, int64(1_100_000), offsets[0].C) // 500k + 600k + + // Query for offset 0 — before the subset, should find nothing. 
+ offsets, err = collectRange(sub, 0, 1<<20) + require.NoError(t, err) + require.Empty(t, offsets, "Range should not find frames before StartAt") + }) +} + +func TestSubset(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("full range", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 3 << 20}) + require.NoError(t, err) + require.Len(t, sub.Frames, 3) + require.Equal(t, int64(0), sub.StartAt.U) + }) + + t.Run("last frame", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 2 << 20, Length: 1 << 20}) + require.NoError(t, err) + require.Len(t, sub.Frames, 1) + require.Equal(t, int64(2<<20), sub.StartAt.U) + require.Equal(t, int64(1_100_000), sub.StartAt.C) + require.Equal(t, int32(400_000), sub.Frames[0].C) + }) + + t.Run("preserves compression type", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 1 << 20}) + require.NoError(t, err) + require.Equal(t, CompressionLZ4, sub.CompressionType()) + }) + + t.Run("nil table returns nil", func(t *testing.T) { + t.Parallel() + sub, err := (*FrameTable)(nil).Subset(Range{Start: 0, Length: 100}) + require.NoError(t, err) + require.Nil(t, sub) + }) + + t.Run("zero length returns nil", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 0}) + require.NoError(t, err) + require.Nil(t, sub) + }) + + t.Run("before StartAt errors", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + _, err := sub.Subset(Range{Start: 0, Length: 1 << 20}) + require.Error(t, err) + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, err := ft.Subset(Range{Start: 4 << 20, Length: 1 << 20}) + require.Error(t, err) + }) +} + +func TestFrameFor(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("first byte of each frame", func(t *testing.T) { + t.Parallel() + for i, wantU := range []int64{0, 1 << 20, 2 << 20} { + start, size, err := ft.FrameFor(wantU) + 
require.NoError(t, err, "frame %d", i) + require.Equal(t, wantU, start.U) + require.Equal(t, int32(1<<20), size.U) + } + }) + + t.Run("last byte of frame", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor((1 << 20) - 1) + require.NoError(t, err) + require.Equal(t, int64(0), start.U) + }) + + t.Run("returns correct C offset", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor(2 << 20) + require.NoError(t, err) + require.Equal(t, int64(1_100_000), start.C) // 500k + 600k + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, _, err := ft.FrameFor(3 << 20) + require.Error(t, err) + }) + + t.Run("nil table errors", func(t *testing.T) { + t.Parallel() + _, _, err := (*FrameTable)(nil).FrameFor(0) + require.Error(t, err) + }) + + t.Run("respects StartAt", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + start, _, err := sub.FrameFor(1 << 20) + require.NoError(t, err) + require.Equal(t, int64(1<<20), start.U) + require.Equal(t, int64(500_000), start.C) + + // Before StartAt — no frame should contain offset 0. 
+ _, _, err = sub.FrameFor(0) + require.Error(t, err) + }) +} + +func TestGetFetchRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("translates U-space to C-space", func(t *testing.T) { + t.Parallel() + r, err := ft.GetFetchRange(Range{Start: 1 << 20, Length: 1 << 20}) + require.NoError(t, err) + require.Equal(t, int64(500_000), r.Start) + require.Equal(t, 600_000, r.Length) + }) + + t.Run("range spanning multiple frames errors", func(t *testing.T) { + t.Parallel() + _, err := ft.GetFetchRange(Range{Start: 0, Length: 2 << 20}) + require.Error(t, err) + }) + + t.Run("nil table returns input unchanged", func(t *testing.T) { + t.Parallel() + input := Range{Start: 42, Length: 100} + r, err := (*FrameTable)(nil).GetFetchRange(input) + require.NoError(t, err) + require.Equal(t, input, r) + }) + + t.Run("uncompressed table returns input unchanged", func(t *testing.T) { + t.Parallel() + uncompressed := &FrameTable{compressionType: CompressionNone} + input := Range{Start: 42, Length: 100} + r, err := uncompressed.GetFetchRange(input) + require.NoError(t, err) + require.Equal(t, input, r) + }) +} diff --git a/packages/shared/pkg/storage/compress_upload.go b/packages/shared/pkg/storage/compress_upload.go new file mode 100644 index 0000000000..00b2a8f668 --- /dev/null +++ b/packages/shared/pkg/storage/compress_upload.go @@ -0,0 +1,264 @@ +package storage + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + "fmt" + "io" + "slices" + "sync" + "sync/atomic" + + "golang.org/x/sync/errgroup" +) + +const ( + // DefaultCompressFrameSize is the default uncompressed size of each compression + // frame (2 MiB). Overridable via CompressConfig.FrameSizeKB. + // The last frame in a file may be shorter. + // + // The chunker fetches one frame at a time from storage on a cache miss. + // Larger frame sizes mean more data cached per fetch (faster warm-up and + // fewer GCS round-trips), but higher memory and I/O cost per miss. 
+ // + // This MUST be multiple of every block/page size: + // - header.HugepageSize (2 MiB) — UFFD huge-page size, also used by prefetch + // - header.RootfsBlockSize (4 KiB) — NBD / rootfs block size + DefaultCompressFrameSize = 2 * 1024 * 1024 + + // File type identifiers for per-file-type compression targeting. + FileTypeMemfile = "memfile" + FileTypeRootfs = "rootfs" + + // Use case identifiers for per-use-case compression targeting. + UseCaseBuild = "build" + UseCasePause = "pause" +) + +// partUploader is the interface for uploading data in parts. +// Implementations exist for GCS multipart uploads and local file writes. +type partUploader interface { + Start(ctx context.Context) error + UploadPart(ctx context.Context, partIndex int, data ...[]byte) error + Complete(ctx context.Context) error + Close() error +} + +// memPartUploader collects compressed parts in memory. Thread-safe. +// Useful for tests and benchmarks that need CompressStream output as bytes. +type memPartUploader struct { + mu sync.Mutex + parts map[int][]byte +} + +func (m *memPartUploader) Start(context.Context) error { + m.parts = make(map[int][]byte) + + return nil +} + +func (m *memPartUploader) UploadPart(_ context.Context, partIndex int, data ...[]byte) error { + var buf bytes.Buffer + for _, d := range data { + buf.Write(d) + } + m.mu.Lock() + m.parts[partIndex] = buf.Bytes() + m.mu.Unlock() + + return nil +} + +func (m *memPartUploader) Complete(context.Context) error { return nil } +func (m *memPartUploader) Close() error { return nil } + +// Assemble returns the concatenated parts in index order. 
+func (m *memPartUploader) Assemble() []byte { + keys := make([]int, 0, len(m.parts)) + for k := range m.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + var buf bytes.Buffer + for _, k := range keys { + buf.Write(m.parts[k]) + } + + return buf.Bytes() +} + +type frame struct { + uncompressedSize int + compressed []byte +} + +type part struct { + index int + frames []*frame + compressedSize atomic.Int64 + eg *errgroup.Group + readyToUpload chan error +} + +func newPart(index int, parentCtx context.Context, workers int) (p *part, ctx context.Context) { + p = &part{index: index} + p.eg, ctx = errgroup.WithContext(parentCtx) + p.eg.SetLimit(workers) + + return p, ctx +} + +func (p *part) addFrame(ctx context.Context, uncompressedData []byte, pool *sync.Pool) { + if len(uncompressedData) == 0 { + return + } + + frameInPart := &frame{uncompressedSize: len(uncompressedData)} + p.frames = append(p.frames, frameInPart) + + p.eg.Go(func() error { + if err := ctx.Err(); err != nil { + return err + } + c := pool.Get().(compressor) + out, err := c.compress(uncompressedData) + pool.Put(c) + if err != nil { + return err + } + frameInPart.compressed = out + p.compressedSize.Add(int64(len(out))) + + return nil + }) +} + +func (p *part) submit(ctx context.Context, queue chan<- *part) { + p.readyToUpload = make(chan error, 1) + + go func() { + p.readyToUpload <- p.eg.Wait() + close(p.readyToUpload) + }() + + select { + case queue <- p: + case <-ctx.Done(): + } +} + +// compressStream: read → compress (parallel) → emit metadata (ordered) → upload (concurrent). 
+func compressStream(ctx context.Context, in io.Reader, cfg *CompressConfig, uploader partUploader, maxUploadConcurrency int) (ft *FrameTable, checksum [32]byte, err error) { //nolint:unparam // callers in later PRs pass different values + frameSize := cfg.FrameSize() + targetPartSize := cfg.TargetPartSize() + + if err := uploader.Start(ctx); err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to start framed upload: %w", err) + } + defer uploader.Close() + + // for compression we create a pool per file since there are often enough + // frames to justify pooling. + compressors, err := newCompressorPool(cfg) + if err != nil { + return nil, [32]byte{}, err + } + hasher := sha256.New() + + ft = &FrameTable{compressionType: cfg.CompressionType()} + + ctx, cancel := context.WithCancel(ctx) // pipeline errors cancel the read loop + defer cancel() + + q := make(chan *part, maxUploadConcurrency) + var closeQ sync.Once + defer closeQ.Do(func() { close(q) }) + + uploadEG, uploadCtx := errgroup.WithContext(ctx) + uploadEG.SetLimit(maxUploadConcurrency) + + var emitEG errgroup.Group + emitEG.Go(func() error { + for p := range q { + select { + case compressErr := <-p.readyToUpload: + if compressErr != nil { + cancel() + + return compressErr + } + case <-ctx.Done(): + return ctx.Err() + } + + var compressed [][]byte + for _, f := range p.frames { + ft.Frames = append(ft.Frames, FrameSize{U: int32(f.uncompressedSize), C: int32(len(f.compressed))}) + compressed = append(compressed, f.compressed) + } + + pi := p.index + uploadEG.Go(func() error { + return uploader.UploadPart(uploadCtx, pi, compressed...) 
+ }) + } + + return nil + }) + + part, compressCtx := newPart(1, ctx, cfg.FrameEncodeWorkers) + for { + if err := ctx.Err(); err != nil { + return nil, [32]byte{}, err + } + + buf := make([]byte, frameSize) + n, err := io.ReadFull(in, buf) + + switch { + case err == nil: + case errors.Is(err, io.EOF): + case errors.Is(err, io.ErrUnexpectedEOF): + // fall through + default: + return nil, [32]byte{}, fmt.Errorf("read frame: %w", err) + } + + if n > 0 { + hasher.Write(buf[:n]) + part.addFrame(compressCtx, buf[:n], compressors) + } + + if err != nil { + break + } + + if part.compressedSize.Load() >= targetPartSize { + part.submit(ctx, q) + part, compressCtx = newPart(part.index+1, ctx, cfg.FrameEncodeWorkers) + } + } + + if len(part.frames) > 0 { + part.submit(ctx, q) + } + + closeQ.Do(func() { close(q) }) + + emitErr := emitEG.Wait() + uploadErr := uploadEG.Wait() + if err := errors.Join(emitErr, uploadErr); err != nil { + return nil, [32]byte{}, err + } + + if err := uploader.Complete(ctx); err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to finish uploading frames: %w", err) + } + + copy(checksum[:], hasher.Sum(nil)) + + return ft, checksum, nil +} diff --git a/packages/shared/pkg/storage/compress_upload_test.go b/packages/shared/pkg/storage/compress_upload_test.go new file mode 100644 index 0000000000..80188acafe --- /dev/null +++ b/packages/shared/pkg/storage/compress_upload_test.go @@ -0,0 +1,458 @@ +package storage + +import ( + "bytes" + "context" + crand "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "math/rand/v2" + "os" + "path/filepath" + "slices" + "sync/atomic" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// generateSemiRandomData produces deterministic, compressible data. +// Random byte repeated 1-16 times — gives ~0.5-0.7 compression ratio. 
+func generateSemiRandomData(size int) []byte { + data := make([]byte, size) + rng := rand.New(rand.NewPCG(1, 2)) //nolint:gosec // deterministic + i := 0 + for i < size { + runLen := rng.IntN(16) + 1 + if i+runLen > size { + runLen = size - i + } + b := byte(rng.IntN(256)) + for j := range runLen { + data[i+j] = b + } + i += runLen + } + + return data +} + +// ThrottledPartUploader wraps memPartUploader with simulated upload bandwidth. +type ThrottledPartUploader struct { + memPartUploader + + bandwidth int64 // bytes/sec; 0 = unlimited +} + +func (t *ThrottledPartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + if t.bandwidth > 0 { + total := 0 + for _, d := range data { + total += len(d) + } + time.Sleep(time.Duration(float64(total) / float64(t.bandwidth) * float64(time.Second))) + } + + return t.memPartUploader.UploadPart(ctx, partIndex, data...) +} + +// decompressAll walks the FrameTable and decompresses each frame from the +// concatenated compressed blob, returning the original uncompressed data. +func decompressAll(ft *FrameTable, compressed []byte) ([]byte, error) { + var result []byte + var cOff int64 + + for i, fs := range ft.Frames { + if cOff+int64(fs.C) > int64(len(compressed)) { + return nil, fmt.Errorf("frame %d: compressed data truncated (need %d, have %d)", i, cOff+int64(fs.C), len(compressed)) + } + + frameData := compressed[cOff : cOff+int64(fs.C)] + + var frame []byte + var err error + + switch ft.CompressionType() { + case CompressionLZ4: + dec := getLZ4Decoder(bytes.NewReader(frameData)) + frame, err = io.ReadAll(dec) + putLZ4Decoder(dec) + case CompressionZstd: + var dec *zstd.Decoder + dec, err = getZstdDecoder(bytes.NewReader(frameData)) + if err == nil { + frame, err = io.ReadAll(dec) + putZstdDecoder(dec) + } + } + if err != nil { + return nil, fmt.Errorf("frame %d: %w", i, err) + } + result = append(result, frame...) 
+ cOff += int64(fs.C) + } + + return result, nil +} + +// defaultCfg returns a CompressConfig with the given overrides applied. +func defaultCfg(ct CompressionType, workers, frameSize int) *CompressConfig { + level := 2 // zstd default + if ct == CompressionLZ4 { + level = 0 + } + + return &CompressConfig{ + Enabled: true, + Type: ct.String(), + Level: level, + EncoderConcurrency: 1, + FrameEncodeWorkers: workers, + FrameSizeKB: frameSize / 1024, + TargetPartSizeMB: 50, + } +} + +func TestCompressStreamRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dataSize int + frameSize int + workers int + codec CompressionType + incompressible bool // use crypto/rand data that cannot be compressed + }{ + {"basic", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"workers_1", 10 * megabyte, 2 * megabyte, 1, CompressionZstd, false}, + {"workers_2", 10 * megabyte, 2 * megabyte, 2, CompressionZstd, false}, + {"not_frame_aligned", 10*megabyte + 1, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_frame", 100 * 1024, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_part", 5 * megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"empty", 0, 2 * megabyte, 4, CompressionZstd, false}, + {"single_byte", 1, 2 * megabyte, 1, CompressionZstd, false}, + {"lz4", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, false}, + {"lz4_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, true}, + {"zstd_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var original []byte + if tc.dataSize > 0 { + if tc.incompressible { + original = make([]byte, tc.dataSize) + _, err := crand.Read(original) + require.NoError(t, err) + } else { + original = generateSemiRandomData(tc.dataSize) + } + } + + up := &memPartUploader{} + cfg := defaultCfg(tc.codec, tc.workers, tc.frameSize) + + ft, checksum, err := compressStream( + 
context.Background(), + bytes.NewReader(original), + cfg, + up, + 4, + ) + require.NoError(t, err) + + if tc.dataSize == 0 { + require.Empty(t, ft.Frames) + require.Equal(t, sha256.Sum256(nil), checksum) + + return + } + + // Verify frame count. + expectedFrames := (tc.dataSize + tc.frameSize - 1) / tc.frameSize + require.Len(t, ft.Frames, expectedFrames) + + // Verify checksum. + require.Equal(t, sha256.Sum256(original), checksum) + + // Round-trip: decompress and compare. + compressed := up.Assemble() + decompressed, err := decompressAll(ft, compressed) + require.NoError(t, err) + require.Equal(t, original, decompressed) + }) + } +} + +func TestCompressStreamContextCancel(t *testing.T) { + t.Parallel() + + data := generateSemiRandomData(10 * megabyte) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, 2*megabyte) + + _, _, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestCompressStreamPartSizeMinimum(t *testing.T) { + t.Parallel() + + // Generate once; subtests slice to their needed size. + sharedData := generateSemiRandomData(100 * megabyte) + + tests := []struct { + name string + dataSize int + frameSize int + targetPartSizeMB int + }{ + {"large_file", 100 * megabyte, 2 * megabyte, 50}, + {"small_file_one_part", 5 * megabyte, 2 * megabyte, 50}, + {"small_target", 100 * megabyte, 2 * megabyte, 10}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := sharedData[:tc.dataSize] + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, tc.frameSize) + cfg.TargetPartSizeMB = tc.targetPartSizeMB + + _, _, err := compressStream(context.Background(), bytes.NewReader(data), cfg, up, 4) + require.NoError(t, err) + + // Verify: no non-final part is under 5 MiB. 
+ keys := make([]int, 0, len(up.parts)) + for k := range up.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + for i, k := range keys { + isFinal := i == len(keys)-1 + if !isFinal { + require.GreaterOrEqual(t, len(up.parts[k]), 5*1024*1024, + "non-final part %d is under 5 MiB (%d bytes)", k, len(up.parts[k])) + } + } + + require.NotEmpty(t, up.parts, "should have at least one part") + }) + } +} + +// TestCompressStreamRace runs many concurrent CompressStream calls with high +// worker counts to shake out data races in the compressor pool, memPartUploader, +// and errgroup coordination. Run with -race. +func TestCompressStreamRace(t *testing.T) { + t.Parallel() + + const ( + streams = 8 // concurrent CompressStream calls + dataSize = 4 * megabyte // small enough to be fast, big enough to exercise batching + frameSize = 128 * 1024 // 128 KB — many frames per part + workers = 8 // high worker count to maximise contention + targetPartSizeMB = 1 // small parts → many parts per stream + ) + + data := generateSemiRandomData(dataSize) + wantChecksum := sha256.Sum256(data) + + // Use an errgroup to run all streams concurrently. 
+ eg, ctx := errgroup.WithContext(context.Background()) + for i := range streams { + codec := CompressionZstd + if i%2 == 1 { + codec = CompressionLZ4 // mix codecs for more coverage + } + + eg.Go(func() error { + up := &memPartUploader{} + cfg := defaultCfg(codec, workers, frameSize) + cfg.TargetPartSizeMB = targetPartSizeMB + if codec == CompressionZstd { + cfg.EncoderConcurrency = 4 // multi-threaded zstd encoders for more contention + } + + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + if err != nil { + return fmt.Errorf("stream %d: compress: %w", i, err) + } + + if checksum != wantChecksum { + return fmt.Errorf("stream %d: checksum mismatch", i) + } + + decompressed, err := decompressAll(ft, up.Assemble()) + if err != nil { + return fmt.Errorf("stream %d: decompress: %w", i, err) + } + + if !bytes.Equal(data, decompressed) { + return fmt.Errorf("stream %d: round-trip data mismatch", i) + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) +} + +func BenchmarkCompress(b *testing.B) { + const dataSize = 256 * megabyte + data := generateSemiRandomData(dataSize) + + configs := []struct { + name string + workers int + bandwidth int64 // bytes/sec; 0 = unlimited + }{ + {"w1_unlimited", 1, 0}, + {"w2_unlimited", 2, 0}, + {"w4_unlimited", 4, 0}, + {"w1_200MBs", 1, 200 * megabyte}, + {"w4_200MBs", 4, 200 * megabyte}, + {"w4_100MBs", 4, 100 * megabyte}, + } + + for _, bcfg := range configs { + b.Run(bcfg.name, func(b *testing.B) { + compCfg := &CompressConfig{ + Enabled: true, + Type: "zstd", + Level: 2, + EncoderConcurrency: 1, + FrameEncodeWorkers: bcfg.workers, + FrameSizeKB: 2 * 1024, + TargetPartSizeMB: 50, + } + + var lastParts atomic.Int32 + + b.ResetTimer() + b.SetBytes(int64(dataSize)) + + for range b.N { + up := &ThrottledPartUploader{bandwidth: bcfg.bandwidth} + + _, _, err := compressStream( + context.Background(), + bytes.NewReader(data), + compCfg, + up, 4, + ) + if err != nil { + b.Fatal(err) + } + + 
lastParts.Store(int32(len(up.parts))) + } + + // Report after all iterations using last run's values. + // b.SetBytes already reports MB/s (uncompressed throughput). + b.ReportMetric(float64(lastParts.Load()), "parts") + }) + } +} + +func BenchmarkStoreFile(b *testing.B) { + const dataSize = 1024 * megabyte // 1 GB + + data := generateSemiRandomData(dataSize) + inputDir := b.TempDir() + inputPath := filepath.Join(inputDir, "input.bin") + require.NoError(b, os.WriteFile(inputPath, data, 0o644)) + data = nil //nolint:ineffassign,wastedassign // hint GC to free 1GB before benchmark loop + + codecs := []struct { + name string + codec CompressionType + level int + }{ + {"zstd1", CompressionZstd, 1}, + {"zstd2", CompressionZstd, 2}, + {"zstd3", CompressionZstd, 3}, + {"lz4", CompressionLZ4, 0}, + } + workerCounts := []int{1, 2, 4, 8} + + for _, codec := range codecs { + for _, workers := range workerCounts { + name := fmt.Sprintf("%s/w%d", codec.name, workers) + b.Run(name, func(b *testing.B) { + compCfg := &CompressConfig{ + Enabled: true, + Type: codec.codec.String(), + Level: codec.level, + EncoderConcurrency: 1, + FrameEncodeWorkers: workers, + FrameSizeKB: 2 * 1024, + TargetPartSizeMB: 50, + } + + b.SetBytes(int64(dataSize)) + b.ResetTimer() + + for range b.N { + outDir := b.TempDir() + outPath := filepath.Join(outDir, "output.dat") + obj := &fsObject{path: outPath} + + ft, _, err := obj.StoreFile(b.Context(), inputPath, compCfg) + if err != nil { + b.Fatal(err) + } + + uSize, cSize := ft.Size() + b.ReportMetric(float64(cSize)/float64(uSize), "ratio") + } + }) + } + } + + b.Run("uncompressed", func(b *testing.B) { + b.SetBytes(int64(dataSize)) + b.ResetTimer() + + for range b.N { + outDir := b.TempDir() + outPath := filepath.Join(outDir, "output.dat") + + in, err := os.Open(inputPath) + if err != nil { + b.Fatal(err) + } + out, err := os.Create(outPath) + if err != nil { + in.Close() + b.Fatal(err) + } + if _, err := io.Copy(out, in); err != nil { + in.Close() + 
out.Close() + b.Fatal(err) + } + in.Close() + out.Close() + } + }) +} diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go index 75324c16c1..ee568df86f 100644 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ b/packages/shared/pkg/storage/gcp_multipart.go @@ -139,9 +139,61 @@ type MultipartUploader struct { client *retryablehttp.Client retryConfig RetryConfig baseURL string // Allow overriding for testing + metadata map[string]string + + // Fields for partUploader interface + uploadID string + mu sync.Mutex + parts []Part +} + +var _ partUploader = (*MultipartUploader)(nil) + +// Start initiates the GCS multipart upload. +func (m *MultipartUploader) Start(ctx context.Context) error { + uploadID, err := m.initiateUpload(ctx) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + m.uploadID = uploadID + + return nil +} + +// UploadPart uploads a single part to GCS. Multiple data slices are hashed +// and uploaded without copying into a single contiguous buffer. +func (m *MultipartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + etag, err := m.uploadPartSlices(ctx, m.uploadID, partIndex, data) + if err != nil { + return fmt.Errorf("failed to upload part %d: %w", partIndex, err) + } + + m.mu.Lock() + m.parts = append(m.parts, Part{ + PartNumber: partIndex, + ETag: etag, + }) + m.mu.Unlock() + + return nil +} + +// Complete finalizes the GCS multipart upload with all collected parts. 
+func (m *MultipartUploader) Complete(ctx context.Context) error { + m.mu.Lock() + parts := make([]Part, len(m.parts)) + copy(parts, m.parts) + m.mu.Unlock() + + return m.completeUpload(ctx, m.uploadID, parts) +} + +func (m *MultipartUploader) Close() error { + return nil } -func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig) (*MultipartUploader, error) { +func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig, metadata map[string]string) (*MultipartUploader, error) { creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) @@ -159,6 +211,7 @@ func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, object client: createRetryableClient(ctx, retryConfig), retryConfig: retryConfig, baseURL: fmt.Sprintf("https://%s.storage.googleapis.com", bucketName), + metadata: metadata, }, nil } @@ -174,6 +227,10 @@ func (m *MultipartUploader) initiateUpload(ctx context.Context) (string, error) req.Header.Set("Content-Length", "0") req.Header.Set("Content-Type", "application/octet-stream") + for k, v := range m.metadata { + req.Header.Set("x-goog-meta-"+k, v) + } + resp, err := m.client.Do(req) if err != nil { return "", err @@ -232,6 +289,60 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par return etag, nil } +// uploadPartSlices uploads a part from multiple byte slices without concatenating them. +// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads. 
+func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) { + // Compute MD5 and total length without copying + hasher := md5.New() + totalLen := 0 + for _, s := range slices { + hasher.Write(s) + totalLen += len(s) + } + md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) + + url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", + m.baseURL, m.objectName, partNumber, uploadID) + + // Use a ReaderFunc so the retryable client can replay the body on retries + bodyFn := func() (io.Reader, error) { + readers := make([]io.Reader, len(slices)) + for i, s := range slices { + readers[i] = bytes.NewReader(s) + } + + return io.MultiReader(readers...), nil + } + + req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn)) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", "Bearer "+m.token) + req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen)) + req.Header.Set("Content-MD5", md5Sum) + + resp, err := m.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("failed to upload part %d (status %d): %s", partNumber, resp.StatusCode, string(body)) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return "", fmt.Errorf("no ETag returned for part %d", partNumber) + } + + return etag, nil +} + func (m *MultipartUploader) completeUpload(ctx context.Context, uploadID string, parts []Part) error { // Sort parts by part number sort.Slice(parts, func(i, j int) bool { diff --git a/packages/shared/pkg/storage/gcp_multipart_test.go b/packages/shared/pkg/storage/gcp_multipart_test.go index 7fe4d397ce..49ed8bbc51 100644 --- a/packages/shared/pkg/storage/gcp_multipart_test.go +++ b/packages/shared/pkg/storage/gcp_multipart_test.go @@ -1,6 +1,8 @@ package storage import ( + "crypto/md5" + "encoding/base64" 
"encoding/xml" "fmt" "io" @@ -115,6 +117,42 @@ func TestMultipartUploader_UploadPart_Success(t *testing.T) { require.Equal(t, expectedETag, etag) } +func TestMultipartUploader_UploadPartSlices_Success(t *testing.T) { + t.Parallel() + expectedETag := `"slice-etag"` + slices := [][]byte{[]byte("hello "), []byte("world"), []byte("!")} + + // Compute expected MD5 over all slices. + h := md5.New() + for _, s := range slices { + h.Write(s) + } + expectedMD5 := base64.StdEncoding.EncodeToString(h.Sum(nil)) + + handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "PUT", r.Method) + assert.Contains(t, r.URL.RawQuery, "partNumber=3") + assert.Contains(t, r.URL.RawQuery, "uploadId=test-upload-id") + + // Verify MD5 matches the expected hash of all slices. + assert.Equal(t, expectedMD5, r.Header.Get("Content-MD5")) + + // Verify body is the concatenation of all slices. + body, err := io.ReadAll(r.Body) + assert.NoError(t, err) + assert.Equal(t, []byte("hello world!"), body) + + w.Header().Set("ETag", expectedETag) + w.WriteHeader(http.StatusOK) + }) + + uploader := createTestMultipartUploader(t, handler) + etag, err := uploader.uploadPartSlices(t.Context(), "test-upload-id", 3, slices) + + require.NoError(t, err) + require.Equal(t, expectedETag, etag) +} + func TestMultipartUploader_UploadPart_MissingETag(t *testing.T) { t.Parallel() handler := http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { @@ -170,7 +208,6 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { err := os.WriteFile(testFile, []byte(testContent), 0o644) require.NoError(t, err) - var uploadID string var initiateCount, uploadPartCount, completeCount int32 receivedParts := sync.Map{} @@ -179,11 +216,10 @@ func TestMultipartUploader_UploadFileInParallel_Success(t *testing.T) { case r.URL.RawQuery == uploadsPath: // Initiate upload atomic.AddInt32(&initiateCount, 1) - uploadID = "test-upload-id-123" response := InitiateMultipartUploadResult{ 
Bucket: testBucketName, Key: testObjectName, - UploadID: uploadID, + UploadID: "test-upload-id-123", } xmlData, _ := xml.Marshal(response) w.Header().Set("Content-Type", "application/xml") @@ -524,7 +560,7 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { err := os.WriteFile(smallFile, []byte(smallContent), 0o644) require.NoError(t, err) - var receivedData string + var receivedParts sync.Map handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { switch { @@ -540,7 +576,8 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { case strings.Contains(r.URL.RawQuery, "partNumber"): body, _ := io.ReadAll(r.Body) - receivedData = string(body) + partNum := r.URL.Query().Get("partNumber") + receivedParts.Store(partNum, string(body)) w.Header().Set("ETag", `"small-etag"`) w.WriteHeader(http.StatusOK) @@ -553,7 +590,18 @@ func TestMultipartUploader_EdgeCases_VerySmallFile(t *testing.T) { uploader := createTestMultipartUploader(t, handler) _, err = uploader.UploadFileInParallel(t.Context(), smallFile, 10) // High concurrency for small file require.NoError(t, err) - require.Equal(t, smallContent, receivedData) + + // Small file should produce exactly one part + var partCount int + receivedParts.Range(func(_, _ any) bool { + partCount++ + + return true + }) + require.Equal(t, 1, partCount) + data, ok := receivedParts.Load("1") + require.True(t, ok) + require.Equal(t, smallContent, data.(string)) } type repeatReader struct { @@ -692,8 +740,9 @@ func TestMultipartUploader_BoundaryConditions_ExactChunkSize(t *testing.T) { // Should have exactly 2 parts, each of ChunkSize require.Len(t, partSizes, 2) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[0]) - require.Equal(t, gcpMultipartUploadChunkSize, partSizes[1]) + for _, size := range partSizes { + require.Equal(t, gcpMultipartUploadChunkSize, size) + } } func TestMultipartUploader_FileNotFound_Error(t *testing.T) { diff --git 
a/packages/shared/pkg/storage/header/header.go b/packages/shared/pkg/storage/header/header.go index 9a1f3008f5..6541ac0d96 100644 --- a/packages/shared/pkg/storage/header/header.go +++ b/packages/shared/pkg/storage/header/header.go @@ -1,26 +1,68 @@ package header import ( + "cmp" "context" "fmt" + "maps" + "slices" "github.com/bits-and-blooms/bitset" "github.com/google/uuid" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) +// BuildFileInfo holds metadata about a build's data file, stored in the header +// so the read path can avoid network round-trips (e.g. Size() calls to GCS). +type BuildFileInfo struct { + Size int64 // uncompressed file size + Checksum [32]byte // SHA-256 of uncompressed data; zero value means unknown +} + const NormalizeFixVersion = 3 type Header struct { - Metadata *Metadata + Metadata *Metadata + // BuildFiles maps build IDs to their file metadata (size + checksum). + // Each layer's upload adds its own entry via applyToHeader, and inherits + // all parent entries via ToDiffHeader (which copies originalHeader.BuildFiles). + // This means every V4 header has a complete map of all builds referenced + // in its Mapping. V3 headers have no BuildFiles; the read path falls back + // to a Size() RPC for those. + BuildFiles map[uuid.UUID]BuildFileInfo blockStarts *bitset.BitSet startMap map[int64]*BuildMap Mapping []*BuildMap } +// CloneForUpload returns a clone with copied Mapping and BuildFiles, safe to +// mutate for serialization without racing with concurrent readers of the +// original. Only serialization-relevant fields are populated (Metadata, +// Mapping, BuildFiles); lookup indices (blockStarts, startMap) are left nil. 
+func (t *Header) CloneForUpload() *Header { + mappings := make([]*BuildMap, len(t.Mapping)) + for i, m := range t.Mapping { + mappings[i] = m.Copy() + } + + metaCopy := *t.Metadata + clone := &Header{ + Metadata: &metaCopy, + Mapping: mappings, + } + + if t.BuildFiles != nil { + clone.BuildFiles = make(map[uuid.UUID]BuildFileInfo, len(t.BuildFiles)) + maps.Copy(clone.BuildFiles, t.BuildFiles) + } + + return clone +} + func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { if metadata.BlockSize == 0 { return nil, fmt.Errorf("block size cannot be zero") @@ -40,11 +82,11 @@ func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { intervals := bitset.New(uint(blocks)) startMap := make(map[int64]*BuildMap, len(mapping)) - for _, mapping := range mapping { - block := BlockIdx(int64(mapping.Offset), int64(metadata.BlockSize)) + for _, m := range mapping { + block := BlockIdx(int64(m.Offset), int64(metadata.BlockSize)) intervals.Set(uint(block)) - startMap[block] = mapping + startMap[block] = m } return &Header{ @@ -55,27 +97,81 @@ func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { }, nil } +func (t *Header) String() string { + if t == nil { + return "[nil Header]" + } + + return fmt.Sprintf("[Header: version=%d, size=%d, blockSize=%d, generation=%d, buildId=%s, mappings=%d]", + t.Metadata.Version, + t.Metadata.Size, + t.Metadata.BlockSize, + t.Metadata.Generation, + t.Metadata.BuildId.String(), + len(t.Mapping), + ) +} + +func (t *Header) Mappings(all bool) string { + if t == nil { + return "[nil Header, no mappings]" + } + n := 0 + for _, m := range t.Mapping { + if all || m.BuildId == t.Metadata.BuildId { + n++ + } + } + result := fmt.Sprintf("All mappings: %d\n", n) + if !all { + result = fmt.Sprintf("Mappings for build %s: %d\n", t.Metadata.BuildId.String(), n) + } + for _, m := range t.Mapping { + if !all && m.BuildId != t.Metadata.BuildId { + continue + } + frames := 0 + if m.FrameTable != nil { + frames 
= len(m.FrameTable.Frames) + } + result += fmt.Sprintf(" - Offset: %d, Length: %d, BuildId: %s, BuildStorageOffset: %d, numFrames: %d\n", + m.Offset, + m.Length, + m.BuildId.String(), + m.BuildStorageOffset, + frames, + ) + } + + return result +} + // IsNormalizeFixApplied is a helper method to soft fail for older versions of the header where fix for normalization was not applied. // This should be removed in the future. func (t *Header) IsNormalizeFixApplied() bool { return t.Metadata.Version >= NormalizeFixVersion } -func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedOffset int64, mappedLength int64, buildID *uuid.UUID, err error) { +func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (BuildMap, error) { mapping, shift, err := t.getMapping(ctx, offset) if err != nil { - return 0, 0, nil, err + return BuildMap{}, err } + mappedLength := int64(mapping.Length) - shift - mappedOffset = int64(mapping.BuildStorageOffset) + shift - mappedLength = int64(mapping.Length) - shift - buildID = &mapping.BuildId + b := BuildMap{ + Offset: mapping.BuildStorageOffset + uint64(shift), + Length: uint64(mappedLength), + BuildId: mapping.BuildId, + FrameTable: mapping.FrameTable, + } if mappedLength < 0 { if t.IsNormalizeFixApplied() { - return 0, 0, nil, fmt.Errorf("mapped length for offset %d is negative: %d", offset, mappedLength) + return BuildMap{}, fmt.Errorf("mapped length for offset %d is negative: %d", offset, mappedLength) } + b.Length = 0 logger.L().Warn(ctx, "mapped length is negative, but normalize fix is not applied", zap.Int64("offset", offset), zap.Int64("mappedLength", mappedLength), @@ -83,7 +179,7 @@ func (t *Header) GetShiftedMapping(ctx context.Context, offset int64) (mappedOff ) } - return mappedOffset, mappedLength, buildID, nil + return b, nil } // TODO: Maybe we can optimize mapping by automatically assuming the mapping is uuid.Nil if we don't find it + stopping storing the nil mapping. 
@@ -143,3 +239,101 @@ func (t *Header) getMapping(ctx context.Context, offset int64) (*BuildMap, int64 return mapping, shift, nil } + +// ValidateHeader checks header integrity and returns an error if corruption is detected. +// This verifies: +// 1. Header and metadata are valid +// 2. Mappings cover the entire file [0, Size) with no gaps +// 3. Mappings don't extend beyond file size (with block alignment tolerance) +func ValidateHeader(h *Header) error { + if h == nil { + return fmt.Errorf("header is nil") + } + if h.Metadata == nil { + return fmt.Errorf("header metadata is nil") + } + if h.Metadata.BlockSize == 0 { + return fmt.Errorf("header has zero block size") + } + if h.Metadata.Size == 0 { + return fmt.Errorf("header has zero size") + } + if len(h.Mapping) == 0 { + return fmt.Errorf("header has no mappings") + } + + // Sort mappings by offset to check for gaps/overlaps + sortedMappings := make([]*BuildMap, len(h.Mapping)) + copy(sortedMappings, h.Mapping) + slices.SortFunc(sortedMappings, func(a, b *BuildMap) int { + return cmp.Compare(a.Offset, b.Offset) + }) + + // Check that first mapping starts at 0 + if sortedMappings[0].Offset != 0 { + return fmt.Errorf("mappings don't start at 0: first mapping starts at %d for buildId %s", + sortedMappings[0].Offset, h.Metadata.BuildId.String()) + } + + // Check for gaps and overlaps between consecutive mappings + for i := range len(sortedMappings) - 1 { + currentEnd := sortedMappings[i].Offset + sortedMappings[i].Length + nextStart := sortedMappings[i+1].Offset + + if currentEnd < nextStart { + return fmt.Errorf("gap in mappings: mapping[%d] ends at %d but mapping[%d] starts at %d (gap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, nextStart-currentEnd, h.Metadata.BuildId.String()) + } + if currentEnd > nextStart { + return fmt.Errorf("overlap in mappings: mapping[%d] ends at %d but mapping[%d] starts at %d (overlap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, currentEnd-nextStart, 
h.Metadata.BuildId.String()) + } + } + + // Check that last mapping covers up to (at least) Size + lastMapping := sortedMappings[len(sortedMappings)-1] + lastEnd := lastMapping.Offset + lastMapping.Length + if lastEnd < h.Metadata.Size { + return fmt.Errorf("mappings don't cover entire file: last mapping ends at %d but file size is %d (missing %d bytes) for buildId %s", + lastEnd, h.Metadata.Size, h.Metadata.Size-lastEnd, h.Metadata.BuildId.String()) + } + + // Allow last mapping to extend up to one block past size (for alignment) + if lastEnd > h.Metadata.Size+h.Metadata.BlockSize { + return fmt.Errorf("last mapping extends too far: ends at %d but file size is %d (overhang=%d bytes, max allowed=%d) for buildId %s", + lastEnd, h.Metadata.Size, lastEnd-h.Metadata.Size, h.Metadata.BlockSize, h.Metadata.BuildId.String()) + } + + // Validate individual mapping bounds + for i, m := range h.Mapping { + if m.Offset > h.Metadata.Size { + return fmt.Errorf("mapping[%d] has Offset %d beyond header size %d for buildId %s", + i, m.Offset, h.Metadata.Size, m.BuildId.String()) + } + if m.Length == 0 { + return fmt.Errorf("mapping[%d] has zero length at offset %d for buildId %s", + i, m.Offset, m.BuildId.String()) + } + } + + return nil +} + +// SetFrames associates compression frame information with this header's mappings. +// +// Only mappings matching this header's BuildId will be updated. Returns nil if frameTable is nil. 
+func (t *Header) SetFrames(frameTable *storage.FrameTable) error { + if frameTable == nil { + return nil + } + + for _, mapping := range t.Mapping { + if mapping.BuildId == t.Metadata.BuildId { + if err := mapping.SetFrames(frameTable); err != nil { + return err + } + } + } + + return nil +} diff --git a/packages/shared/pkg/storage/header/mapping.go b/packages/shared/pkg/storage/header/mapping.go index 0802bb1fe8..241ca0133d 100644 --- a/packages/shared/pkg/storage/header/mapping.go +++ b/packages/shared/pkg/storage/header/mapping.go @@ -6,6 +6,8 @@ import ( "github.com/bits-and-blooms/bitset" "github.com/google/uuid" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) // Start, Length and SourceStart are in bytes of the data file @@ -17,6 +19,7 @@ type BuildMap struct { Length uint64 BuildId uuid.UUID BuildStorageOffset uint64 + FrameTable *storage.FrameTable } func (mapping *BuildMap) Copy() *BuildMap { @@ -25,7 +28,47 @@ func (mapping *BuildMap) Copy() *BuildMap { Length: mapping.Length, BuildId: mapping.BuildId, BuildStorageOffset: mapping.BuildStorageOffset, + FrameTable: mapping.FrameTable, + } +} + +// SetFrames associates compression frame information with this mapping. +// +// When a file is uploaded with compression, the compressor produces a FrameTable +// that describes how the compressed data is organized into frames. This method +// computes which compressed frames cover this mapping's data within the build's +// storage file based on BuildStorageOffset and Length. +// +// Returns nil if frameTable is nil. Returns an error if the mapping's range +// cannot be found in the frame table. +func (mapping *BuildMap) SetFrames(frameTable *storage.FrameTable) error { + _, err := mapping.SetFramesFrom(frameTable, 0) + + return err +} + +// SetFramesFrom is like SetFrames but starts scanning from frame index `from`, +// returning the next cursor position. Use this when applying frames to a +// sorted sequence of mappings to avoid O(N²) rescanning. 
+func (mapping *BuildMap) SetFramesFrom(frameTable *storage.FrameTable, from int) (int, error) { + if frameTable == nil { + return from, nil + } + + mappedRange := storage.Range{ + Start: int64(mapping.BuildStorageOffset), + Length: int(mapping.Length), } + + subset, next := frameTable.SubsetFrom(mappedRange, from) + if subset == nil && mapping.Length > 0 { + return next, fmt.Errorf("mapping at virtual offset %d (storage offset %d, length %d): no frames found from index %d", + mapping.Offset, mapping.BuildStorageOffset, mapping.Length, from) + } + + mapping.FrameTable = subset + + return next, nil } func CreateMapping( @@ -84,9 +127,9 @@ func CreateMapping( func MergeMappings( baseMapping []*BuildMap, diffMapping []*BuildMap, -) []*BuildMap { +) ([]*BuildMap, error) { if len(diffMapping) == 0 { - return baseMapping + return baseMapping, nil } baseMappingCopy := make([]*BuildMap, len(baseMapping)) @@ -160,6 +203,9 @@ func MergeMappings( // the build storage offset is the same as the base mapping BuildStorageOffset: base.BuildStorageOffset, } + if err := leftBase.SetFrames(base.FrameTable); err != nil { + return nil, fmt.Errorf("set frames for left split at offset %d: %w", leftBase.Offset, err) + } mappings = append(mappings, leftBase) } @@ -178,6 +224,9 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift), } + if err := rightBase.SetFrames(base.FrameTable); err != nil { + return nil, fmt.Errorf("set frames for right split at offset %d: %w", rightBase.Offset, err) + } baseMapping[baseIdx] = rightBase } else { @@ -205,6 +254,9 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift), } + if err := rightBase.SetFrames(base.FrameTable); err != nil { + return nil, fmt.Errorf("set frames for right split at offset %d: %w", rightBase.Offset, err) + } baseMapping[baseIdx] = rightBase } else { @@ -226,6 +278,9 @@ func MergeMappings( BuildId: base.BuildId, 
BuildStorageOffset: base.BuildStorageOffset, } + if err := leftBase.SetFrames(base.FrameTable); err != nil { + return nil, fmt.Errorf("set frames for left split at offset %d: %w", leftBase.Offset, err) + } mappings = append(mappings, leftBase) } @@ -241,10 +296,12 @@ func MergeMappings( mappings = append(mappings, baseMapping[baseIdx:]...) mappings = append(mappings, diffMapping[diffIdx:]...) - return mappings + return mappings, nil } // NormalizeMappings joins adjacent mappings that have the same buildId. +// When merging mappings, FrameTables are also merged by extending the first +// mapping's FrameTable with frames from subsequent mappings. func NormalizeMappings(mappings []*BuildMap) []*BuildMap { if len(mappings) == 0 { return nil @@ -252,7 +309,7 @@ func NormalizeMappings(mappings []*BuildMap) []*BuildMap { result := make([]*BuildMap, 0, len(mappings)) - // Start with a copy of the first mapping + // Start with a copy of the first mapping (Copy() now includes FrameTable) current := mappings[0].Copy() for i := 1; i < len(mappings); i++ { @@ -260,10 +317,22 @@ func NormalizeMappings(mappings []*BuildMap) []*BuildMap { if mp.BuildId != current.BuildId { // BuildId changed, add the current map to results and start a new one result = append(result, current) - current = mp.Copy() // New copy + current = mp.Copy() // New copy (includes FrameTable) } else { - // Same BuildId, just add the length + // Same BuildId, merge: add the length and extend FrameTable current.Length += mp.Length + + // Extend FrameTable if the mapping being merged has one + if mp.FrameTable != nil { + if current.FrameTable == nil { + // Current has no FrameTable but merged one does - take it + current.FrameTable = mp.FrameTable + } else { + // Both have FrameTables - extend current's with mp's frames + // The frames are contiguous subsets, so we append non-overlapping frames + current.FrameTable = mergeFrameTables(current.FrameTable, mp.FrameTable) + } + } } } @@ -272,3 +341,63 @@ func 
NormalizeMappings(mappings []*BuildMap) []*BuildMap {
 
 	return result
 }
+
+// mergeFrameTables extends ft1 with frames from ft2. The FrameTables are
+// assumed to be contiguous subsets from the same original, so ft2's frames
+// follow ft1's frames (with possible overlap at the boundary). This function
+// returns either a reference to one of the input tables, unchanged, or a new
+// FrameTable with frames from both tables.
+func mergeFrameTables(ft1, ft2 *storage.FrameTable) *storage.FrameTable {
+	if ft1 == nil {
+		return ft2
+	}
+	if ft2 == nil {
+		return ft1
+	}
+
+	// Calculate where ft1 ends (uncompressed offset)
+	ft1EndU := ft1.StartAt.U
+	for _, frame := range ft1.Frames {
+		ft1EndU += int64(frame.U)
+	}
+
+	// Find where to start appending from ft2 (skip frames already covered by ft1)
+	ft2CurrentU := ft2.StartAt.U
+	startIdx := 0
+	for i, frame := range ft2.Frames {
+		frameEndU := ft2CurrentU + int64(frame.U)
+		if frameEndU <= ft1EndU {
+			// This frame is already covered by ft1
+			ft2CurrentU = frameEndU
+			startIdx = i + 1
+
+			continue
+		}
+		if ft2CurrentU < ft1EndU {
+			// This frame overlaps with ft1's last frame - it's the same frame, skip it
+			ft2CurrentU = frameEndU
+			startIdx = i + 1
+
+			continue
+		}
+		// This frame is beyond ft1's coverage
+		break
+	}
+
+	// Append remaining frames from ft2
+	if startIdx < len(ft2.Frames) {
+		// Create a new FrameTable with extended frames
+		newFrames := make([]storage.FrameSize, len(ft1.Frames), len(ft1.Frames)+len(ft2.Frames)-startIdx)
+		copy(newFrames, ft1.Frames)
+		newFrames = append(newFrames, ft2.Frames[startIdx:]...)
+ + result := storage.NewFrameTable(ft1.CompressionType()) + result.StartAt = ft1.StartAt + result.Frames = newFrames + + return result + } + + // All of ft2's frames were already covered by ft1 + return ft1 +} diff --git a/packages/shared/pkg/storage/header/mapping_test.go b/packages/shared/pkg/storage/header/mapping_test.go index d20f070a3c..28728c2df5 100644 --- a/packages/shared/pkg/storage/header/mapping_test.go +++ b/packages/shared/pkg/storage/header/mapping_test.go @@ -46,11 +46,12 @@ func TestMergeMappingsRemoveEmpty(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, simpleBase)) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -65,7 +66,8 @@ func TestMergeMappingsBaseBeforeDiffNoOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -90,7 +92,7 @@ func TestMergeMappingsBaseBeforeDiffNoOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -105,7 +107,8 @@ func TestMergeMappingsDiffBeforeBaseNoOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -130,7 +133,7 @@ func TestMergeMappingsDiffBeforeBaseNoOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -145,7 +148,8 @@ func TestMergeMappingsBaseInsideDiff(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -165,7 +169,7 @@ func TestMergeMappingsBaseInsideDiff(t *testing.T) { }, 
})) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -180,7 +184,8 @@ func TestMergeMappingsDiffInsideBase(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -210,7 +215,7 @@ func TestMergeMappingsDiffInsideBase(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -225,7 +230,8 @@ func TestMergeMappingsBaseAfterDiffWithOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -250,7 +256,7 @@ func TestMergeMappingsBaseAfterDiffWithOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -265,7 +271,8 @@ func TestMergeMappingsDiffAfterBaseWithOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -290,7 +297,7 @@ func TestMergeMappingsDiffAfterBaseWithOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } diff --git a/packages/shared/pkg/storage/header/metadata.go b/packages/shared/pkg/storage/header/metadata.go index 574dea78bf..7d4725a0cd 100644 --- a/packages/shared/pkg/storage/header/metadata.go +++ b/packages/shared/pkg/storage/header/metadata.go @@ -1,7 +1,9 @@ package header import ( + "bytes" "context" + "encoding/binary" "fmt" "io" @@ -15,6 +17,59 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" ) +const ( + // metadataVersion is used by template-manager for uncompressed builds (V3 headers). 
+ metadataVersion = 3 + // MetadataVersionCompressed is used for compressed builds (V4 headers with FrameTables). + MetadataVersionCompressed = 4 +) + +type Metadata struct { + Version uint64 + BlockSize uint64 + Size uint64 + Generation uint64 + BuildId uuid.UUID + // TODO: Use the base build id when setting up the snapshot rootfs + BaseBuildId uuid.UUID +} + +func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { + return &Metadata{ + Version: metadataVersion, + Generation: 0, + BlockSize: blockSize, + Size: size, + BuildId: buildId, + BaseBuildId: buildId, + } +} + +func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata { + return &Metadata{ + Version: m.Version, + Generation: m.Generation + 1, + BlockSize: m.BlockSize, + Size: m.Size, + BuildId: buildID, + BaseBuildId: m.BaseBuildId, + } +} + +// metadataSize is the binary size of the Metadata struct, computed from the struct layout. +var metadataSize = binary.Size(Metadata{}) + +func deserializeMetadata(data []byte) (*Metadata, error) { + var metadata Metadata + + err := binary.Read(bytes.NewReader(data), binary.LittleEndian, &metadata) + if err != nil { + return nil, fmt.Errorf("failed to read metadata: %w", err) + } + + return &metadata, nil +} + var ignoreBuildID = uuid.Nil type DiffMetadata struct { @@ -27,7 +82,7 @@ type DiffMetadata struct { func (d *DiffMetadata) toDiffMapping( ctx context.Context, buildID uuid.UUID, -) (mapping []*BuildMap) { +) ([]*BuildMap, error) { dirtyMappings := CreateMapping( &buildID, d.Dirty, @@ -43,10 +98,13 @@ func (d *DiffMetadata) toDiffMapping( ) telemetry.ReportEvent(ctx, "created empty mapping") - mappings := MergeMappings(dirtyMappings, emptyMappings) + mappings, err := MergeMappings(dirtyMappings, emptyMappings) + if err != nil { + return nil, fmt.Errorf("merge dirty+empty mappings: %w", err) + } telemetry.ReportEvent(ctx, "merge mappings") - return mappings + return mappings, nil } func (d *DiffMetadata) ToDiffHeader( @@ -63,12 
+121,18 @@ func (d *DiffMetadata) ToDiffHeader( } }() - diffMapping := d.toDiffMapping(ctx, buildID) + diffMapping, err := d.toDiffMapping(ctx, buildID) + if err != nil { + return nil, fmt.Errorf("toDiffMapping: %w", err) + } - m := MergeMappings( + m, err := MergeMappings( originalHeader.Mapping, diffMapping, ) + if err != nil { + return nil, fmt.Errorf("merge base+diff mappings: %w", err) + } telemetry.ReportEvent(ctx, "merged mappings") // TODO: We can run normalization only when empty mappings are not empty for this snapshot @@ -93,6 +157,18 @@ func (d *DiffMetadata) ToDiffHeader( return nil, fmt.Errorf("failed to create header: %w", err) } + // Copy only BuildFiles referenced by the merged mappings. + referenced := make(map[uuid.UUID]struct{}, len(m)) + for _, mapping := range m { + referenced[mapping.BuildId] = struct{}{} + } + header.BuildFiles = make(map[uuid.UUID]BuildFileInfo, len(referenced)) + for id := range referenced { + if info, ok := originalHeader.BuildFiles[id]; ok { + header.BuildFiles[id] = info + } + } + err = ValidateMappings(header.Mapping, header.Metadata.Size, header.Metadata.BlockSize) if err != nil { if header.IsNormalizeFixApplied() { diff --git a/packages/shared/pkg/storage/header/serialization.go b/packages/shared/pkg/storage/header/serialization.go index 6af71f832b..1e1c28f516 100644 --- a/packages/shared/pkg/storage/header/serialization.go +++ b/packages/shared/pkg/storage/header/serialization.go @@ -1,102 +1,83 @@ package header import ( - "bytes" "context" - "encoding/binary" - "errors" "fmt" - "io" - - "github.com/google/uuid" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) -const metadataVersion = 3 - -type Metadata struct { - Version uint64 - BlockSize uint64 - Size uint64 - Generation uint64 - BuildId uuid.UUID - // TODO: Use the base build id when setting up the snapshot rootfs - BaseBuildId uuid.UUID -} - -func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { - return &Metadata{ - Version: 
metadataVersion, - Generation: 0, - BlockSize: blockSize, - Size: size, - BuildId: buildId, - BaseBuildId: buildId, +// SerializeHeader serializes a header, dispatching to the version-specific format. +// +// V3 (Version <= 3): [Metadata] [v3 mappings…] +// V4 (Version >= 4): [Metadata] [uint32 uncompressedSize] [LZ4( BuildFiles + v4 mappings + FrameTables )] +func SerializeHeader(h *Header) ([]byte, error) { + if h.Metadata.Version <= 3 { + return serializeV3(h.Metadata, h.Mapping) } -} -func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata { - return &Metadata{ - Version: m.Version, - Generation: m.Generation + 1, - BlockSize: m.BlockSize, - Size: m.Size, - BuildId: buildID, - BaseBuildId: m.BaseBuildId, - } + return serializeV4(h.Metadata, h.BuildFiles, h.Mapping) } -func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { - var buf bytes.Buffer +// DeserializeBytes auto-detects the header version and deserializes accordingly. +// See SerializeHeader for the binary layout. +func DeserializeBytes(data []byte) (*Header, error) { + if len(data) < metadataSize { + return nil, fmt.Errorf("header too short: %d bytes", len(data)) + } - err := binary.Write(&buf, binary.LittleEndian, metadata) + metadata, err := deserializeMetadata(data[:metadataSize]) if err != nil { - return nil, fmt.Errorf("failed to write metadata: %w", err) + return nil, err } - for _, mapping := range mappings { - err := binary.Write(&buf, binary.LittleEndian, mapping) - if err != nil { - return nil, fmt.Errorf("failed to write block mapping: %w", err) - } + blockData := data[metadataSize:] + + if metadata.Version >= 4 { + return deserializeV4(metadata, blockData) } - return buf.Bytes(), nil + return deserializeV3(metadata, blockData) } -func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { - data, err := storage.GetBlob(ctx, in) +// LoadHeader fetches a serialized header from storage and deserializes it. 
+// Errors (including storage.ErrObjectNotExist) are returned as-is. +func LoadHeader(ctx context.Context, s storage.StorageProvider, path string) (*Header, error) { + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) if err != nil { - return nil, fmt.Errorf("failed to write to buffer: %w", err) + return nil, fmt.Errorf("open blob %s: %w", path, err) + } + + data, err := storage.GetBlob(ctx, blob) + if err != nil { + return nil, err } return DeserializeBytes(data) } -func DeserializeBytes(data []byte) (*Header, error) { - reader := bytes.NewReader(data) - var metadata Metadata - err := binary.Read(reader, binary.LittleEndian, &metadata) +// StoreHeader serializes a header and uploads it to storage. +// Inverse of LoadHeader. +func StoreHeader(ctx context.Context, s storage.StorageProvider, path string, h *Header) ([]byte, error) { + data, err := SerializeHeader(h) if err != nil { - return nil, fmt.Errorf("failed to read metadata: %w", err) + return nil, fmt.Errorf("serialize header: %w", err) } - mappings := make([]*BuildMap, 0) - - for { - var m BuildMap - err := binary.Read(reader, binary.LittleEndian, &m) - if errors.Is(err, io.EOF) { - break - } + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) + if err != nil { + return nil, fmt.Errorf("open blob %s: %w", path, err) + } - if err != nil { - return nil, fmt.Errorf("failed to read block mapping: %w", err) - } + return data, blob.Put(ctx, data) +} - mappings = append(mappings, &m) +// Deserialize reads a header from a storage Blob (legacy API). 
+func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { + data, err := storage.GetBlob(ctx, in) + if err != nil { + return nil, fmt.Errorf("failed to write to buffer: %w", err) } - return NewHeader(&metadata, mappings) + return DeserializeBytes(data) } diff --git a/packages/shared/pkg/storage/header/serialization_test.go b/packages/shared/pkg/storage/header/serialization_test.go new file mode 100644 index 0000000000..cfc01d9c1a --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_test.go @@ -0,0 +1,397 @@ +package header + +import ( + "crypto/sha256" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// newFT creates a FrameTable for test fixtures. +func newFT(ct storage.CompressionType, startAt storage.FrameOffset, frames []storage.FrameSize) *storage.FrameTable { + ft := storage.NewFrameTable(ct) + ft.StartAt = startAt + ft.Frames = frames + + return ft +} + +func TestSerializeDeserialize_V3_RoundTrip(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 7, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 123, + }, + } + + data, err := serializeV3(metadata, mappings) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, metadata, got.Metadata) + require.Len(t, got.Mapping, 2) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, uint64(4096), got.Mapping[0].Length) + require.Equal(t, buildID, got.Mapping[0].BuildId) + require.Equal(t, uint64(0), got.Mapping[0].BuildStorageOffset) + + require.Equal(t, uint64(4096), got.Mapping[1].Offset) + require.Equal(t, 
uint64(4096), got.Mapping[1].Length) + require.Equal(t, baseID, got.Mapping[1].BuildId) + require.Equal(t, uint64(123), got.Mapping[1].BuildStorageOffset) + + // V3 headers have no BuildFiles + require.Nil(t, got.BuildFiles) +} + +func TestDeserialize_TruncatedMetadata(t *testing.T) { + t.Parallel() + + _, err := DeserializeBytes([]byte{0x01, 0x02, 0x03}) + require.Error(t, err) + require.Contains(t, err.Error(), "header too short") +} + +func TestSerializeDeserialize_EmptyMappings_Defaults(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serializeV3(metadata, nil) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + // NewHeader creates a default mapping when none provided + require.Len(t, got.Mapping, 1) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, metadata.Size, got.Mapping[0].Length) + require.Equal(t, metadata.BuildId, got.Mapping[0].BuildId) +} + +func TestDeserialize_BlockSizeZero(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 0, + Size: 4096, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serializeV3(metadata, nil) + require.NoError(t, err) + + _, err = DeserializeBytes(data) + require.Error(t, err) + require.Contains(t, err.Error(), "block size cannot be zero") +} + +func TestSerializeDeserialize_V4_WithFrameTable(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 1, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + FrameTable: newFT(storage.CompressionLZ4, storage.FrameOffset{U: 0, C: 0}, []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 2048, C: 900}, 
+ }), + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + checksum := sha256.Sum256([]byte("test-data")) + buildFiles := map[uuid.UUID]BuildFileInfo{ + buildID: {Size: 12345, Checksum: checksum}, + baseID: {Size: 67890}, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.BuildFiles = buildFiles + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, uint64(4), got.Metadata.Version) + require.Len(t, got.Mapping, 2) + + // First mapping has FrameTable + m0 := got.Mapping[0] + require.Equal(t, uint64(0), m0.Offset) + require.Equal(t, uint64(4096), m0.Length) + require.Equal(t, buildID, m0.BuildId) + require.NotNil(t, m0.FrameTable) + require.Equal(t, storage.CompressionLZ4, m0.FrameTable.CompressionType()) + require.Equal(t, int64(0), m0.FrameTable.StartAt.U) + require.Equal(t, int64(0), m0.FrameTable.StartAt.C) + require.Len(t, m0.FrameTable.Frames, 2) + require.Equal(t, int32(2048), m0.FrameTable.Frames[0].U) + require.Equal(t, int32(1024), m0.FrameTable.Frames[0].C) + require.Equal(t, int32(2048), m0.FrameTable.Frames[1].U) + require.Equal(t, int32(900), m0.FrameTable.Frames[1].C) + + // Second mapping has no FrameTable + m1 := got.Mapping[1] + require.Equal(t, uint64(4096), m1.Offset) + require.Equal(t, uint64(4096), m1.Length) + require.Equal(t, baseID, m1.BuildId) + require.Nil(t, m1.FrameTable) + + // BuildFiles round-trip + require.Len(t, got.BuildFiles, 2) + require.Equal(t, int64(12345), got.BuildFiles[buildID].Size) + require.Equal(t, checksum, got.BuildFiles[buildID].Checksum) + require.Equal(t, int64(67890), got.BuildFiles[baseID].Size) + require.Equal(t, [32]byte{}, got.BuildFiles[baseID].Checksum) +} + +func TestSerializeDeserialize_V4_Zstd_NonZeroStartAt(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + 
Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 8192, + FrameTable: newFT(storage.CompressionZstd, storage.FrameOffset{U: 8192, C: 4000}, []storage.FrameSize{ + {U: 4096, C: 3500}, + }), + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + m := got.Mapping[0] + require.NotNil(t, m.FrameTable) + require.Equal(t, storage.CompressionZstd, m.FrameTable.CompressionType()) + require.Equal(t, int64(8192), m.FrameTable.StartAt.U) + require.Equal(t, int64(4000), m.FrameTable.StartAt.C) + require.Len(t, m.FrameTable.Frames, 1) + require.Equal(t, int32(4096), m.FrameTable.Frames[0].U) + require.Equal(t, int32(3500), m.FrameTable.Frames[0].C) + + // No BuildFiles set + require.Nil(t, got.BuildFiles) +} + +// TestSerializeDeserialize_V4_CompressionNone_EmptyFrames verifies that a +// FrameTable with CompressionNone and zero frames does not corrupt the stream. +// Before the fix, the serializer wrote a StartAt offset (16 bytes) but the +// deserializer skipped it because the packed value was 0. +func TestSerializeDeserialize_V4_CompressionNone_EmptyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + // FrameTable with CompressionNone and no frames — packed value is 0. 
+ FrameTable: newFT(storage.CompressionNone, storage.FrameOffset{U: 100, C: 50}, nil), + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 2) + + // First mapping: FrameTable was effectively empty, deserializer should treat as nil. + require.Nil(t, got.Mapping[0].FrameTable) + + // Second mapping must not be corrupted by stray StartAt bytes. + require.Equal(t, uint64(4096), got.Mapping[1].Offset) + require.Equal(t, uint64(4096), got.Mapping[1].Length) + require.Equal(t, baseID, got.Mapping[1].BuildId) +} + +func TestSerializeDeserialize_V4_ManyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + const numFrames = 1000 + frames := make([]storage.FrameSize, numFrames) + for i := range frames { + frames[i] = storage.FrameSize{U: 4096, C: int32(2000 + i)} + } + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096 * numFrames, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096 * numFrames, + BuildId: buildID, + BuildStorageOffset: 0, + FrameTable: newFT(storage.CompressionLZ4, storage.FrameOffset{U: 0, C: 0}, frames), + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.NotNil(t, got.Mapping[0].FrameTable) + require.Len(t, got.Mapping[0].FrameTable.Frames, numFrames) + + // Spot-check first and last frame + require.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[0].U) + require.Equal(t, 
int32(2000), got.Mapping[0].FrameTable.Frames[0].C) + require.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[numFrames-1].U) + require.Equal(t, int32(2000+numFrames-1), got.Mapping[0].FrameTable.Frames[numFrames-1].C) +} + +func TestSerializeDeserialize_V4_EmptyBuildFiles(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + // No BuildFiles set (nil map) + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.Nil(t, got.BuildFiles) // numBuilds=0 → nil +} diff --git a/packages/shared/pkg/storage/header/serialization_v3.go b/packages/shared/pkg/storage/header/serialization_v3.go new file mode 100644 index 0000000000..2f150dbb86 --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_v3.go @@ -0,0 +1,65 @@ +package header + +import ( + "bytes" + "encoding/binary" + "errors" + "fmt" + "io" +) + +type v3SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId [16]byte // uuid.UUID + BuildStorageOffset uint64 +} + +// serializeV3 writes [Metadata] [v3 mappings…] with no length prefix. 
+func serializeV3(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { + var buf bytes.Buffer + + if err := binary.Write(&buf, binary.LittleEndian, metadata); err != nil { + return nil, fmt.Errorf("failed to write metadata: %w", err) + } + + for _, mapping := range mappings { + v3 := &v3SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + if err := binary.Write(&buf, binary.LittleEndian, v3); err != nil { + return nil, fmt.Errorf("failed to write block mapping: %w", err) + } + } + + return buf.Bytes(), nil +} + +// deserializeV3 reads V3 mappings (read until EOF, no count prefix). +func deserializeV3(metadata *Metadata, blockData []byte) (*Header, error) { + reader := bytes.NewReader(blockData) + var mappings []*BuildMap + + for { + var v3 v3SerializableBuildMap + err := binary.Read(reader, binary.LittleEndian, &v3) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + mappings = append(mappings, &BuildMap{ + Offset: v3.Offset, + Length: v3.Length, + BuildId: v3.BuildId, + BuildStorageOffset: v3.BuildStorageOffset, + }) + } + + return NewHeader(metadata, mappings) +} diff --git a/packages/shared/pkg/storage/header/serialization_v4.go b/packages/shared/pkg/storage/header/serialization_v4.go new file mode 100644 index 0000000000..21fdcd45fc --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_v4.go @@ -0,0 +1,248 @@ +package header + +import ( + "bytes" + "cmp" + "encoding/binary" + "fmt" + "io" + "slices" + + "github.com/google/uuid" + lz4 "github.com/pierrec/lz4/v4" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +type v4SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId [16]byte // uuid.UUID + BuildStorageOffset uint64 + CompressionType uint32 + NumFrames uint32 + + // if CompressionType != CompressionNone and NumFrames > 
0: + // - followed by FrameOffset (16 bytes) + // - followed by FrameSize × NumFrames (8 bytes each) +} + +// v4SerializableBuildFileInfo is the on-disk format for a BuildFileInfo entry. +type v4SerializableBuildFileInfo struct { + BuildId uuid.UUID + Size int64 + Checksum [32]byte +} + +// serializeV4 writes [Metadata] [uint32 uncompressedSize] [LZ4( BuildFiles + counted mappings + FrameTables )]. +func serializeV4(metadata *Metadata, buildFiles map[uuid.UUID]BuildFileInfo, mappings []*BuildMap) ([]byte, error) { + // --- raw metadata prefix (not compressed) --- + var metaBuf bytes.Buffer + if err := binary.Write(&metaBuf, binary.LittleEndian, metadata); err != nil { + return nil, fmt.Errorf("failed to write metadata: %w", err) + } + + // --- compressed block: build-info + mappings + frame tables --- + var block bytes.Buffer + + // Build-info section. + if err := binary.Write(&block, binary.LittleEndian, uint32(len(buildFiles))); err != nil { + return nil, fmt.Errorf("failed to write build files count: %w", err) + } + + // Sort by UUID for deterministic serialization. + buildIDs := make([]uuid.UUID, 0, len(buildFiles)) + for id := range buildFiles { + buildIDs = append(buildIDs, id) + } + slices.SortFunc(buildIDs, func(a, b uuid.UUID) int { + return cmp.Compare(a.String(), b.String()) + }) + + for _, id := range buildIDs { + info := buildFiles[id] + entry := v4SerializableBuildFileInfo{ + BuildId: id, + Size: info.Size, + Checksum: info.Checksum, + } + if err := binary.Write(&block, binary.LittleEndian, &entry); err != nil { + return nil, fmt.Errorf("failed to write build file info: %w", err) + } + } + + // Counted mappings with inline FrameTables. 
+ if err := binary.Write(&block, binary.LittleEndian, uint32(len(mappings))); err != nil { + return nil, fmt.Errorf("failed to write mappings count: %w", err) + } + + for _, mapping := range mappings { + v4 := &v4SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + + var offset *storage.FrameOffset + var frames []storage.FrameSize + if mapping.FrameTable != nil { + v4.CompressionType = uint32(mapping.FrameTable.CompressionType()) + v4.NumFrames = uint32(len(mapping.FrameTable.Frames)) + if v4.CompressionType != 0 && v4.NumFrames > 0 { + offset = &mapping.FrameTable.StartAt + frames = mapping.FrameTable.Frames + } + } + + if err := binary.Write(&block, binary.LittleEndian, v4); err != nil { + return nil, fmt.Errorf("failed to write block mapping: %w", err) + } + if offset != nil { + if err := binary.Write(&block, binary.LittleEndian, offset); err != nil { + return nil, fmt.Errorf("failed to write compression frames starting offset: %w", err) + } + } + for _, frame := range frames { + if err := binary.Write(&block, binary.LittleEndian, frame); err != nil { + return nil, fmt.Errorf("failed to write compression frame: %w", err) + } + } + } + + // LZ4-compress the block and assemble: [metadata] [uint32 size] [compressed block]. + blockBytes := block.Bytes() + compressed, err := compressLZ4(blockBytes) + if err != nil { + return nil, fmt.Errorf("failed to LZ4-compress v4 header block: %w", err) + } + + result := make([]byte, metadataSize+4+len(compressed)) + copy(result, metaBuf.Bytes()) + binary.LittleEndian.PutUint32(result[metadataSize:], uint32(len(blockBytes))) + copy(result[metadataSize+4:], compressed) + + return result, nil +} + +// deserializeV4 decompresses and reads the V4 block: build-info + counted mappings + FrameTables. 
+func deserializeV4(metadata *Metadata, blockData []byte) (*Header, error) { + if len(blockData) < 4 { + return nil, fmt.Errorf("v4 header block too short for size prefix: %d bytes", len(blockData)) + } + + decompressed, err := decompressLZ4(blockData[4:]) + if err != nil { + return nil, fmt.Errorf("failed to LZ4-decompress v4 header block: %w", err) + } + + reader := bytes.NewReader(decompressed) + + // Build-info section. + var numBuilds uint32 + if err := binary.Read(reader, binary.LittleEndian, &numBuilds); err != nil { + return nil, fmt.Errorf("failed to read build files count: %w", err) + } + + var buildFiles map[uuid.UUID]BuildFileInfo + if numBuilds > 0 { + buildFiles = make(map[uuid.UUID]BuildFileInfo, numBuilds) + for range numBuilds { + var entry v4SerializableBuildFileInfo + if err := binary.Read(reader, binary.LittleEndian, &entry); err != nil { + return nil, fmt.Errorf("failed to read build file info: %w", err) + } + buildFiles[entry.BuildId] = BuildFileInfo{ + Size: entry.Size, + Checksum: entry.Checksum, + } + } + } + + // Counted mappings with inline FrameTables. 
+ var numMappings uint32 + if err := binary.Read(reader, binary.LittleEndian, &numMappings); err != nil { + return nil, fmt.Errorf("failed to read mappings count: %w", err) + } + + mappings := make([]*BuildMap, 0, numMappings) + for range numMappings { + var v4 v4SerializableBuildMap + if err := binary.Read(reader, binary.LittleEndian, &v4); err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + m := &BuildMap{ + Offset: v4.Offset, + Length: v4.Length, + BuildId: v4.BuildId, + BuildStorageOffset: v4.BuildStorageOffset, + } + + if v4.CompressionType != 0 && v4.NumFrames > 0 { + m.FrameTable = storage.NewFrameTable(storage.CompressionType(v4.CompressionType)) + numFrames := v4.NumFrames + + var startAt storage.FrameOffset + if err := binary.Read(reader, binary.LittleEndian, &startAt); err != nil { + return nil, fmt.Errorf("failed to read compression frames starting offset: %w", err) + } + m.FrameTable.StartAt = startAt + + for range numFrames { + var frame storage.FrameSize + if err := binary.Read(reader, binary.LittleEndian, &frame); err != nil { + return nil, fmt.Errorf("failed to read the expected compression frame: %w", err) + } + m.FrameTable.Frames = append(m.FrameTable.Frames, frame) + } + } + + mappings = append(mappings, m) + } + + h, err := NewHeader(metadata, mappings) + if err != nil { + return nil, err + } + h.BuildFiles = buildFiles + + return h, nil +} + +// compressLZ4 compresses data for V4 header serialization using the LZ4 +// streaming API. Settings are fixed for the V4 wire format. 
+func compressLZ4(data []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(len(data)) + + w := lz4.NewWriter(&buf) + w.Apply( + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(true), + lz4.CompressionLevelOption(lz4.Fast), + ) + + if _, err := w.Write(data); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} + +// decompressLZ4 decompresses an LZ4 frame from V4 header data. +func decompressLZ4(src []byte) ([]byte, error) { + r := lz4.NewReader(bytes.NewReader(src)) + + data, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("lz4 decompress: %w", err) + } + + return data, nil +} diff --git a/packages/shared/pkg/storage/mocks/mockobjectprovider.go b/packages/shared/pkg/storage/mock_blob.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockobjectprovider.go rename to packages/shared/pkg/storage/mock_blob.go index 6955ab4312..d65768339f 100644 --- a/packages/shared/pkg/storage/mocks/mockobjectprovider.go +++ b/packages/shared/pkg/storage/mock_blob.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go b/packages/shared/pkg/storage/mock_featureflagsclient.go similarity index 99% rename from packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go rename to packages/shared/pkg/storage/mock_featureflagsclient.go index d83936eddd..53dd4c5b29 100644 --- a/packages/shared/pkg/storage/mocks/mockfeatureflagsclient.go +++ b/packages/shared/pkg/storage/mock_featureflagsclient.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" diff --git a/packages/shared/pkg/storage/mocks/mockioreader.go b/packages/shared/pkg/storage/mock_ioreader.go 
similarity index 99% rename from packages/shared/pkg/storage/mocks/mockioreader.go rename to packages/shared/pkg/storage/mock_ioreader.go index 5497bc53c5..9adb02421e 100644 --- a/packages/shared/pkg/storage/mocks/mockioreader.go +++ b/packages/shared/pkg/storage/mock_ioreader.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( mock "github.com/stretchr/testify/mock" diff --git a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go b/packages/shared/pkg/storage/mock_seekable.go similarity index 63% rename from packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go rename to packages/shared/pkg/storage/mock_seekable.go index 3931f6b349..77a199c456 100644 --- a/packages/shared/pkg/storage/mocks/mockseekableobjectprovider.go +++ b/packages/shared/pkg/storage/mock_seekable.go @@ -2,7 +2,7 @@ // github.com/vektra/mockery // template: testify -package storagemocks +package storage import ( "context" @@ -39,8 +39,8 @@ func (_m *MockSeekable) EXPECT() *MockSeekable_Expecter { } // OpenRangeReader provides a mock function for the type MockSeekable -func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, length int64) (io.ReadCloser, error) { - ret := _mock.Called(ctx, off, length) +func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + ret := _mock.Called(ctx, offsetU, length, frameTable) if len(ret) == 0 { panic("no return value specified for OpenRangeReader") @@ -48,18 +48,18 @@ func (_mock *MockSeekable) OpenRangeReader(ctx context.Context, off int64, lengt var r0 io.ReadCloser var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) (io.ReadCloser, error)); ok { - return returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *FrameTable) (io.ReadCloser, error)); ok { + return returnFunc(ctx, offsetU, length, 
frameTable) } - if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64) io.ReadCloser); ok { - r0 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(0).(func(context.Context, int64, int64, *FrameTable) io.ReadCloser); ok { + r0 = returnFunc(ctx, offsetU, length, frameTable) } else { if ret.Get(0) != nil { r0 = ret.Get(0).(io.ReadCloser) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64) error); ok { - r1 = returnFunc(ctx, off, length) + if returnFunc, ok := ret.Get(1).(func(context.Context, int64, int64, *FrameTable) error); ok { + r1 = returnFunc(ctx, offsetU, length, frameTable) } else { r1 = ret.Error(1) } @@ -73,13 +73,14 @@ type MockSeekable_OpenRangeReader_Call struct { // OpenRangeReader is a helper method to define mock.On call // - ctx context.Context -// - off int64 +// - offsetU int64 // - length int64 -func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, off interface{}, length interface{}) *MockSeekable_OpenRangeReader_Call { - return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, off, length)} +// - frameTable *FrameTable +func (_e *MockSeekable_Expecter) OpenRangeReader(ctx interface{}, offsetU interface{}, length interface{}, frameTable interface{}) *MockSeekable_OpenRangeReader_Call { + return &MockSeekable_OpenRangeReader_Call{Call: _e.mock.On("OpenRangeReader", ctx, offsetU, length, frameTable)} } -func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, off int64, length int64)) *MockSeekable_OpenRangeReader_Call { +func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable)) *MockSeekable_OpenRangeReader_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -93,10 +94,15 @@ func (_c *MockSeekable_OpenRangeReader_Call) Run(run func(ctx context.Context, o if args[2] != nil { arg2 = args[2].(int64) } + var arg3 *FrameTable + 
if args[3] != nil { + arg3 = args[3].(*FrameTable) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -107,79 +113,7 @@ func (_c *MockSeekable_OpenRangeReader_Call) Return(readCloser io.ReadCloser, er return _c } -func (_c *MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, off int64, length int64) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { - _c.Call.Return(run) - return _c -} - -// ReadAt provides a mock function for the type MockSeekable -func (_mock *MockSeekable) ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) { - ret := _mock.Called(ctx, buffer, off) - - if len(ret) == 0 { - panic("no return value specified for ReadAt") - } - - var r0 int - var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) (int, error)); ok { - return returnFunc(ctx, buffer, off) - } - if returnFunc, ok := ret.Get(0).(func(context.Context, []byte, int64) int); ok { - r0 = returnFunc(ctx, buffer, off) - } else { - r0 = ret.Get(0).(int) - } - if returnFunc, ok := ret.Get(1).(func(context.Context, []byte, int64) error); ok { - r1 = returnFunc(ctx, buffer, off) - } else { - r1 = ret.Error(1) - } - return r0, r1 -} - -// MockSeekable_ReadAt_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'ReadAt' -type MockSeekable_ReadAt_Call struct { - *mock.Call -} - -// ReadAt is a helper method to define mock.On call -// - ctx context.Context -// - buffer []byte -// - off int64 -func (_e *MockSeekable_Expecter) ReadAt(ctx interface{}, buffer interface{}, off interface{}) *MockSeekable_ReadAt_Call { - return &MockSeekable_ReadAt_Call{Call: _e.mock.On("ReadAt", ctx, buffer, off)} -} - -func (_c *MockSeekable_ReadAt_Call) Run(run func(ctx context.Context, buffer []byte, off int64)) *MockSeekable_ReadAt_Call { - _c.Call.Run(func(args mock.Arguments) { - var arg0 context.Context - if args[0] != nil { - arg0 = args[0].(context.Context) - } - var arg1 []byte - if args[1] != 
nil { - arg1 = args[1].([]byte) - } - var arg2 int64 - if args[2] != nil { - arg2 = args[2].(int64) - } - run( - arg0, - arg1, - arg2, - ) - }) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) Return(n int, err error) *MockSeekable_ReadAt_Call { - _c.Call.Return(n, err) - return _c -} - -func (_c *MockSeekable_ReadAt_Call) RunAndReturn(run func(ctx context.Context, buffer []byte, off int64) (int, error)) *MockSeekable_ReadAt_Call { +func (_c *MockSeekable_OpenRangeReader_Call) RunAndReturn(run func(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error)) *MockSeekable_OpenRangeReader_Call { _c.Call.Return(run) return _c } @@ -245,20 +179,39 @@ func (_c *MockSeekable_Size_Call) RunAndReturn(run func(ctx context.Context) (in } // StoreFile provides a mock function for the type MockSeekable -func (_mock *MockSeekable) StoreFile(ctx context.Context, path string) error { - ret := _mock.Called(ctx, path) +func (_mock *MockSeekable) StoreFile(ctx context.Context, path string, cfg *CompressConfig) (*FrameTable, [32]byte, error) { + ret := _mock.Called(ctx, path, cfg) if len(ret) == 0 { panic("no return value specified for StoreFile") } - var r0 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string) error); ok { - r0 = returnFunc(ctx, path) + var r0 *FrameTable + var r1 [32]byte + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *CompressConfig) (*FrameTable, [32]byte, error)); ok { + return returnFunc(ctx, path, cfg) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, string, *CompressConfig) *FrameTable); ok { + r0 = returnFunc(ctx, path, cfg) } else { - r0 = ret.Error(0) + if ret.Get(0) != nil { + r0 = ret.Get(0).(*FrameTable) + } } - return r0 + if returnFunc, ok := ret.Get(1).(func(context.Context, string, *CompressConfig) [32]byte); ok { + r1 = returnFunc(ctx, path, cfg) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).([32]byte) + } + } + if returnFunc, ok := 
ret.Get(2).(func(context.Context, string, *CompressConfig) error); ok { + r2 = returnFunc(ctx, path, cfg) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 } // MockSeekable_StoreFile_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'StoreFile' @@ -269,11 +222,12 @@ type MockSeekable_StoreFile_Call struct { // StoreFile is a helper method to define mock.On call // - ctx context.Context // - path string -func (_e *MockSeekable_Expecter) StoreFile(ctx interface{}, path interface{}) *MockSeekable_StoreFile_Call { - return &MockSeekable_StoreFile_Call{Call: _e.mock.On("StoreFile", ctx, path)} +// - cfg *CompressConfig +func (_e *MockSeekable_Expecter) StoreFile(ctx interface{}, path interface{}, cfg interface{}) *MockSeekable_StoreFile_Call { + return &MockSeekable_StoreFile_Call{Call: _e.mock.On("StoreFile", ctx, path, cfg)} } -func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path string)) *MockSeekable_StoreFile_Call { +func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path string, cfg *CompressConfig)) *MockSeekable_StoreFile_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -283,20 +237,25 @@ func (_c *MockSeekable_StoreFile_Call) Run(run func(ctx context.Context, path st if args[1] != nil { arg1 = args[1].(string) } + var arg2 *CompressConfig + if args[2] != nil { + arg2 = args[2].(*CompressConfig) + } run( arg0, arg1, + arg2, ) }) return _c } -func (_c *MockSeekable_StoreFile_Call) Return(err error) *MockSeekable_StoreFile_Call { - _c.Call.Return(err) +func (_c *MockSeekable_StoreFile_Call) Return(frameTable *FrameTable, bytes [32]byte, err error) *MockSeekable_StoreFile_Call { + _c.Call.Return(frameTable, bytes, err) return _c } -func (_c *MockSeekable_StoreFile_Call) RunAndReturn(run func(ctx context.Context, path string) error) *MockSeekable_StoreFile_Call { +func (_c *MockSeekable_StoreFile_Call) RunAndReturn(run func(ctx 
context.Context, path string, cfg *CompressConfig) (*FrameTable, [32]byte, error)) *MockSeekable_StoreFile_Call { _c.Call.Return(run) return _c } diff --git a/packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go b/packages/shared/pkg/storage/mock_storageprovider.go similarity index 86% rename from packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go rename to packages/shared/pkg/storage/mock_storageprovider.go index b505eb617f..4657bf0754 100644 --- a/packages/shared/pkg/storage/mocks/provider/mockstorageprovider.go +++ b/packages/shared/pkg/storage/mock_storageprovider.go @@ -2,13 +2,12 @@ // github.com/vektra/mockery // template: testify -package providermocks +package storage import ( "context" "time" - "github.com/e2b-dev/infra/packages/shared/pkg/storage" mock "github.com/stretchr/testify/mock" ) @@ -141,26 +140,26 @@ func (_c *MockStorageProvider_GetDetails_Call) RunAndReturn(run func() string) * } // OpenBlob provides a mock function for the type MockStorageProvider -func (_mock *MockStorageProvider) OpenBlob(ctx context.Context, path string, objectType storage.ObjectType) (storage.Blob, error) { +func (_mock *MockStorageProvider) OpenBlob(ctx context.Context, path string, objectType ObjectType) (Blob, error) { ret := _mock.Called(ctx, path, objectType) if len(ret) == 0 { panic("no return value specified for OpenBlob") } - var r0 storage.Blob + var r0 Blob var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.ObjectType) (storage.Blob, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ObjectType) (Blob, error)); ok { return returnFunc(ctx, path, objectType) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.ObjectType) storage.Blob); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, ObjectType) Blob); ok { r0 = returnFunc(ctx, path, objectType) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(storage.Blob) + r0 = 
ret.Get(0).(Blob) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, storage.ObjectType) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, ObjectType) error); ok { r1 = returnFunc(ctx, path, objectType) } else { r1 = ret.Error(1) @@ -176,12 +175,12 @@ type MockStorageProvider_OpenBlob_Call struct { // OpenBlob is a helper method to define mock.On call // - ctx context.Context // - path string -// - objectType storage.ObjectType +// - objectType ObjectType func (_e *MockStorageProvider_Expecter) OpenBlob(ctx interface{}, path interface{}, objectType interface{}) *MockStorageProvider_OpenBlob_Call { return &MockStorageProvider_OpenBlob_Call{Call: _e.mock.On("OpenBlob", ctx, path, objectType)} } -func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, path string, objectType storage.ObjectType)) *MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, path string, objectType ObjectType)) *MockStorageProvider_OpenBlob_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -191,9 +190,9 @@ func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, p if args[1] != nil { arg1 = args[1].(string) } - var arg2 storage.ObjectType + var arg2 ObjectType if args[2] != nil { - arg2 = args[2].(storage.ObjectType) + arg2 = args[2].(ObjectType) } run( arg0, @@ -204,37 +203,37 @@ func (_c *MockStorageProvider_OpenBlob_Call) Run(run func(ctx context.Context, p return _c } -func (_c *MockStorageProvider_OpenBlob_Call) Return(blob storage.Blob, err error) *MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) Return(blob Blob, err error) *MockStorageProvider_OpenBlob_Call { _c.Call.Return(blob, err) return _c } -func (_c *MockStorageProvider_OpenBlob_Call) RunAndReturn(run func(ctx context.Context, path string, objectType storage.ObjectType) (storage.Blob, error)) 
*MockStorageProvider_OpenBlob_Call { +func (_c *MockStorageProvider_OpenBlob_Call) RunAndReturn(run func(ctx context.Context, path string, objectType ObjectType) (Blob, error)) *MockStorageProvider_OpenBlob_Call { _c.Call.Return(run) return _c } // OpenSeekable provides a mock function for the type MockStorageProvider -func (_mock *MockStorageProvider) OpenSeekable(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType) (storage.Seekable, error) { +func (_mock *MockStorageProvider) OpenSeekable(ctx context.Context, path string, seekableObjectType SeekableObjectType) (Seekable, error) { ret := _mock.Called(ctx, path, seekableObjectType) if len(ret) == 0 { panic("no return value specified for OpenSeekable") } - var r0 storage.Seekable + var r0 Seekable var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.SeekableObjectType) (storage.Seekable, error)); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, SeekableObjectType) (Seekable, error)); ok { return returnFunc(ctx, path, seekableObjectType) } - if returnFunc, ok := ret.Get(0).(func(context.Context, string, storage.SeekableObjectType) storage.Seekable); ok { + if returnFunc, ok := ret.Get(0).(func(context.Context, string, SeekableObjectType) Seekable); ok { r0 = returnFunc(ctx, path, seekableObjectType) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(storage.Seekable) + r0 = ret.Get(0).(Seekable) } } - if returnFunc, ok := ret.Get(1).(func(context.Context, string, storage.SeekableObjectType) error); ok { + if returnFunc, ok := ret.Get(1).(func(context.Context, string, SeekableObjectType) error); ok { r1 = returnFunc(ctx, path, seekableObjectType) } else { r1 = ret.Error(1) @@ -250,12 +249,12 @@ type MockStorageProvider_OpenSeekable_Call struct { // OpenSeekable is a helper method to define mock.On call // - ctx context.Context // - path string -// - seekableObjectType storage.SeekableObjectType +// - seekableObjectType 
SeekableObjectType func (_e *MockStorageProvider_Expecter) OpenSeekable(ctx interface{}, path interface{}, seekableObjectType interface{}) *MockStorageProvider_OpenSeekable_Call { return &MockStorageProvider_OpenSeekable_Call{Call: _e.mock.On("OpenSeekable", ctx, path, seekableObjectType)} } -func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType)) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Context, path string, seekableObjectType SeekableObjectType)) *MockStorageProvider_OpenSeekable_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -265,9 +264,9 @@ func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Contex if args[1] != nil { arg1 = args[1].(string) } - var arg2 storage.SeekableObjectType + var arg2 SeekableObjectType if args[2] != nil { - arg2 = args[2].(storage.SeekableObjectType) + arg2 = args[2].(SeekableObjectType) } run( arg0, @@ -278,12 +277,12 @@ func (_c *MockStorageProvider_OpenSeekable_Call) Run(run func(ctx context.Contex return _c } -func (_c *MockStorageProvider_OpenSeekable_Call) Return(seekable storage.Seekable, err error) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) Return(seekable Seekable, err error) *MockStorageProvider_OpenSeekable_Call { _c.Call.Return(seekable, err) return _c } -func (_c *MockStorageProvider_OpenSeekable_Call) RunAndReturn(run func(ctx context.Context, path string, seekableObjectType storage.SeekableObjectType) (storage.Seekable, error)) *MockStorageProvider_OpenSeekable_Call { +func (_c *MockStorageProvider_OpenSeekable_Call) RunAndReturn(run func(ctx context.Context, path string, seekableObjectType SeekableObjectType) (Seekable, error)) *MockStorageProvider_OpenSeekable_Call { _c.Call.Return(run) return _c } diff --git 
a/packages/shared/pkg/storage/paths.go b/packages/shared/pkg/storage/paths.go index 0164d73b69..f9df0ee9fb 100644 --- a/packages/shared/pkg/storage/paths.go +++ b/packages/shared/pkg/storage/paths.go @@ -53,6 +53,24 @@ func (p Paths) Metadata() string { return fmt.Sprintf("%s/%s", p.BuildID, MetadataName) } +func (p Paths) MemfileCompressed(ct CompressionType) string { + return fmt.Sprintf("%s/%s%s", p.BuildID, MemfileName, ct.Suffix()) +} + +func (p Paths) RootfsCompressed(ct CompressionType) string { + return fmt.Sprintf("%s/%s%s", p.BuildID, RootfsName, ct.Suffix()) +} + +// DataFile returns the storage path for a data file (e.g. "memfile", "rootfs.ext4"), +// with compression suffix appended if ct is not CompressionNone. +func (p Paths) DataFile(name string, ct CompressionType) string { + if ct == CompressionNone { + return fmt.Sprintf("%s/%s", p.BuildID, name) + } + + return fmt.Sprintf("%s/%s%s", p.BuildID, name, ct.Suffix()) +} + // SplitPath splits a storage path of the form "{buildID}/{fileName}" // back into its components. This is the inverse of the path methods. func SplitPath(path string) (buildID, fileName string) { @@ -60,3 +78,24 @@ func SplitPath(path string) (buildID, fileName string) { return buildID, fileName } + +// StripCompression removes a known compression suffix from a file name. +// For example: "memfile.zstd" → "memfile". +// If no known suffix is present, the name is returned unchanged. +func StripCompression(name string) string { + for _, suffix := range knownCompressionSuffixes { + if before, ok := strings.CutSuffix(name, suffix); ok { + return before + } + } + + return name +} + +// AppendCompression adds a compression suffix to a path. +// For example: "buildId/memfile" → "buildId/memfile.zstd". 
+func AppendCompression(path string, ct CompressionType) string { + return path + ct.Suffix() +} + +var knownCompressionSuffixes = []string{".lz4", ".zstd"} diff --git a/packages/shared/pkg/storage/storage.go b/packages/shared/pkg/storage/storage.go index 3ba75e84d1..b01c6629d0 100644 --- a/packages/shared/pkg/storage/storage.go +++ b/packages/shared/pkg/storage/storage.go @@ -14,6 +14,7 @@ import ( "github.com/e2b-dev/infra/packages/shared/pkg/env" "github.com/e2b-dev/infra/packages/shared/pkg/limit" + "github.com/e2b-dev/infra/packages/shared/pkg/telemetry" "github.com/e2b-dev/infra/packages/shared/pkg/utils" ) @@ -41,6 +42,10 @@ const ( // MemoryChunkSize must always be bigger or equal to the block size. MemoryChunkSize = 4 * 1024 * 1024 // 4 MB + + // MetadataKeyUncompressedSize stores the original size so that Size() + // returns the uncompressed size for compressed objects. + MetadataKeyUncompressedSize = "uncompressed-size" ) // GetProviderType returns the configured storage provider type from the @@ -91,24 +96,35 @@ type Blob interface { type SeekableReader interface { // Random slice access, off and buffer length must be aligned to block size - ReadAt(ctx context.Context, buffer []byte, off int64) (int, error) + ReadAt(ctx context.Context, buffer []byte, off int64, ft *FrameTable) (int, error) Size(ctx context.Context) (int64, error) } // StreamingReader supports progressive reads via a streaming range reader. 
type StreamingReader interface { - OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) + OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) } type SeekableWriter interface { // Store entire file - StoreFile(ctx context.Context, path string) error + StoreFile(ctx context.Context, path string, cfg *CompressConfig) (*FrameTable, [32]byte, error) } type Seekable interface { - SeekableReader - SeekableWriter StreamingReader + SeekableWriter + Size(ctx context.Context) (int64, error) +} + +// PeerTransitionedError is returned by the peer Seekable when the GCS upload +// has completed and serialized V4 headers are available. +type PeerTransitionedError struct { + MemfileHeader []byte + RootfsHeader []byte +} + +func (e *PeerTransitionedError) Error() string { + return "peer upload completed, headers available" } // StorageConfig holds the configuration for creating a storage provider. @@ -197,3 +213,46 @@ func GetBlob(ctx context.Context, b Blob) ([]byte, error) { return buf.Bytes(), nil } + +// LoadBlob opens a blob by path and reads its contents. +func LoadBlob(ctx context.Context, s StorageProvider, path string, objectType ObjectType) ([]byte, error) { + blob, err := s.OpenBlob(ctx, path, objectType) + if err != nil { + return nil, fmt.Errorf("failed to open blob %s: %w", path, err) + } + + return GetBlob(ctx, blob) +} + +// timedReadCloser wraps a reader with OTEL timer metrics. +// Close records success (with total bytes read) or failure on the timer. 
+type timedReadCloser struct { + inner io.ReadCloser + timer *telemetry.Stopwatch + ctx context.Context //nolint:containedctx // needed for timer recording in Close + bytesRead int64 + closeErr error +} + +func (r *timedReadCloser) Read(p []byte) (int, error) { + n, err := r.inner.Read(p) + r.bytesRead += int64(n) + + if err != nil && err != io.EOF { + r.closeErr = err + } + + return n, err +} + +func (r *timedReadCloser) Close() error { + err := r.inner.Close() + + if r.closeErr != nil || err != nil { + r.timer.Failure(r.ctx, r.bytesRead) + } else { + r.timer.Success(r.ctx, r.bytesRead) + } + + return err +} diff --git a/packages/shared/pkg/storage/storage_aws.go b/packages/shared/pkg/storage/storage_aws.go index 189e1cd501..ca252e9dfe 100644 --- a/packages/shared/pkg/storage/storage_aws.go +++ b/packages/shared/pkg/storage/storage_aws.go @@ -162,13 +162,17 @@ func (o *awsObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { return io.Copy(dst, resp.Body) } -func (o *awsObject) StoreFile(ctx context.Context, path string) error { +func (o *awsObject) StoreFile(ctx context.Context, path string, cfg *CompressConfig) (*FrameTable, [32]byte, error) { + if cfg.IsEnabled() { + return nil, [32]byte{}, fmt.Errorf("compressed uploads are not supported on AWS (builds target GCP only)") + } + ctx, cancel := context.WithTimeout(ctx, awsWriteTimeout) defer cancel() f, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, [32]byte{}, fmt.Errorf("failed to open file %s: %w", path, err) } defer f.Close() @@ -189,7 +193,7 @@ func (o *awsObject) StoreFile(ctx context.Context, path string) error { }, ) - return err + return nil, [32]byte{}, err } func (o *awsObject) Put(ctx context.Context, data []byte) error { @@ -211,7 +215,11 @@ func (o *awsObject) Put(ctx context.Context, data []byte) error { return nil } -func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { 
+func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + if frameTable.IsCompressed() { + return nil, fmt.Errorf("compressed reads are not supported on AWS") + } + readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+length-1)) resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ Bucket: aws.String(o.bucketName), @@ -230,37 +238,6 @@ func (o *awsObject) OpenRangeReader(ctx context.Context, off, length int64) (io. return resp.Body, nil } -func (o *awsObject) ReadAt(ctx context.Context, buff []byte, off int64) (n int, err error) { - ctx, cancel := context.WithTimeout(ctx, awsReadTimeout) - defer cancel() - - readRange := aws.String(fmt.Sprintf("bytes=%d-%d", off, off+int64(len(buff))-1)) - resp, err := o.client.GetObject(ctx, &s3.GetObjectInput{ - Bucket: aws.String(o.bucketName), - Key: aws.String(o.path), - Range: readRange, - }) - if err != nil { - var nsk *types.NoSuchKey - if errors.As(err, &nsk) { - return 0, ErrObjectNotExist - } - - return 0, err - } - - defer resp.Body.Close() - - // When the object is smaller than requested range there will be unexpected EOF, - // but backend expects to return EOF in this case. 
- n, err = io.ReadFull(resp.Body, buff) - if errors.Is(err, io.ErrUnexpectedEOF) { - err = io.EOF - } - - return n, err -} - func (o *awsObject) Size(ctx context.Context) (int64, error) { ctx, cancel := context.WithTimeout(ctx, awsOperationTimeout) defer cancel() diff --git a/packages/shared/pkg/storage/storage_cache_blob.go b/packages/shared/pkg/storage/storage_cache_blob.go index 489b19836f..32ebd40daa 100644 --- a/packages/shared/pkg/storage/storage_cache_blob.go +++ b/packages/shared/pkg/storage/storage_cache_blob.go @@ -45,12 +45,12 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err bytesRead, err := b.copyFullFileFromCache(ctx, dst) if err == nil { - recordCacheRead(ctx, true, bytesRead, cacheTypeObject, cacheOpWriteTo) + recordCacheRead(ctx, true, bytesRead, cacheTypeBlob, cacheOpWriteTo) return bytesRead, nil } - recordCacheReadError(ctx, cacheTypeObject, cacheOpWriteTo, err) + recordCacheReadError(ctx, cacheTypeBlob, cacheOpWriteTo, err) // This is semi-arbitrary. this code path is called for files that tend to be less than 1 MB (headers, metadata, etc), // so 2 MB allows us to read the file without needing to allocate more memory, with some room for growth. 
If the @@ -73,13 +73,13 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err count, err := b.writeFileToCache(ctx, buffer) if err != nil { - recordCacheWriteError(ctx, cacheTypeObject, cacheOpWriteTo, err) + recordCacheWriteError(ctx, cacheTypeBlob, cacheOpWriteTo, err) recordError(span, err) return } - recordCacheWrite(ctx, count, cacheTypeObject, cacheOpWriteTo) + recordCacheWrite(ctx, count, cacheTypeBlob, cacheOpWriteTo) }) } @@ -88,7 +88,7 @@ func (b *cachedBlob) WriteTo(ctx context.Context, dst io.Writer) (n int64, e err return int64(written), fmt.Errorf("failed to write object: %w", err) } - recordCacheRead(ctx, false, int64(written), cacheTypeObject, cacheOpWriteTo) + recordCacheRead(ctx, false, int64(written), cacheTypeBlob, cacheOpWriteTo) return int64(written), err // in case err == EOF } @@ -110,9 +110,9 @@ func (b *cachedBlob) Put(ctx context.Context, data []byte) (e error) { count, err := b.writeFileToCache(ctx, bytes.NewReader(data)) if err != nil { recordError(span, err) - recordCacheWriteError(ctx, cacheTypeObject, cacheOpWrite, err) + recordCacheWriteError(ctx, cacheTypeBlob, cacheOpPut, err) } else { - recordCacheWrite(ctx, count, cacheTypeObject, cacheOpWrite) + recordCacheWrite(ctx, count, cacheTypeBlob, cacheOpPut) } }) } diff --git a/packages/shared/pkg/storage/storage_cache_blob_test.go b/packages/shared/pkg/storage/storage_cache_blob_test.go index 27c4afe88e..fbeecac4e0 100644 --- a/packages/shared/pkg/storage/storage_cache_blob_test.go +++ b/packages/shared/pkg/storage/storage_cache_blob_test.go @@ -13,8 +13,6 @@ import ( "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" "go.opentelemetry.io/otel/trace/noop" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) var noopTracer = noop.TracerProvider{}.Tracer("github.com/e2b-dev/infra/packages/shared/pkg/storage") @@ -32,12 +30,12 @@ func TestCachedObjectProvider_Put(t *testing.T) { err := os.MkdirAll(cacheDir, 
os.ModePerm) require.NoError(t, err) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). Put(mock.Anything, mock.Anything). Return(nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) c := cachedBlob{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} @@ -68,7 +66,7 @@ func TestCachedObjectProvider_Put(t *testing.T) { const dataSize = 10 * megabyte actualData := generateData(t, dataSize) - inner := storagemocks.NewMockBlob(t) + inner := NewMockBlob(t) inner.EXPECT(). WriteTo(mock.Anything, mock.Anything). RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -101,7 +99,7 @@ func TestCachedObjectProvider_WriteFileToCache(t *testing.T) { tracer: noopTracer, } errTarget := errors.New("find me") - reader := storagemocks.NewMockReader(t) + reader := NewMockReader(t) reader.EXPECT().Read(mock.Anything).Return(4, nil).Once() reader.EXPECT().Read(mock.Anything).Return(0, errTarget).Once() diff --git a/packages/shared/pkg/storage/storage_cache_compressed.go b/packages/shared/pkg/storage/storage_cache_compressed.go new file mode 100644 index 0000000000..45712ab3b3 --- /dev/null +++ b/packages/shared/pkg/storage/storage_cache_compressed.go @@ -0,0 +1,151 @@ +package storage + +import ( + "bytes" + "context" + "fmt" + "io" + "os" + + "go.opentelemetry.io/otel/attribute" +) + +// openReaderCompressed handles the compressed cache path for OpenRangeReader. +// NFS stores compressed frames (.frm); on hit we decompress, on miss we fetch +// raw compressed bytes and tee them to NFS on Close. 
+func (c *cachedSeekable) openReaderCompressed(ctx context.Context, offsetU int64, frameTable *FrameTable) (io.ReadCloser, error) { + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + return nil, fmt.Errorf("cache OpenRangeReader: frame lookup for offset %d: %w", offsetU, err) + } + + framePath := makeFrameFilename(c.path, frameStart, frameSize) + + timer := cacheSlabReadTimerFactory.Begin( + attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrRead), + attribute.Bool("compressed", true), + attribute.String("compression_type", string(frameTable.CompressionType())), + ) + + // Cache hit: open compressed frame from NFS and wrap with decompressor. + f, err := os.Open(framePath) + + switch { + case err == nil: + recordCacheRead(ctx, true, int64(frameSize.C), cacheTypeSeekable, cacheOpOpenRangeReader) + timer.Success(ctx, int64(frameSize.C)) + + decompressed, err := newDecompressingReadCloser(f, frameTable.CompressionType()) + if err != nil { + f.Close() + + return nil, fmt.Errorf("cache OpenRangeReader: decompress cached frame: %w", err) + } + + return decompressed, nil + case !os.IsNotExist(err): + recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + + timer.Failure(ctx, 0) + + // Cache miss: fetch raw compressed bytes via OpenRangeReader(nil frameTable). 
+ raw, err := c.inner.OpenRangeReader(ctx, frameStart.C, int64(frameSize.C), nil) + if err != nil { + return nil, fmt.Errorf("cache OpenRangeReader: raw fetch at C=%d: %w", frameStart.C, err) + } + + recordCacheRead(ctx, false, int64(frameSize.C), cacheTypeSeekable, cacheOpOpenRangeReader) + + rc, err := newDecompressingCacheReader(raw, frameTable.CompressionType(), int(frameSize.C), c, ctx, framePath, offsetU) + if err != nil { + raw.Close() + + return nil, fmt.Errorf("cache OpenRangeReader: create decompressor: %w", err) + } + + return rc, nil +} + +// newDecompressingCacheReader creates a reader that decompresses on Read and +// writes the accumulated compressed bytes to the NFS cache on Close. +func newDecompressingCacheReader( + raw io.ReadCloser, + ct CompressionType, + expectedSize int, + cache *cachedSeekable, + ctx context.Context, //nolint:revive // ctx after other params for readability at call site + framePath string, + offset int64, +) (io.ReadCloser, error) { + var compressedBuf bytes.Buffer + compressedBuf.Grow(expectedSize) + + tee := io.TeeReader(raw, &compressedBuf) + + dec, err := NewDecompressingReader(tee, ct) + if err != nil { + return nil, err + } + + return &decompressingCacheReader{ + decompressor: dec, + raw: raw, + compressedBuf: &compressedBuf, + expectedSize: expectedSize, + cache: cache, + ctx: ctx, + framePath: framePath, + offset: offset, + }, nil +} + +type decompressingCacheReader struct { + decompressor io.ReadCloser // decompresses on Read + raw io.ReadCloser // underlying compressed stream (must be closed) + compressedBuf *bytes.Buffer + expectedSize int + cache *cachedSeekable + ctx context.Context //nolint:containedctx // needed for async cache write-back in Close + framePath string + offset int64 +} + +func (r *decompressingCacheReader) Read(p []byte) (int, error) { + return r.decompressor.Read(p) +} + +func (r *decompressingCacheReader) Close() error { + if err := r.decompressor.Close(); err != nil { + r.raw.Close() + + 
return err + } + + if err := r.raw.Close(); err != nil { + return err + } + + if !skipCacheWriteback(r.ctx) && isCompleteRead(r.compressedBuf.Len(), r.expectedSize, nil) { + data := make([]byte, r.compressedBuf.Len()) + copy(data, r.compressedBuf.Bytes()) + + r.cache.goCtx(r.ctx, func(ctx context.Context) { + ctx, span := r.cache.tracer.Start(ctx, "write compressed frame back to cache") + defer span.End() + + if err := r.cache.writeToCache(ctx, r.offset, r.framePath, data); err != nil { + recordError(span, err) + recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) + } + }) + } + + return nil +} + +// makeFrameFilename returns the NFS cache path for a compressed frame. +// Format: {cacheBasePath}/{016xC}-{xC}.frm +func makeFrameFilename(cacheBasePath string, offset FrameOffset, size FrameSize) string { + return fmt.Sprintf("%s/%016x-%x.frm", cacheBasePath, offset.C, size.C) +} diff --git a/packages/shared/pkg/storage/storage_cache_metrics.go b/packages/shared/pkg/storage/storage_cache_metrics.go index 037bc7ed06..24c8a4016d 100644 --- a/packages/shared/pkg/storage/storage_cache_metrics.go +++ b/packages/shared/pkg/storage/storage_cache_metrics.go @@ -28,20 +28,18 @@ var ( type cacheOp string const ( - cacheOpWriteTo cacheOp = "write_to" - cacheOpReadAt cacheOp = "read_at" - cacheOpSize cacheOp = "size" - - cacheOpOpenRangeReader cacheOp = "open_range_reader" - - cacheOpWrite cacheOp = "write" + cacheOpWriteTo cacheOp = "write_to" + cacheOpOpenRangeReader cacheOp = "open_range_reader" + cacheOpSize cacheOp = "size" cacheOpWriteFromFileSystem cacheOp = "write_from_filesystem" + + cacheOpPut cacheOp = "put" ) type cacheType string const ( - cacheTypeObject cacheType = "object" + cacheTypeBlob cacheType = "blob" cacheTypeSeekable cacheType = "seekable" ) diff --git a/packages/shared/pkg/storage/storage_cache_seekable.go b/packages/shared/pkg/storage/storage_cache_seekable.go index 7341107392..7e8db4ce20 100644 --- 
a/packages/shared/pkg/storage/storage_cache_seekable.go +++ b/packages/shared/pkg/storage/storage_cache_seekable.go @@ -32,9 +32,9 @@ var ( ) const ( - nfsCacheOperationAttr = "operation" - nfsCacheOperationAttrReadAt = "ReadAt" - nfsCacheOperationAttrSize = "Size" + nfsCacheOperationAttr = "operation" + nfsCacheOperationAttrRead = "Read" + nfsCacheOperationAttrSize = "Size" ) var ( @@ -72,114 +72,116 @@ var ( _ StreamingReader = (*cachedSeekable)(nil) ) -func (c *cachedSeekable) ReadAt(ctx context.Context, buff []byte, offset int64) (n int, err error) { - ctx, span := c.tracer.Start(ctx, "read object at offset", trace.WithAttributes( - attribute.Int64("offset", offset), - attribute.Int("buff_len", len(buff)), - )) - defer func() { - recordError(span, err) - span.End() - }() - - if err := c.validateReadAtParams(int64(len(buff)), offset); err != nil { - return 0, err - } - - // try to read from cache first - chunkPath := c.makeChunkFilename(offset) +func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + compressed := frameTable.IsCompressed() - readTimer := cacheSlabReadTimerFactory.Begin(attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrReadAt)) - count, err := c.readAtFromCache(ctx, chunkPath, buff) - if ignoreEOF(err) == nil { - recordCacheRead(ctx, true, int64(count), cacheTypeSeekable, cacheOpReadAt) - readTimer.Success(ctx, int64(count)) + ctx, span := c.tracer.Start(ctx, "read", trace.WithAttributes( + attribute.Int64("offset", off), + attribute.Int64("length", length), + attribute.Bool("compressed", compressed), + )) - return count, err // return `err` in case it's io.EOF - } - readTimer.Failure(ctx, int64(count)) + if compressed { + rc, err := c.openReaderCompressed(ctx, off, frameTable) + if err != nil { + recordError(span, err) + span.End() - if !os.IsNotExist(err) { - recordCacheReadError(ctx, cacheTypeSeekable, cacheOpReadAt, err) - } + return nil, err + } - 
logger.L().Debug(ctx, "failed to read cached chunk, falling back to remote read", - zap.String("chunk_path", chunkPath), - zap.Int64("offset", offset), - zap.Error(err)) + rc = withSpan(rc, span) - // read remote file - readCount, err := c.inner.ReadAt(ctx, buff, offset) - if ignoreEOF(err) != nil { - return readCount, fmt.Errorf("failed to perform uncached read: %w", err) + return rc, nil } - if !skipCacheWriteback(ctx) && isCompleteRead(readCount, len(buff), err) { - shadowBuff := make([]byte, readCount) - copy(shadowBuff, buff[:readCount]) - - c.goCtx(ctx, func(ctx context.Context) { - ctx, span := c.tracer.Start(ctx, "write chunk at offset back to cache") - defer span.End() + if err := c.validateReadParams(length, off); err != nil { + recordError(span, err) + span.End() - if err := c.writeChunkToCache(ctx, offset, chunkPath, shadowBuff); err != nil { - recordError(span, err) - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) - } - }) + return nil, err } - recordCacheRead(ctx, false, int64(readCount), cacheTypeSeekable, cacheOpReadAt) - - return readCount, err -} + timer := cacheSlabReadTimerFactory.Begin( + attribute.String(nfsCacheOperationAttr, nfsCacheOperationAttrRead), + attribute.Bool("compressed", false), + ) -func (c *cachedSeekable) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { - // Try NFS cache file first chunkPath := c.makeChunkFilename(off) fp, err := os.Open(chunkPath) if err == nil { recordCacheRead(ctx, true, length, cacheTypeSeekable, cacheOpOpenRangeReader) + timer.Success(ctx, length) + + rc := io.ReadCloser(&fsRangeReadCloser{Reader: io.NewSectionReader(fp, 0, length), file: fp}) + rc = withSpan(rc, span) - return &fsRangeReadCloser{ - Reader: io.NewSectionReader(fp, 0, length), - file: fp, - }, nil + return rc, nil } if !os.IsNotExist(err) { recordCacheReadError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) } - // Cache miss: delegate to the inner backend (Seekable embeds 
StreamingReader). - inner, err := c.inner.OpenRangeReader(ctx, off, length) + timer.Failure(ctx, 0) + + rc, err := c.inner.OpenRangeReader(ctx, off, length, nil) if err != nil { + recordError(span, err) + span.End() + return nil, fmt.Errorf("failed to open inner range reader: %w", err) } recordCacheRead(ctx, false, length, cacheTypeSeekable, cacheOpOpenRangeReader) - // Skip write-through when the caller has opted out of cache writeback. - if skipCacheWriteback(ctx) { - return inner, nil + if !skipCacheWriteback(ctx) { + rc = newCacheWriteThroughReader(rc, c, ctx, off, length, chunkPath) } - // Wrap in a write-through reader that caches data on Close + rc = withSpan(rc, span) + + return rc, nil +} + +// withSpan wraps a reader with an OTEL span that ends on Close. +func withSpan(rc io.ReadCloser, span trace.Span) io.ReadCloser { + return &spanReadCloser{inner: rc, span: span} +} + +type spanReadCloser struct { + inner io.ReadCloser + span trace.Span +} + +func (r *spanReadCloser) Read(p []byte) (int, error) { + return r.inner.Read(p) +} + +func (r *spanReadCloser) Close() error { + err := r.inner.Close() + recordError(r.span, err) + r.span.End() + + return err +} + +// newCacheWriteThroughReader wraps a reader, buffering all data read through it. +// On Close, it asynchronously writes the buffered data to the NFS cache only +// if the total bytes read match the expected length (to avoid caching truncated data). +func newCacheWriteThroughReader(inner io.ReadCloser, cache *cachedSeekable, ctx context.Context, off, expectedLen int64, chunkPath string) io.ReadCloser { return &cacheWriteThroughReader{ inner: inner, - buf: bytes.NewBuffer(make([]byte, 0, length)), - cache: c, + buf: bytes.NewBuffer(make([]byte, 0, expectedLen)), + cache: cache, ctx: ctx, off: off, - expectedLen: length, + expectedLen: expectedLen, chunkPath: chunkPath, - }, nil + } } -// cacheWriteThroughReader wraps an inner reader, buffering all data read through it. 
-// On Close, it asynchronously writes the buffered data to the NFS cache only -// if the total bytes read match the expected length (to avoid caching truncated data). type cacheWriteThroughReader struct { inner io.ReadCloser buf *bytes.Buffer @@ -206,7 +208,7 @@ func (r *cacheWriteThroughReader) Close() error { // Unlike ReadAt where io.EOF can justify a short read (last chunk), // a streaming reader always ends with EOF regardless of whether the // data was truncated, so the byte count is the only reliable check. - if r.buf.Len() > 0 && int64(r.buf.Len()) == r.expectedLen { + if isCompleteRead(r.buf.Len(), int(r.expectedLen), nil) { data := make([]byte, r.buf.Len()) copy(data, r.buf.Bytes()) @@ -214,7 +216,7 @@ func (r *cacheWriteThroughReader) Close() error { ctx, span := r.cache.tracer.Start(ctx, "write range reader chunk back to cache") defer span.End() - if err := r.cache.writeChunkToCache(ctx, r.off, r.chunkPath, data); err != nil { + if err := r.cache.writeToCache(ctx, r.off, r.chunkPath, data); err != nil { recordError(span, err) recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) } @@ -266,7 +268,7 @@ func (c *cachedSeekable) Size(ctx context.Context) (n int64, e error) { return size, nil } -func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { +func (c *cachedSeekable) StoreFile(ctx context.Context, path string, cfg *CompressConfig) (_ *FrameTable, _ [32]byte, e error) { ctx, span := c.tracer.Start(ctx, "write object from file system", trace.WithAttributes(attribute.String("path", path)), ) @@ -278,7 +280,7 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { // write the file to the disk and the remote system at the same time. 
// this opens the file twice, but the API makes it difficult to use a MultiWriter - if c.flags.BoolFlag(ctx, featureflags.EnableWriteThroughCacheFlag) { + if !cfg.IsEnabled() && c.flags.BoolFlag(ctx, featureflags.EnableWriteThroughCacheFlag) { c.goCtx(ctx, func(ctx context.Context) { ctx, span := c.tracer.Start(ctx, "write cache object from file system", trace.WithAttributes(attribute.String("path", path))) @@ -301,7 +303,7 @@ func (c *cachedSeekable) StoreFile(ctx context.Context, path string) (e error) { }) } - return c.inner.StoreFile(ctx, path) + return c.inner.StoreFile(ctx, path, cfg) } func (c *cachedSeekable) goCtx(ctx context.Context, fn func(context.Context)) { @@ -314,36 +316,8 @@ func (c *cachedSeekable) makeChunkFilename(offset int64) string { return fmt.Sprintf("%s/%012d-%d.bin", c.path, offset/c.chunkSize, c.chunkSize) } -func (c *cachedSeekable) makeTempChunkFilename(offset int64) string { - tempFilename := uuid.NewString() - - return fmt.Sprintf("%s/.temp.%012d-%d.bin.%s", c.path, offset/c.chunkSize, c.chunkSize, tempFilename) -} - -func (c *cachedSeekable) readAtFromCache(ctx context.Context, chunkPath string, buff []byte) (n int, e error) { - ctx, span := c.tracer.Start(ctx, "read chunk at offset from cache") - defer func() { - recordError(span, e) - span.End() - }() - - fp, err := os.Open(chunkPath) - if err != nil { - return 0, fmt.Errorf("failed to open file: %w", err) - } - - defer utils.Cleanup(ctx, "failed to close chunk", fp.Close) - - // ReadAt (pread) is used instead of Read so that short reads from cache - // files (e.g. last chunk) return io.EOF per the io.ReaderAt contract. - // Plain Read on Linux returns (n, nil) for short reads and only - // signals EOF on a subsequent call, which would hide truncation. 
- count, err := fp.ReadAt(buff, 0) - if ignoreEOF(err) != nil { - return 0, fmt.Errorf("failed to read from chunk: %w", err) - } - - return count, err // return `err` in case it's io.EOF +func (c *cachedSeekable) makeTempFilename(path string) string { + return path + ".tmp." + uuid.NewString() } func (c *cachedSeekable) sizeFilename() string { @@ -365,7 +339,7 @@ func (c *cachedSeekable) readLocalSize(context.Context) (int64, error) { return size, nil } -func (c *cachedSeekable) validateReadAtParams(buffSize, offset int64) error { +func (c *cachedSeekable) validateReadParams(buffSize, offset int64) error { if buffSize == 0 { return ErrBufferTooSmall } @@ -382,14 +356,14 @@ func (c *cachedSeekable) validateReadAtParams(buffSize, offset int64) error { return nil } -func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, chunkPath string, bytes []byte) error { +func (c *cachedSeekable) writeToCache(ctx context.Context, offset int64, finalPath string, bytes []byte) error { writeTimer := cacheSlabWriteTimerFactory.Begin() // Try to acquire lock for this chunk write to NFS cache - lockFile, err := lock.TryAcquireLock(ctx, chunkPath) + lockFile, err := lock.TryAcquireLock(ctx, finalPath) if err != nil { // failed to acquire lock, which is a different category of failure than "write failed" - recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpReadAt, err) + recordCacheWriteError(ctx, cacheTypeSeekable, cacheOpOpenRangeReader, err) writeTimer.Failure(ctx, 0) @@ -400,14 +374,14 @@ func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, ch defer func() { err := lock.ReleaseLock(ctx, lockFile) if err != nil { - logger.L().Warn(ctx, "failed to release lock after writing chunk to cache", + logger.L().Warn(ctx, "failed to release lock after writing to cache", zap.Int64("offset", offset), - zap.String("path", chunkPath), + zap.String("path", finalPath), zap.Error(err)) } }() - tempPath := c.makeTempChunkFilename(offset) + tempPath := 
c.makeTempFilename(finalPath) if err := os.WriteFile(tempPath, bytes, cacheFilePermissions); err != nil { go safelyRemoveFile(ctx, tempPath) @@ -417,7 +391,7 @@ func (c *cachedSeekable) writeChunkToCache(ctx context.Context, offset int64, ch return fmt.Errorf("failed to write temp cache file: %w", err) } - if err := utils.RenameOrDeleteFile(ctx, tempPath, chunkPath); err != nil { + if err := utils.RenameOrDeleteFile(ctx, tempPath, finalPath); err != nil { writeTimer.Failure(ctx, int64(len(bytes))) return fmt.Errorf("failed to rename temp file: %w", err) diff --git a/packages/shared/pkg/storage/storage_cache_seekable_test.go b/packages/shared/pkg/storage/storage_cache_seekable_test.go index 40b9ea03d7..b69eeaf271 100644 --- a/packages/shared/pkg/storage/storage_cache_seekable_test.go +++ b/packages/shared/pkg/storage/storage_cache_seekable_test.go @@ -12,10 +12,30 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" - - storagemocks "github.com/e2b-dev/infra/packages/shared/pkg/storage/mocks" ) +// testReadAt emulates the removed cachedSeekable.ReadAt via OpenRangeReader. +// This preserves the base test structure after ReadAt was removed from the Seekable interface. 
+func testReadAt(ctx context.Context, c *cachedSeekable, buff []byte, off int64) (int, error) { + rc, err := c.OpenRangeReader(ctx, off, int64(len(buff)), nil) + if err != nil { + return 0, err + } + + n, err := io.ReadFull(rc, buff) + + closeErr := rc.Close() + if errors.Is(err, io.ErrUnexpectedEOF) { + err = io.EOF + } + + if err == nil { + err = closeErr + } + + return n, err +} + func TestCachedFileObjectProvider_MakeChunkFilename(t *testing.T) { t.Parallel() @@ -32,7 +52,7 @@ func TestCachedFileObjectProvider_Size(t *testing.T) { const expectedSize int64 = 1024 - inner := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) inner.EXPECT().Size(mock.Anything).Return(expectedSize, nil) c := cachedSeekable{path: t.TempDir(), inner: inner, tracer: noopTracer} @@ -71,19 +91,19 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { err = os.WriteFile(tempFilename, data, 0o644) require.NoError(t, err) - inner := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) inner.EXPECT(). - StoreFile(mock.Anything, mock.Anything). - Return(nil) + StoreFile(mock.Anything, mock.Anything, (*CompressConfig)(nil)). 
+ Return(nil, [32]byte{}, nil) - featureFlags := storagemocks.NewMockFeatureFlagsClient(t) + featureFlags := NewMockFeatureFlagsClient(t) featureFlags.EXPECT().BoolFlag(mock.Anything, mock.Anything).Return(true) featureFlags.EXPECT().IntFlag(mock.Anything, mock.Anything).Return(10) c := cachedSeekable{path: cacheDir, inner: inner, chunkSize: 1024, flags: featureFlags, tracer: noopTracer} // write temp file - err = c.StoreFile(t.Context(), tempFilename) + _, _, err = c.StoreFile(t.Context(), tempFilename, nil) require.NoError(t, err) // file is written asynchronously, wait for it to finish @@ -98,7 +118,7 @@ func TestCachedFileObjectProvider_WriteFromFileSystem(t *testing.T) { // verify that the size has been cached buff := make([]byte, len(data)) - bytesRead, err := c.ReadAt(t.Context(), buff, 0) + bytesRead, err := testReadAt(t.Context(), &c, buff, 0) require.NoError(t, err) assert.Equal(t, data, buff) assert.Equal(t, len(data), bytesRead) @@ -125,7 +145,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { require.NoError(t, err) buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 0) + read, err := testReadAt(t.Context(), &c, buffer, 0) require.NoError(t, err) assert.Equal(t, []byte{1, 2, 3}, buffer) assert.Equal(t, 3, read) @@ -147,7 +167,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { // per the io.ReaderAt contract. This is a cache hit — the caller // sees the data with EOF indicating end of file. buffer := make([]byte, 10) - read, err := c.ReadAt(t.Context(), buffer, 0) + read, err := testReadAt(t.Context(), &c, buffer, 0) require.ErrorIs(t, err, io.EOF) assert.Equal(t, 3, read) assert.Equal(t, []byte{1, 2, 3}, buffer[:read]) @@ -157,30 +177,27 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { t.Parallel() fakeData := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - fakeStorageObjectProvider := storagemocks.NewMockSeekable(t) + inner := NewMockSeekable(t) - fakeStorageObjectProvider.EXPECT(). 
- ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, off int64) (int, error) { - start := off - end := off + int64(len(buff)) - end = min(end, int64(len(fakeData))) - copy(buff, fakeData[start:end]) - - return int(end - start), nil + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, off int64, length int64, _ *FrameTable) (io.ReadCloser, error) { + end := min(int(off)+int(length), len(fakeData)) + + return io.NopCloser(bytes.NewReader(fakeData[off:end])), nil }) tempDir := t.TempDir() c := cachedSeekable{ path: tempDir, chunkSize: 3, - inner: fakeStorageObjectProvider, + inner: inner, tracer: noopTracer, } // first read goes to source buffer := make([]byte, 3) - read, err := c.ReadAt(t.Context(), buffer, 3) + read, err := testReadAt(t.Context(), &c, buffer, 3) require.NoError(t, err) assert.Equal(t, []byte{4, 5, 6}, buffer) assert.Equal(t, 3, read) @@ -191,7 +208,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { // second read pulls from cache c.inner = nil // prevent remote reads, force cache read buffer = make([]byte, 3) - read, err = c.ReadAt(t.Context(), buffer, 3) + read, err = testReadAt(t.Context(), &c, buffer, 3) require.NoError(t, err) assert.Equal(t, []byte{4, 5, 6}, buffer) assert.Equal(t, 3, read) @@ -202,7 +219,7 @@ func TestCachedFileObjectProvider_WriteTo(t *testing.T) { fakeData := []byte{1, 2, 3} - fakeStorageObjectProvider := storagemocks.NewMockBlob(t) + fakeStorageObjectProvider := NewMockBlob(t) fakeStorageObjectProvider.EXPECT(). WriteTo(mock.Anything, mock.Anything). 
RunAndReturn(func(_ context.Context, dst io.Writer) (int64, error) { @@ -279,7 +296,7 @@ func TestCachedFileObjectProvider_validateReadAtParams(t *testing.T) { chunkSize: tc.chunkSize, tracer: noopTracer, } - err := c.validateReadAtParams(tc.bufferSize, tc.offset) + err := c.validateReadParams(tc.bufferSize, tc.offset) if tc.expected == nil { require.NoError(t, err) } else { @@ -292,39 +309,24 @@ func TestCachedFileObjectProvider_validateReadAtParams(t *testing.T) { func TestCachedSeekableObjectProvider_ReadAt(t *testing.T) { t.Parallel() - t.Run("failed but returns count on short read", func(t *testing.T) { - t.Parallel() - - c := cachedSeekable{chunkSize: 10, tracer: noopTracer} - errTarget := errors.New("find me") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT().ReadAt(mock.Anything, mock.Anything, mock.Anything).Return(5, errTarget) - c.inner = mockSeeker - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.ErrorIs(t, err, errTarget) - assert.Equal(t, 5, count) - }) - t.Run("zero byte read with EOF is not cached", func(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - Return(0, io.EOF) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader(nil)), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) + count, err := testReadAt(t.Context(), &c, buff, 0) require.ErrorIs(t, err, io.EOF) assert.Equal(t, 0, count) @@ -335,127 +337,25 @@ func TestCachedSeekableObjectProvider_ReadAt(t *testing.T) { assert.True(t, os.IsNotExist(err), "zero-byte read should not be cached") }) - t.Run("zero byte read without EOF is not cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - Return(0, nil) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.NoError(t, err) - assert.Equal(t, 0, count) - - c.wg.Wait() - - chunkPath := c.makeChunkFilename(0) - _, err = os.Stat(chunkPath) - assert.True(t, os.IsNotExist(err), "zero-byte read should not be cached") - }) - - t.Run("short read without EOF is not cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - // Simulate a truncated upstream response: return fewer - // bytes than requested with no error and no EOF. - copy(buff[:2], []byte{0xAA, 0xBB}) - - return 2, nil - }) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.NoError(t, err) - assert.Equal(t, 2, count) - - c.wg.Wait() - - // Verify no cache file was written. 
- chunkPath := c.makeChunkFilename(0) - _, err = os.Stat(chunkPath) - assert.True(t, os.IsNotExist(err), "truncated data should not be cached") - }) - - t.Run("short read with EOF is cached", func(t *testing.T) { - t.Parallel() - - tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - // Last chunk: fewer bytes than the buffer with EOF. - copy(buff[:3], []byte{1, 2, 3}) - - return 3, io.EOF - }) - - c := cachedSeekable{ - path: tempDir, - chunkSize: 10, - inner: mockSeeker, - tracer: noopTracer, - } - - buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) - require.ErrorIs(t, err, io.EOF) - assert.Equal(t, 3, count) - - c.wg.Wait() - - // Verify the data was cached. - chunkPath := c.makeChunkFilename(0) - cached, err := os.ReadFile(chunkPath) - require.NoError(t, err) - assert.Equal(t, []byte{1, 2, 3}, cached) - }) - t.Run("full read without EOF is cached", func(t *testing.T) { t.Parallel() tempDir := t.TempDir() data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, data) - - return len(data), nil - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader(data)), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - count, err := c.ReadAt(t.Context(), buff, 0) + count, err := testReadAt(t.Context(), &c, buff, 0) require.NoError(t, err) assert.Equal(t, 10, count) @@ -504,24 +404,20 @@ func TestCachedSeekable_ReadAt_PreservesEOF(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff[:3], []byte{1, 2, 3}) - - return 3, io.EOF - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + Return(io.NopCloser(bytes.NewReader([]byte{1, 2, 3})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - n, err := c.ReadAt(t.Context(), buff, 0) + n, err := testReadAt(t.Context(), &c, buff, 0) assert.Equal(t, 3, n) require.ErrorIs(t, err, io.EOF, "cachedSeekable must not swallow io.EOF") @@ -532,24 +428,20 @@ func TestCachedSeekable_ReadAt_PreservesEOF(t *testing.T) { t.Parallel() tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}) - - return 10, nil - }) + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). 
+ Return(io.NopCloser(bytes.NewReader([]byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } buff := make([]byte, 10) - n, err := c.ReadAt(t.Context(), buff, 0) + n, err := testReadAt(t.Context(), &c, buff, 0) assert.Equal(t, 10, n) require.NoError(t, err, "cachedSeekable must not inject errors on full read") @@ -562,25 +454,23 @@ func TestCachedSeekable_ReadAt_SkipCacheWriteback(t *testing.T) { tempDir := t.TempDir() data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - ReadAt(mock.Anything, mock.Anything, mock.Anything). - RunAndReturn(func(_ context.Context, buff []byte, _ int64) (int, error) { - copy(buff, data) - - return len(data), nil + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, mock.Anything, mock.Anything, (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, _ int64, _ int64, _ *FrameTable) (io.ReadCloser, error) { + return io.NopCloser(bytes.NewReader(data)), nil }) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } ctx := WithSkipCacheWriteback(t.Context()) buff := make([]byte, 10) - n, err := c.ReadAt(ctx, buff, 0) + n, err := testReadAt(ctx, &c, buff, 0) require.NoError(t, err) assert.Equal(t, 10, n) @@ -600,21 +490,21 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() data := []byte("hello") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(len(data))). + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(len(data)), (*FrameTable)(nil)). Return(io.NopCloser(bytes.NewReader(data)), nil). Once() c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } // First call: cache miss, reads from inner. 
- rc, err := c.OpenRangeReader(t.Context(), 0, int64(len(data))) + rc, err := c.OpenRangeReader(t.Context(), 0, int64(len(data)), nil) require.NoError(t, err) got, err := io.ReadAll(rc) @@ -626,7 +516,7 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { // Second call: should serve from NFS cache, inner not called again. c.inner = nil - rc2, err := c.OpenRangeReader(t.Context(), 0, int64(len(data))) + rc2, err := c.OpenRangeReader(t.Context(), 0, int64(len(data)), nil) require.NoError(t, err) got2, err := io.ReadAll(rc2) @@ -641,10 +531,10 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() data := []byte("hello") - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(len(data))). - RunAndReturn(func(_ context.Context, _ int64, _ int64) (io.ReadCloser, error) { + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(len(data)), (*FrameTable)(nil)). + RunAndReturn(func(_ context.Context, _ int64, _ int64, _ *FrameTable) (io.ReadCloser, error) { return io.NopCloser(bytes.NewReader(data)), nil }). 
Times(2) @@ -652,13 +542,13 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } ctx := WithSkipCacheWriteback(t.Context()) - rc, err := c.OpenRangeReader(ctx, 0, int64(len(data))) + rc, err := c.OpenRangeReader(ctx, 0, int64(len(data)), nil) require.NoError(t, err) got, err := io.ReadAll(rc) @@ -673,7 +563,7 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { _, err = os.Stat(chunkPath) assert.True(t, os.IsNotExist(err), "skip writeback should not populate cache") - rc2, err := c.OpenRangeReader(ctx, 0, int64(len(data))) + rc2, err := c.OpenRangeReader(ctx, 0, int64(len(data)), nil) require.NoError(t, err) got2, err := io.ReadAll(rc2) @@ -687,19 +577,19 @@ func TestCachedSeekable_OpenRangeReader(t *testing.T) { tempDir := t.TempDir() - mockSeeker := storagemocks.NewMockSeekable(t) - mockSeeker.EXPECT(). - OpenRangeReader(mock.Anything, int64(0), int64(5)). + inner := NewMockSeekable(t) + inner.EXPECT(). + OpenRangeReader(mock.Anything, int64(0), int64(5), (*FrameTable)(nil)). 
Return(io.NopCloser(bytes.NewReader([]byte{0xAA, 0xBB})), nil) c := cachedSeekable{ path: tempDir, chunkSize: 10, - inner: mockSeeker, + inner: inner, tracer: noopTracer, } - rc, err := c.OpenRangeReader(t.Context(), 0, 5) + rc, err := c.OpenRangeReader(t.Context(), 0, 5, nil) require.NoError(t, err) got, err := io.ReadAll(rc) diff --git a/packages/shared/pkg/storage/storage_fs.go b/packages/shared/pkg/storage/storage_fs.go index 8eb2d3cc13..a6d9baf582 100644 --- a/packages/shared/pkg/storage/storage_fs.go +++ b/packages/shared/pkg/storage/storage_fs.go @@ -12,6 +12,7 @@ import ( "os" "path/filepath" "strconv" + "strings" "time" ) @@ -124,28 +125,52 @@ func (o *fsObject) Put(_ context.Context, data []byte) error { return err } -func (o *fsObject) StoreFile(_ context.Context, path string) error { +func (o *fsObject) StoreFile(ctx context.Context, path string, cfg *CompressConfig) (*FrameTable, [32]byte, error) { + if cfg.IsEnabled() { + return o.storeFileCompressed(ctx, path, cfg) + } + r, err := os.Open(path) if err != nil { - return fmt.Errorf("failed to open file %s: %w", path, err) + return nil, [32]byte{}, fmt.Errorf("failed to open file %s: %w", path, err) } defer r.Close() handle, err := o.getHandle(false) if err != nil { - return err + return nil, [32]byte{}, err } defer handle.Close() _, err = io.Copy(handle, r) + + return nil, [32]byte{}, err +} + +func (o *fsObject) storeFileCompressed(ctx context.Context, localPath string, cfg *CompressConfig) (*FrameTable, [32]byte, error) { + file, err := os.Open(localPath) if err != nil { - return err + return nil, [32]byte{}, fmt.Errorf("failed to open local file %s: %w", localPath, err) } + defer file.Close() - return nil + fi, err := file.Stat() + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to stat local file %s: %w", localPath, err) + } + + // Write .uncompressed-size sidecar so Size() returns the correct value. + sidecarPath := o.path + "." 
+ MetadataKeyUncompressedSize + if writeErr := os.WriteFile(sidecarPath, []byte(strconv.FormatInt(fi.Size(), 10)), 0o644); writeErr != nil { + return nil, [32]byte{}, fmt.Errorf("failed to write uncompressed-size sidecar for %s: %w", o.path, writeErr) + } + + uploader := &fsPartUploader{fullPath: o.path} + + return compressStream(ctx, file, cfg, uploader, 4) } -func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { +func (o *fsObject) openRangeReader(_ context.Context, off, length int64) (io.ReadCloser, error) { f, err := o.getHandle(true) if err != nil { return nil, err @@ -157,16 +182,6 @@ func (o *fsObject) OpenRangeReader(_ context.Context, off, length int64) (io.Rea }, nil } -func (o *fsObject) ReadAt(_ context.Context, buff []byte, off int64) (n int, err error) { - handle, err := o.getHandle(true) - if err != nil { - return 0, err - } - defer handle.Close() - - return handle.ReadAt(buff, off) -} - func (o *fsObject) Exists(_ context.Context) (bool, error) { _, err := os.Stat(o.path) if os.IsNotExist(err) { @@ -188,6 +203,14 @@ func (o *fsObject) Size(_ context.Context) (int64, error) { return 0, err } + // Check for .uncompressed-size sidecar file + sidecarPath := o.path + "." + MetadataKeyUncompressedSize + if sidecarData, sidecarErr := os.ReadFile(sidecarPath); sidecarErr == nil { + if parsed, parseErr := strconv.ParseInt(strings.TrimSpace(string(sidecarData)), 10, 64); parseErr == nil { + return parsed, nil + } + } + return fileInfo.Size(), nil } @@ -239,3 +262,38 @@ func (o *fsObject) getHandle(checkExistence bool) (*os.File, error) { return handle, nil } + +// fsPartUploader implements partUploader for local filesystem. +// Embeds memPartUploader for concurrent-safe part collection, +// then writes atomically on Complete. 
+type fsPartUploader struct { + memPartUploader + + fullPath string +} + +func (u *fsPartUploader) Complete(_ context.Context) error { + if err := os.MkdirAll(filepath.Dir(u.fullPath), 0o755); err != nil { + return fmt.Errorf("failed to create directory: %w", err) + } + + return os.WriteFile(u.fullPath, u.Assemble(), 0o644) +} + +func (o *fsObject) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + if frameTable.IsCompressed() { + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + return nil, fmt.Errorf("get frame for offset %d, FS:%s: %w", offsetU, o.path, err) + } + + raw, err := o.openRangeReader(ctx, frameStart.C, int64(frameSize.C)) + if err != nil { + return nil, err + } + + return newDecompressingReadCloser(raw, frameTable.CompressionType()) + } + + return o.openRangeReader(ctx, offsetU, length) +} diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index 9434963c44..0b0477dcfe 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -10,6 +10,7 @@ import ( "io" "net/http" "os" + "strconv" "time" "cloud.google.com/go/storage" @@ -45,12 +46,13 @@ const ( gcloudDefaultUploadConcurrency = 16 gcsOperationAttr = "operation" - gcsOperationAttrReadAt = "ReadAt" gcsOperationAttrWrite = "Write" gcsOperationAttrWriteFromFileSystem = "WriteFromFileSystem" gcsOperationAttrWriteFromFileSystemOneShot = "WriteFromFileSystemOneShot" gcsOperationAttrWriteTo = "WriteTo" gcsOperationAttrSize = "Size" + gcsOperationAttrReadAt = "ReadAt" + gcsOperationAttrOpenReader = "OpenRangeReader" ) var ( @@ -245,10 +247,17 @@ func (o *gcpObject) Size(ctx context.Context) (int64, error) { timer.Success(ctx, 0) + if v, ok := attrs.Metadata[MetadataKeyUncompressedSize]; ok { + parsed, parseErr := strconv.ParseInt(v, 10, 64) + if parseErr == nil { + return parsed, nil + } + } + return 
attrs.Size, nil } -func (o *gcpObject) OpenRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { +func (o *gcpObject) openRangeReader(ctx context.Context, off, length int64) (io.ReadCloser, error) { ctx, cancel := context.WithTimeout(ctx, googleReadTimeout) reader, err := o.handle.NewRangeReader(ctx, off, length) @@ -384,7 +393,7 @@ func (o *gcpObject) WriteTo(ctx context.Context, dst io.Writer) (int64, error) { return n, nil } -func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { +func (o *gcpObject) StoreFile(ctx context.Context, path string, cfg *CompressConfig) (_ *FrameTable, _ [32]byte, e error) { ctx, span := tracer.Start(ctx, "write to gcp from file system") defer func() { recordError(span, e) @@ -396,7 +405,33 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { fileInfo, err := os.Stat(path) if err != nil { - return fmt.Errorf("failed to get file size: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to get file size: %w", err) + } + + timer := googleWriteTimerFactory.Begin( + attribute.String(gcsOperationAttr, gcsOperationAttrWriteFromFileSystem), + ) + + maxConcurrency := gcloudDefaultUploadConcurrency + if o.limiter != nil { + uploadLimiter := o.limiter.GCloudUploadLimiter() + if uploadLimiter != nil { + semaphoreErr := uploadLimiter.Acquire(ctx, 1) + if semaphoreErr != nil { + timer.Failure(ctx, 0) + + return nil, [32]byte{}, fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) + } + defer uploadLimiter.Release(1) + } + + maxConcurrency = o.limiter.GCloudMaxTasks(ctx) + } + + // Compressed uploads always go through the multipart compressed path, + // regardless of file size. + if cfg.IsEnabled() { + return o.storeFileCompressed(ctx, path, cfg, maxConcurrency) } // If the file is too small, the overhead of writing in parallel isn't worth the effort. 
@@ -410,39 +445,19 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if err != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to read file: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to read file: %w", err) } err = o.Put(ctx, data) if err != nil { timer.Failure(ctx, int64(len(data))) - return fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) + return nil, [32]byte{}, fmt.Errorf("failed to write file (%d bytes): %w", len(data), err) } timer.Success(ctx, int64(len(data))) - return nil - } - - timer := googleWriteTimerFactory.Begin( - attribute.String(gcsOperationAttr, gcsOperationAttrWriteFromFileSystem), - ) - - maxConcurrency := gcloudDefaultUploadConcurrency - if o.limiter != nil { - uploadLimiter := o.limiter.GCloudUploadLimiter() - if uploadLimiter != nil { - semaphoreErr := uploadLimiter.Acquire(ctx, 1) - if semaphoreErr != nil { - timer.Failure(ctx, 0) - - return fmt.Errorf("failed to acquire semaphore: %w", semaphoreErr) - } - defer uploadLimiter.Release(1) - } - - maxConcurrency = o.limiter.GCloudMaxTasks(ctx) + return nil, [32]byte{}, e } uploader, err := NewMultipartUploaderWithRetryConfig( @@ -450,11 +465,12 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { bucketName, objectName, DefaultRetryConfig(), + nil, ) if err != nil { timer.Failure(ctx, 0) - return fmt.Errorf("failed to create multipart uploader: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to create multipart uploader: %w", err) } start := time.Now() @@ -462,7 +478,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { if err != nil { timer.Failure(ctx, count) - return fmt.Errorf("failed to upload file in parallel: %w", err) + return nil, [32]byte{}, fmt.Errorf("failed to upload file in parallel: %w", err) } logger.L().Debug(ctx, "Uploaded file in parallel", @@ -476,7 +492,35 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { 
timer.Success(ctx, count) - return nil + return nil, [32]byte{}, e +} + +func (o *gcpObject) storeFileCompressed(ctx context.Context, localPath string, cfg *CompressConfig, maxConcurrency int) (*FrameTable, [32]byte, error) { + file, err := os.Open(localPath) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to open local file %s: %w", localPath, err) + } + defer file.Close() + + fi, err := file.Stat() + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to stat local file %s: %w", localPath, err) + } + + uploader, err := NewMultipartUploaderWithRetryConfig( + ctx, + o.storage.bucket.BucketName(), + o.path, + DefaultRetryConfig(), + map[string]string{ + MetadataKeyUncompressedSize: strconv.FormatInt(fi.Size(), 10), + }, + ) + if err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to create multipart uploader: %w", err) + } + + return compressStream(ctx, file, cfg, uploader, maxConcurrency) } type gcpServiceToken struct { @@ -498,6 +542,45 @@ func parseServiceAccountBase64(serviceAccount string) (*gcpServiceToken, error) return &sa, nil } +func (o *gcpObject) OpenRangeReader(ctx context.Context, offsetU int64, length int64, frameTable *FrameTable) (io.ReadCloser, error) { + timer := googleReadTimerFactory.Begin(attribute.String(gcsOperationAttr, gcsOperationAttrOpenReader)) + + if !frameTable.IsCompressed() { + rc, err := o.openRangeReader(ctx, offsetU, length) + if err != nil { + timer.Failure(ctx, 0) + + return nil, err + } + + return &timedReadCloser{inner: rc, timer: timer, ctx: ctx}, nil + } + + frameStart, frameSize, err := frameTable.FrameFor(offsetU) + if err != nil { + timer.Failure(ctx, 0) + + return nil, fmt.Errorf("get frame for offset %d, GCS:%s: %w", offsetU, o.path, err) + } + + raw, err := o.openRangeReader(ctx, frameStart.C, int64(frameSize.C)) + if err != nil { + timer.Failure(ctx, 0) + + return nil, err + } + + decompressed, err := newDecompressingReadCloser(raw, frameTable.CompressionType()) + if err != nil { + 
raw.Close() + timer.Failure(ctx, 0) + + return nil, err + } + + return &timedReadCloser{inner: decompressed, timer: timer, ctx: ctx}, nil +} + func isResourceExhausted(err error) bool { type grpcStatusProvider interface { GRPCStatus() *status.Status diff --git a/tests/integration/Makefile b/tests/integration/Makefile index 00349fcfd4..13b52698be 100644 --- a/tests/integration/Makefile +++ b/tests/integration/Makefile @@ -40,9 +40,9 @@ test/%: *.go:*) \ BASE=$${TEST_PATH%%:*}; \ TEST_FN=$${TEST_PATH#*:}; \ - go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -run "$${TEST_FN}" ;; \ - *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ - *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 ;; \ + go tool gotestsum --rerun-fails=1 --packages="$$BASE" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m -run "$${TEST_FN}" ;; \ + *.go) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH" --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m ;; \ + *) go tool gotestsum --rerun-fails=1 --packages="$$TEST_PATH/..." --format standard-verbose --junitfile=test-results.xml -- -count=1 -parallel=4 -timeout=20m ;; \ esac .PHONY: connect-orchestrator