diff --git a/packages/orchestrator/cmd/show-build-diff/main.go b/packages/orchestrator/cmd/show-build-diff/main.go index defa10be4a..edb59684ad 100644 --- a/packages/orchestrator/cmd/show-build-diff/main.go +++ b/packages/orchestrator/cmd/show-build-diff/main.go @@ -142,7 +142,10 @@ func main() { ) } - mergedHeader := header.MergeMappings(baseHeader.Mapping, onlyDiffMappings) + mergedHeader, err := header.MergeMappings(baseHeader.Mapping, onlyDiffMappings) + if err != nil { + log.Fatalf("merge mappings: %v", err) + } fmt.Printf("\n\nMERGED METADATA\n") fmt.Printf("========\n") diff --git a/packages/shared/go.mod b/packages/shared/go.mod index 4a08ce0826..0b27112271 100644 --- a/packages/shared/go.mod +++ b/packages/shared/go.mod @@ -30,11 +30,13 @@ require ( github.com/hashicorp/go-retryablehttp v0.7.7 github.com/hashicorp/nomad/api v0.0.0-20251216171439-1dee0671280e github.com/jellydator/ttlcache/v3 v3.4.0 + github.com/klauspost/compress v1.18.2 github.com/launchdarkly/go-sdk-common/v3 v3.3.0 github.com/launchdarkly/go-server-sdk/v7 v7.13.0 github.com/ngrok/firewall_toolkit v0.0.18 github.com/oapi-codegen/runtime v1.1.1 github.com/orcaman/concurrent-map/v2 v2.0.1 + github.com/pierrec/lz4/v4 v4.1.22 github.com/redis/go-redis/extra/redisotel/v9 v9.17.3 github.com/redis/go-redis/v9 v9.17.3 github.com/stretchr/testify v1.11.1 @@ -229,7 +231,6 @@ require ( github.com/json-iterator/go v1.1.12 // indirect github.com/julienschmidt/httprouter v1.3.0 // indirect github.com/kamstrup/intmap v0.5.1 // indirect - github.com/klauspost/compress v1.18.2 // indirect github.com/klauspost/cpuid/v2 v2.2.11 // indirect github.com/knadh/koanf/maps v0.1.2 // indirect github.com/knadh/koanf/providers/confmap v1.0.0 // indirect @@ -281,7 +282,6 @@ require ( github.com/opentracing/opentracing-go v1.2.1-0.20220228012449-10b1cf09e00b // indirect github.com/patrickmn/go-cache v2.1.0+incompatible // indirect github.com/pelletier/go-toml/v2 v2.2.4 // indirect - github.com/pierrec/lz4/v4 v4.1.22 // indirect github.com/pires/go-proxyproto v0.7.0 // indirect github.com/pkg/browser v0.0.0-20240102092130-5ac0b6a4141c // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/packages/shared/pkg/featureflags/context.go b/packages/shared/pkg/featureflags/context.go index 4f0e957ff0..79e52b1557 100644 --- a/packages/shared/pkg/featureflags/context.go +++ b/packages/shared/pkg/featureflags/context.go @@ -164,6 +164,14 @@ func VolumeContext(volumeName string) ldcontext.Context { return ldcontext.NewWithKind(VolumeKind, volumeName) } +func CompressFileTypeContext(fileType string) ldcontext.Context { + return ldcontext.NewWithKind(CompressFileTypeKind, fileType) +} + +func CompressUseCaseContext(useCase string) ldcontext.Context { + return ldcontext.NewWithKind(CompressUseCaseKind, useCase) +} + func VersionContext(orchestratorID, commit string) ldcontext.Context { return ldcontext.NewBuilder(orchestratorID). Kind(OrchestratorKind). 
diff --git a/packages/shared/pkg/featureflags/flags.go b/packages/shared/pkg/featureflags/flags.go index 675d6c811f..fadad31bf4 100644 --- a/packages/shared/pkg/featureflags/flags.go +++ b/packages/shared/pkg/featureflags/flags.go @@ -18,14 +18,16 @@ const ( SandboxKernelVersionAttribute string = "kernel-version" SandboxFirecrackerVersionAttribute string = "firecracker-version" - TeamKind ldcontext.Kind = "team" - UserKind ldcontext.Kind = "user" - ClusterKind ldcontext.Kind = "cluster" - deploymentKind ldcontext.Kind = "deployment" - TierKind ldcontext.Kind = "tier" - ServiceKind ldcontext.Kind = "service" - TemplateKind ldcontext.Kind = "template" - VolumeKind ldcontext.Kind = "volume" + TeamKind ldcontext.Kind = "team" + UserKind ldcontext.Kind = "user" + ClusterKind ldcontext.Kind = "cluster" + deploymentKind ldcontext.Kind = "deployment" + TierKind ldcontext.Kind = "tier" + ServiceKind ldcontext.Kind = "service" + TemplateKind ldcontext.Kind = "template" + VolumeKind ldcontext.Kind = "volume" + CompressFileTypeKind ldcontext.Kind = "compress-file-type" + CompressUseCaseKind ldcontext.Kind = "compress-use-case" OrchestratorKind ldcontext.Kind = "orchestrator" OrchestratorCommitAttribute string = "commit" @@ -326,6 +328,19 @@ var ChunkerConfigFlag = newJSONFlag("chunker-config", ldvalue.FromJSONMarshal(ma "minReadBatchSizeKB": 16, })) +// CompressConfigFlag controls compression during template builds. +// When compressBuilds is true, builds upload exclusively compressed data +// (no uncompressed fallback). When false, exclusively uncompressed with V3 headers. +var CompressConfigFlag = newJSONFlag("compress-config", ldvalue.FromJSONMarshal(map[string]any{ + "compressBuilds": false, + "compressionType": "zstd", + "compressionLevel": 2, + "frameSizeKB": 2048, + "targetPartSizeMB": 50, + "frameEncodeWorkers": 4, + "encoderConcurrency": 1, +})) + // TCPFirewallEgressThrottleConfig controls per-sandbox egress throttling via Firecracker's // VMM-level token bucket rate limiters on the network interface. // Structure mirrors the Firecracker RateLimiter API: two independent token buckets. diff --git a/packages/shared/pkg/storage/compress_config.go b/packages/shared/pkg/storage/compress_config.go new file mode 100644 index 0000000000..50e92d16e1 --- /dev/null +++ b/packages/shared/pkg/storage/compress_config.go @@ -0,0 +1,131 @@ +package storage + +import ( + "context" + "fmt" + + "github.com/e2b-dev/infra/packages/shared/pkg/featureflags" +) + +// CompressConfig is the base compression configuration, loaded from environment +// variables at startup. Feature flags can override individual fields at runtime +// via ResolveCompressConfig. +type CompressConfig struct { + Enabled bool `env:"COMPRESS_ENABLED" envDefault:"false"` + Type string `env:"COMPRESS_TYPE" envDefault:"zstd"` + Level int `env:"COMPRESS_LEVEL" envDefault:"2"` + FrameSizeKB int `env:"COMPRESS_FRAME_SIZE_KB" envDefault:"2048"` + TargetPartSizeMB int `env:"COMPRESS_TARGET_PART_SIZE_MB" envDefault:"50"` + FrameEncodeWorkers int `env:"COMPRESS_FRAME_ENCODE_WORKERS" envDefault:"4"` + EncoderConcurrency int `env:"COMPRESS_ENCODER_CONCURRENCY" envDefault:"1"` +} + +// CompressionType returns the parsed CompressionType. +func (c *CompressConfig) CompressionType() CompressionType { + if c == nil { + return CompressionNone + } + + return ParseCompressionType(c.Type) +} + +// FrameSize returns the frame size in bytes. 
+func (c *CompressConfig) FrameSize() int { + if c == nil || c.FrameSizeKB <= 0 { + return DefaultCompressFrameSize + } + + return c.FrameSizeKB * 1024 +} + +// TargetPartSize returns the target part size in bytes. +func (c *CompressConfig) TargetPartSize() int64 { + if c == nil || c.TargetPartSizeMB <= 0 { + return int64(gcpMultipartUploadChunkSize) + } + + return int64(c.TargetPartSizeMB) * (1 << 20) +} + +// IsEnabled reports whether compression is configured and active. +func (c *CompressConfig) IsEnabled() bool { + return c != nil && c.Enabled && c.CompressionType() != CompressionNone +} + +// Validate checks that the config is internally consistent. +func (c *CompressConfig) Validate() error { + if c == nil || !c.IsEnabled() { + return nil + } + + fs := c.FrameSize() + if fs <= 0 { + return fmt.Errorf("frame size must be positive, got %d KB", c.FrameSizeKB) + } + if MemoryChunkSize%fs != 0 && fs%MemoryChunkSize != 0 { + return fmt.Errorf("frame size (%d) must be a divisor or multiple of MemoryChunkSize (%d)", fs, MemoryChunkSize) + } + + return nil +} + +// Resolve returns a pointer to this config if compression is enabled, or nil. +// Callers use nil to mean "no compression". +func (c *CompressConfig) Resolve() *CompressConfig { + if c == nil || !c.IsEnabled() { + return nil + } + + return c +} + +// CompressConfigFromLDValue parses the LaunchDarkly CompressConfigFlag JSON +// into a CompressConfig. Returns nil if the flag disables compression. +func CompressConfigFromLDValue(ctx context.Context, ff *featureflags.Client) *CompressConfig { + if ff == nil { + return nil + } + + v := ff.JSONFlag(ctx, featureflags.CompressConfigFlag).AsValueMap() + + if !v.Get("compressBuilds").BoolValue() { + return nil + } + + ct := v.Get("compressionType").StringValue() + if ParseCompressionType(ct) == CompressionNone { + return nil + } + + return &CompressConfig{ + Enabled: true, + Type: ct, + Level: v.Get("compressionLevel").IntValue(), + FrameSizeKB: v.Get("frameSizeKB").IntValue(), + TargetPartSizeMB: v.Get("targetPartSizeMB").IntValue(), + FrameEncodeWorkers: v.Get("frameEncodeWorkers").IntValue(), + EncoderConcurrency: v.Get("encoderConcurrency").IntValue(), + } +} + +// ResolveCompressConfig returns the effective compression config for a given +// file type and use case. Feature flags override the base config when active. +// Returns nil when compression is disabled. +// +// fileType and useCase are added to the LD evaluation context so that +// LaunchDarkly targeting rules can differentiate (e.g. compress memfile +// but not rootfs, or compress builds but not pauses). 
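Before the implementation below, a hedged sketch of how a caller might resolve the effective compression config. This is not part of the diff: the package name and helper name are illustrative, and FileTypeMemfile/UseCaseBuild are the identifiers added later in this PR (compress_upload.go).

package buildexample

import (
	"context"
	"fmt"

	"github.com/e2b-dev/infra/packages/shared/pkg/featureflags"
	"github.com/e2b-dev/infra/packages/shared/pkg/storage"
)

// resolveForMemfileBuild is an illustrative caller, not part of this PR. It shows
// the intended flow: validate the env-derived base config, then let LaunchDarkly
// override it per file type and use case via ResolveCompressConfig.
func resolveForMemfileBuild(ctx context.Context, base storage.CompressConfig, ff *featureflags.Client) (*storage.CompressConfig, error) {
	if err := base.Validate(); err != nil {
		return nil, fmt.Errorf("invalid compress config: %w", err)
	}

	// FileTypeMemfile / UseCaseBuild are defined later in this PR (compress_upload.go).
	// A nil result means "write uncompressed".
	cfg := storage.ResolveCompressConfig(ctx, base, ff, storage.FileTypeMemfile, storage.UseCaseBuild)

	return cfg, nil
}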
+func ResolveCompressConfig(ctx context.Context, base CompressConfig, ff *featureflags.Client, fileType, useCase string) *CompressConfig { + if ff != nil { + ctx = featureflags.AddToContext(ctx, + featureflags.CompressFileTypeContext(fileType), + featureflags.CompressUseCaseContext(useCase), + ) + + if override := CompressConfigFromLDValue(ctx, ff); override != nil { + return override + } + } + + return base.Resolve() +} diff --git a/packages/shared/pkg/storage/compress_frame_table.go b/packages/shared/pkg/storage/compress_frame_table.go new file mode 100644 index 0000000000..74dfaa2637 --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table.go @@ -0,0 +1,222 @@ +package storage + +import ( + "fmt" +) + +type CompressionType byte + +const ( + CompressionNone = CompressionType(iota) + CompressionZstd + CompressionLZ4 +) + +func (ct CompressionType) Suffix() string { + switch ct { + case CompressionZstd: + return ".zstd" + case CompressionLZ4: + return ".lz4" + default: + return "" + } +} + +func (ct CompressionType) String() string { + switch ct { + case CompressionZstd: + return "zstd" + case CompressionLZ4: + return "lz4" + default: + return "none" + } +} + +// ParseCompressionType converts a string to CompressionType. +// Returns CompressionNone for unrecognised values. +func ParseCompressionType(s string) CompressionType { + switch s { + case "lz4": + return CompressionLZ4 + case "zstd": + return CompressionZstd + default: + return CompressionNone + } +} + +type FrameOffset struct { + U int64 + C int64 +} + +func (o *FrameOffset) String() string { + return fmt.Sprintf("U:%#x/C:%#x", o.U, o.C) +} + +func (o *FrameOffset) Add(f FrameSize) { + o.U += int64(f.U) + o.C += int64(f.C) +} + +type FrameSize struct { + U int32 + C int32 +} + +func (s FrameSize) String() string { + return fmt.Sprintf("U:%#x/C:%#x", s.U, s.C) +} + +type Range struct { + Start int64 + Length int +} + +func (r Range) String() string { + return fmt.Sprintf("%#x/%#x", r.Start, r.Length) +} + +type FrameTable struct { + compressionType CompressionType + StartAt FrameOffset + Frames []FrameSize +} + +// NewFrameTable creates a FrameTable with the given compression type. +func NewFrameTable(ct CompressionType) *FrameTable { + return &FrameTable{compressionType: ct} +} + +// CompressionType returns the compression type. Nil-safe: returns CompressionNone for nil. +func (ft *FrameTable) CompressionType() CompressionType { + if ft == nil { + return CompressionNone + } + + return ft.compressionType +} + +// IsCompressed reports whether ft is non-nil and has a compression type set. +func (ft *FrameTable) IsCompressed() bool { + return ft != nil && ft.compressionType != CompressionNone +} + +// Range calls fn for each frame overlapping [start, start+length). 
+func (ft *FrameTable) Range(start, length int64, fn func(offset FrameOffset, frame FrameSize) error) error { + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + requestEnd := start + length + if frameEnd <= start { + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + + continue + } + if currentOffset.U >= requestEnd { + break + } + + if err := fn(currentOffset, frame); err != nil { + return err + } + currentOffset.U += int64(frame.U) + currentOffset.C += int64(frame.C) + } + + return nil +} + +func (ft *FrameTable) Size() (uncompressed, compressed int64) { + for _, frame := range ft.Frames { + uncompressed += int64(frame.U) + compressed += int64(frame.C) + } + + return uncompressed, compressed +} + +// Subset returns frames covering r. Whole frames only (can't split compressed). +// Stops silently at the end of the frameset if r extends beyond. +func (ft *FrameTable) Subset(r Range) (*FrameTable, error) { + if ft == nil || r.Length == 0 { + return nil, nil + } + if r.Start < ft.StartAt.U { + return nil, fmt.Errorf("requested range starts before the beginning of the frame table") + } + newFrameTable := &FrameTable{ + compressionType: ft.compressionType, + } + + startSet := false + currentOffset := ft.StartAt + requestedEnd := r.Start + int64(r.Length) + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + if frameEnd <= r.Start { + currentOffset.Add(frame) + + continue + } + if currentOffset.U >= requestedEnd { + break + } + + if !startSet { + newFrameTable.StartAt = currentOffset + startSet = true + } + newFrameTable.Frames = append(newFrameTable.Frames, frame) + currentOffset.Add(frame) + } + + if !startSet { + return nil, fmt.Errorf("requested range is beyond the end of the frame table") + } + + return newFrameTable, nil +} + +// FrameFor finds the frame containing the given offset and returns its start position and full size. +func (ft *FrameTable) FrameFor(offset int64) (starts FrameOffset, size FrameSize, err error) { + if ft == nil { + return FrameOffset{}, FrameSize{}, fmt.Errorf("FrameFor called with nil frame table - data is not compressed") + } + + currentOffset := ft.StartAt + for _, frame := range ft.Frames { + frameEnd := currentOffset.U + int64(frame.U) + if offset >= currentOffset.U && offset < frameEnd { + return currentOffset, frame, nil + } + currentOffset.Add(frame) + } + + return FrameOffset{}, FrameSize{}, fmt.Errorf("offset %#x is beyond the end of the frame table", offset) +} + +// GetFetchRange translates a U-space range to C-space using the frame table. 
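A small worked example of the U-space to C-space translation that GetFetchRange (implemented just below) performs. The frame sizes here are made up for illustration.

package main

import (
	"fmt"

	"github.com/e2b-dev/infra/packages/shared/pkg/storage"
)

func main() {
	// Three 2 MiB uncompressed frames with hypothetical compressed sizes.
	ft := storage.NewFrameTable(storage.CompressionZstd)
	ft.Frames = []storage.FrameSize{
		{U: 2 << 20, C: 700_000},
		{U: 2 << 20, C: 650_000},
		{U: 2 << 20, C: 800_000},
	}

	// A 4 KiB read at uncompressed offset 5 MiB lands in the third frame,
	// so the whole compressed frame has to be fetched and decompressed.
	fetch, err := ft.GetFetchRange(storage.Range{Start: 5 << 20, Length: 4096})
	if err != nil {
		panic(err)
	}

	// Prints the C-space range: starts at 700_000+650_000, length 800_000.
	fmt.Printf("fetch C-range %s\n", fetch)
}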
+func (ft *FrameTable) GetFetchRange(rangeU Range) (Range, error) { + fetchRange := rangeU + if ft.IsCompressed() { + start, size, err := ft.FrameFor(rangeU.Start) + if err != nil { + return Range{}, fmt.Errorf("getting frame for offset %#x: %w", rangeU.Start, err) + } + endOffset := rangeU.Start + int64(rangeU.Length) + frameEnd := start.U + int64(size.U) + if endOffset > frameEnd { + return Range{}, fmt.Errorf("range %v spans beyond frame ending at %#x", rangeU, frameEnd) + } + fetchRange = Range{ + Start: start.C, + Length: int(size.C), + } + } + + return fetchRange, nil +} diff --git a/packages/shared/pkg/storage/compress_frame_table_test.go b/packages/shared/pkg/storage/compress_frame_table_test.go new file mode 100644 index 0000000000..c06647eede --- /dev/null +++ b/packages/shared/pkg/storage/compress_frame_table_test.go @@ -0,0 +1,246 @@ +package storage + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/require" +) + +// threeFrameFT returns a FrameTable with three 1MB uncompressed frames +// and varying compressed sizes, starting at the given offset. +func threeFrameFT(startU, startC int64) *FrameTable { + ft := &FrameTable{ + compressionType: CompressionLZ4, + StartAt: FrameOffset{U: startU, C: startC}, + Frames: []FrameSize{ + {U: 1 << 20, C: 500_000}, // frame 0 + {U: 1 << 20, C: 600_000}, // frame 1 + {U: 1 << 20, C: 400_000}, // frame 2 + }, + } + + return ft +} + +// collectRange calls ft.Range and returns the offsets visited. +func collectRange(ft *FrameTable, start, length int64) ([]FrameOffset, error) { + var offsets []FrameOffset + err := ft.Range(start, length, func(offset FrameOffset, _ FrameSize) error { + offsets = append(offsets, offset) + + return nil + }) + + return offsets, err +} + +func TestRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("selects all frames", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 0, 3<<20) + require.NoError(t, err) + require.Len(t, offsets, 3) + }) + + t.Run("selects single middle frame", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 1<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + require.Equal(t, int64(1<<20), offsets[0].U) + require.Equal(t, int64(500_000), offsets[0].C) + }) + + t.Run("partial overlap selects touched frames", func(t *testing.T) { + t.Parallel() + // 1 byte spanning frames 0 and 1 boundary. + offsets, err := collectRange(ft, (1<<20)-1, 2) + require.NoError(t, err) + require.Len(t, offsets, 2) + }) + + t.Run("beyond end returns nothing", func(t *testing.T) { + t.Parallel() + offsets, err := collectRange(ft, 3<<20, 1) + require.NoError(t, err) + require.Empty(t, offsets) + }) + + t.Run("callback error propagates", func(t *testing.T) { + t.Parallel() + sentinel := fmt.Errorf("stop") + err := ft.Range(0, 3<<20, func(_ FrameOffset, _ FrameSize) error { + return sentinel + }) + require.ErrorIs(t, err, sentinel) + }) + + t.Run("respects StartAt on subset", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 1 << 20, Length: 2 << 20}) + require.NoError(t, err) + + // Query for offset 2MB — the second frame of the subset. + offsets, err := collectRange(sub, 2<<20, 1<<20) + require.NoError(t, err) + require.Len(t, offsets, 1) + require.Equal(t, int64(2<<20), offsets[0].U) + require.Equal(t, int64(1_100_000), offsets[0].C) // 500k + 600k + + // Query for offset 0 — before the subset, should find nothing. 
+ offsets, err = collectRange(sub, 0, 1<<20) + require.NoError(t, err) + require.Empty(t, offsets, "Range should not find frames before StartAt") + }) +} + +func TestSubset(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("full range", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 3 << 20}) + require.NoError(t, err) + require.Len(t, sub.Frames, 3) + require.Equal(t, int64(0), sub.StartAt.U) + }) + + t.Run("last frame", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 2 << 20, Length: 1 << 20}) + require.NoError(t, err) + require.Len(t, sub.Frames, 1) + require.Equal(t, int64(2<<20), sub.StartAt.U) + require.Equal(t, int64(1_100_000), sub.StartAt.C) + require.Equal(t, int32(400_000), sub.Frames[0].C) + }) + + t.Run("preserves compression type", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 1 << 20}) + require.NoError(t, err) + require.Equal(t, CompressionLZ4, sub.CompressionType()) + }) + + t.Run("nil table returns nil", func(t *testing.T) { + t.Parallel() + sub, err := (*FrameTable)(nil).Subset(Range{Start: 0, Length: 100}) + require.NoError(t, err) + require.Nil(t, sub) + }) + + t.Run("zero length returns nil", func(t *testing.T) { + t.Parallel() + sub, err := ft.Subset(Range{Start: 0, Length: 0}) + require.NoError(t, err) + require.Nil(t, sub) + }) + + t.Run("before StartAt errors", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + _, err := sub.Subset(Range{Start: 0, Length: 1 << 20}) + require.Error(t, err) + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, err := ft.Subset(Range{Start: 4 << 20, Length: 1 << 20}) + require.Error(t, err) + }) +} + +func TestFrameFor(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("first byte of each frame", func(t *testing.T) { + t.Parallel() + for i, wantU := range []int64{0, 1 << 20, 2 << 20} { + start, size, err := ft.FrameFor(wantU) + require.NoError(t, err, "frame %d", i) + require.Equal(t, wantU, start.U) + require.Equal(t, int32(1<<20), size.U) + } + }) + + t.Run("last byte of frame", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor((1 << 20) - 1) + require.NoError(t, err) + require.Equal(t, int64(0), start.U) + }) + + t.Run("returns correct C offset", func(t *testing.T) { + t.Parallel() + start, _, err := ft.FrameFor(2 << 20) + require.NoError(t, err) + require.Equal(t, int64(1_100_000), start.C) // 500k + 600k + }) + + t.Run("beyond end errors", func(t *testing.T) { + t.Parallel() + _, _, err := ft.FrameFor(3 << 20) + require.Error(t, err) + }) + + t.Run("nil table errors", func(t *testing.T) { + t.Parallel() + _, _, err := (*FrameTable)(nil).FrameFor(0) + require.Error(t, err) + }) + + t.Run("respects StartAt", func(t *testing.T) { + t.Parallel() + sub := threeFrameFT(1<<20, 500_000) + start, _, err := sub.FrameFor(1 << 20) + require.NoError(t, err) + require.Equal(t, int64(1<<20), start.U) + require.Equal(t, int64(500_000), start.C) + + // Before StartAt — no frame should contain offset 0. 
+ _, _, err = sub.FrameFor(0) + require.Error(t, err) + }) +} + +func TestGetFetchRange(t *testing.T) { + t.Parallel() + ft := threeFrameFT(0, 0) + + t.Run("translates U-space to C-space", func(t *testing.T) { + t.Parallel() + r, err := ft.GetFetchRange(Range{Start: 1 << 20, Length: 1 << 20}) + require.NoError(t, err) + require.Equal(t, int64(500_000), r.Start) + require.Equal(t, 600_000, r.Length) + }) + + t.Run("range spanning multiple frames errors", func(t *testing.T) { + t.Parallel() + _, err := ft.GetFetchRange(Range{Start: 0, Length: 2 << 20}) + require.Error(t, err) + }) + + t.Run("nil table returns input unchanged", func(t *testing.T) { + t.Parallel() + input := Range{Start: 42, Length: 100} + r, err := (*FrameTable)(nil).GetFetchRange(input) + require.NoError(t, err) + require.Equal(t, input, r) + }) + + t.Run("uncompressed table returns input unchanged", func(t *testing.T) { + t.Parallel() + uncompressed := &FrameTable{compressionType: CompressionNone} + input := Range{Start: 42, Length: 100} + r, err := uncompressed.GetFetchRange(input) + require.NoError(t, err) + require.Equal(t, input, r) + }) +} diff --git a/packages/shared/pkg/storage/compress_pool.go b/packages/shared/pkg/storage/compress_pool.go new file mode 100644 index 0000000000..dae91251cd --- /dev/null +++ b/packages/shared/pkg/storage/compress_pool.go @@ -0,0 +1,159 @@ +package storage + +import ( + "bytes" + "fmt" + "io" + "sync" + + "github.com/klauspost/compress/zstd" + lz4 "github.com/pierrec/lz4/v4" +) + +// compressor compresses individual frames. Implementations are pooled and +// reused across frames within a single CompressStream call. +type compressor interface { + compress(src []byte) ([]byte, error) +} + +// lz4Compressor wraps a pooled lz4.Writer. The writer is reused via Reset +// between frames to avoid re-allocating internal hash tables (~64KB). +type lz4Compressor struct { + w *lz4.Writer +} + +func (c *lz4Compressor) compress(src []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(lz4.CompressBlockBound(len(src))) + c.w.Reset(&buf) + + if _, err := c.w.Write(src); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := c.w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} + +// zstdCompressor wraps a pooled zstd.Encoder using EncodeAll. +type zstdCompressor struct { + enc *zstd.Encoder +} + +func (z *zstdCompressor) compress(src []byte) ([]byte, error) { //nolint:unparam // satisfies compressor interface + return z.enc.EncodeAll(src, make([]byte, 0, len(src))), nil +} + +// newCompressorPool returns a pool of compressors for the given config. +// Both LZ4 and zstd encoders are pooled and reused via Reset/EncodeAll. +// The config is validated eagerly — if zstd options are invalid, an error +// is returned immediately rather than deferred to pool.Get(). +func newCompressorPool(cfg *CompressConfig) (*sync.Pool, error) { + pool := &sync.Pool{} + + switch cfg.CompressionType() { + case CompressionZstd: + zstdOpts := []zstd.EOption{ + zstd.WithEncoderLevel(zstd.EncoderLevel(cfg.Level)), + zstd.WithEncoderCRC(true), + } + if cfg.FrameSize() > 0 { + zstdOpts = append(zstdOpts, zstd.WithWindowSize(cfg.FrameSize())) + } + if cfg.EncoderConcurrency > 0 { + zstdOpts = append(zstdOpts, zstd.WithEncoderConcurrency(cfg.EncoderConcurrency)) + } + + // Validate options by creating one encoder upfront. + first, err := zstd.NewWriter(nil, zstdOpts...) 
+ if err != nil { + return nil, fmt.Errorf("zstd encoder: %w", err) + } + pool.Put(&zstdCompressor{enc: first}) + + pool.New = func() any { + // Options are already validated; NewWriter won't fail. + enc, _ := zstd.NewWriter(nil, zstdOpts...) + + return &zstdCompressor{enc: enc} + } + case CompressionLZ4: + lz4Opts := []lz4.Option{ + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(false), + lz4.ConcurrencyOption(1), + lz4.CompressionLevelOption(lz4.Fast), + } + + // Validate options by creating one encoder upfront. + first := lz4.NewWriter(nil) + if err := first.Apply(lz4Opts...); err != nil { + return nil, fmt.Errorf("lz4 encoder: %w", err) + } + pool.Put(&lz4Compressor{w: first}) + + pool.New = func() any { + w := lz4.NewWriter(nil) + _ = w.Apply(lz4Opts...) //nolint:errcheck // options validated above + + return &lz4Compressor{w: w} + } + default: + return nil, fmt.Errorf("unsupported compression type: %s", cfg.CompressionType()) + } + + return pool, nil +} + +var lz4DecoderPool sync.Pool + +func getLZ4Decoder(r io.Reader) *lz4.Reader { + if v := lz4DecoderPool.Get(); v != nil { + dec := v.(*lz4.Reader) + dec.Reset(r) + + return dec + } + + dec := lz4.NewReader(r) + + return dec +} + +func putLZ4Decoder(dec *lz4.Reader) { + dec.Reset(nil) + lz4DecoderPool.Put(dec) +} + +// zstd concurrency is hardcoded to 1: benchmarks show higher values hurt +// throughput for single 2MiB frame decodes. +var zstdDecoderPool sync.Pool + +func getZstdDecoder(r io.Reader) (*zstd.Decoder, error) { + if v := zstdDecoderPool.Get(); v != nil { + dec := v.(*zstd.Decoder) + if err := dec.Reset(r); err != nil { + dec.Close() + + return nil, err + } + + return dec, nil + } + + dec, err := zstd.NewReader(r) + if err != nil { + return nil, err + } + + return dec, nil +} + +func putZstdDecoder(dec *zstd.Decoder) { + dec.Reset(nil) + zstdDecoderPool.Put(dec) +} diff --git a/packages/shared/pkg/storage/compress_upload.go b/packages/shared/pkg/storage/compress_upload.go new file mode 100644 index 0000000000..f2b0b0969b --- /dev/null +++ b/packages/shared/pkg/storage/compress_upload.go @@ -0,0 +1,280 @@ +package storage + +import ( + "bytes" + "context" + "crypto/sha256" + "errors" + "fmt" + "io" + "slices" + "sync" + "sync/atomic" + + "golang.org/x/sync/errgroup" +) + +// MaxCompressedHeaderSize is the maximum allowed decompressed header size (64 MiB). +// Headers are typically a few hundred KiB (e.g., 100 layers × 256 frames × 32 bytes/frame ≈ 800 KB). +// This is a safety bound to prevent unbounded allocation from corrupt data. +const MaxCompressedHeaderSize = 64 << 20 + +const ( + // DefaultCompressFrameSize is the default uncompressed size of each compression + // frame (2 MiB). Overridable via CompressConfig.FrameSizeKB. + // The last frame in a file may be shorter. + // + // The chunker fetches one frame at a time from storage on a cache miss. + // Larger frame sizes mean more data cached per fetch (faster warm-up and + // fewer GCS round-trips), but higher memory and I/O cost per miss. + // + // This MUST be multiple of every block/page size: + // - header.HugepageSize (2 MiB) — UFFD huge-page size, also used by prefetch + // - header.RootfsBlockSize (4 KiB) — NBD / rootfs block size + DefaultCompressFrameSize = 2 * 1024 * 1024 + + // File type identifiers for per-file-type compression targeting. + FileTypeMemfile = "memfile" + FileTypeRootfs = "rootfs" + + // Use case identifiers for per-use-case compression targeting. 
+ UseCaseBuild = "build" + UseCasePause = "pause" +) + +// partUploader is the interface for uploading data in parts. +// Implementations exist for GCS multipart uploads and local file writes. +type partUploader interface { + Start(ctx context.Context) error + UploadPart(ctx context.Context, partIndex int, data ...[]byte) error + Complete(ctx context.Context) error + Close() error +} + +// memPartUploader collects compressed parts in memory. Thread-safe. +// Useful for tests and benchmarks that need CompressStream output as bytes. +type memPartUploader struct { + mu sync.Mutex + parts map[int][]byte +} + +func (m *memPartUploader) Start(context.Context) error { + m.parts = make(map[int][]byte) + + return nil +} + +func (m *memPartUploader) UploadPart(_ context.Context, partIndex int, data ...[]byte) error { + var buf bytes.Buffer + for _, d := range data { + buf.Write(d) + } + m.mu.Lock() + m.parts[partIndex] = buf.Bytes() + m.mu.Unlock() + + return nil +} + +func (m *memPartUploader) Complete(context.Context) error { return nil } +func (m *memPartUploader) Close() error { return nil } + +// Assemble returns the concatenated parts in index order. +func (m *memPartUploader) Assemble() []byte { + keys := make([]int, 0, len(m.parts)) + for k := range m.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + var buf bytes.Buffer + for _, k := range keys { + buf.Write(m.parts[k]) + } + + return buf.Bytes() +} + +type frame struct { + uncompressedSize int + compressed []byte +} + +type part struct { + index int + frames []*frame + compressedSize atomic.Int64 + eg *errgroup.Group + readyToUpload chan error +} + +func newPart(index int, parentCtx context.Context, workers int) (p *part, ctx context.Context) { + p = &part{index: index} + p.eg, ctx = errgroup.WithContext(parentCtx) + p.eg.SetLimit(workers) + + return p, ctx +} + +func (p *part) addFrame(ctx context.Context, uncompressedData []byte, pool *sync.Pool) { + if len(uncompressedData) == 0 { + return + } + + frameInPart := &frame{uncompressedSize: len(uncompressedData)} + p.frames = append(p.frames, frameInPart) + + p.eg.Go(func() error { + if err := ctx.Err(); err != nil { + return err + } + c := pool.Get().(compressor) + out, err := c.compress(uncompressedData) + pool.Put(c) + if err != nil { + return err + } + frameInPart.compressed = out + p.compressedSize.Add(int64(len(out))) + + return nil + }) +} + +func (p *part) submit(ctx context.Context, queue chan<- *part) { + p.readyToUpload = make(chan error, 1) + + go func() { + p.readyToUpload <- p.eg.Wait() + close(p.readyToUpload) + }() + + select { + case queue <- p: + case <-ctx.Done(): + } +} + +// compressStream: read → compress (parallel) → emit metadata (ordered) → upload (concurrent). +func compressStream(ctx context.Context, in io.Reader, cfg *CompressConfig, uploader partUploader, maxUploadConcurrency int) (ft *FrameTable, checksum [32]byte, err error) { //nolint:unparam // callers in later PRs pass different values + frameSize := cfg.FrameSize() + targetPartSize := cfg.TargetPartSize() + + if err := uploader.Start(ctx); err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to start framed upload: %w", err) + } + defer uploader.Close() + + // for compression we create a pool per file since there are often enough + // frames to justify pooling. 
+ compressors, err := newCompressorPool(cfg) + if err != nil { + return nil, [32]byte{}, err + } + hasher := sha256.New() + + ft = &FrameTable{compressionType: cfg.CompressionType()} + + ctx, cancel := context.WithCancel(ctx) // pipeline errors cancel the read loop + defer cancel() + + q := make(chan *part, maxUploadConcurrency) + var closeQ sync.Once + defer closeQ.Do(func() { close(q) }) + + uploadEG, uploadCtx := errgroup.WithContext(ctx) + uploadEG.SetLimit(maxUploadConcurrency) + + var emitEG errgroup.Group + emitEG.Go(func() error { + for p := range q { + select { + case compressErr := <-p.readyToUpload: + if compressErr != nil { + cancel() + + return compressErr + } + case <-ctx.Done(): + return ctx.Err() + } + + var compressed [][]byte + for _, f := range p.frames { + ft.Frames = append(ft.Frames, FrameSize{U: int32(f.uncompressedSize), C: int32(len(f.compressed))}) + compressed = append(compressed, f.compressed) + } + + pi := p.index + uploadEG.Go(func() error { + return uploader.UploadPart(uploadCtx, pi, compressed...) + }) + } + + return nil + }) + + part, compressCtx := newPart(1, ctx, cfg.FrameEncodeWorkers) + for { + if err := ctx.Err(); err != nil { + return nil, [32]byte{}, err + } + + buf := make([]byte, frameSize) + n, err := io.ReadFull(in, buf) + + switch { + case err == nil: + case errors.Is(err, io.EOF): + case errors.Is(err, io.ErrUnexpectedEOF): + // fall through + default: + return nil, [32]byte{}, fmt.Errorf("read frame: %w", err) + } + + if n > 0 { + hasher.Write(buf[:n]) + part.addFrame(compressCtx, buf[:n], compressors) + } + + if err != nil { + break + } + + if part.compressedSize.Load() >= targetPartSize { + part.submit(ctx, q) + part, compressCtx = newPart(part.index+1, ctx, cfg.FrameEncodeWorkers) + } + } + + if len(part.frames) > 0 { + part.submit(ctx, q) + } + + closeQ.Do(func() { close(q) }) + + emitErr := emitEG.Wait() + uploadErr := uploadEG.Wait() + if err := errors.Join(emitErr, uploadErr); err != nil { + return nil, [32]byte{}, err + } + + if err := uploader.Complete(ctx); err != nil { + return nil, [32]byte{}, fmt.Errorf("failed to finish uploading frames: %w", err) + } + + copy(checksum[:], hasher.Sum(nil)) + + return ft, checksum, nil +} + +func CompressBytes(ctx context.Context, data []byte, cfg *CompressConfig) (*FrameTable, []byte, [32]byte, error) { + up := &memPartUploader{} + + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + if err != nil { + return nil, nil, [32]byte{}, err + } + + return ft, up.Assemble(), checksum, nil +} diff --git a/packages/shared/pkg/storage/compress_upload_test.go b/packages/shared/pkg/storage/compress_upload_test.go new file mode 100644 index 0000000000..67fb6da4c2 --- /dev/null +++ b/packages/shared/pkg/storage/compress_upload_test.go @@ -0,0 +1,373 @@ +package storage + +import ( + "bytes" + "context" + crand "crypto/rand" + "crypto/sha256" + "fmt" + "io" + "math/rand/v2" + "slices" + "sync/atomic" + "testing" + "time" + + "github.com/klauspost/compress/zstd" + "github.com/stretchr/testify/require" + "golang.org/x/sync/errgroup" +) + +// generateSemiRandomData produces deterministic, compressible data. +// Random byte repeated 1-16 times — gives ~0.5-0.7 compression ratio. 
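Before the tests, a hedged usage sketch of CompressBytes as defined above; the input data and the config values are illustrative, not taken from any call site in this diff.

package main

import (
	"bytes"
	"context"
	"fmt"

	"github.com/e2b-dev/infra/packages/shared/pkg/storage"
)

func main() {
	// Hypothetical input: 8 MiB of zero bytes (highly compressible).
	data := bytes.Repeat([]byte{0}, 8<<20)

	cfg := &storage.CompressConfig{
		Enabled:            true,
		Type:               "zstd",
		Level:              2,
		FrameSizeKB:        2048, // 2 MiB frames -> 4 frames for 8 MiB of input
		TargetPartSizeMB:   50,
		FrameEncodeWorkers: 4,
		EncoderConcurrency: 1,
	}

	ft, compressed, checksum, err := storage.CompressBytes(context.Background(), data, cfg)
	if err != nil {
		panic(err)
	}

	u, c := ft.Size()
	fmt.Printf("frames=%d uncompressed=%d compressed=%d (blob %d bytes) sha256=%x\n",
		len(ft.Frames), u, c, len(compressed), checksum)
}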
+func generateSemiRandomData(size int) []byte { + data := make([]byte, size) + rng := rand.New(rand.NewPCG(1, 2)) //nolint:gosec // deterministic + i := 0 + for i < size { + runLen := rng.IntN(16) + 1 + if i+runLen > size { + runLen = size - i + } + b := byte(rng.IntN(256)) + for j := range runLen { + data[i+j] = b + } + i += runLen + } + + return data +} + +// ThrottledPartUploader wraps memPartUploader with simulated upload bandwidth. +type ThrottledPartUploader struct { + memPartUploader + + bandwidth int64 // bytes/sec; 0 = unlimited +} + +func (t *ThrottledPartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + if t.bandwidth > 0 { + total := 0 + for _, d := range data { + total += len(d) + } + time.Sleep(time.Duration(float64(total) / float64(t.bandwidth) * float64(time.Second))) + } + + return t.memPartUploader.UploadPart(ctx, partIndex, data...) +} + +// decompressAll walks the FrameTable and decompresses each frame from the +// concatenated compressed blob, returning the original uncompressed data. +func decompressAll(ft *FrameTable, compressed []byte) ([]byte, error) { + var result []byte + var cOff int64 + + for i, fs := range ft.Frames { + if cOff+int64(fs.C) > int64(len(compressed)) { + return nil, fmt.Errorf("frame %d: compressed data truncated (need %d, have %d)", i, cOff+int64(fs.C), len(compressed)) + } + + frameData := compressed[cOff : cOff+int64(fs.C)] + + var frame []byte + var err error + + switch ft.CompressionType() { + case CompressionLZ4: + dec := getLZ4Decoder(bytes.NewReader(frameData)) + frame, err = io.ReadAll(dec) + putLZ4Decoder(dec) + case CompressionZstd: + var dec *zstd.Decoder + dec, err = getZstdDecoder(bytes.NewReader(frameData)) + if err == nil { + frame, err = io.ReadAll(dec) + putZstdDecoder(dec) + } + } + if err != nil { + return nil, fmt.Errorf("frame %d: %w", i, err) + } + result = append(result, frame...) + cOff += int64(fs.C) + } + + return result, nil +} + +// defaultCfg returns a CompressConfig with the given overrides applied. 
+func defaultCfg(ct CompressionType, workers, frameSize int) *CompressConfig { + level := 2 // zstd default + if ct == CompressionLZ4 { + level = 0 + } + + return &CompressConfig{ + Enabled: true, + Type: ct.String(), + Level: level, + EncoderConcurrency: 1, + FrameEncodeWorkers: workers, + FrameSizeKB: frameSize / 1024, + TargetPartSizeMB: 50, + } +} + +func TestCompressStreamRoundTrip(t *testing.T) { + t.Parallel() + + tests := []struct { + name string + dataSize int + frameSize int + workers int + codec CompressionType + incompressible bool // use crypto/rand data that cannot be compressed + }{ + {"basic", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"workers_1", 10 * megabyte, 2 * megabyte, 1, CompressionZstd, false}, + {"workers_2", 10 * megabyte, 2 * megabyte, 2, CompressionZstd, false}, + {"not_frame_aligned", 10*megabyte + 1, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_frame", 100 * 1024, 2 * megabyte, 4, CompressionZstd, false}, + {"smaller_than_part", 5 * megabyte, 2 * megabyte, 4, CompressionZstd, false}, + {"empty", 0, 2 * megabyte, 4, CompressionZstd, false}, + {"single_byte", 1, 2 * megabyte, 1, CompressionZstd, false}, + {"lz4", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, false}, + {"lz4_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionLZ4, true}, + {"zstd_incompressible", 10 * megabyte, 2 * megabyte, 4, CompressionZstd, true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + var original []byte + if tc.dataSize > 0 { + if tc.incompressible { + original = make([]byte, tc.dataSize) + _, err := crand.Read(original) + require.NoError(t, err) + } else { + original = generateSemiRandomData(tc.dataSize) + } + } + + up := &memPartUploader{} + cfg := defaultCfg(tc.codec, tc.workers, tc.frameSize) + + ft, checksum, err := compressStream( + context.Background(), + bytes.NewReader(original), + cfg, + up, + 4, + ) + require.NoError(t, err) + + if tc.dataSize == 0 { + require.Empty(t, ft.Frames) + require.Equal(t, sha256.Sum256(nil), checksum) + + return + } + + // Verify frame count. + expectedFrames := (tc.dataSize + tc.frameSize - 1) / tc.frameSize + require.Len(t, ft.Frames, expectedFrames) + + // Verify checksum. + require.Equal(t, sha256.Sum256(original), checksum) + + // Round-trip: decompress and compare. + compressed := up.Assemble() + decompressed, err := decompressAll(ft, compressed) + require.NoError(t, err) + require.Equal(t, original, decompressed) + }) + } +} + +func TestCompressStreamContextCancel(t *testing.T) { + t.Parallel() + + data := generateSemiRandomData(10 * megabyte) + + ctx, cancel := context.WithCancel(context.Background()) + go func() { + time.Sleep(10 * time.Millisecond) + cancel() + }() + + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, 2*megabyte) + + _, _, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + require.Error(t, err) + require.ErrorIs(t, err, context.Canceled) +} + +func TestCompressStreamPartSizeMinimum(t *testing.T) { + t.Parallel() + + // Generate once; subtests slice to their needed size. 
+ sharedData := generateSemiRandomData(100 * megabyte) + + tests := []struct { + name string + dataSize int + frameSize int + targetPartSizeMB int + }{ + {"large_file", 100 * megabyte, 2 * megabyte, 50}, + {"small_file_one_part", 5 * megabyte, 2 * megabyte, 50}, + {"small_target", 100 * megabyte, 2 * megabyte, 10}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + t.Parallel() + + data := sharedData[:tc.dataSize] + up := &memPartUploader{} + cfg := defaultCfg(CompressionZstd, 4, tc.frameSize) + cfg.TargetPartSizeMB = tc.targetPartSizeMB + + _, _, err := compressStream(context.Background(), bytes.NewReader(data), cfg, up, 4) + require.NoError(t, err) + + // Verify: no non-final part is under 5 MiB. + keys := make([]int, 0, len(up.parts)) + for k := range up.parts { + keys = append(keys, k) + } + slices.Sort(keys) + + for i, k := range keys { + isFinal := i == len(keys)-1 + if !isFinal { + require.GreaterOrEqual(t, len(up.parts[k]), 5*1024*1024, + "non-final part %d is under 5 MiB (%d bytes)", k, len(up.parts[k])) + } + } + + require.NotEmpty(t, up.parts, "should have at least one part") + }) + } +} + +// TestCompressStreamRace runs many concurrent CompressStream calls with high +// worker counts to shake out data races in the compressor pool, memPartUploader, +// and errgroup coordination. Run with -race. +func TestCompressStreamRace(t *testing.T) { + t.Parallel() + + const ( + streams = 8 // concurrent CompressStream calls + dataSize = 4 * megabyte // small enough to be fast, big enough to exercise batching + frameSize = 128 * 1024 // 128 KB — many frames per part + workers = 8 // high worker count to maximise contention + targetPartSizeMB = 1 // small parts → many parts per stream + ) + + data := generateSemiRandomData(dataSize) + wantChecksum := sha256.Sum256(data) + + // Use an errgroup to run all streams concurrently. 
+ eg, ctx := errgroup.WithContext(context.Background()) + for i := range streams { + codec := CompressionZstd + if i%2 == 1 { + codec = CompressionLZ4 // mix codecs for more coverage + } + + eg.Go(func() error { + up := &memPartUploader{} + cfg := defaultCfg(codec, workers, frameSize) + cfg.TargetPartSizeMB = targetPartSizeMB + if codec == CompressionZstd { + cfg.EncoderConcurrency = 4 // multi-threaded zstd encoders for more contention + } + + ft, checksum, err := compressStream(ctx, bytes.NewReader(data), cfg, up, 4) + if err != nil { + return fmt.Errorf("stream %d: compress: %w", i, err) + } + + if checksum != wantChecksum { + return fmt.Errorf("stream %d: checksum mismatch", i) + } + + decompressed, err := decompressAll(ft, up.Assemble()) + if err != nil { + return fmt.Errorf("stream %d: decompress: %w", i, err) + } + + if !bytes.Equal(data, decompressed) { + return fmt.Errorf("stream %d: round-trip data mismatch", i) + } + + return nil + }) + } + + require.NoError(t, eg.Wait()) +} + +func BenchmarkCompress(b *testing.B) { + const dataSize = 256 * megabyte + data := generateSemiRandomData(dataSize) + + configs := []struct { + name string + workers int + bandwidth int64 // bytes/sec; 0 = unlimited + }{ + {"w1_unlimited", 1, 0}, + {"w2_unlimited", 2, 0}, + {"w4_unlimited", 4, 0}, + {"w1_200MBs", 1, 200 * megabyte}, + {"w4_200MBs", 4, 200 * megabyte}, + {"w4_100MBs", 4, 100 * megabyte}, + } + + for _, bcfg := range configs { + b.Run(bcfg.name, func(b *testing.B) { + compCfg := &CompressConfig{ + Enabled: true, + Type: "zstd", + Level: 2, + EncoderConcurrency: 1, + FrameEncodeWorkers: bcfg.workers, + FrameSizeKB: 2 * 1024, + TargetPartSizeMB: 50, + } + + var lastParts atomic.Int32 + + b.ResetTimer() + b.SetBytes(int64(dataSize)) + + for range b.N { + up := &ThrottledPartUploader{bandwidth: bcfg.bandwidth} + + _, _, err := compressStream( + context.Background(), + bytes.NewReader(data), + compCfg, + up, 4, + ) + if err != nil { + b.Fatal(err) + } + + lastParts.Store(int32(len(up.parts))) + } + + // Report after all iterations using last run's values. + // b.SetBytes already reports MB/s (uncompressed throughput). + b.ReportMetric(float64(lastParts.Load()), "parts") + }) + } +} diff --git a/packages/shared/pkg/storage/gcp_multipart.go b/packages/shared/pkg/storage/gcp_multipart.go index 75324c16c1..ee568df86f 100644 --- a/packages/shared/pkg/storage/gcp_multipart.go +++ b/packages/shared/pkg/storage/gcp_multipart.go @@ -139,9 +139,61 @@ type MultipartUploader struct { client *retryablehttp.Client retryConfig RetryConfig baseURL string // Allow overriding for testing + metadata map[string]string + + // Fields for partUploader interface + uploadID string + mu sync.Mutex + parts []Part +} + +var _ partUploader = (*MultipartUploader)(nil) + +// Start initiates the GCS multipart upload. +func (m *MultipartUploader) Start(ctx context.Context) error { + uploadID, err := m.initiateUpload(ctx) + if err != nil { + return fmt.Errorf("failed to initiate multipart upload: %w", err) + } + + m.uploadID = uploadID + + return nil +} + +// UploadPart uploads a single part to GCS. Multiple data slices are hashed +// and uploaded without copying into a single contiguous buffer. 
+func (m *MultipartUploader) UploadPart(ctx context.Context, partIndex int, data ...[]byte) error { + etag, err := m.uploadPartSlices(ctx, m.uploadID, partIndex, data) + if err != nil { + return fmt.Errorf("failed to upload part %d: %w", partIndex, err) + } + + m.mu.Lock() + m.parts = append(m.parts, Part{ + PartNumber: partIndex, + ETag: etag, + }) + m.mu.Unlock() + + return nil +} + +// Complete finalizes the GCS multipart upload with all collected parts. +func (m *MultipartUploader) Complete(ctx context.Context) error { + m.mu.Lock() + parts := make([]Part, len(m.parts)) + copy(parts, m.parts) + m.mu.Unlock() + + return m.completeUpload(ctx, m.uploadID, parts) +} + +func (m *MultipartUploader) Close() error { + return nil } -func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig) (*MultipartUploader, error) { +func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, objectName string, retryConfig RetryConfig, metadata map[string]string) (*MultipartUploader, error) { creds, err := google.FindDefaultCredentials(ctx, "https://www.googleapis.com/auth/cloud-platform") if err != nil { return nil, fmt.Errorf("failed to get credentials: %w", err) @@ -159,6 +211,7 @@ func NewMultipartUploaderWithRetryConfig(ctx context.Context, bucketName, object client: createRetryableClient(ctx, retryConfig), retryConfig: retryConfig, baseURL: fmt.Sprintf("https://%s.storage.googleapis.com", bucketName), + metadata: metadata, }, nil } @@ -174,6 +227,10 @@ func (m *MultipartUploader) initiateUpload(ctx context.Context) (string, error) req.Header.Set("Content-Length", "0") req.Header.Set("Content-Type", "application/octet-stream") + for k, v := range m.metadata { + req.Header.Set("x-goog-meta-"+k, v) + } + resp, err := m.client.Do(req) if err != nil { return "", err @@ -232,6 +289,60 @@ func (m *MultipartUploader) uploadPart(ctx context.Context, uploadID string, par return etag, nil } +// uploadPartSlices uploads a part from multiple byte slices without concatenating them. +// It computes MD5 by hashing each slice and uses a ReaderFunc for retryable reads. 
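For orientation, a hedged sketch of how MultipartUploader is meant to plug into compressStream as its partUploader. The actual call sites land in later PRs, so this is an assumption about intended use; it would have to live inside package storage (compressStream and partUploader are unexported), and the helper name and the "compression" metadata key are illustrative.

package storage

import (
	"context"
	"fmt"
	"io"
)

// compressToGCS is a hypothetical helper (not part of this PR) sketching how the
// pieces fit together: MultipartUploader satisfies the unexported partUploader
// interface, so compressStream can stream parts straight into a GCS multipart upload.
func compressToGCS(ctx context.Context, src io.Reader, bucket, object string, retry RetryConfig, cfg *CompressConfig) (*FrameTable, [32]byte, error) {
	uploader, err := NewMultipartUploaderWithRetryConfig(ctx, bucket, object, retry, map[string]string{
		"compression": cfg.CompressionType().String(), // illustrative metadata key
	})
	if err != nil {
		return nil, [32]byte{}, fmt.Errorf("create uploader: %w", err)
	}

	// compressStream drives Start, UploadPart (one call per part, up to 4 in flight)
	// and Complete on the uploader, returning the frame table and SHA-256 of the input.
	return compressStream(ctx, src, cfg, uploader, 4)
}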
+func (m *MultipartUploader) uploadPartSlices(ctx context.Context, uploadID string, partNumber int, slices [][]byte) (string, error) { + // Compute MD5 and total length without copying + hasher := md5.New() + totalLen := 0 + for _, s := range slices { + hasher.Write(s) + totalLen += len(s) + } + md5Sum := base64.StdEncoding.EncodeToString(hasher.Sum(nil)) + + url := fmt.Sprintf("%s/%s?partNumber=%d&uploadId=%s", + m.baseURL, m.objectName, partNumber, uploadID) + + // Use a ReaderFunc so the retryable client can replay the body on retries + bodyFn := func() (io.Reader, error) { + readers := make([]io.Reader, len(slices)) + for i, s := range slices { + readers[i] = bytes.NewReader(s) + } + + return io.MultiReader(readers...), nil + } + + req, err := retryablehttp.NewRequestWithContext(ctx, "PUT", url, retryablehttp.ReaderFunc(bodyFn)) + if err != nil { + return "", err + } + + req.Header.Set("Authorization", "Bearer "+m.token) + req.Header.Set("Content-Length", fmt.Sprintf("%d", totalLen)) + req.Header.Set("Content-MD5", md5Sum) + + resp, err := m.client.Do(req) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + body, _ := io.ReadAll(resp.Body) + + return "", fmt.Errorf("failed to upload part %d (status %d): %s", partNumber, resp.StatusCode, string(body)) + } + + etag := resp.Header.Get("ETag") + if etag == "" { + return "", fmt.Errorf("no ETag returned for part %d", partNumber) + } + + return etag, nil +} + func (m *MultipartUploader) completeUpload(ctx context.Context, uploadID string, parts []Part) error { // Sort parts by part number sort.Slice(parts, func(i, j int) bool { diff --git a/packages/shared/pkg/storage/header/header.go b/packages/shared/pkg/storage/header/header.go index 9a1f3008f5..d51452736b 100644 --- a/packages/shared/pkg/storage/header/header.go +++ b/packages/shared/pkg/storage/header/header.go @@ -1,26 +1,62 @@ package header import ( + "cmp" "context" "fmt" + "maps" + "slices" "github.com/bits-and-blooms/bitset" "github.com/google/uuid" "go.uber.org/zap" "github.com/e2b-dev/infra/packages/shared/pkg/logger" + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) +// BuildFileInfo holds metadata about a build's data file, stored in the header +// so the read path can avoid network round-trips (e.g. Size() calls to GCS). +type BuildFileInfo struct { + Size int64 // uncompressed file size + Checksum [32]byte // SHA-256 of uncompressed data; zero value means unknown +} + const NormalizeFixVersion = 3 type Header struct { Metadata *Metadata + BuildFiles map[uuid.UUID]BuildFileInfo // V4 only: per-build file size + checksum blockStarts *bitset.BitSet startMap map[int64]*BuildMap Mapping []*BuildMap } +// CloneForUpload returns a clone with copied Mapping and BuildFiles, safe to +// mutate for serialization without racing with concurrent readers of the +// original. Only serialization-relevant fields are populated (Metadata, +// Mapping, BuildFiles); lookup indices (blockStarts, startMap) are left nil. 
+func (t *Header) CloneForUpload() *Header { + mappings := make([]*BuildMap, len(t.Mapping)) + for i, m := range t.Mapping { + mappings[i] = m.Copy() + } + + metaCopy := *t.Metadata + clone := &Header{ + Metadata: &metaCopy, + Mapping: mappings, + } + + if t.BuildFiles != nil { + clone.BuildFiles = make(map[uuid.UUID]BuildFileInfo, len(t.BuildFiles)) + maps.Copy(clone.BuildFiles, t.BuildFiles) + } + + return clone +} + func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { if metadata.BlockSize == 0 { return nil, fmt.Errorf("block size cannot be zero") @@ -40,19 +76,75 @@ func NewHeader(metadata *Metadata, mapping []*BuildMap) (*Header, error) { intervals := bitset.New(uint(blocks)) startMap := make(map[int64]*BuildMap, len(mapping)) - for _, mapping := range mapping { - block := BlockIdx(int64(mapping.Offset), int64(metadata.BlockSize)) + for _, m := range mapping { + block := BlockIdx(int64(m.Offset), int64(metadata.BlockSize)) intervals.Set(uint(block)) - startMap[block] = mapping + startMap[block] = m } - return &Header{ + h := &Header{ blockStarts: intervals, Metadata: metadata, Mapping: mapping, startMap: startMap, - }, nil + } + + // Validate header integrity at creation time + if err := ValidateHeader(h); err != nil { + return nil, fmt.Errorf("header validation failed: %w", err) + } + + return h, nil +} + +func (t *Header) String() string { + if t == nil { + return "[nil Header]" + } + + return fmt.Sprintf("[Header: version=%d, size=%d, blockSize=%d, generation=%d, buildId=%s, mappings=%d]", + t.Metadata.Version, + t.Metadata.Size, + t.Metadata.BlockSize, + t.Metadata.Generation, + t.Metadata.BuildId.String(), + len(t.Mapping), + ) +} + +func (t *Header) Mappings(all bool) string { + if t == nil { + return "[nil Header, no mappings]" + } + n := 0 + for _, m := range t.Mapping { + if all || m.BuildId == t.Metadata.BuildId { + n++ + } + } + result := fmt.Sprintf("All mappings: %d\n", n) + if !all { + result = fmt.Sprintf("Mappings for build %s: %d\n", t.Metadata.BuildId.String(), n) + } + for _, m := range t.Mapping { + if !all && m.BuildId != t.Metadata.BuildId { + continue + } + frames := 0 + if m.FrameTable != nil { + frames = len(m.FrameTable.Frames) + } + result += fmt.Sprintf(" - Offset: %#x, Length: %#x, BuildId: %s, BuildStorageOffset: %#x, numFrames: %d\n", + m.Offset, + m.Length, + m.BuildId.String(), + m.BuildStorageOffset, + frames, + ) + } + + return result } // IsNormalizeFixApplied is a helper method to soft fail for older versions of the header where fix for normalization was not applied. @@ -143,3 +235,101 @@ func (t *Header) getMapping(ctx context.Context, offset int64) (*BuildMap, int64 return mapping, shift, nil } + +// ValidateHeader checks header integrity and returns an error if corruption is detected. +// This verifies: +// 1. Header and metadata are valid +// 2. Mappings cover the entire file [0, Size) with no gaps +// 3. 
Mappings don't extend beyond file size (with block alignment tolerance) +func ValidateHeader(h *Header) error { + if h == nil { + return fmt.Errorf("header is nil") + } + if h.Metadata == nil { + return fmt.Errorf("header metadata is nil") + } + if h.Metadata.BlockSize == 0 { + return fmt.Errorf("header has zero block size") + } + if h.Metadata.Size == 0 { + return fmt.Errorf("header has zero size") + } + if len(h.Mapping) == 0 { + return fmt.Errorf("header has no mappings") + } + + // Sort mappings by offset to check for gaps/overlaps + sortedMappings := make([]*BuildMap, len(h.Mapping)) + copy(sortedMappings, h.Mapping) + slices.SortFunc(sortedMappings, func(a, b *BuildMap) int { + return cmp.Compare(a.Offset, b.Offset) + }) + + // Check that first mapping starts at 0 + if sortedMappings[0].Offset != 0 { + return fmt.Errorf("mappings don't start at 0: first mapping starts at %#x for buildId %s", + sortedMappings[0].Offset, h.Metadata.BuildId.String()) + } + + // Check for gaps and overlaps between consecutive mappings + for i := range len(sortedMappings) - 1 { + currentEnd := sortedMappings[i].Offset + sortedMappings[i].Length + nextStart := sortedMappings[i+1].Offset + + if currentEnd < nextStart { + return fmt.Errorf("gap in mappings: mapping[%d] ends at %#x but mapping[%d] starts at %#x (gap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, nextStart-currentEnd, h.Metadata.BuildId.String()) + } + if currentEnd > nextStart { + return fmt.Errorf("overlap in mappings: mapping[%d] ends at %#x but mapping[%d] starts at %#x (overlap=%d bytes) for buildId %s", + i, currentEnd, i+1, nextStart, currentEnd-nextStart, h.Metadata.BuildId.String()) + } + } + + // Check that last mapping covers up to (at least) Size + lastMapping := sortedMappings[len(sortedMappings)-1] + lastEnd := lastMapping.Offset + lastMapping.Length + if lastEnd < h.Metadata.Size { + return fmt.Errorf("mappings don't cover entire file: last mapping ends at %#x but file size is %#x (missing %d bytes) for buildId %s", + lastEnd, h.Metadata.Size, h.Metadata.Size-lastEnd, h.Metadata.BuildId.String()) + } + + // Allow last mapping to extend up to one block past size (for alignment) + if lastEnd > h.Metadata.Size+h.Metadata.BlockSize { + return fmt.Errorf("last mapping extends too far: ends at %#x but file size is %#x (overhang=%d bytes, max allowed=%d) for buildId %s", + lastEnd, h.Metadata.Size, lastEnd-h.Metadata.Size, h.Metadata.BlockSize, h.Metadata.BuildId.String()) + } + + // Validate individual mapping bounds + for i, m := range h.Mapping { + if m.Offset > h.Metadata.Size { + return fmt.Errorf("mapping[%d] has Offset %#x beyond header size %#x for buildId %s", + i, m.Offset, h.Metadata.Size, m.BuildId.String()) + } + if m.Length == 0 { + return fmt.Errorf("mapping[%d] has zero length at offset %#x for buildId %s", + i, m.Offset, m.BuildId.String()) + } + } + + return nil +} + +// AddFrames associates compression frame information with this header's mappings. +// +// Only mappings matching this header's BuildId will be updated. Returns nil if frameTable is nil. 
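A hedged sketch of the intended write-path use of Header.AddFrames (implemented below) together with the new BuildFiles map. The helper name and the choice to key BuildFiles by the build's own BuildId are assumptions for illustration, not taken from this diff.

package buildexample

import (
	"fmt"

	"github.com/google/uuid"

	"github.com/e2b-dev/infra/packages/shared/pkg/storage"
	"github.com/e2b-dev/infra/packages/shared/pkg/storage/header"
)

// attachCompressionInfo is a hypothetical write-path helper: after a build's data
// file has been compressed, it records the frame table on the header's own mappings
// and the file's size/checksum in BuildFiles (the new V4 metadata).
func attachCompressionInfo(hdr *header.Header, ft *storage.FrameTable, uncompressedSize int64, checksum [32]byte) error {
	// Frame information is only attached to mappings owned by this build;
	// mappings inherited from base builds keep their existing frame tables.
	if err := hdr.AddFrames(ft); err != nil {
		return fmt.Errorf("attach frame table: %w", err)
	}

	if hdr.BuildFiles == nil {
		hdr.BuildFiles = make(map[uuid.UUID]header.BuildFileInfo)
	}
	// Assumed convention: the build's own entry is keyed by its BuildId.
	hdr.BuildFiles[hdr.Metadata.BuildId] = header.BuildFileInfo{
		Size:     uncompressedSize,
		Checksum: checksum,
	}

	return nil
}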
+func (t *Header) AddFrames(frameTable *storage.FrameTable) error { + if frameTable == nil { + return nil + } + + for _, mapping := range t.Mapping { + if mapping.BuildId == t.Metadata.BuildId { + if err := mapping.AddFrames(frameTable); err != nil { + return err + } + } + } + + return nil +} diff --git a/packages/shared/pkg/storage/header/mapping.go b/packages/shared/pkg/storage/header/mapping.go index 0802bb1fe8..512e5a2907 100644 --- a/packages/shared/pkg/storage/header/mapping.go +++ b/packages/shared/pkg/storage/header/mapping.go @@ -6,6 +6,8 @@ import ( "github.com/bits-and-blooms/bitset" "github.com/google/uuid" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) // Start, Length and SourceStart are in bytes of the data file @@ -17,6 +19,7 @@ type BuildMap struct { Length uint64 BuildId uuid.UUID BuildStorageOffset uint64 + FrameTable *storage.FrameTable } func (mapping *BuildMap) Copy() *BuildMap { @@ -25,9 +28,40 @@ func (mapping *BuildMap) Copy() *BuildMap { Length: mapping.Length, BuildId: mapping.BuildId, BuildStorageOffset: mapping.BuildStorageOffset, + FrameTable: mapping.FrameTable, } } +// AddFrames associates compression frame information with this mapping. +// +// When a file is uploaded with compression, the compressor produces a FrameTable +// that describes how the compressed data is organized into frames. This method +// computes which compressed frames cover this mapping's data within the build's +// storage file based on BuildStorageOffset and Length. +// +// Returns nil if frameTable is nil. Returns an error if the mapping's range +// cannot be found in the frame table. +func (mapping *BuildMap) AddFrames(frameTable *storage.FrameTable) error { + if frameTable == nil { + return nil + } + + mappedRange := storage.Range{ + Start: int64(mapping.BuildStorageOffset), + Length: int(mapping.Length), + } + + subset, err := frameTable.Subset(mappedRange) + if err != nil { + return fmt.Errorf("mapping at virtual offset %#x (storage offset %#x, length %#x): %w", + mapping.Offset, mapping.BuildStorageOffset, mapping.Length, err) + } + + mapping.FrameTable = subset + + return nil +} + func CreateMapping( buildId *uuid.UUID, dirty *bitset.BitSet, @@ -84,9 +118,9 @@ func CreateMapping( func MergeMappings( baseMapping []*BuildMap, diffMapping []*BuildMap, -) []*BuildMap { +) ([]*BuildMap, error) { if len(diffMapping) == 0 { - return baseMapping + return baseMapping, nil } baseMappingCopy := make([]*BuildMap, len(baseMapping)) @@ -97,6 +131,7 @@ func MergeMappings( mappings := make([]*BuildMap, 0) + var err error var baseIdx int var diffIdx int @@ -160,6 +195,10 @@ func MergeMappings( // the build storage offset is the same as the base mapping BuildStorageOffset: base.BuildStorageOffset, } + leftBase.FrameTable, err = base.FrameTable.Subset(storage.Range{Start: int64(leftBase.BuildStorageOffset), Length: int(leftBase.Length)}) + if err != nil { + return nil, fmt.Errorf("subset frame table for left split at offset %#x: %w", leftBase.Offset, err) + } mappings = append(mappings, leftBase) } @@ -178,6 +217,10 @@ func MergeMappings( BuildId: base.BuildId, BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift), } + rightBase.FrameTable, err = base.FrameTable.Subset(storage.Range{Start: int64(rightBase.BuildStorageOffset), Length: int(rightBase.Length)}) + if err != nil { + return nil, fmt.Errorf("subset frame table for right split at offset %#x: %w", rightBase.Offset, err) + } baseMapping[baseIdx] = rightBase } else { @@ -205,6 +248,10 @@ func MergeMappings( 
 BuildId: base.BuildId,
 BuildStorageOffset: base.BuildStorageOffset + uint64(rightBaseShift),
 }
+ rightBase.FrameTable, err = base.FrameTable.Subset(storage.Range{Start: int64(rightBase.BuildStorageOffset), Length: int(rightBase.Length)})
+ if err != nil {
+ return nil, fmt.Errorf("subset frame table for right split at offset %#x: %w", rightBase.Offset, err)
+ }
 baseMapping[baseIdx] = rightBase
 } else {
@@ -226,6 +273,10 @@ func MergeMappings(
 BuildId: base.BuildId,
 BuildStorageOffset: base.BuildStorageOffset,
 }
+ leftBase.FrameTable, err = base.FrameTable.Subset(storage.Range{Start: int64(leftBase.BuildStorageOffset), Length: int(leftBase.Length)})
+ if err != nil {
+ return nil, fmt.Errorf("subset frame table for left split at offset %#x: %w", leftBase.Offset, err)
+ }
 mappings = append(mappings, leftBase)
 }
@@ -241,10 +292,12 @@ func MergeMappings(
 mappings = append(mappings, baseMapping[baseIdx:]...)
 mappings = append(mappings, diffMapping[diffIdx:]...)
- return mappings
+ return mappings, nil
 }
 // NormalizeMappings joins adjacent mappings that have the same buildId.
+// When merging mappings, FrameTables are also merged by extending the first
+// mapping's FrameTable with frames from subsequent mappings.
 func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
 if len(mappings) == 0 {
 return nil
@@ -252,7 +305,7 @@ func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
 result := make([]*BuildMap, 0, len(mappings))
- // Start with a copy of the first mapping
+ // Start with a copy of the first mapping (Copy() now includes FrameTable)
 current := mappings[0].Copy()
 for i := 1; i < len(mappings); i++ {
@@ -260,10 +313,22 @@ func NormalizeMappings(mappings []*BuildMap) []*BuildMap {
 if mp.BuildId != current.BuildId {
 // BuildId changed, add the current map to results and start a new one
 result = append(result, current)
- current = mp.Copy() // New copy
+ current = mp.Copy() // New copy (includes FrameTable)
 } else {
- // Same BuildId, just add the length
+ // Same BuildId, merge: add the length and extend FrameTable
 current.Length += mp.Length
+
+ // Extend FrameTable if the mapping being merged has one
+ if mp.FrameTable != nil {
+ if current.FrameTable == nil {
+ // Current has no FrameTable but merged one does - take it
+ current.FrameTable = mp.FrameTable
+ } else {
+ // Both have FrameTables - extend current's with mp's frames
+ // The frames are contiguous subsets, so we append non-overlapping frames
+ current.FrameTable = mergeFrameTables(current.FrameTable, mp.FrameTable)
+ }
+ }
 }
 }
@@ -272,3 +337,63 @@
 return result
 }
+
+// mergeFrameTables extends ft1 with frames from ft2. The FrameTables are
+// assumed to be contiguous subsets from the same original, so ft2's frames
+// follow ft1's frames (with possible overlap at the boundary). This function
+// returns either a reference to one of the input tables, unchanged, or a new
+// FrameTable with frames from both tables.
+func mergeFrameTables(ft1, ft2 *storage.FrameTable) *storage.FrameTable { + if ft1 == nil { + return ft2 + } + if ft2 == nil { + return ft1 + } + + // Calculate where ft1 ends (uncompressed offset) + ft1EndU := ft1.StartAt.U + for _, frame := range ft1.Frames { + ft1EndU += int64(frame.U) + } + + // Find where to start appending from ft2 (skip frames already covered by ft1) + ft2CurrentU := ft2.StartAt.U + startIdx := 0 + for i, frame := range ft2.Frames { + frameEndU := ft2CurrentU + int64(frame.U) + if frameEndU <= ft1EndU { + // This frame is already covered by ft1 + ft2CurrentU = frameEndU + startIdx = i + 1 + + continue + } + if ft2CurrentU < ft1EndU { + // This frame overlaps with ft1's last frame - it's the same frame, skip it + ft2CurrentU = frameEndU + startIdx = i + 1 + + continue + } + // This frame is beyond ft1's coverage + break + } + + // Append remaining frames from ft2 + if startIdx < len(ft2.Frames) { + // Create a new FrameTable with extended frames + newFrames := make([]storage.FrameSize, len(ft1.Frames), len(ft1.Frames)+len(ft2.Frames)-startIdx) + copy(newFrames, ft1.Frames) + newFrames = append(newFrames, ft2.Frames[startIdx:]...) + + result := storage.NewFrameTable(ft1.CompressionType()) + result.StartAt = ft1.StartAt + result.Frames = newFrames + + return result + } + + // All of ft2's frames were already covered by ft1 + return ft1 +} diff --git a/packages/shared/pkg/storage/header/mapping_test.go b/packages/shared/pkg/storage/header/mapping_test.go index d20f070a3c..28728c2df5 100644 --- a/packages/shared/pkg/storage/header/mapping_test.go +++ b/packages/shared/pkg/storage/header/mapping_test.go @@ -46,11 +46,12 @@ func TestMergeMappingsRemoveEmpty(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, simpleBase)) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -65,7 +66,8 @@ func TestMergeMappingsBaseBeforeDiffNoOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -90,7 +92,7 @@ func TestMergeMappingsBaseBeforeDiffNoOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -105,7 +107,8 @@ func TestMergeMappingsDiffBeforeBaseNoOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -130,7 +133,7 @@ func TestMergeMappingsDiffBeforeBaseNoOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -145,7 +148,8 @@ func TestMergeMappingsBaseInsideDiff(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -165,7 +169,7 @@ func TestMergeMappingsBaseInsideDiff(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -180,7 +184,8 @@ func TestMergeMappingsDiffInsideBase(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -210,7 +215,7 @@ func 
TestMergeMappingsDiffInsideBase(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -225,7 +230,8 @@ func TestMergeMappingsBaseAfterDiffWithOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -250,7 +256,7 @@ func TestMergeMappingsBaseAfterDiffWithOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } @@ -265,7 +271,8 @@ func TestMergeMappingsDiffAfterBaseWithOverlap(t *testing.T) { }, } - m := MergeMappings(simpleBase, diff) + m, err := MergeMappings(simpleBase, diff) + require.NoError(t, err) require.True(t, Equal(m, []*BuildMap{ { @@ -290,7 +297,7 @@ func TestMergeMappingsDiffAfterBaseWithOverlap(t *testing.T) { }, })) - err := ValidateMappings(m, size, blockSize) + err = ValidateMappings(m, size, blockSize) require.NoError(t, err) } diff --git a/packages/shared/pkg/storage/header/metadata.go b/packages/shared/pkg/storage/header/metadata.go index 32dac10d19..eab9e574af 100644 --- a/packages/shared/pkg/storage/header/metadata.go +++ b/packages/shared/pkg/storage/header/metadata.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "io" + "maps" "github.com/bits-and-blooms/bitset" "github.com/google/uuid" @@ -27,7 +28,7 @@ type DiffMetadata struct { func (d *DiffMetadata) toDiffMapping( ctx context.Context, buildID uuid.UUID, -) (mapping []*BuildMap) { +) ([]*BuildMap, error) { dirtyMappings := CreateMapping( &buildID, d.Dirty, @@ -43,10 +44,13 @@ func (d *DiffMetadata) toDiffMapping( ) telemetry.ReportEvent(ctx, "created empty mapping") - mappings := MergeMappings(dirtyMappings, emptyMappings) + mappings, err := MergeMappings(dirtyMappings, emptyMappings) + if err != nil { + return nil, fmt.Errorf("merge dirty+empty mappings: %w", err) + } telemetry.ReportEvent(ctx, "merge mappings") - return mappings + return mappings, nil } func (d *DiffMetadata) ToDiffHeader( @@ -63,12 +67,18 @@ func (d *DiffMetadata) ToDiffHeader( } }() - diffMapping := d.toDiffMapping(ctx, buildID) + diffMapping, err := d.toDiffMapping(ctx, buildID) + if err != nil { + return nil, fmt.Errorf("toDiffMapping: %w", err) + } - m := MergeMappings( + m, err := MergeMappings( originalHeader.Mapping, diffMapping, ) + if err != nil { + return nil, fmt.Errorf("merge base+diff mappings: %w", err) + } telemetry.ReportEvent(ctx, "merged mappings") // TODO: We can run normalization only when empty mappings are not empty for this snapshot @@ -93,6 +103,9 @@ func (d *DiffMetadata) ToDiffHeader( return nil, fmt.Errorf("failed to create header: %w", err) } + // Inherit upstream build file info (sizes + checksums). 
+ header.BuildFiles = maps.Clone(originalHeader.BuildFiles) + err = ValidateMappings(header.Mapping, header.Metadata.Size, header.Metadata.BlockSize) if err != nil { if header.IsNormalizeFixApplied() { diff --git a/packages/shared/pkg/storage/header/serialization.go b/packages/shared/pkg/storage/header/serialization.go index 6af71f832b..a1f4de4521 100644 --- a/packages/shared/pkg/storage/header/serialization.go +++ b/packages/shared/pkg/storage/header/serialization.go @@ -2,18 +2,26 @@ package header import ( "bytes" + "cmp" "context" "encoding/binary" "errors" "fmt" "io" + "slices" "github.com/google/uuid" + lz4 "github.com/pierrec/lz4/v4" "github.com/e2b-dev/infra/packages/shared/pkg/storage" ) -const metadataVersion = 3 +const ( + // metadataVersion is used by template-manager for uncompressed builds (V3 headers). + metadataVersion = 3 + // MetadataVersionCompressed is used for compressed builds (V4 headers with FrameTables). + MetadataVersionCompressed = 4 +) type Metadata struct { Version uint64 @@ -25,6 +33,25 @@ type Metadata struct { BaseBuildId uuid.UUID } +type v3SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId uuid.UUID + BuildStorageOffset uint64 +} + +type v4SerializableBuildMap struct { + Offset uint64 + Length uint64 + BuildId uuid.UUID + BuildStorageOffset uint64 + CompressionTypeNumFrames uint64 // CompressionType is stored as uint8 in the high byte, the low 24 bits are NumFrames + + // if CompressionType != CompressionNone and there are frames + // - followed by frames offset (16 bytes) + // - followed by frames... (16 bytes * NumFrames) +} + func NewTemplateMetadata(buildId uuid.UUID, blockSize, size uint64) *Metadata { return &Metadata{ Version: metadataVersion, @@ -47,7 +74,14 @@ func (m *Metadata) NextGeneration(buildID uuid.UUID) *Metadata { } } -func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { +// v4SerializableBuildFileInfo is the on-disk format for a BuildFileInfo entry. +type v4SerializableBuildFileInfo struct { + BuildId uuid.UUID + Size int64 + Checksum [32]byte +} + +func serialize(metadata *Metadata, buildFiles map[uuid.UUID]BuildFileInfo, mappings []*BuildMap) ([]byte, error) { var buf bytes.Buffer err := binary.Write(&buf, binary.LittleEndian, metadata) @@ -55,16 +89,268 @@ func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { return nil, fmt.Errorf("failed to write metadata: %w", err) } + if metadata.Version >= 4 { + // V4: write build-info section before mappings. + if err := binary.Write(&buf, binary.LittleEndian, uint32(len(buildFiles))); err != nil { + return nil, fmt.Errorf("failed to write build files count: %w", err) + } + + // Sort by UUID for deterministic serialization. + buildIDs := make([]uuid.UUID, 0, len(buildFiles)) + for id := range buildFiles { + buildIDs = append(buildIDs, id) + } + slices.SortFunc(buildIDs, func(a, b uuid.UUID) int { + return cmp.Compare(a.String(), b.String()) + }) + + for _, id := range buildIDs { + info := buildFiles[id] + entry := v4SerializableBuildFileInfo{ + BuildId: id, + Size: info.Size, + Checksum: info.Checksum, + } + if err := binary.Write(&buf, binary.LittleEndian, &entry); err != nil { + return nil, fmt.Errorf("failed to write build file info: %w", err) + } + } + + // V4: write mapping count before mappings. 
+ if err := binary.Write(&buf, binary.LittleEndian, uint32(len(mappings))); err != nil { + return nil, fmt.Errorf("failed to write mappings count: %w", err) + } + } + + var v any for _, mapping := range mappings { - err := binary.Write(&buf, binary.LittleEndian, mapping) + var offset *storage.FrameOffset + var frames []storage.FrameSize + if metadata.Version <= 3 { + v = &v3SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + } else { + v4 := &v4SerializableBuildMap{ + Offset: mapping.Offset, + Length: mapping.Length, + BuildId: mapping.BuildId, + BuildStorageOffset: mapping.BuildStorageOffset, + } + if mapping.FrameTable != nil { + v4.CompressionTypeNumFrames = uint64(mapping.FrameTable.CompressionType())<<24 | uint64(len(mapping.FrameTable.Frames)&0xFFFFFF) + // Only write offset/frames when the packed value is non-zero, + // matching the deserializer's condition. A FrameTable with + // CompressionNone and zero frames produces a packed value of 0. + if v4.CompressionTypeNumFrames != 0 { + offset = &mapping.FrameTable.StartAt + frames = mapping.FrameTable.Frames + } + } + v = v4 + } + + err := binary.Write(&buf, binary.LittleEndian, v) if err != nil { return nil, fmt.Errorf("failed to write block mapping: %w", err) } + if offset != nil { + err := binary.Write(&buf, binary.LittleEndian, offset) + if err != nil { + return nil, fmt.Errorf("failed to write compression frames starting offset: %w", err) + } + } + for _, frame := range frames { + err := binary.Write(&buf, binary.LittleEndian, frame) + if err != nil { + return nil, fmt.Errorf("failed to write compression frame: %w", err) + } + } } return buf.Bytes(), nil } +// metadataSize is the binary size of the Metadata struct, computed from the struct layout. +var metadataSize = binary.Size(Metadata{}) + +func deserializeMetadata(data []byte) (*Metadata, error) { + var metadata Metadata + + err := binary.Read(bytes.NewReader(data), binary.LittleEndian, &metadata) + if err != nil { + return nil, fmt.Errorf("failed to read metadata: %w", err) + } + + return &metadata, nil +} + +// deserializeV3Mappings reads V3 mappings until EOF. +func deserializeV3Mappings(reader *bytes.Reader) ([]*BuildMap, error) { + var mappings []*BuildMap + + for { + var v3 v3SerializableBuildMap + err := binary.Read(reader, binary.LittleEndian, &v3) + if errors.Is(err, io.EOF) { + break + } + if err != nil { + return nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + mappings = append(mappings, &BuildMap{ + Offset: v3.Offset, + Length: v3.Length, + BuildId: v3.BuildId, + BuildStorageOffset: v3.BuildStorageOffset, + }) + } + + return mappings, nil +} + +// deserializeV4Block reads the V4 block: build-info section, then counted mappings. +func deserializeV4Block(reader *bytes.Reader) (map[uuid.UUID]BuildFileInfo, []*BuildMap, error) { + // Read build-info section. 
+ var numBuilds uint32 + if err := binary.Read(reader, binary.LittleEndian, &numBuilds); err != nil { + return nil, nil, fmt.Errorf("failed to read build files count: %w", err) + } + + var buildFiles map[uuid.UUID]BuildFileInfo + if numBuilds > 0 { + buildFiles = make(map[uuid.UUID]BuildFileInfo, numBuilds) + for range numBuilds { + var entry v4SerializableBuildFileInfo + if err := binary.Read(reader, binary.LittleEndian, &entry); err != nil { + return nil, nil, fmt.Errorf("failed to read build file info: %w", err) + } + buildFiles[entry.BuildId] = BuildFileInfo{ + Size: entry.Size, + Checksum: entry.Checksum, + } + } + } + + // Read counted mappings. + var numMappings uint32 + if err := binary.Read(reader, binary.LittleEndian, &numMappings); err != nil { + return nil, nil, fmt.Errorf("failed to read mappings count: %w", err) + } + + mappings := make([]*BuildMap, 0, numMappings) + for range numMappings { + var v4 v4SerializableBuildMap + if err := binary.Read(reader, binary.LittleEndian, &v4); err != nil { + return nil, nil, fmt.Errorf("failed to read block mapping: %w", err) + } + + m := &BuildMap{ + Offset: v4.Offset, + Length: v4.Length, + BuildId: v4.BuildId, + BuildStorageOffset: v4.BuildStorageOffset, + } + + if v4.CompressionTypeNumFrames != 0 { + m.FrameTable = storage.NewFrameTable(storage.CompressionType((v4.CompressionTypeNumFrames >> 24) & 0xFF)) + numFrames := v4.CompressionTypeNumFrames & 0xFFFFFF + + var startAt storage.FrameOffset + if err := binary.Read(reader, binary.LittleEndian, &startAt); err != nil { + return nil, nil, fmt.Errorf("failed to read compression frames starting offset: %w", err) + } + m.FrameTable.StartAt = startAt + + for range numFrames { + var frame storage.FrameSize + if err := binary.Read(reader, binary.LittleEndian, &frame); err != nil { + return nil, nil, fmt.Errorf("failed to read the expected compression frame: %w", err) + } + m.FrameTable.Frames = append(m.FrameTable.Frames, frame) + } + } + + mappings = append(mappings, m) + } + + return buildFiles, mappings, nil +} + +// Serialize serializes a V3 header from metadata and mappings (legacy API). +func Serialize(metadata *Metadata, mappings []*BuildMap) ([]byte, error) { + return serialize(metadata, nil, mappings) +} + +// SerializeHeader serializes a header with optional LZ4 compression for V4. +// +// V3 (Version <= 3): [Metadata (raw binary)] [v3 mappings (raw binary)] +// +// V4 (Version >= 4): [Metadata (raw binary)] [uint32 uncompressed block size] [LZ4-compressed block] +// +// where the LZ4 block contains: BuildFiles + v4 mappings with FrameTables. +func SerializeHeader(h *Header) ([]byte, error) { + raw, err := serialize(h.Metadata, h.BuildFiles, h.Mapping) + if err != nil { + return nil, err + } + + if h.Metadata.Version <= 3 { + return raw, nil + } + + // V4: keep Metadata prefix raw, then [uint32 uncompressed size] + [LZ4 frame]. + block := raw[metadataSize:] + compressed, err := compressLZ4(block) + if err != nil { + return nil, fmt.Errorf("failed to LZ4-compress v4 header mappings: %w", err) + } + + result := make([]byte, metadataSize+4+len(compressed)) + copy(result, raw[:metadataSize]) + binary.LittleEndian.PutUint32(result[metadataSize:], uint32(len(block))) + copy(result[metadataSize+4:], compressed) + + return result, nil +} + +// LoadHeader fetches a serialized header from storage and deserializes it. +// Errors (including storage.ErrObjectNotExist) are returned as-is. 
+func LoadHeader(ctx context.Context, s storage.StorageProvider, path string) (*Header, error) { + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) + if err != nil { + return nil, fmt.Errorf("open blob %s: %w", path, err) + } + + data, err := storage.GetBlob(ctx, blob) + if err != nil { + return nil, err + } + + return DeserializeBytes(data) +} + +// StoreHeader serializes a header and uploads it to storage. +// Inverse of LoadHeader. +func StoreHeader(ctx context.Context, s storage.StorageProvider, path string, h *Header) error { + data, err := SerializeHeader(h) + if err != nil { + return fmt.Errorf("serialize header: %w", err) + } + + blob, err := s.OpenBlob(ctx, path, storage.MetadataObjectType) + if err != nil { + return fmt.Errorf("open blob %s: %w", path, err) + } + + return blob.Put(ctx, data) +} + +// Deserialize reads a header from a storage Blob (legacy API). func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { data, err := storage.GetBlob(ctx, in) if err != nil { @@ -74,29 +360,92 @@ func Deserialize(ctx context.Context, in storage.Blob) (*Header, error) { return DeserializeBytes(data) } +// DeserializeBytes auto-detects the header version and deserializes accordingly. +// See SerializeHeader for the binary layout. +// The uint32 size prefix in V4 allows exact-size allocation for decompression +// instead of a fixed upper-bound buffer. func DeserializeBytes(data []byte) (*Header, error) { - reader := bytes.NewReader(data) - var metadata Metadata - err := binary.Read(reader, binary.LittleEndian, &metadata) + if len(data) < metadataSize { + return nil, fmt.Errorf("header too short: %d bytes", len(data)) + } + + metadata, err := deserializeMetadata(data[:metadataSize]) if err != nil { - return nil, fmt.Errorf("failed to read metadata: %w", err) + return nil, err } - mappings := make([]*BuildMap, 0) + blockData := data[metadataSize:] - for { - var m BuildMap - err := binary.Read(reader, binary.LittleEndian, &m) - if errors.Is(err, io.EOF) { - break + if metadata.Version >= 4 { + if len(blockData) < 4 { + return nil, fmt.Errorf("v4 header block too short for size prefix: %d bytes", len(blockData)) } + uncompressedSize := binary.LittleEndian.Uint32(blockData[:4]) + if uncompressedSize > storage.MaxCompressedHeaderSize { + return nil, fmt.Errorf("v4 header uncompressed size %d exceeds maximum %d", uncompressedSize, storage.MaxCompressedHeaderSize) + } + + blockData, err = decompressLZ4(blockData[4:]) if err != nil { - return nil, fmt.Errorf("failed to read block mapping: %w", err) + return nil, fmt.Errorf("failed to LZ4-decompress v4 header block: %w", err) + } + + buildFiles, mappings, err := deserializeV4Block(bytes.NewReader(blockData)) + if err != nil { + return nil, err + } + + h, err := NewHeader(metadata, mappings) + if err != nil { + return nil, err } + h.BuildFiles = buildFiles + + return h, nil + } + + mappings, err := deserializeV3Mappings(bytes.NewReader(blockData)) + if err != nil { + return nil, err + } + + return NewHeader(metadata, mappings) +} + +// compressLZ4 compresses data for V4 header serialization using the LZ4 +// streaming API. Settings are fixed for the V4 wire format. 
+func compressLZ4(data []byte) ([]byte, error) { + var buf bytes.Buffer + buf.Grow(len(data)) + + w := lz4.NewWriter(&buf) + w.Apply( + lz4.BlockSizeOption(lz4.Block4Mb), + lz4.BlockChecksumOption(true), + lz4.ChecksumOption(true), + lz4.CompressionLevelOption(lz4.Fast), + ) + + if _, err := w.Write(data); err != nil { + return nil, fmt.Errorf("lz4 compress: %w", err) + } + + if err := w.Close(); err != nil { + return nil, fmt.Errorf("lz4 compress close: %w", err) + } + + return buf.Bytes(), nil +} - mappings = append(mappings, &m) +// decompressLZ4 decompresses an LZ4 frame from V4 header data. +func decompressLZ4(src []byte) ([]byte, error) { + r := lz4.NewReader(bytes.NewReader(src)) + + data, err := io.ReadAll(r) + if err != nil { + return nil, fmt.Errorf("lz4 decompress: %w", err) } - return NewHeader(&metadata, mappings) + return data, nil } diff --git a/packages/shared/pkg/storage/header/serialization_test.go b/packages/shared/pkg/storage/header/serialization_test.go new file mode 100644 index 0000000000..93f8f5c96c --- /dev/null +++ b/packages/shared/pkg/storage/header/serialization_test.go @@ -0,0 +1,397 @@ +package header + +import ( + "crypto/sha256" + "testing" + + "github.com/google/uuid" + "github.com/stretchr/testify/require" + + "github.com/e2b-dev/infra/packages/shared/pkg/storage" +) + +// newFT creates a FrameTable for test fixtures. +func newFT(ct storage.CompressionType, startAt storage.FrameOffset, frames []storage.FrameSize) *storage.FrameTable { + ft := storage.NewFrameTable(ct) + ft.StartAt = startAt + ft.Frames = frames + + return ft +} + +func TestSerializeDeserialize_V3_RoundTrip(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 7, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 123, + }, + } + + data, err := serialize(metadata, nil, mappings) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, metadata, got.Metadata) + require.Len(t, got.Mapping, 2) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, uint64(4096), got.Mapping[0].Length) + require.Equal(t, buildID, got.Mapping[0].BuildId) + require.Equal(t, uint64(0), got.Mapping[0].BuildStorageOffset) + + require.Equal(t, uint64(4096), got.Mapping[1].Offset) + require.Equal(t, uint64(4096), got.Mapping[1].Length) + require.Equal(t, baseID, got.Mapping[1].BuildId) + require.Equal(t, uint64(123), got.Mapping[1].BuildStorageOffset) + + // V3 headers have no BuildFiles + require.Nil(t, got.BuildFiles) +} + +func TestDeserialize_TruncatedMetadata(t *testing.T) { + t.Parallel() + + _, err := DeserializeBytes([]byte{0x01, 0x02, 0x03}) + require.Error(t, err) + require.Contains(t, err.Error(), "header too short") +} + +func TestSerializeDeserialize_EmptyMappings_Defaults(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serialize(metadata, nil, nil) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + // NewHeader creates a default mapping when none provided + require.Len(t, got.Mapping, 1) + require.Equal(t, uint64(0), got.Mapping[0].Offset) + require.Equal(t, 
metadata.Size, got.Mapping[0].Length) + require.Equal(t, metadata.BuildId, got.Mapping[0].BuildId) +} + +func TestDeserialize_BlockSizeZero(t *testing.T) { + t.Parallel() + + metadata := &Metadata{ + Version: 3, + BlockSize: 0, + Size: 4096, + Generation: 0, + BuildId: uuid.New(), + BaseBuildId: uuid.New(), + } + + data, err := serialize(metadata, nil, nil) + require.NoError(t, err) + + _, err = DeserializeBytes(data) + require.Error(t, err) + require.Contains(t, err.Error(), "block size cannot be zero") +} + +func TestSerializeDeserialize_V4_WithFrameTable(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 1, + BuildId: buildID, + BaseBuildId: baseID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + FrameTable: newFT(storage.CompressionLZ4, storage.FrameOffset{U: 0, C: 0}, []storage.FrameSize{ + {U: 2048, C: 1024}, + {U: 2048, C: 900}, + }), + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + checksum := sha256.Sum256([]byte("test-data")) + buildFiles := map[uuid.UUID]BuildFileInfo{ + buildID: {Size: 12345, Checksum: checksum}, + baseID: {Size: 67890}, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + h.BuildFiles = buildFiles + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Equal(t, uint64(4), got.Metadata.Version) + require.Len(t, got.Mapping, 2) + + // First mapping has FrameTable + m0 := got.Mapping[0] + require.Equal(t, uint64(0), m0.Offset) + require.Equal(t, uint64(4096), m0.Length) + require.Equal(t, buildID, m0.BuildId) + require.NotNil(t, m0.FrameTable) + require.Equal(t, storage.CompressionLZ4, m0.FrameTable.CompressionType()) + require.Equal(t, int64(0), m0.FrameTable.StartAt.U) + require.Equal(t, int64(0), m0.FrameTable.StartAt.C) + require.Len(t, m0.FrameTable.Frames, 2) + require.Equal(t, int32(2048), m0.FrameTable.Frames[0].U) + require.Equal(t, int32(1024), m0.FrameTable.Frames[0].C) + require.Equal(t, int32(2048), m0.FrameTable.Frames[1].U) + require.Equal(t, int32(900), m0.FrameTable.Frames[1].C) + + // Second mapping has no FrameTable + m1 := got.Mapping[1] + require.Equal(t, uint64(4096), m1.Offset) + require.Equal(t, uint64(4096), m1.Length) + require.Equal(t, baseID, m1.BuildId) + require.Nil(t, m1.FrameTable) + + // BuildFiles round-trip + require.Len(t, got.BuildFiles, 2) + require.Equal(t, int64(12345), got.BuildFiles[buildID].Size) + require.Equal(t, checksum, got.BuildFiles[buildID].Checksum) + require.Equal(t, int64(67890), got.BuildFiles[baseID].Size) + require.Equal(t, [32]byte{}, got.BuildFiles[baseID].Checksum) +} + +func TestSerializeDeserialize_V4_Zstd_NonZeroStartAt(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 8192, + FrameTable: newFT(storage.CompressionZstd, storage.FrameOffset{U: 8192, C: 4000}, []storage.FrameSize{ + {U: 4096, C: 3500}, + }), + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + 
require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + m := got.Mapping[0] + require.NotNil(t, m.FrameTable) + require.Equal(t, storage.CompressionZstd, m.FrameTable.CompressionType()) + require.Equal(t, int64(8192), m.FrameTable.StartAt.U) + require.Equal(t, int64(4000), m.FrameTable.StartAt.C) + require.Len(t, m.FrameTable.Frames, 1) + require.Equal(t, int32(4096), m.FrameTable.Frames[0].U) + require.Equal(t, int32(3500), m.FrameTable.Frames[0].C) + + // No BuildFiles set + require.Nil(t, got.BuildFiles) +} + +// TestSerializeDeserialize_V4_CompressionNone_EmptyFrames verifies that a +// FrameTable with CompressionNone and zero frames does not corrupt the stream. +// Before the fix, the serializer wrote a StartAt offset (16 bytes) but the +// deserializer skipped it because the packed value was 0. +func TestSerializeDeserialize_V4_CompressionNone_EmptyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + baseID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 8192, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + BuildStorageOffset: 0, + // FrameTable with CompressionNone and no frames — packed value is 0. + FrameTable: newFT(storage.CompressionNone, storage.FrameOffset{U: 100, C: 50}, nil), + }, + { + Offset: 4096, + Length: 4096, + BuildId: baseID, + BuildStorageOffset: 0, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 2) + + // First mapping: FrameTable was effectively empty, deserializer should treat as nil. + require.Nil(t, got.Mapping[0].FrameTable) + + // Second mapping must not be corrupted by stray StartAt bytes. 
+ require.Equal(t, uint64(4096), got.Mapping[1].Offset) + require.Equal(t, uint64(4096), got.Mapping[1].Length) + require.Equal(t, baseID, got.Mapping[1].BuildId) +} + +func TestSerializeDeserialize_V4_ManyFrames(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + const numFrames = 1000 + frames := make([]storage.FrameSize, numFrames) + for i := range frames { + frames[i] = storage.FrameSize{U: 4096, C: int32(2000 + i)} + } + + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096 * numFrames, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096 * numFrames, + BuildId: buildID, + BuildStorageOffset: 0, + FrameTable: newFT(storage.CompressionLZ4, storage.FrameOffset{U: 0, C: 0}, frames), + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + + // Test with Serialize + Deserialize (unified path) + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.NotNil(t, got.Mapping[0].FrameTable) + require.Len(t, got.Mapping[0].FrameTable.Frames, numFrames) + + // Spot-check first and last frame + require.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[0].U) + require.Equal(t, int32(2000), got.Mapping[0].FrameTable.Frames[0].C) + require.Equal(t, int32(4096), got.Mapping[0].FrameTable.Frames[numFrames-1].U) + require.Equal(t, int32(2000+numFrames-1), got.Mapping[0].FrameTable.Frames[numFrames-1].C) +} + +func TestSerializeDeserialize_V4_EmptyBuildFiles(t *testing.T) { + t.Parallel() + + buildID := uuid.New() + metadata := &Metadata{ + Version: 4, + BlockSize: 4096, + Size: 4096, + Generation: 0, + BuildId: buildID, + BaseBuildId: buildID, + } + + mappings := []*BuildMap{ + { + Offset: 0, + Length: 4096, + BuildId: buildID, + }, + } + + h, err := NewHeader(metadata, mappings) + require.NoError(t, err) + // No BuildFiles set (nil map) + + data, err := SerializeHeader(h) + require.NoError(t, err) + + got, err := DeserializeBytes(data) + require.NoError(t, err) + + require.Len(t, got.Mapping, 1) + require.Nil(t, got.BuildFiles) // numBuilds=0 → nil +} diff --git a/packages/shared/pkg/storage/storage_google.go b/packages/shared/pkg/storage/storage_google.go index f9b5da602b..5b96134bce 100644 --- a/packages/shared/pkg/storage/storage_google.go +++ b/packages/shared/pkg/storage/storage_google.go @@ -448,6 +448,7 @@ func (o *gcpObject) StoreFile(ctx context.Context, path string) (e error) { bucketName, objectName, DefaultRetryConfig(), + nil, ) if err != nil { timer.Failure(ctx, 0) diff --git a/packages/shared/pkg/storage/template.go b/packages/shared/pkg/storage/template.go index 3c501be7b8..677d2a3756 100644 --- a/packages/shared/pkg/storage/template.go +++ b/packages/shared/pkg/storage/template.go @@ -53,6 +53,33 @@ func (t TemplateFiles) StorageMetadataPath() string { return fmt.Sprintf("%s/%s", t.StorageDir(), MetadataName) } +// DataPath returns the data storage path for a given file name within this build. +func (t TemplateFiles) DataPath(fileName string) string { + return fmt.Sprintf("%s/%s", t.StorageDir(), fileName) +} + +// HeaderPath returns the header storage path for a given file name within this build. +func (t TemplateFiles) HeaderPath(fileName string) string { + return fmt.Sprintf("%s/%s%s", t.StorageDir(), fileName, HeaderSuffix) +} + +// CompressedDataName returns the compressed data filename: "memfile.zstd". 
+func CompressedDataName(fileName string, ct CompressionType) string { + return fileName + ct.Suffix() +} + +// CompressedDataPath returns the compressed data path for a given file name. +// Example: "{buildId}/memfile.zstd" +func (t TemplateFiles) CompressedDataPath(fileName string, ct CompressionType) string { + return fmt.Sprintf("%s/%s", t.StorageDir(), CompressedDataName(fileName, ct)) +} + +// CompressedPath transforms a base object path (e.g. "buildId/memfile") into +// the compressed data path (e.g. "buildId/memfile.zstd"). +func CompressedPath(basePath string, ct CompressionType) string { + return basePath + ct.Suffix() +} + // ParseStoragePath splits a storage path of the form "{buildID}/{fileName}" // back into its components. This is the inverse of the Storage*Path methods. func ParseStoragePath(path string) (buildID, fileName string) { @@ -60,3 +87,18 @@ func ParseStoragePath(path string) (buildID, fileName string) { return buildID, fileName } + +// BaseFileName strips known compression suffixes from a file name, +// returning the base name. For example: "memfile.zstd" → "memfile". +// If no known suffix is present, the name is returned unchanged. +func BaseFileName(name string) string { + for _, suffix := range knownCompressionSuffixes { + if before, ok := strings.CutSuffix(name, suffix); ok { + return before + } + } + + return name +} + +var knownCompressionSuffixes = []string{".lz4", ".zstd"}
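Illustrative only, not part of the diff: a minimal sketch of how the new path helpers from template.go are expected to compose, assuming the storage package identifiers shown above (CompressedDataName, CompressedPath, ParseStoragePath, BaseFileName, CompressionZstd) and a hypothetical build ID; the exact suffix strings come from CompressionType.Suffix().

package main

import (
	"fmt"

	"github.com/e2b-dev/infra/packages/shared/pkg/storage"
)

func main() {
	// Hypothetical base object path of the form "{buildId}/{fileName}".
	base := "00000000-0000-0000-0000-000000000000/memfile"

	// "memfile" plus the compression suffix, e.g. "memfile.zstd".
	name := storage.CompressedDataName("memfile", storage.CompressionZstd)

	// "{buildId}/memfile" becomes "{buildId}/memfile.zstd".
	compressedPath := storage.CompressedPath(base, storage.CompressionZstd)

	// Split the compressed path back into components, then strip the suffix:
	// ParseStoragePath yields ("{buildId}", "memfile.zstd"); BaseFileName yields "memfile".
	buildID, fileName := storage.ParseStoragePath(compressedPath)
	fmt.Println(name, buildID, storage.BaseFileName(fileName))
}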