diff --git a/cmd/cli/commands/compose.go b/cmd/cli/commands/compose.go index b96874f75..7e994f078 100644 --- a/cmd/cli/commands/compose.go +++ b/cmd/cli/commands/compose.go @@ -175,9 +175,10 @@ func downloadModelsOnlyIfNotFound(desktopClient *desktop.Client, models []string } return false }) { - _, _, err = desktopClient.Pull(model, false, func(s string) { + printer := desktop.NewSimplePrinter(func(s string) { _ = sendInfo(s) }) + _, _, err = desktopClient.Pull(model, false, printer) if err != nil { _ = sendErrorf("Failed to pull model: %v", err) return fmt.Errorf("Failed to pull model: %v\n", err) diff --git a/cmd/cli/commands/integration_test.go b/cmd/cli/commands/integration_test.go index 4cdca2087..73517c579 100644 --- a/cmd/cli/commands/integration_test.go +++ b/cmd/cli/commands/integration_test.go @@ -700,9 +700,9 @@ func TestIntegration_PushModel(t *testing.T) { // Push the tagged model t.Logf("Pushing model to custom registry with reference: %s", tc.ref) - _, _, err = env.client.Push(tc.ref, func(msg string) { + _, _, err = env.client.Push(tc.ref, desktop.NewSimplePrinter(func(msg string) { t.Logf("Progress: %s", msg) - }) + })) require.NoError(t, err, "Failed to push model to custom registry") t.Logf("✓ Successfully pushed model to custom registry: %s", tc.ref) }) @@ -742,9 +742,9 @@ func TestIntegration_PushModel(t *testing.T) { // Push the tagged model t.Logf("Pushing model to custom registry with reference: %s", tc.targetRef) - _, _, err = env.client.Push(tc.targetRef, func(msg string) { + _, _, err = env.client.Push(tc.targetRef, desktop.NewSimplePrinter(func(msg string) { t.Logf("Progress: %s", msg) - }) + })) require.NoError(t, err, "Failed to push model to custom registry") t.Logf("✓ Successfully pushed model to custom registry: %s", tc.targetRef) }) @@ -754,13 +754,13 @@ func TestIntegration_PushModel(t *testing.T) { // Test 3: Error cases t.Run("error cases", func(t *testing.T) { t.Run("push non-existent model", func(t *testing.T) { - _, _, err := env.client.Push("non-existent-model:v1", func(msg string) {}) + _, _, err := env.client.Push("non-existent-model:v1", desktop.NewSimplePrinter(func(msg string) {})) require.Error(t, err, "Should fail when pushing non-existent model") t.Logf("✓ Correctly failed to push non-existent model: %v", err) }) t.Run("push with invalid reference", func(t *testing.T) { - _, _, err := env.client.Push("", func(msg string) {}) + _, _, err := env.client.Push("", desktop.NewSimplePrinter(func(msg string) {})) require.Error(t, err, "Should fail with empty reference") t.Logf("✓ Correctly failed to push with invalid reference: %v", err) }) diff --git a/cmd/cli/commands/package.go b/cmd/cli/commands/package.go index 33b5d9c6a..6bec46dd4 100644 --- a/cmd/cli/commands/package.go +++ b/cmd/cli/commands/package.go @@ -374,7 +374,7 @@ func packageModel(cmd *cobra.Command, opts packageOptions) error { } // Print progress messages - TUIProgress(progressMsg.Message) + fmt.Print("\r\033[K", progressMsg.Message) } cmd.PrintErrln("") // newline after progress diff --git a/cmd/cli/commands/pull.go b/cmd/cli/commands/pull.go index db0408be5..09ce61a4a 100644 --- a/cmd/cli/commands/pull.go +++ b/cmd/cli/commands/pull.go @@ -2,12 +2,10 @@ package commands import ( "fmt" - "os" "github.com/docker/model-runner/cmd/cli/commands/completion" "github.com/docker/model-runner/cmd/cli/desktop" - "github.com/mattn/go-isatty" "github.com/spf13/cobra" ) @@ -33,18 +31,8 @@ func newPullCmd() *cobra.Command { } func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, ignoreRuntimeMemoryCheck bool) error { - var progress func(string) - if isatty.IsTerminal(os.Stdout.Fd()) { - progress = TUIProgress - } else { - progress = RawProgress - } - response, progressShown, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, progress) - - // Add a newline before any output (success or error) if progress was shown. - if progressShown { - cmd.Println() - } + printer := asPrinter(cmd) + response, _, err := desktopClient.Pull(model, ignoreRuntimeMemoryCheck, printer) if err != nil { return handleClientError(err, "Failed to pull model") @@ -53,11 +41,3 @@ func pullModel(cmd *cobra.Command, desktopClient *desktop.Client, model string, cmd.Println(response) return nil } - -func TUIProgress(message string) { - fmt.Print("\r\033[K", message) -} - -func RawProgress(message string) { - fmt.Println(message) -} diff --git a/cmd/cli/commands/push.go b/cmd/cli/commands/push.go index 2116b0802..34aabe036 100644 --- a/cmd/cli/commands/push.go +++ b/cmd/cli/commands/push.go @@ -26,7 +26,8 @@ func newPushCmd() *cobra.Command { } func pushModel(cmd *cobra.Command, desktopClient *desktop.Client, model string) error { - response, progressShown, err := desktopClient.Push(model, TUIProgress) + printer := asPrinter(cmd) + response, progressShown, err := desktopClient.Push(model, printer) // Add a newline before any output (success or error) if progress was shown. if progressShown { diff --git a/cmd/cli/desktop/desktop.go b/cmd/cli/desktop/desktop.go index 6fe794e01..871c867d2 100644 --- a/cmd/cli/desktop/desktop.go +++ b/cmd/cli/desktop/desktop.go @@ -6,7 +6,6 @@ import ( "context" "encoding/json" "fmt" - "html" "io" "net/http" "net/url" @@ -14,7 +13,7 @@ import ( "strings" "time" - "github.com/docker/go-units" + "github.com/docker/model-runner/cmd/cli/pkg/standalone" "github.com/docker/model-runner/pkg/distribution/distribution" "github.com/docker/model-runner/pkg/inference" dmrm "github.com/docker/model-runner/pkg/inference/models" @@ -105,7 +104,7 @@ func (c *Client) Status() Status { } } -func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func(string)) (string, bool, error) { +func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, printer standalone.StatusPrinter) (string, bool, error) { model = normalizeHuggingFaceModelName(model) jsonData, err := json.Marshal(dmrm.ModelCreateRequest{From: model, IgnoreRuntimeMemoryCheck: ignoreRuntimeMemoryCheck}) if err != nil { @@ -128,52 +127,16 @@ func (c *Client) Pull(model string, ignoreRuntimeMemoryCheck bool, progress func return "", false, fmt.Errorf("pulling %s failed with status %s: %s", model, resp.Status, string(body)) } - progressShown := false - current := uint64(0) // Track cumulative progress across all layers - layerProgress := make(map[string]uint64) // Track progress per layer ID - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - progressLine := scanner.Text() - if progressLine == "" { - continue - } - - // Parse the progress message - var progressMsg ProgressMessage - if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { - return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) - } - - // Handle different message types - switch progressMsg.Type { - case "progress": - // Update the current progress for this layer - layerID := progressMsg.Layer.ID - layerProgress[layerID] = progressMsg.Layer.Current - - // Sum all layer progress values - current = uint64(0) - for _, layerCurrent := range layerProgress { - current += layerCurrent - } - - progress(fmt.Sprintf("Downloaded %s of %s", units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}))) - progressShown = true - case "error": - return "", progressShown, fmt.Errorf("error pulling model: %s", progressMsg.Message) - case "success": - return progressMsg.Message, progressShown, nil - default: - return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) - } + // Use Docker-style progress display + message, progressShown, err := DisplayProgress(resp.Body, printer) + if err != nil { + return "", progressShown, err } - // If we get here, something went wrong - return "", progressShown, fmt.Errorf("unexpected end of stream while pulling model %s", model) + return message, progressShown, nil } -func (c *Client) Push(model string, progress func(string)) (string, bool, error) { +func (c *Client) Push(model string, printer standalone.StatusPrinter) (string, bool, error) { model = normalizeHuggingFaceModelName(model) pushPath := inference.ModelsPrefix + "/" + model + "/push" resp, err := c.doRequest( @@ -191,37 +154,13 @@ func (c *Client) Push(model string, progress func(string)) (string, bool, error) return "", false, fmt.Errorf("pushing %s failed with status %s: %s", model, resp.Status, string(body)) } - progressShown := false - - scanner := bufio.NewScanner(resp.Body) - for scanner.Scan() { - progressLine := scanner.Text() - if progressLine == "" { - continue - } - - // Parse the progress message - var progressMsg ProgressMessage - if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { - return "", progressShown, fmt.Errorf("error parsing progress message: %w", err) - } - - // Handle different message types - switch progressMsg.Type { - case "progress": - progress(progressMsg.Message) - progressShown = true - case "error": - return "", progressShown, fmt.Errorf("error pushing model: %s", progressMsg.Message) - case "success": - return progressMsg.Message, progressShown, nil - default: - return "", progressShown, fmt.Errorf("unknown message type: %s", progressMsg.Type) - } + // Use Docker-style progress display + message, progressShown, err := DisplayProgress(resp.Body, printer) + if err != nil { + return "", progressShown, err } - // If we get here, something went wrong - return "", progressShown, fmt.Errorf("unexpected end of stream while pushing model %s", model) + return message, progressShown, nil } func (c *Client) List() ([]dmrm.Model, error) { diff --git a/cmd/cli/desktop/desktop_test.go b/cmd/cli/desktop/desktop_test.go index 57dac0894..0b14ed8b3 100644 --- a/cmd/cli/desktop/desktop_test.go +++ b/cmd/cli/desktop/desktop_test.go @@ -36,7 +36,8 @@ func TestPullHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), }, nil) - _, _, err := client.Pull(modelName, false, func(s string) {}) + printer := NewSimplePrinter(func(s string) {}) + _, _, err := client.Pull(modelName, false, printer) assert.NoError(t, err) } @@ -122,7 +123,8 @@ func TestNonHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pulled successfully"}`)), }, nil) - _, _, err := client.Pull(modelName, false, func(s string) {}) + printer := NewSimplePrinter(func(s string) {}) + _, _, err := client.Pull(modelName, false, printer) assert.NoError(t, err) } @@ -145,7 +147,8 @@ func TestPushHuggingFaceModel(t *testing.T) { Body: io.NopCloser(bytes.NewBufferString(`{"type":"success","message":"Model pushed successfully"}`)), }, nil) - _, _, err := client.Push(modelName, func(s string) {}) + printer := NewSimplePrinter(func(s string) {}) + _, _, err := client.Push(modelName, printer) assert.NoError(t, err) } diff --git a/cmd/cli/desktop/progress.go b/cmd/cli/desktop/progress.go new file mode 100644 index 000000000..ac87dc90c --- /dev/null +++ b/cmd/cli/desktop/progress.go @@ -0,0 +1,259 @@ +package desktop + +import ( + "bufio" + "encoding/json" + "fmt" + "html" + "io" + "strings" + + "github.com/docker/docker/pkg/jsonmessage" + "github.com/docker/go-units" + + "github.com/docker/model-runner/cmd/cli/pkg/standalone" +) + +// DisplayProgress displays progress messages from a model pull/push operation +// using Docker-style multi-line progress bars. +// Returns the final message, whether progress was actually shown, and any error. +func DisplayProgress(body io.Reader, printer standalone.StatusPrinter) (string, bool, error) { + fd, isTerminal := printer.GetFdInfo() + + // If not a terminal, fall back to simple line-by-line output + if !isTerminal { + return displayProgressSimple(body, printer) + } + + // Use a pipe to convert our progress messages to Docker's JSONMessage format + pr, pw := io.Pipe() + errCh := make(chan error, 1) + + // Start the display goroutine + go func() { + err := jsonmessage.DisplayJSONMessagesStream(pr, &writerAdapter{printer}, fd, isTerminal, nil) + if err != nil { + errCh <- err + } + close(errCh) + }() + + // Convert progress messages to JSONMessage format + scanner := bufio.NewScanner(body) + layerStatus := make(map[string]string) // Track status of each layer + var finalMessage string + progressShown := false // Track if we actually showed any progress bars + + for scanner.Scan() { + progressLine := scanner.Text() + if progressLine == "" { + continue + } + + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + // If we can't parse, just skip + continue + } + + switch progressMsg.Type { + case "progress": + progressShown = true // We're showing actual progress + if err := writeDockerProgress(pw, &progressMsg, layerStatus); err != nil { + pw.Close() + return "", false, err + } + + case "success": + finalMessage = progressMsg.Message + // Don't write the success message here - let the caller print it + // to avoid duplicate output + + case "error": + pw.Close() + return "", false, fmt.Errorf("%s", progressMsg.Message) + } + } + + if err := scanner.Err(); err != nil { + pw.Close() + return "", false, err + } + + pw.Close() + + // Wait for display to finish + if err := <-errCh; err != nil && err != io.EOF { + return finalMessage, progressShown, err + } + + return finalMessage, progressShown, nil +} + +// displayProgressSimple displays progress messages in simple line-by-line format +func displayProgressSimple(body io.Reader, printer standalone.StatusPrinter) (string, bool, error) { + scanner := bufio.NewScanner(body) + current := uint64(0) + layerProgress := make(map[string]uint64) + var finalMessage string + progressShown := false // Track if we actually showed any progress + + for scanner.Scan() { + progressLine := scanner.Text() + if progressLine == "" { + continue + } + + var progressMsg ProgressMessage + if err := json.Unmarshal([]byte(html.UnescapeString(progressLine)), &progressMsg); err != nil { + continue + } + + switch progressMsg.Type { + case "progress": + progressShown = true // We're showing actual progress + layerID := progressMsg.Layer.ID + layerProgress[layerID] = progressMsg.Layer.Current + + // Sum all layer progress + current = uint64(0) + for _, layerCurrent := range layerProgress { + current += layerCurrent + } + + printer.Println(fmt.Sprintf("Downloaded %s of %s", + units.CustomSize("%.2f%s", float64(current), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}), + units.CustomSize("%.2f%s", float64(progressMsg.Total), 1000.0, []string{"B", "kB", "MB", "GB", "TB", "PB", "EB", "ZB", "YB"}))) + + case "success": + finalMessage = progressMsg.Message + + case "error": + return "", false, fmt.Errorf("%s", progressMsg.Message) + } + } + + if err := scanner.Err(); err != nil { + return "", false, err + } + + return finalMessage, progressShown, nil +} + +// writeDockerProgress writes a progress update in Docker's JSONMessage format +func writeDockerProgress(w io.Writer, msg *ProgressMessage, layerStatus map[string]string) error { + layerID := msg.Layer.ID + if layerID == "" { + return nil + } + + // Determine status based on progress + var status string + var progressDetail *jsonmessage.JSONProgress + + if msg.Layer.Current == 0 { + status = "Waiting" + } else if msg.Layer.Current < msg.Layer.Size { + status = "Downloading" + progressDetail = &jsonmessage.JSONProgress{ + Current: int64(msg.Layer.Current), + Total: int64(msg.Layer.Size), + } + } else if msg.Layer.Current >= msg.Layer.Size && msg.Layer.Size > 0 { + status = "Pull complete" + progressDetail = &jsonmessage.JSONProgress{ + Current: int64(msg.Layer.Current), + Total: int64(msg.Layer.Size), + } + } + + if status == "" { + return nil + } + + // Shorten layer ID for display (similar to Docker) + displayID := strings.TrimPrefix(layerID, "sha256:") + if len(displayID) > 12 { + displayID = displayID[:12] + } + + dockerMsg := jsonmessage.JSONMessage{ + ID: displayID, + Status: status, + Progress: progressDetail, + } + + data, err := json.Marshal(dockerMsg) + if err != nil { + return err + } + + _, err = fmt.Fprintf(w, "%s\n", data) + return err +} + +// writeDockerStatus writes a status message in Docker's JSONMessage format +func writeDockerStatus(w io.Writer, id, status, message string) error { + msg := jsonmessage.JSONMessage{ + ID: id, + Status: status, + } + + if message != "" { + msg.Status = message + } + + data, err := json.Marshal(msg) + if err != nil { + return err + } + + _, err = fmt.Fprintf(w, "%s\n", data) + return err +} + +// writerAdapter adapts StatusPrinter to io.Writer for jsonmessage +type writerAdapter struct { + printer standalone.StatusPrinter +} + +func (w *writerAdapter) Write(p []byte) (n int, err error) { + return w.printer.Write(p) +} + +// simplePrinter is a simple StatusPrinter that just writes to a function +type simplePrinter struct { + printFunc func(string) +} + +func (p *simplePrinter) Printf(format string, args ...any) { + s := fmt.Sprintf(format, args...) + p.printFunc(s) +} + +func (p *simplePrinter) Println(args ...any) { + s := fmt.Sprintln(args...) + p.printFunc(s) +} + +func (p *simplePrinter) PrintErrf(format string, args ...any) { + // For simple printer, just print to the same output + s := fmt.Sprintf(format, args...) + p.printFunc(s) +} + +func (p *simplePrinter) Write(p2 []byte) (n int, err error) { + p.printFunc(string(p2)) + return len(p2), nil +} + +func (p *simplePrinter) GetFdInfo() (uintptr, bool) { + return 0, false +} + +// NewSimplePrinter creates a StatusPrinter from a simple print function +func NewSimplePrinter(printFunc func(string)) standalone.StatusPrinter { + return &simplePrinter{ + printFunc: printFunc, + } +} diff --git a/cmd/cli/go.mod b/cmd/cli/go.mod index 666e70c17..6c324550a 100644 --- a/cmd/cli/go.mod +++ b/cmd/cli/go.mod @@ -14,7 +14,6 @@ require ( github.com/emirpasic/gods/v2 v2.0.0-alpha github.com/fatih/color v1.18.0 github.com/google/go-containerregistry v0.20.6 - github.com/mattn/go-isatty v0.0.20 github.com/mattn/go-runewidth v0.0.16 github.com/moby/term v0.5.2 github.com/muesli/termenv v0.16.0 @@ -87,6 +86,7 @@ require ( github.com/lufia/plan9stats v0.0.0-20211012122336-39d0f177ccd0 // indirect github.com/magiconair/properties v1.8.10 // indirect github.com/mattn/go-colorable v0.1.13 // indirect + github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-shellwords v1.0.12 // indirect github.com/microcosm-cc/bluemonday v1.0.27 // indirect github.com/mitchellh/go-homedir v1.1.0 // indirect diff --git a/go.mod b/go.mod index 36f0d813c..8c69cdb3f 100644 --- a/go.mod +++ b/go.mod @@ -17,7 +17,6 @@ require ( github.com/prometheus/client_model v0.6.2 github.com/prometheus/common v0.67.1 github.com/sirupsen/logrus v1.9.3 - github.com/spf13/cobra v1.10.1 github.com/stretchr/testify v1.11.1 golang.org/x/sync v0.17.0 ) @@ -39,7 +38,6 @@ require ( github.com/go-ole/go-ole v1.3.0 // indirect github.com/gogo/protobuf v1.3.2 // indirect github.com/henvic/httpretty v0.1.4 // indirect - github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/jaypipes/pcidb v1.1.1 // indirect github.com/json-iterator/go v1.1.12 // indirect github.com/klauspost/compress v1.18.0 // indirect @@ -52,7 +50,6 @@ require ( github.com/pmezard/go-difflib v1.0.1-0.20181226105442-5d4384ee4fb2 // indirect github.com/prometheus/procfs v0.15.1 // indirect github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d // indirect - github.com/spf13/pflag v1.0.9 // indirect github.com/vbatts/tar-split v0.12.1 // indirect github.com/yusufpapurcu/wmi v1.2.4 // indirect go.opentelemetry.io/auto/sdk v1.1.0 // indirect diff --git a/go.sum b/go.sum index 972b91e11..b22527591 100644 --- a/go.sum +++ b/go.sum @@ -24,7 +24,6 @@ github.com/containerd/stargz-snapshotter/estargz v0.16.3 h1:7evrXtoh1mSbGj/pfRcc github.com/containerd/stargz-snapshotter/estargz v0.16.3/go.mod h1:uyr4BfYfOj3G9WBVE8cOlQmXAbPN9VEQpBBeJIuOipU= github.com/containerd/typeurl/v2 v2.2.3 h1:yNA/94zxWdvYACdYO8zofhrTVuQY73fFU1y++dYSw40= github.com/containerd/typeurl/v2 v2.2.3/go.mod h1:95ljDnPfD3bAbDJRugOiShd/DlAAsxGtUBhJxIn7SCk= -github.com/cpuguy83/go-md2man/v2 v2.0.6/go.mod h1:oOW0eioCTA6cOiMLiUPZOpcVxMig6NIQQ7OS05n1F4g= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.2-0.20180830191138-d8f796af33cc h1:U9qPSI2PIWSS1VwoXQT9A3Wy9MM3WgvqSxFWenqJduM= @@ -70,8 +69,6 @@ github.com/gpustack/gguf-parser-go v0.22.1 h1:FRnEDWqT0Rcplr/R9ctCRSN2+3DhVsf6dn github.com/gpustack/gguf-parser-go v0.22.1/go.mod h1:y4TwTtDqFWTK+xvprOjRUh+dowgU2TKCX37vRKvGiZ0= github.com/henvic/httpretty v0.1.4 h1:Jo7uwIRWVFxkqOnErcoYfH90o3ddQyVrSANeS4cxYmU= github.com/henvic/httpretty v0.1.4/go.mod h1:Dn60sQTZfbt2dYsdUSNsCljyF4AfdqnuJFDLJA1I4AM= -github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= -github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= github.com/jaypipes/ghw v0.19.1 h1:Lhybk6aadgEJqIxeS0h07UOL/EgMGIdxbAy6V8J7RgY= github.com/jaypipes/ghw v0.19.1/go.mod h1:GPrvwbtPoxYUenr74+nAnWbardIZq600vJDD5HnPsPE= github.com/jaypipes/pcidb v1.1.1 h1:QmPhpsbmmnCwZmHeYAATxEaoRuiMAJusKYkUncMC0ro= @@ -121,15 +118,10 @@ github.com/prometheus/procfs v0.15.1 h1:YagwOFzUgYfKKHX6Dr+sHT7km/hxC76UB0leargg github.com/prometheus/procfs v0.15.1/go.mod h1:fB45yRUv8NstnjriLhBQLuOUt+WW4BsoGhij/e3PBqk= github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII= github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o= -github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d h1:3VwvTjiRPA7cqtgOWddEL+JrcijMlXUmj99c/6YyZoY= github.com/smallnest/ringbuffer v0.0.0-20241116012123-461381446e3d/go.mod h1:tAG61zBM1DYRaGIPloumExGvScf08oHuo0kFoOqdbT0= -github.com/spf13/cobra v1.10.1 h1:lJeBwCfmrnXthfAupyUTzJ/J4Nc1RsHC/mSRU2dll/s= -github.com/spf13/cobra v1.10.1/go.mod h1:7SmJGaTHFVBY0jW4NXGluQoLvhqFQM+6XSKD+P4XaB0= -github.com/spf13/pflag v1.0.9 h1:9exaQaMOCwffKiiiYk6/BndUBv+iRViNW+4lEMi0PvY= -github.com/spf13/pflag v1.0.9/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= diff --git a/main.go b/main.go index e5c30ea98..99e634938 100644 --- a/main.go +++ b/main.go @@ -11,7 +11,6 @@ import ( "strings" "syscall" - "github.com/docker/model-runner/pkg/distribution/transport/resumable" "github.com/docker/model-runner/pkg/gpuinfo" "github.com/docker/model-runner/pkg/inference" "github.com/docker/model-runner/pkg/inference/backends/llamacpp" @@ -108,7 +107,7 @@ func main() { models.ClientConfig{ StoreRootPath: modelPath, Logger: log.WithFields(logrus.Fields{"component": "model-manager"}), - Transport: resumable.New(baseTransport), + Transport: baseTransport, }, nil, memEstimator, diff --git a/pkg/distribution/transport/internal/bufferfile/fifo.go b/pkg/distribution/transport/internal/bufferfile/fifo.go deleted file mode 100644 index 1f5dccb19..000000000 --- a/pkg/distribution/transport/internal/bufferfile/fifo.go +++ /dev/null @@ -1,234 +0,0 @@ -// Package bufferfile provides a FIFO implementation backed by a temporary file -// that supports concurrent reads and writes. -package bufferfile - -import ( - "fmt" - "io" - "os" - "sync" -) - -// FIFO is an io.ReadWriteCloser implementation that supports concurrent -// reads and writes to a temporary file. Reads begin from the start of the file -// and writes always append to the end. The type maintains separate read and write -// positions internally. -type FIFO struct { - // file is the underlying temporary file used for storage. - file *os.File - // mu protects all fields and synchronizes access to the FIFO. - mu sync.Mutex - // cond is used to signal waiting readers when new data becomes available - // or when the write side is closed. - cond *sync.Cond - // readPos tracks the current read position within the file. - readPos int64 - // writePos tracks the current write position within the file - // (always at EOF). - writePos int64 - // closed indicates whether Close() has been called, making the FIFO - // unusable. - closed bool - // writeClosed indicates whether CloseWrite() has been called, meaning - // no more writes will occur but reads can continue until all data is - // consumed. - writeClosed bool - // writeErr holds any persistent write error that should be returned to - // future write operations. - writeErr error -} - -// NewFIFO creates a new FIFO backed by a temporary file. -// The caller is responsible for calling Close() to clean up the temporary -// file. -func NewFIFO() (*FIFO, error) { - return NewFIFOInDir("") -} - -// NewFIFOInDir creates a new FIFO backed by a temporary file in the provided -// directory. If dir is empty, the system temporary directory is used. -// The caller is responsible for calling Close() to clean up the temporary -// file. -func NewFIFOInDir(dir string) (*FIFO, error) { - file, err := os.CreateTemp(dir, "model-buffer-*.tmp") - if err != nil { - return nil, fmt.Errorf("failed to create temporary file in dir: %w", err) - } - - fifo := &FIFO{ - file: file, - readPos: 0, - writePos: 0, - closed: false, - } - fifo.cond = sync.NewCond(&fifo.mu) - - return fifo, nil -} - -// Write implements io.Writer. Writes always append to the end of the file. -// Write is safe for concurrent use with Read. -func (f *FIFO) Write(p []byte) (int, error) { - f.mu.Lock() - defer f.mu.Unlock() - - // Check if FIFO is closed for writing. - if f.closed || f.writeClosed { - return 0, fmt.Errorf("write to closed FIFO") - } - - // Return persistent write error if we have one. - if f.writeErr != nil { - return 0, f.writeErr - } - - // Handle empty writes. - if len(p) == 0 { - return 0, nil - } - - // Seek to current write position (end of file). - _, err := f.file.Seek(f.writePos, io.SeekStart) - if err != nil { - f.writeErr = fmt.Errorf("seek to write position failed: %w", err) - return 0, f.writeErr - } - - // Write the data to the file. - n, err := f.file.Write(p) - if n > 0 { - // Update our write position to track how much data we've written. - f.writePos += int64(n) - // Signal all waiting readers that new data is available. - f.cond.Broadcast() - } - if err != nil { - // Store the error for future write attempts. - f.writeErr = fmt.Errorf("write failed: %w", err) - return n, f.writeErr - } - - return n, nil -} - -// Read implements io.Reader. Reads from the current read position in the file. -// Read blocks until data is available or the FIFO is closed. -// Read is safe for concurrent use with Write. -func (f *FIFO) Read(p []byte) (int, error) { - if len(p) == 0 { - return 0, nil - } - - f.mu.Lock() - defer f.mu.Unlock() - - for { - if f.closed { - // FIFO has been fully closed - file is closed and cleaned up. - // Return EOF immediately since no more data can be read. - return 0, io.EOF - } - - // Calculate how much unread data is available - availableBytes := f.writePos - f.readPos - if availableBytes > 0 { - // Data is available - read it immediately. - return f.readFromFile(p) - } - - // No data currently available - check if writes are finished - if f.writeClosed { - // Write side is closed and no data available - return EOF. - return 0, io.EOF - } - - // No data available and writes are still possible - wait for more - // data. - // The condition variable will be signaled when: - // - New data is written (f.cond.Broadcast() in Write). - // - Write side is closed (f.cond.Broadcast() in CloseWrite). - // - FIFO is fully closed (f.cond.Broadcast() in Close). - f.cond.Wait() - } -} - -// readFromFile performs the actual file read operation. -// Must be called with mutex held. -func (f *FIFO) readFromFile(p []byte) (int, error) { - availableBytes := f.writePos - f.readPos - toRead := int64(len(p)) - if toRead > availableBytes { - toRead = availableBytes - } - - // Seek to current read position - _, err := f.file.Seek(f.readPos, io.SeekStart) - if err != nil { - return 0, fmt.Errorf("seek to read position failed: %w", err) - } - - // Read the data - n, err := f.file.Read(p[:toRead]) - if n > 0 { - f.readPos += int64(n) - } - if err != nil && err != io.EOF { - return n, fmt.Errorf("read failed: %w", err) - } - - return n, nil -} - -// Close closes the FIFO and removes the temporary file. -// Any blocked Read or Write operations will be interrupted. -// Close is safe to call multiple times. -func (f *FIFO) Close() error { - f.mu.Lock() - defer f.mu.Unlock() - - if f.closed { - return nil - } - - f.closed = true - - // Wake up all waiting readers. - f.cond.Broadcast() - - var err error - if f.file != nil { - // Get the file name before closing for cleanup. - fileName := f.file.Name() - - // Close the file (this will interrupt any blocked I/O operations). - if closeErr := f.file.Close(); closeErr != nil { - err = fmt.Errorf("failed to close file: %w", closeErr) - } - - // Remove the temporary file. - if removeErr := os.Remove(fileName); removeErr != nil { - if err != nil { - err = fmt.Errorf("%w; also failed to remove temp file: %v", err, removeErr) - } else { - err = fmt.Errorf("failed to remove temp file: %w", removeErr) - } - } - - f.file = nil - } - - return err -} - -// CloseWrite signals that no more writes will happen. -// Readers can still read remaining data, and will receive EOF when all data -// is consumed. Does not clean up resources - use Close() for that. -func (f *FIFO) CloseWrite() { - f.mu.Lock() - defer f.mu.Unlock() - - f.writeClosed = true - - // Wake up all waiting readers to check the new state. - f.cond.Broadcast() -} diff --git a/pkg/distribution/transport/internal/bufferfile/fifo_test.go b/pkg/distribution/transport/internal/bufferfile/fifo_test.go deleted file mode 100644 index 553503f16..000000000 --- a/pkg/distribution/transport/internal/bufferfile/fifo_test.go +++ /dev/null @@ -1,621 +0,0 @@ -package bufferfile - -import ( - "bytes" - "io" - "math/rand" - "sync" - "sync/atomic" - "testing" - "time" -) - -// stat returns information about the current state of the FIFO for testing -// purposes. -func (f *FIFO) stat() (readPos, writePos int64, closed bool) { - f.mu.Lock() - defer f.mu.Unlock() - return f.readPos, f.writePos, f.closed -} - -// TestFIFO_BasicReadWrite tests that data written to a FIFO can be read -// back exactly. This is the fundamental requirement for the FIFO to work -// correctly. -func TestFIFO_BasicReadWrite(t *testing.T) { - // Arrange: Create a new FIFO - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - data := []byte("hello world") - buf := make([]byte, len(data)) - - // Act: Write data to FIFO - n, err := fifo.Write(data) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - if n != len(data) { - t.Fatalf("Expected to write %d bytes, wrote %d", len(data), n) - } - - // Act: Read data back from FIFO - n, err = fifo.Read(buf) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - // Assert: Verify read data matches written data - if n != len(data) { - t.Fatalf("Expected to read %d bytes, read %d", len(data), n) - } - if !bytes.Equal(buf, data) { - t.Fatalf("Read data doesn't match written data: got %q, want %q", buf, data) - } -} - -// TestFIFO_MultipleWrites tests that multiple separate writes are -// concatenated correctly when reading back from the FIFO, preserving the -// order and boundaries. -func TestFIFO_MultipleWrites(t *testing.T) { - // Arrange: Create FIFO and test data - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - chunks := [][]byte{ - []byte("chunk1"), - []byte("chunk2"), - []byte("chunk3"), - } - - // Act: Write multiple chunks sequentially - for i, chunk := range chunks { - n, err := fifo.Write(chunk) - if err != nil { - t.Fatalf("Write %d failed: %v", i, err) - } - if n != len(chunk) { - t.Fatalf("Write %d: expected %d bytes, wrote %d", i, len(chunk), n) - } - } - - // Act: Read all data back - expected := bytes.Join(chunks, nil) - buf := make([]byte, len(expected)) - totalRead := 0 - - for totalRead < len(expected) { - n, err := fifo.Read(buf[totalRead:]) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - totalRead += n - } - - // Assert: Verify concatenated data is correct - if !bytes.Equal(buf, expected) { - t.Fatalf("Read data doesn't match expected: got %q, want %q", buf, expected) - } -} - -// TestFIFO_PartialReads tests that data can be read in smaller chunks than -// it was written, ensuring proper read position tracking. -func TestFIFO_PartialReads(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - // Write data - data := []byte("0123456789") - _, err = fifo.Write(data) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Read in small chunks - buf := make([]byte, 3) // Smaller than data - var result []byte - - for len(result) < len(data) { - n, err := fifo.Read(buf) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - result = append(result, buf[:n]...) - } - - if !bytes.Equal(result, data) { - t.Fatalf("Partial read result doesn't match: got %q, want %q", result, data) - } -} - -// TestFIFO_ConcurrentReadWrite tests that multiple concurrent writers and -// readers can safely access the FIFO without data corruption or race -// conditions. -func TestFIFO_ConcurrentReadWrite(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - const numWriters = 3 - const numChunksPerWriter = 100 - const chunkSize = 100 - - var wg sync.WaitGroup - var writeOrder []int - var writeOrderMu sync.Mutex - - // Start multiple writers - for writerID := 0; writerID < numWriters; writerID++ { - wg.Add(1) - go func(id int) { - defer wg.Done() - for i := 0; i < numChunksPerWriter; i++ { - // Create unique data for this writer and chunk - data := make([]byte, chunkSize) - for j := range data { - data[j] = byte((id*1000 + i) % 256) - } - - writeOrderMu.Lock() - writeOrder = append(writeOrder, id*1000+i) - writeOrderMu.Unlock() - - _, err := fifo.Write(data) - if err != nil { - t.Errorf("Writer %d chunk %d failed: %v", id, i, err) - return - } - } - }(writerID) - } - - // Read all data - var readData []byte - totalExpected := numWriters * numChunksPerWriter * chunkSize - buf := make([]byte, 1024) // Read buffer - - readDone := make(chan struct{}) - go func() { - defer close(readDone) - for len(readData) < totalExpected { - n, err := fifo.Read(buf) - if err != nil { - t.Errorf("Read failed: %v", err) - return - } - readData = append(readData, buf[:n]...) - } - }() - - // Wait for all writes to complete - wg.Wait() - - // Wait for all reads to complete - select { - case <-readDone: - // Success - case <-time.After(5 * time.Second): - t.Fatal("Read timed out") - } - - if len(readData) != totalExpected { - t.Fatalf("Expected to read %d bytes, got %d", totalExpected, len(readData)) - } - - t.Logf("Successfully handled %d concurrent writers writing %d total bytes", - numWriters, totalExpected) -} - -// TestFIFO_ReadBlocksUntilData tests that reads block when no data is -// available and unblock immediately when data is written, which is essential -// for the streaming behavior needed by the parallel transport. -func TestFIFO_ReadBlocksUntilData(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - buf := make([]byte, 10) - readDone := make(chan struct{}) - var readErr error - - // Start a reader that should block - go func() { - defer close(readDone) - _, readErr = fifo.Read(buf) - }() - - // Ensure reader is blocked - select { - case <-readDone: - t.Fatal("Read should have blocked") - case <-time.After(100 * time.Millisecond): - // Good, read is blocked - } - - // Write data to unblock reader - data := []byte("test") - _, err = fifo.Write(data) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Now read should complete - select { - case <-readDone: - if readErr != nil { - t.Fatalf("Read failed: %v", readErr) - } - case <-time.After(time.Second): - t.Fatal("Read did not complete after write") - } -} - -// TestFIFO_CloseInterruptsRead tests that Close() interrupts blocked -// readers and causes them to return EOF, which is needed for proper cleanup. -func TestFIFO_CloseInterruptsRead(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - - buf := make([]byte, 10) - readDone := make(chan struct{}) - var readN int - var readErr error - - // Start a reader that should block - go func() { - defer close(readDone) - readN, readErr = fifo.Read(buf) - }() - - // Ensure reader is blocked - select { - case <-readDone: - t.Fatal("Read should have blocked") - case <-time.After(100 * time.Millisecond): - // Good, read is blocked - } - - // Close FIFO to interrupt read - err = fifo.Close() - if err != nil { - t.Fatalf("Close failed: %v", err) - } - - // Read should complete with EOF - select { - case <-readDone: - if readErr != io.EOF { - t.Fatalf("Expected EOF after close, got: %v", readErr) - } - if readN != 0 { - t.Fatalf("Expected 0 bytes read after close, got %d", readN) - } - case <-time.After(time.Second): - t.Fatal("Read did not complete after close") - } -} - -// TestFIFO_CloseWithPendingData tests that Close() immediately makes all -// data unavailable, which implements the interruptible FIFO semantics. -func TestFIFO_CloseWithPendingData(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - - // Write some data - data := []byte("pending data") - _, err = fifo.Write(data) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - // Close FIFO - err = fifo.Close() - if err != nil { - t.Fatalf("Close failed: %v", err) - } - - // After close, reads should return EOF immediately (data is lost) - buf := make([]byte, len(data)) - n, err := fifo.Read(buf) - if err != io.EOF { - t.Fatalf("Expected EOF after close, got: %v", err) - } - if n != 0 { - t.Fatalf("Expected 0 bytes read after close, got %d", n) - } -} - -// TestFIFO_WriteAfterClose tests that writes fail after the FIFO is closed. -func TestFIFO_WriteAfterClose(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - - err = fifo.Close() - if err != nil { - t.Fatalf("Close failed: %v", err) - } - - // Write after close should fail - _, err = fifo.Write([]byte("test")) - if err == nil { - t.Fatal("Expected write after close to fail") - } - - // Even empty writes should fail after close - _, err = fifo.Write(nil) - if err == nil { - t.Fatal("Expected empty write after close to fail") - } -} - -// TestFIFO_WriteAfterCloseWrite tests that writes fail after CloseWrite -// is called. -func TestFIFO_WriteAfterCloseWrite(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - fifo.CloseWrite() - - // Write after CloseWrite should fail - _, err = fifo.Write([]byte("test")) - if err == nil { - t.Fatal("Expected write after CloseWrite to fail") - } - - // Even empty writes should fail after CloseWrite - _, err = fifo.Write(nil) - if err == nil { - t.Fatal("Expected empty write after CloseWrite to fail") - } -} - -// TestFIFO_Stat tests the internal stat method used for debugging and -// testing position tracking. -func TestFIFO_Stat(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - // Check initial state - readPos, writePos, closed := fifo.stat() - if readPos != 0 || writePos != 0 || closed { - t.Fatalf("Initial state wrong: readPos=%d, writePos=%d, closed=%v", - readPos, writePos, closed) - } - - // Write some data - data := []byte("test data") - _, err = fifo.Write(data) - if err != nil { - t.Fatalf("Write failed: %v", err) - } - - readPos, writePos, closed = fifo.stat() - if readPos != 0 || writePos != int64(len(data)) || closed { - t.Fatalf("After write state wrong: readPos=%d, writePos=%d, closed=%v", - readPos, writePos, closed) - } - - // Read some data - buf := make([]byte, 4) - n, err := fifo.Read(buf) - if err != nil { - t.Fatalf("Read failed: %v", err) - } - - readPos, writePos, closed = fifo.stat() - if readPos != int64(n) || writePos != int64(len(data)) || closed { - t.Fatalf("After read state wrong: readPos=%d, writePos=%d, closed=%v", - readPos, writePos, closed) - } - - // Close and check - fifo.Close() - readPos, writePos, closed = fifo.stat() - if !closed { - t.Fatal("FIFO should be marked as closed") - } -} - -// TestFIFO_StressTest performs concurrent read/write operations to test -// for race conditions and data corruption under heavy load. -func TestFIFO_StressTest(t *testing.T) { - if testing.Short() { - t.Skip("Skipping stress test in short mode") - } - - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - const duration = 2 * time.Second - const maxWriteSize = 1024 - const maxReadSize = 512 - - var totalWritten int64 - var totalRead int64 - var wg sync.WaitGroup - - // Start writer goroutine - wg.Add(1) - go func() { - defer wg.Done() - defer fifo.CloseWrite() - - // Signal to readers that no more bytes will arrive once the writer - // finishes so blocked reads can terminate. - start := time.Now() - for time.Since(start) < duration { - size := rand.Intn(maxWriteSize) + 1 - data := make([]byte, size) - rand.Read(data) - - n, err := fifo.Write(data) - if err != nil { - t.Errorf("Write failed: %v", err) - return - } - atomic.AddInt64(&totalWritten, int64(n)) - } - }() - - // Start reader goroutine - wg.Add(1) - go func() { - defer wg.Done() - buf := make([]byte, maxReadSize) - start := time.Now() - - for time.Since(start) < duration+time.Second { - // Give extra time to read. - n, err := fifo.Read(buf) - if err == io.EOF { - break - } - if err != nil { - t.Errorf("Read failed: %v", err) - return - } - atomic.AddInt64(&totalRead, int64(n)) - - // If we've read everything written and writer is done, we're - // done. - if atomic.LoadInt64(&totalRead) >= atomic.LoadInt64(&totalWritten) && - time.Since(start) > duration { - break - } - } - }() - - wg.Wait() - - finalWritten := atomic.LoadInt64(&totalWritten) - finalRead := atomic.LoadInt64(&totalRead) - t.Logf("Stress test completed: wrote %d bytes, read %d bytes", - finalWritten, finalRead) - - if finalRead > finalWritten { - t.Fatalf("Read more than written: read=%d, written=%d", - finalRead, finalWritten) - } -} - -// TestFIFO_EmptyOperations tests that empty reads and writes are handled -// correctly. -func TestFIFO_EmptyOperations(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - // Test empty write - n, err := fifo.Write(nil) - if err != nil { - t.Fatalf("Empty write failed: %v", err) - } - if n != 0 { - t.Fatalf("Expected 0 bytes written for empty write, got %d", n) - } - - // Test empty read - n, err = fifo.Read(nil) - if err != nil { - t.Fatalf("Empty read failed: %v", err) - } - if n != 0 { - t.Fatalf("Expected 0 bytes read for empty read, got %d", n) - } -} - -// TestFIFO_MultipleClose tests that calling Close() multiple times is -// safe and doesn't cause errors or panics. -func TestFIFO_MultipleClose(t *testing.T) { - fifo, err := NewFIFO() - if err != nil { - t.Fatalf("Failed to create FIFO: %v", err) - } - - // First close should succeed - err = fifo.Close() - if err != nil { - t.Fatalf("First close failed: %v", err) - } - - // Second close should not panic and should not error - err = fifo.Close() - if err != nil { - t.Fatalf("Second close failed: %v", err) - } -} - -// Benchmark tests. -// BenchmarkFIFO_Write measures the performance of write operations. -func BenchmarkFIFO_Write(b *testing.B) { - fifo, err := NewFIFO() - if err != nil { - b.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - data := make([]byte, 1024) - rand.Read(data) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := fifo.Write(data) - if err != nil { - b.Fatalf("Write failed: %v", err) - } - } -} - -// BenchmarkFIFO_Read measures the performance of read operations. -func BenchmarkFIFO_Read(b *testing.B) { - fifo, err := NewFIFO() - if err != nil { - b.Fatalf("Failed to create FIFO: %v", err) - } - defer fifo.Close() - - // Pre-fill with data - data := make([]byte, 1024) - rand.Read(data) - for i := 0; i < b.N; i++ { - fifo.Write(data) - } - - buf := make([]byte, 1024) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _, err := fifo.Read(buf) - if err != nil { - b.Fatalf("Read failed: %v", err) - } - } -} diff --git a/pkg/distribution/transport/internal/common/http_utils.go b/pkg/distribution/transport/internal/common/http_utils.go deleted file mode 100644 index a71186ddd..000000000 --- a/pkg/distribution/transport/internal/common/http_utils.go +++ /dev/null @@ -1,110 +0,0 @@ -// Package common provides shared utilities for HTTP transport implementations. -package common - -import ( - "net/http" - "strconv" - "strings" -) - -// SupportsRange determines whether an HTTP response indicates support for range requests. -func SupportsRange(h http.Header) bool { - ar := strings.ToLower(h.Get("Accept-Ranges")) - for _, part := range strings.Split(ar, ",") { - if strings.TrimSpace(part) == "bytes" { - return true - } - } - return false -} - -// ScrubConditionalHeaders removes conditional headers we do not want to forward -// on range requests, because they can alter semantics or conflict with If-Range logic. -func ScrubConditionalHeaders(h http.Header) { - h.Del("If-None-Match") - h.Del("If-Modified-Since") - h.Del("If-Match") - h.Del("If-Unmodified-Since") - // Range/If-Range headers are set explicitly by the caller. -} - -// IsWeakETag reports whether the ETag is a weak validator (W/"...") which must -// not be used with If-Range per RFC 7232 §2.1. -func IsWeakETag(etag string) bool { - etag = strings.TrimSpace(etag) - return strings.HasPrefix(etag, "W/") || strings.HasPrefix(etag, "w/") -} - -// ParseSingleRange parses a single "Range: bytes=start-end" header. -// It returns (start, end, ok). When end is omitted, end == -1. -// -// Notes: -// - Only absolute-start forms are supported (no suffix ranges "-N"). -// - Multi-range specifications (comma separated) return ok == false. -func ParseSingleRange(h string) (int64, int64, bool) { - if h == "" { - return 0, -1, false - } - h = strings.TrimSpace(h) - if !strings.HasPrefix(strings.ToLower(h), "bytes=") { - return 0, -1, false - } - spec := strings.TrimSpace(h[len("bytes="):]) - if strings.Contains(spec, ",") { - return 0, -1, false - } - parts := strings.SplitN(spec, "-", 2) - if len(parts) != 2 { - return 0, -1, false - } - if parts[0] == "" { - // Suffix form is not supported here. - return 0, -1, false - } - start, err := strconv.ParseInt(strings.TrimSpace(parts[0]), 10, 64) - if err != nil || start < 0 { - return 0, -1, false - } - end := int64(-1) - if strings.TrimSpace(parts[1]) != "" { - e, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) - if err != nil || e < start { - return 0, -1, false - } - end = e - } - return start, end, true -} - -// ParseContentRange parses "Content-Range: bytes start-end/total". It -// returns (start, end, total, ok). When total is unknown, total == -1. -func ParseContentRange(h string) (int64, int64, int64, bool) { - if h == "" { - return 0, -1, -1, false - } - h = strings.ToLower(strings.TrimSpace(h)) - if !strings.HasPrefix(h, "bytes ") { - return 0, -1, -1, false - } - body := strings.TrimSpace(h[len("bytes "):]) - seTotal := strings.SplitN(body, "/", 2) - if len(seTotal) != 2 { - return 0, -1, -1, false - } - se := strings.SplitN(strings.TrimSpace(seTotal[0]), "-", 2) - if len(se) != 2 { - return 0, -1, -1, false - } - start, err1 := strconv.ParseInt(strings.TrimSpace(se[0]), 10, 64) - end, err2 := strconv.ParseInt(strings.TrimSpace(se[1]), 10, 64) - totalStr := strings.TrimSpace(seTotal[1]) - var total int64 = -1 - var err3 error - if totalStr != "*" { - total, err3 = strconv.ParseInt(totalStr, 10, 64) - } - if err1 != nil || err2 != nil || (err3 != nil && totalStr != "*") { - return 0, -1, -1, false - } - return start, end, total, true -} diff --git a/pkg/distribution/transport/internal/common/http_utils_test.go b/pkg/distribution/transport/internal/common/http_utils_test.go deleted file mode 100644 index 1becd3b18..000000000 --- a/pkg/distribution/transport/internal/common/http_utils_test.go +++ /dev/null @@ -1,195 +0,0 @@ -package common - -import ( - "net/http" - "testing" -) - -// TestParseSingleRange exercises valid and invalid single-range specs. -func TestParseSingleRange(t *testing.T) { - cases := []struct { - in string - start, end int64 - ok bool - }{ - {"", 0, -1, false}, - {"bytes=0-99", 0, 99, true}, - {"bytes=0-", 0, -1, true}, - {"bytes=5-5", 5, 5, true}, - {"BYTES=7-9", 7, 9, true}, - // End before start. - {"bytes=10-5", 0, -1, false}, - // Suffix not supported. - {"bytes=-100", 0, -1, false}, - {"items=0-10", 0, -1, false}, - // Multi-range unsupported. - {"bytes=0-1,3-5", 0, -1, false}, - } - for _, tc := range cases { - start, end, ok := ParseSingleRange(tc.in) - if start != tc.start || end != tc.end || ok != tc.ok { - t.Errorf("ParseSingleRange(%q) = (%d,%d,%v), want (%d,%d,%v)", tc.in, start, end, ok, tc.start, tc.end, tc.ok) - } - } -} - -// TestParseContentRange exercises valid and invalid Content-Range headers. -func TestParseContentRange(t *testing.T) { - cases := []struct { - in string - start, end int64 - total int64 - ok bool - }{ - {"", 0, -1, -1, false}, - {"bytes 0-99/200", 0, 99, 200, true}, - {"BYTES 1-1/2", 1, 1, 2, true}, - {"bytes 0-0/*", 0, 0, -1, true}, - {"items 0-1/2", 0, -1, -1, false}, - {"bytes 0-99/abc", 0, -1, -1, false}, - // Parser accepts; semantic check happens elsewhere. - {"bytes 5-4/10", 5, 4, 10, true}, - } - for _, tc := range cases { - start, end, total, ok := ParseContentRange(tc.in) - if start != tc.start || end != tc.end || total != tc.total || ok != tc.ok { - t.Errorf("ParseContentRange(%q) = (%d,%d,%d,%v), want (%d,%d,%d,%v)", tc.in, start, end, total, ok, tc.start, tc.end, tc.total, tc.ok) - } - } -} - -// TestSupportsRange tests the Accept-Ranges header parsing. -func TestSupportsRange(t *testing.T) { - cases := []struct { - name string - header http.Header - expected bool - }{ - { - name: "no header", - header: http.Header{}, - expected: false, - }, - { - name: "bytes supported", - header: http.Header{"Accept-Ranges": []string{"bytes"}}, - expected: true, - }, - { - name: "bytes with mixed case", - header: http.Header{"Accept-Ranges": []string{"BYTES"}}, - expected: true, - }, - { - name: "bytes with other values", - header: http.Header{"Accept-Ranges": []string{"none, bytes"}}, - expected: true, - }, - { - name: "none only", - header: http.Header{"Accept-Ranges": []string{"none"}}, - expected: false, - }, - { - name: "other unit", - header: http.Header{"Accept-Ranges": []string{"items"}}, - expected: false, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - result := SupportsRange(tc.header) - if result != tc.expected { - t.Errorf("SupportsRange() = %v, want %v", result, tc.expected) - } - }) - } -} - -// TestIsWeakETag tests weak ETag detection. -func TestIsWeakETag(t *testing.T) { - cases := []struct { - name string - etag string - expected bool - }{ - { - name: "strong etag", - etag: `"abc123"`, - expected: false, - }, - { - name: "weak etag uppercase W", - etag: `W/"abc123"`, - expected: true, - }, - { - name: "weak etag lowercase w", - etag: `w/"abc123"`, - expected: true, - }, - { - name: "empty", - etag: "", - expected: false, - }, - { - name: "with spaces", - etag: ` W/"abc123" `, - expected: true, - }, - { - name: "malformed but starts with W", - etag: "W/malformed", - expected: true, - }, - } - - for _, tc := range cases { - t.Run(tc.name, func(t *testing.T) { - result := IsWeakETag(tc.etag) - if result != tc.expected { - t.Errorf("IsWeakETag(%q) = %v, want %v", tc.etag, result, tc.expected) - } - }) - } -} - -// TestScrubConditionalHeaders tests conditional header removal. -func TestScrubConditionalHeaders(t *testing.T) { - // Set up test headers with both conditional and non-conditional headers. - headers := http.Header{ - "If-None-Match": []string{`"etag1"`}, - "If-Modified-Since": []string{"Wed, 21 Oct 2015 07:28:00 GMT"}, - "If-Match": []string{`"etag2"`}, - "If-Unmodified-Since": []string{"Thu, 22 Oct 2015 07:28:00 GMT"}, - "Range": []string{"bytes=0-99"}, - "If-Range": []string{`"etag3"`}, - "Authorization": []string{"Bearer token"}, - } - - // Scrub the conditional headers. - ScrubConditionalHeaders(headers) - - // Verify conditional headers are removed. - conditionalHeaders := []string{ - "If-None-Match", - "If-Modified-Since", - "If-Match", - "If-Unmodified-Since", - } - for _, header := range conditionalHeaders { - if headers.Get(header) != "" { - t.Errorf("conditional header %s was not scrubbed", header) - } - } - - // Verify other headers are preserved. - preservedHeaders := []string{"Range", "If-Range", "Authorization"} - for _, header := range preservedHeaders { - if headers.Get(header) == "" { - t.Errorf("header %s was incorrectly removed", header) - } - } -} diff --git a/pkg/distribution/transport/internal/testing/fake_transport.go b/pkg/distribution/transport/internal/testing/fake_transport.go deleted file mode 100644 index 49c69fa9d..000000000 --- a/pkg/distribution/transport/internal/testing/fake_transport.go +++ /dev/null @@ -1,340 +0,0 @@ -// Package testing provides common test utilities for transport packages. -package testing - -import ( - "bytes" - "fmt" - "io" - "net/http" - "strconv" - "strings" - "sync" -) - -// FakeResource represents a resource that can be served by FakeTransport. -type FakeResource struct { - // Data provides random access to the resource content. - Data io.ReaderAt - // Length is the total number of bytes in the resource content. - Length int64 - // SupportsRange indicates if this resource supports byte ranges. - SupportsRange bool - // ETag is the ETag header value (optional). - ETag string - // LastModified is the Last-Modified header value (optional). - LastModified string - // ContentType is the Content-Type header value (optional). - ContentType string - // Headers are additional headers to include in responses. - Headers http.Header -} - -// FakeTransport is a test http.RoundTripper that serves fake resources. -type FakeTransport struct { - mu sync.Mutex - resources map[string]*FakeResource - requests []http.Request - // FailAfter causes the transport to fail after serving this many bytes - // on a request (for simulating connection failures). - failAfter map[string]int - // failCount tracks how many times we've failed for each URL. - failCount map[string]int - // RequestHook is called for each request if set. - RequestHook func(*http.Request) - // ResponseHook is called for each response if set. - ResponseHook func(*http.Response) -} - -// NewFakeTransport creates a new FakeTransport. -func NewFakeTransport() *FakeTransport { - return &FakeTransport{ - resources: make(map[string]*FakeResource), - failAfter: make(map[string]int), - failCount: make(map[string]int), - } -} - -// Add adds a resource to the fake transport. -func (ft *FakeTransport) Add(url string, resource *FakeResource) { - ft.mu.Lock() - defer ft.mu.Unlock() - ft.resources[url] = resource -} - -// AddSimple adds a simple resource with the provided reader and length. -func (ft *FakeTransport) AddSimple(url string, data io.ReaderAt, length int64, supportsRange bool) { - ft.Add(url, &FakeResource{ - Data: data, - Length: length, - SupportsRange: supportsRange, - }) -} - -// SetFailAfter configures the transport to fail after serving n bytes for -// the given URL. -func (ft *FakeTransport) SetFailAfter(url string, n int) { - ft.mu.Lock() - defer ft.mu.Unlock() - ft.failAfter[url] = n -} - -// GetRequests returns a copy of all requests made to this transport. -func (ft *FakeTransport) GetRequests() []http.Request { - ft.mu.Lock() - defer ft.mu.Unlock() - reqs := make([]http.Request, len(ft.requests)) - copy(reqs, ft.requests) - return reqs -} - -// GetRequestHeaders returns the headers from all requests for a given URL. -func (ft *FakeTransport) GetRequestHeaders(url string) []http.Header { - ft.mu.Lock() - defer ft.mu.Unlock() - - var headers []http.Header - for _, req := range ft.requests { - if req.URL.String() == url { - h := make(http.Header) - for k, v := range req.Header { - h[k] = append([]string(nil), v...) - } - headers = append(headers, h) - } - } - return headers -} - -// RoundTrip implements http.RoundTripper. -func (ft *FakeTransport) RoundTrip(req *http.Request) (*http.Response, error) { - ft.mu.Lock() - // Store request - reqCopy := *req - if req.Header != nil { - reqCopy.Header = req.Header.Clone() - } - ft.requests = append(ft.requests, reqCopy) - - // Get resource - resource, exists := ft.resources[req.URL.String()] - failAfter := ft.failAfter[req.URL.String()] - ft.mu.Unlock() - - if ft.RequestHook != nil { - ft.RequestHook(req) - } - - if !exists { - return &http.Response{ - StatusCode: http.StatusNotFound, - Status: "404 Not Found", - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Body: io.NopCloser(bytes.NewReader(nil)), - Request: req, - }, nil - } - - // Handle HEAD request - if req.Method == http.MethodHead { - resp := ft.createResponse(req, resource, nil, http.StatusOK) - if ft.ResponseHook != nil { - ft.ResponseHook(resp) - } - return resp, nil - } - - // Handle Range request - if rangeHeader := req.Header.Get("Range"); rangeHeader != "" && resource.SupportsRange { - return ft.handleRangeRequest(req, resource, rangeHeader, failAfter) - } - - // Regular GET request - var body io.ReadCloser - if failAfter > 0 && ft.getFailCount(req.URL.String()) == 0 { - // First request - fail after specified bytes - body = NewFlakyReader(resource.Data, resource.Length, failAfter) - ft.incrementFailCount(req.URL.String()) - } else { - // Subsequent request or no failure configured - body = io.NopCloser(io.NewSectionReader(resource.Data, 0, resource.Length)) - } - - resp := ft.createResponse(req, resource, body, http.StatusOK) - if ft.ResponseHook != nil { - ft.ResponseHook(resp) - } - return resp, nil -} - -// handleRangeRequest serves a single byte range request for a resource. -// It validates the Range and If-Range headers and returns either 206 with the -// requested slice, or 200 with the full resource if validation fails. -// Multi-range specifications are not supported and result in 400. -func (ft *FakeTransport) handleRangeRequest(req *http.Request, resource *FakeResource, rangeHeader string, failAfter int) (*http.Response, error) { - // Parse range header (simplified - only handles single ranges) - if !strings.HasPrefix(rangeHeader, "bytes=") { - return ft.createErrorResponse(req, http.StatusBadRequest), nil - } - - rangeSpec := strings.TrimPrefix(rangeHeader, "bytes=") - parts := strings.Split(rangeSpec, "-") - if len(parts) != 2 { - return ft.createErrorResponse(req, http.StatusBadRequest), nil - } - - var start, end int64 - var err error - - if parts[0] != "" { - start, err = strconv.ParseInt(parts[0], 10, 64) - if err != nil { - return ft.createErrorResponse(req, http.StatusBadRequest), nil - } - } - - if parts[1] != "" { - end, err = strconv.ParseInt(parts[1], 10, 64) - if err != nil { - return ft.createErrorResponse(req, http.StatusBadRequest), nil - } - } else { - end = resource.Length - 1 - } - - // Validate range - if start < 0 || end >= resource.Length || start > end { - resp := ft.createErrorResponse(req, http.StatusRequestedRangeNotSatisfiable) - resp.Header.Set("Content-Range", fmt.Sprintf("bytes */%d", resource.Length)) - if ft.ResponseHook != nil { - ft.ResponseHook(resp) - } - return resp, nil - } - - // Check If-Range - if ifRange := req.Header.Get("If-Range"); ifRange != "" { - // Check if If-Range matches either ETag or Last-Modified - matches := false - - // Only match strong ETags for If-Range - if resource.ETag != "" && !strings.HasPrefix(resource.ETag, "W/") { - if ifRange == resource.ETag { - matches = true - } - } - - // Also check Last-Modified - if !matches && resource.LastModified != "" { - if ifRange == resource.LastModified { - matches = true - } - } - - if !matches { - // Validator doesn't match - return full content - body := NewFlakyReader(resource.Data, resource.Length, failAfter) - resp := ft.createResponse(req, resource, body, http.StatusOK) - if ft.ResponseHook != nil { - ft.ResponseHook(resp) - } - return resp, nil - } - } - - // Serve range - body := io.NopCloser(io.NewSectionReader(resource.Data, start, end-start+1)) - - resp := ft.createResponse(req, resource, body, http.StatusPartialContent) - resp.Header.Set("Content-Range", fmt.Sprintf("bytes %d-%d/%d", start, end, resource.Length)) - resp.ContentLength = end - start + 1 - - if ft.ResponseHook != nil { - ft.ResponseHook(resp) - } - return resp, nil -} - -// createResponse builds a basic http.Response for the given resource and -// status code, copying standard headers and any optional metadata. -func (ft *FakeTransport) createResponse(req *http.Request, resource *FakeResource, body io.ReadCloser, statusCode int) *http.Response { - if body == nil { - body = io.NopCloser(bytes.NewReader(nil)) - } - - resp := &http.Response{ - StatusCode: statusCode, - Status: http.StatusText(statusCode), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Body: body, - Request: req, - } - - // Set standard headers - if resource.SupportsRange { - resp.Header.Set("Accept-Ranges", "bytes") - } - - if resource.ETag != "" { - resp.Header.Set("ETag", resource.ETag) - } - - if resource.LastModified != "" { - resp.Header.Set("Last-Modified", resource.LastModified) - } - - if resource.ContentType != "" { - resp.Header.Set("Content-Type", resource.ContentType) - } - - // Copy additional headers - if resource.Headers != nil { - for k, v := range resource.Headers { - resp.Header[k] = v - } - } - - // Set Content-Length - if statusCode == http.StatusOK { - resp.ContentLength = resource.Length - resp.Header.Set("Content-Length", strconv.FormatInt(resource.Length, 10)) - } - - return resp -} - -// createErrorResponse constructs a minimal error response with the provided -// status code and an empty body. -func (ft *FakeTransport) createErrorResponse(req *http.Request, statusCode int) *http.Response { - return &http.Response{ - StatusCode: statusCode, - Status: http.StatusText(statusCode), - Proto: "HTTP/1.1", - ProtoMajor: 1, - ProtoMinor: 1, - Header: make(http.Header), - Body: io.NopCloser(bytes.NewReader(nil)), - Request: req, - } -} - -// getFailCount returns how many failures have been injected for the URL so -// far. It is safe for concurrent use. -func (ft *FakeTransport) getFailCount(url string) int { - ft.mu.Lock() - defer ft.mu.Unlock() - return ft.failCount[url] -} - -// incrementFailCount increments the injected failure counter for the URL. -// It is safe for concurrent use. -func (ft *FakeTransport) incrementFailCount(url string) { - ft.mu.Lock() - defer ft.mu.Unlock() - ft.failCount[url]++ -} diff --git a/pkg/distribution/transport/internal/testing/flaky_reader.go b/pkg/distribution/transport/internal/testing/flaky_reader.go deleted file mode 100644 index 707d1ae6e..000000000 --- a/pkg/distribution/transport/internal/testing/flaky_reader.go +++ /dev/null @@ -1,244 +0,0 @@ -package testing - -import ( - "errors" - "io" - "sync" -) - -// ErrFlakyFailure is returned when FlakyReader simulates a failure. -var ErrFlakyFailure = errors.New("simulated read failure") - -// FlakyReader simulates a reader that fails after a certain number of -// bytes. -type FlakyReader struct { - // data holds the content to be read through random access reads. - data io.ReaderAt - // length is the total number of readable bytes. - length int64 - // failAfter is the byte position after which reads should fail. - failAfter int64 - // pos is the current read position. - pos int64 - // failed indicates if the reader has already failed. - failed bool - // closed indicates if the reader has been closed. - closed bool - // mu protects all fields from concurrent access. - mu sync.Mutex -} - -// NewFlakyReader creates a FlakyReader that fails after reading failAfter -// bytes. If failAfter is 0 or negative, it never fails. -func NewFlakyReader(data io.ReaderAt, length int64, failAfter int) *FlakyReader { - return &FlakyReader{ - data: data, - length: length, - failAfter: int64(failAfter), - } -} - -// Read implements io.Reader. -func (fr *FlakyReader) Read(p []byte) (int, error) { - fr.mu.Lock() - defer fr.mu.Unlock() - - if fr.closed { - return 0, errors.New("read from closed reader") - } - - if fr.failed { - return 0, ErrFlakyFailure - } - - if fr.pos >= fr.length { - return 0, io.EOF - } - - // Calculate how much we can read. - remaining := fr.length - fr.pos - toRead := int64(len(p)) - if toRead > remaining { - toRead = remaining - } - - // Check if we should fail. - if fr.failAfter > 0 && fr.pos+toRead > fr.failAfter { - toRead = fr.failAfter - fr.pos - if toRead <= 0 { - fr.failed = true - return 0, ErrFlakyFailure - } - } - - if toRead == 0 { - return 0, nil - } - - buf := p[:toRead] - n, err := fr.data.ReadAt(buf, fr.pos) - fr.pos += int64(n) - - if err != nil && err != io.EOF { - return n, err - } - - if fr.failAfter > 0 && fr.pos >= fr.failAfter && fr.pos < fr.length { - fr.failed = true - if n == 0 { - return 0, ErrFlakyFailure - } - } - - if fr.pos >= fr.length { - return n, io.EOF - } - - if err == io.EOF { - return n, io.EOF - } - - return n, nil -} - -// Close implements io.Closer. -func (fr *FlakyReader) Close() error { - fr.mu.Lock() - defer fr.mu.Unlock() - fr.closed = true - return nil -} - -// Reset resets the reader to start from the beginning. -func (fr *FlakyReader) Reset() { - fr.mu.Lock() - defer fr.mu.Unlock() - fr.pos = 0 - fr.failed = false - fr.closed = false -} - -// Position returns the current read position. -func (fr *FlakyReader) Position() int { - fr.mu.Lock() - defer fr.mu.Unlock() - return int(fr.pos) -} - -// HasFailed returns true if the reader has simulated a failure. -func (fr *FlakyReader) HasFailed() bool { - fr.mu.Lock() - defer fr.mu.Unlock() - return fr.failed -} - -// MultiFailReader simulates multiple failures at different points. -type MultiFailReader struct { - // data holds the content to be read through random access reads. - data io.ReaderAt - // length is the total number of readable bytes. - length int64 - // failurePoints are the byte positions where failures should occur. - failurePoints []int - // failureCount tracks how many failures have been simulated. - failureCount int - // pos is the current read position. - pos int64 - // closed indicates if the reader has been closed. - closed bool - // mu protects all fields from concurrent access. - mu sync.Mutex -} - -// NewMultiFailReader creates a reader that fails at specified byte -// positions. -func NewMultiFailReader(data io.ReaderAt, length int64, failurePoints []int) *MultiFailReader { - return &MultiFailReader{ - data: data, - length: length, - failurePoints: failurePoints, - } -} - -// Read implements io.Reader. -func (mfr *MultiFailReader) Read(p []byte) (int, error) { - mfr.mu.Lock() - defer mfr.mu.Unlock() - - if mfr.closed { - return 0, errors.New("read from closed reader") - } - - if mfr.pos >= mfr.length { - return 0, io.EOF - } - - // Check if we're at a failure point. - for i, point := range mfr.failurePoints { - if i < mfr.failureCount { - continue // Already failed here. - } - if mfr.pos == int64(point) { - mfr.failureCount++ - return 0, ErrFlakyFailure - } - } - - // Calculate how much to read. - remaining := mfr.length - mfr.pos - toRead := int64(len(p)) - if toRead > remaining { - toRead = remaining - } - - // Check if we would cross a failure point. - for i, point := range mfr.failurePoints { - if i < mfr.failureCount { - continue // Skip already used failure points. - } - if mfr.pos < int64(point) && mfr.pos+toRead > int64(point) { - toRead = int64(point) - mfr.pos - break - } - } - - // Copy data. - if toRead == 0 { - return 0, nil - } - - buf := p[:toRead] - n, err := mfr.data.ReadAt(buf, mfr.pos) - mfr.pos += int64(n) - - if err != nil && err != io.EOF { - return n, err - } - - if mfr.pos >= mfr.length { - return n, io.EOF - } - - if err == io.EOF { - return n, io.EOF - } - - return n, nil -} - -// Close implements io.Closer. -func (mfr *MultiFailReader) Close() error { - mfr.mu.Lock() - defer mfr.mu.Unlock() - mfr.closed = true - return nil -} - -// Reset resets the reader to the beginning and clears failure state. -func (mfr *MultiFailReader) Reset() { - mfr.mu.Lock() - defer mfr.mu.Unlock() - mfr.pos = 0 - mfr.failureCount = 0 - mfr.closed = false -} diff --git a/pkg/distribution/transport/internal/testing/helpers.go b/pkg/distribution/transport/internal/testing/helpers.go deleted file mode 100644 index e3aa2dd87..000000000 --- a/pkg/distribution/transport/internal/testing/helpers.go +++ /dev/null @@ -1,198 +0,0 @@ -package testing - -import ( - "bytes" - "crypto/rand" - "fmt" - "io" - "testing" -) - -// GenerateTestData generates deterministic test data of the specified size. -func GenerateTestData(size int) []byte { - data := make([]byte, size) - for i := range data { - data[i] = byte(i % 256) - } - return data -} - -// GenerateRandomData generates random test data of the specified size. -func GenerateRandomData(size int) []byte { - data := make([]byte, size) - if _, err := rand.Read(data); err != nil { - panic(fmt.Sprintf("failed to generate random data: %v", err)) - } - return data -} - -// AssertDataEquals checks if two byte slices are equal. -func AssertDataEquals(t *testing.T, got, want []byte) { - t.Helper() - if !bytes.Equal(got, want) { - t.Errorf("data mismatch: got %d bytes, want %d bytes", len(got), len(want)) - if len(got) == len(want) { - // Find first difference. - for i := range got { - if got[i] != want[i] { - t.Errorf( - "first difference at byte %d: got %02x, want %02x", - i, got[i], want[i]) - break - } - } - } - } -} - -// ReadAll reads all data from a reader and returns it. -func ReadAll(t *testing.T, r io.Reader) []byte { - t.Helper() - data, err := io.ReadAll(r) - if err != nil { - t.Fatalf("failed to read all data: %v", err) - } - return data -} - -// ReadAllWithError reads all data from a reader and returns both data and -// error. -func ReadAllWithError(r io.Reader) ([]byte, error) { - return io.ReadAll(r) -} - -// MustRead reads exactly n bytes from a reader or fails the test. -func MustRead(t *testing.T, r io.Reader, n int) []byte { - t.Helper() - buf := make([]byte, n) - nn, err := io.ReadFull(r, buf) - if err != nil { - t.Fatalf( - "failed to read %d bytes: got %d, err: %v", n, nn, err) - } - return buf -} - -// AssertHeaderEquals checks if a header has the expected value. -func AssertHeaderEquals(t *testing.T, headers map[string][]string, key, want string) { - t.Helper() - values, ok := headers[key] - if !ok || len(values) == 0 { - if want != "" { - t.Errorf("header %q not found, want %q", key, want) - } - return - } - if values[0] != want { - t.Errorf("header %q = %q, want %q", key, values[0], want) - } -} - -// AssertHeaderPresent checks if a header is present. -func AssertHeaderPresent(t *testing.T, headers map[string][]string, key string) { - t.Helper() - if _, ok := headers[key]; !ok { - t.Errorf("header %q not found", key) - } -} - -// AssertHeaderAbsent checks if a header is absent. -func AssertHeaderAbsent(t *testing.T, headers map[string][]string, key string) { - t.Helper() - if _, ok := headers[key]; ok { - t.Errorf("header %q found, want absent", key) - } -} - -// ChunkData splits data into n chunks of approximately equal size. -func ChunkData(data []byte, n int) [][]byte { - if n <= 0 { - return nil - } - if n == 1 { - return [][]byte{data} - } - - chunkSize := len(data) / n - remainder := len(data) % n - - chunks := make([][]byte, n) - offset := 0 - - for i := 0; i < n; i++ { - size := chunkSize - if i == n-1 { - size += remainder - } - chunks[i] = data[offset : offset+size] - offset += size - } - - return chunks -} - -// ConcatChunks concatenates multiple byte slices into one. -func ConcatChunks(chunks [][]byte) []byte { - var total int - for _, chunk := range chunks { - total += len(chunk) - } - - result := make([]byte, 0, total) - for _, chunk := range chunks { - result = append(result, chunk...) - } - - return result -} - -// ByteRange represents a byte range. -type ByteRange struct { - // Start is the starting byte position (inclusive). - Start int64 - // End is the ending byte position (inclusive). - End int64 -} - -// CalculateByteRanges calculates byte ranges for splitting a file of given -// size into n parts. -func CalculateByteRanges(totalSize int64, n int) []ByteRange { - if n <= 0 || totalSize <= 0 { - return nil - } - - ranges := make([]ByteRange, n) - chunkSize := totalSize / int64(n) - remainder := totalSize % int64(n) - - var start int64 - for i := 0; i < n; i++ { - size := chunkSize - if i == n-1 { - size += remainder - } - ranges[i] = ByteRange{ - Start: start, - End: start + size - 1, - } - start += size - } - - return ranges -} - -// AssertNoError fails the test if err is not nil. -func AssertNoError(t *testing.T, err error, msg string) { - t.Helper() - if err != nil { - t.Fatalf("%s: %v", msg, err) - } -} - -// AssertError fails the test if err is nil. -func AssertError(t *testing.T, err error, msg string) { - t.Helper() - if err == nil { - t.Fatalf("%s: expected error, got nil", msg) - } -} diff --git a/pkg/distribution/transport/internal/testing/testing_test.go b/pkg/distribution/transport/internal/testing/testing_test.go deleted file mode 100644 index c68db4974..000000000 --- a/pkg/distribution/transport/internal/testing/testing_test.go +++ /dev/null @@ -1,137 +0,0 @@ -package testing - -import ( - "bytes" - "io" - "net/http" - "testing" -) - -// TestFakeTransport_Basic tests the basic functionality of FakeTransport. -func TestFakeTransport_Basic(t *testing.T) { - ft := NewFakeTransport() - - // Add a simple resource. - data := []byte("Hello, World!") - ft.AddSimple("http://example.com/test", bytes.NewReader(data), int64(len(data)), true) - - // Create a request. - req, err := http.NewRequest("GET", "http://example.com/test", nil) - if err != nil { - t.Fatalf("Failed to create request: %v", err) - } - - // Perform the request. - resp, err := ft.RoundTrip(req) - if err != nil { - t.Fatalf("RoundTrip failed: %v", err) - } - defer resp.Body.Close() - - // Read the response. - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("Failed to read response: %v", err) - } - - // Check the data. - if !bytes.Equal(got, data) { - t.Errorf("Response data mismatch: got %q, want %q", got, data) - } -} - -// TestFlakyReader_FailsAfterN tests that FlakyReader fails after reading -// a specified number of bytes. -func TestFlakyReader_FailsAfterN(t *testing.T) { - data := []byte("Hello, World!") - fr := NewFlakyReader(bytes.NewReader(data), int64(len(data)), 5) - - // Read first 5 bytes. - buf := make([]byte, 5) - n, err := fr.Read(buf) - if err != nil { - t.Fatalf("First read failed: %v", err) - } - if n != 5 { - t.Fatalf("Expected to read 5 bytes, got %d", n) - } - if string(buf) != "Hello" { - t.Errorf("Expected 'Hello', got %q", string(buf)) - } - - // Next read should fail. - _, err = fr.Read(buf) - if err != ErrFlakyFailure { - t.Errorf("Expected ErrFlakyFailure, got %v", err) - } -} - -// TestHelpers_GenerateTestData tests the deterministic test data generator. -func TestHelpers_GenerateTestData(t *testing.T) { - data := GenerateTestData(256) - - if len(data) != 256 { - t.Errorf("Expected 256 bytes, got %d", len(data)) - } - - // Check deterministic pattern. - for i := 0; i < 256; i++ { - if data[i] != byte(i%256) { - t.Errorf("Byte %d: expected %d, got %d", i, i%256, data[i]) - } - } -} - -// TestHelpers_ChunkData tests the data chunking functionality. -func TestHelpers_ChunkData(t *testing.T) { - data := GenerateTestData(100) - chunks := ChunkData(data, 4) - - if len(chunks) != 4 { - t.Fatalf("Expected 4 chunks, got %d", len(chunks)) - } - - // First 3 chunks should be 25 bytes each. - for i := 0; i < 3; i++ { - if len(chunks[i]) != 25 { - t.Errorf("Chunk %d: expected 25 bytes, got %d", i, len(chunks[i])) - } - } - - // Last chunk should be 25 + remainder. - if len(chunks[3]) != 25 { - t.Errorf("Last chunk: expected 25 bytes, got %d", len(chunks[3])) - } - - // Concatenate and verify. - combined := ConcatChunks(chunks) - if !bytes.Equal(combined, data) { - t.Error("Concatenated chunks don't match original data") - } -} - -// TestHelpers_ByteRanges tests byte range calculation for parallel -// downloads. -func TestHelpers_ByteRanges(t *testing.T) { - ranges := CalculateByteRanges(100, 4) - - if len(ranges) != 4 { - t.Fatalf("Expected 4 ranges, got %d", len(ranges)) - } - - expectedRanges := []ByteRange{ - {Start: 0, End: 24}, - {Start: 25, End: 49}, - {Start: 50, End: 74}, - {Start: 75, End: 99}, - } - - for i, r := range ranges { - if r.Start != expectedRanges[i].Start || - r.End != expectedRanges[i].End { - t.Errorf( - "Range %d: got %d-%d, want %d-%d", - i, r.Start, r.End, expectedRanges[i].Start, expectedRanges[i].End) - } - } -} diff --git a/pkg/distribution/transport/parallel/large_file_test.go b/pkg/distribution/transport/parallel/large_file_test.go deleted file mode 100644 index 56d6dc5ec..000000000 --- a/pkg/distribution/transport/parallel/large_file_test.go +++ /dev/null @@ -1,340 +0,0 @@ -package parallel - -import ( - "bytes" - "crypto/sha256" - "fmt" - "hash" - "io" - "net/http" - "os" - "strconv" - "testing" - - testutil "github.com/docker/model-runner/pkg/distribution/transport/internal/testing" -) - -// deterministicDataGenerator generates deterministic data based on position. -// This allows us to generate GB-sized data streams without storing them in -// memory. -type deterministicDataGenerator struct { - position int64 - size int64 -} - -// newDeterministicDataGenerator creates a new deterministic data generator -// with the specified size. -func newDeterministicDataGenerator(size int64) *deterministicDataGenerator { - return &deterministicDataGenerator{ - position: 0, - size: size, - } -} - -// Read implements io.Reader for deterministicDataGenerator. -func (g *deterministicDataGenerator) Read(p []byte) (int, error) { - if g.position >= g.size { - return 0, io.EOF - } - - // Calculate how much we can read. - remaining := g.size - g.position - toRead := int64(len(p)) - if toRead > remaining { - toRead = remaining - } - - // Generate deterministic data based on position. - for i := int64(0); i < toRead; i++ { - pos := g.position + i - // Use a simple but deterministic pattern: position mod 256. - // XOR with some constants to make it more interesting. - p[i] = byte((pos ^ (pos >> 8) ^ (pos >> 16)) % 256) - } - - g.position += toRead - return int(toRead), nil -} - -// ReadAt implements io.ReaderAt for deterministicDataGenerator. -func (g *deterministicDataGenerator) ReadAt(p []byte, off int64) (int, error) { - if off >= g.size { - return 0, io.EOF - } - - remaining := g.size - off - toRead := int64(len(p)) - if toRead > remaining { - toRead = remaining - } - - for i := int64(0); i < toRead; i++ { - pos := off + i - p[i] = byte((pos ^ (pos >> 8) ^ (pos >> 16)) % 256) - } - - if toRead < int64(len(p)) { - return int(toRead), io.EOF - } - - return int(toRead), nil -} - -// addLargeFileResource registers a deterministic large file with the fake -// transport. The resource shares behavior with the previous httptest server -// implementation, including range support and metadata headers. -func addLargeFileResource(ft *testutil.FakeTransport, url string, size int64) { - ft.Add(url, &testutil.FakeResource{ - Data: newDeterministicDataGenerator(size), - Length: size, - SupportsRange: true, - ETag: fmt.Sprintf(`"test-file-%d"`, size), - ContentType: "application/octet-stream", - }) -} - -// hashingReader wraps an io.Reader and computes SHA-256 while reading. -type hashingReader struct { - reader io.Reader - hasher hash.Hash - bytesRead int64 -} - -// newHashingReader creates a new hashing reader that computes SHA-256 -// hash while reading from the provided reader. -func newHashingReader(r io.Reader) *hashingReader { - return &hashingReader{ - reader: r, - hasher: sha256.New(), - bytesRead: 0, - } -} - -// Read implements io.Reader for hashingReader. -func (hr *hashingReader) Read(p []byte) (int, error) { - n, err := hr.reader.Read(p) - if n > 0 { - hr.hasher.Write(p[:n]) - hr.bytesRead += int64(n) - } - return n, err -} - -// Sum returns the SHA-256 hash of all data read so far. -func (hr *hashingReader) Sum() []byte { - return hr.hasher.Sum(nil) -} - -// BytesRead returns the total number of bytes read. -func (hr *hashingReader) BytesRead() int64 { - return hr.bytesRead -} - -// computeExpectedHash computes the expected SHA-256 hash for a file of -// given size. -func computeExpectedHash(size int64) []byte { - hasher := sha256.New() - gen := newDeterministicDataGenerator(size) - io.Copy(hasher, gen) - return hasher.Sum(nil) -} - -// getTestFileSize returns an appropriate file size for testing based on -// whether we're running under the race detector or other conditions. -// The returned size ensures parallel downloads will still occur (larger than -// typical minimum chunk sizes of 1-10MB). -func getTestFileSize(baseSize int64) int64 { - // Allow environment override for custom testing. - if sizeStr := os.Getenv("TEST_FILE_SIZE"); sizeStr != "" { - if size, err := strconv.ParseInt(sizeStr, 10, 64); err == nil { - return size - } - } - - // Check for race detector or coverage mode. - if testing.CoverMode() != "" || raceEnabled { - // Use ~200MB for "large" (1GB) and ~400MB for "very large" (4GB). - // This is large enough to trigger parallel downloads with typical - // chunk sizes of 4-8MB, but small enough to run quickly. - if baseSize >= 4*1024*1024*1024 { - return 400 * 1024 * 1024 // 400MB instead of 4GB. - } - return 200 * 1024 * 1024 // 200MB instead of 1GB. - } - - return baseSize -} - -// TestLargeFile_ParallelVsSequential tests parallel vs sequential -// download of a large file. The actual file size adapts based on whether -// the race detector is enabled (200MB in race mode, 1GB normally). -func TestLargeFile_ParallelVsSequential(t *testing.T) { - if testing.Short() { - t.Skip("Skipping large file test in short mode") - } - - // Test with large file (1GB normally, 200MB in race/coverage mode). - baseSize := int64(1024 * 1024 * 1024) // 1 GB base size. - size := getTestFileSize(baseSize) - - if size != baseSize { - t.Logf("Running with reduced file size: %d MB (race detector or coverage mode detected)", - size/(1024*1024)) - } - - url := fmt.Sprintf("https://parallel.example/data/%d", size) - - // Prepare fake transport resource metadata once for logging consistency. - resourceETag := fmt.Sprintf(`"test-file-%d"`, size) - - // Compute expected hash. - expectedHash := computeExpectedHash(size) - - t.Run("Sequential", func(t *testing.T) { - transport := testutil.NewFakeTransport() - addLargeFileResource(transport, url, size) - client := &http.Client{Transport: transport} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("Failed to get %s: %v", url, err) - } - defer resp.Body.Close() - - if resp.Header.Get("ETag") != resourceETag { - t.Errorf("Expected ETag %s, got %s", resourceETag, resp.Header.Get("ETag")) - } - - if resp.ContentLength != size { - t.Errorf("Expected Content-Length %d, got %d", - size, resp.ContentLength) - } - - hashingReader := newHashingReader(resp.Body) - _, err = io.Copy(io.Discard, hashingReader) - if err != nil { - t.Fatalf("Failed to read response body: %v", err) - } - - if hashingReader.BytesRead() != size { - t.Errorf("Expected to read %d bytes, actually read %d bytes", - size, hashingReader.BytesRead()) - } - - actualHash := hashingReader.Sum() - if !bytes.Equal(expectedHash, actualHash) { - t.Errorf("Hash mismatch.\nExpected: %x\nActual: %x", - expectedHash, actualHash) - } - }) - - t.Run("Parallel", func(t *testing.T) { - baseTransport := testutil.NewFakeTransport() - addLargeFileResource(baseTransport, url, size) - transport := New( - baseTransport, - WithMaxConcurrentPerHost(map[string]uint{"": 0}), - WithMinChunkSize(4*1024*1024), // 4MB chunks. - WithMaxConcurrentPerRequest(8), - ) - client := &http.Client{Transport: transport} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("Failed to get %s: %v", url, err) - } - defer resp.Body.Close() - - if resp.Header.Get("ETag") != resourceETag { - t.Errorf("Expected ETag %s, got %s", resourceETag, resp.Header.Get("ETag")) - } - - if resp.ContentLength != size { - t.Errorf("Expected Content-Length %d, got %d", - size, resp.ContentLength) - } - - hashingReader := newHashingReader(resp.Body) - _, err = io.Copy(io.Discard, hashingReader) - if err != nil { - t.Fatalf("Failed to read response body: %v", err) - } - - if hashingReader.BytesRead() != size { - t.Errorf("Expected to read %d bytes, actually read %d bytes", - size, hashingReader.BytesRead()) - } - - actualHash := hashingReader.Sum() - if !bytes.Equal(expectedHash, actualHash) { - t.Errorf("Hash mismatch.\nExpected: %x\nActual: %x", - expectedHash, actualHash) - } - }) -} - -// TestVeryLargeFile_ParallelDownload tests parallel download of a very large -// file. The actual file size adapts based on whether the race detector is -// enabled (400MB in race mode, 4GB normally). -func TestVeryLargeFile_ParallelDownload(t *testing.T) { - if testing.Short() { - t.Skip("Skipping very large file test in short mode") - } - - // Test with very large file (4GB normally, 400MB in race/coverage mode). - baseSize := int64(4 * 1024 * 1024 * 1024) // 4 GB base size. - size := getTestFileSize(baseSize) - - if size != baseSize { - t.Logf("Running with reduced file size: %d MB (race detector or coverage mode detected)", - size/(1024*1024)) - } - - url := fmt.Sprintf("https://parallel.example/very-large/%d", size) - - baseTransport := testutil.NewFakeTransport() - addLargeFileResource(baseTransport, url, size) - - // Only test parallel for very large files due to time constraints. - transport := New( - baseTransport, - WithMaxConcurrentPerHost(map[string]uint{"": 0}), - WithMinChunkSize(8*1024*1024), // 8MB chunks. - WithMaxConcurrentPerRequest(16), - ) - client := &http.Client{Transport: transport} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("Failed to get %s: %v", url, err) - } - defer resp.Body.Close() - - if resp.ContentLength != size { - t.Errorf("Expected Content-Length %d, got %d", - size, resp.ContentLength) - } - - // For 4GB, let's just verify we can read the correct number of bytes. - // Computing the full hash would take too long. - bytesRead := int64(0) - buf := make([]byte, 64*1024) // 64KB buffer. - for { - n, err := resp.Body.Read(buf) - bytesRead += int64(n) - if err == io.EOF { - break - } - if err != nil { - t.Fatalf("Failed to read response body: %v", err) - } - } - - if bytesRead != size { - t.Errorf("Expected to read %d bytes, actually read %d bytes", - size, bytesRead) - } - - t.Logf("Successfully read %d bytes (4GB) from parallel download", - bytesRead) -} diff --git a/pkg/distribution/transport/parallel/race_off.go b/pkg/distribution/transport/parallel/race_off.go deleted file mode 100644 index cafa39d74..000000000 --- a/pkg/distribution/transport/parallel/race_off.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build !race -// +build !race - -package parallel - -// raceEnabled is a compile-time constant indicating whether the race -// detector is enabled. -const raceEnabled = false diff --git a/pkg/distribution/transport/parallel/race_on.go b/pkg/distribution/transport/parallel/race_on.go deleted file mode 100644 index aa1bf7743..000000000 --- a/pkg/distribution/transport/parallel/race_on.go +++ /dev/null @@ -1,8 +0,0 @@ -//go:build race -// +build race - -package parallel - -// raceEnabled is a compile-time constant indicating whether the race -// detector is enabled. -const raceEnabled = true diff --git a/pkg/distribution/transport/parallel/transport.go b/pkg/distribution/transport/parallel/transport.go deleted file mode 100644 index 400d4bfb0..000000000 --- a/pkg/distribution/transport/parallel/transport.go +++ /dev/null @@ -1,700 +0,0 @@ -// Package parallel provides an http.RoundTripper that transparently -// parallelizes GET requests using concurrent byte-range requests for better -// throughput. -// -// ───────────────────────────── How it works ───────────────────────────── -// - For non-GET requests, the transport passes them through unmodified to -// the underlying transport. -// - For GET requests, it first performs a HEAD request to check if the -// server supports byte ranges and to determine the total response size. -// - If the HEAD request indicates range support and known size, the -// transport generates multiple concurrent GET requests with specific -// byte-range headers. -// - Subranges are written to temporary files and stitched together in a -// custom Response.Body that's transparent to the caller. -// - Per-host and per-request concurrency limits are enforced using -// semaphores. -// -// ───────────────────────────── Notes & caveats ─────────────────────────── -// - Only works with servers that support "Accept-Ranges: bytes" and -// provide Content-Length or Content-Range headers with total size -// information. -// - Content-Encoding (compression) is not compatible with byte ranges, -// so compressed responses fall back to single-threaded behavior. -// - Temporary files are created for each subrange and cleaned up -// automatically. -// - The transport respects per-host concurrency limits to avoid -// overwhelming servers. -package parallel - -import ( - "context" - "errors" - "fmt" - "io" - "net" - "net/http" - "os" - "strconv" - "strings" - "sync" - - "github.com/docker/model-runner/pkg/distribution/transport/internal/bufferfile" - "github.com/docker/model-runner/pkg/distribution/transport/internal/common" -) - -// Option configures a ParallelTransport. -type Option func(*ParallelTransport) - -// WithMaxConcurrentPerHost sets the maximum concurrent requests per -// hostname. Default concurrency limits are applied if not specified. -func WithMaxConcurrentPerHost(limits map[string]uint) Option { - return func(pt *ParallelTransport) { - pt.maxConcurrentPerHost = make(map[string]uint, len(limits)) - for host, limit := range limits { - pt.maxConcurrentPerHost[host] = limit - } - } -} - -// WithMaxConcurrentPerRequest sets the maximum concurrent subrange -// requests for a single request. Default: 4. -func WithMaxConcurrentPerRequest(n uint) Option { - return func(pt *ParallelTransport) { pt.maxConcurrentPerRequest = n } -} - -// WithMinChunkSize sets the minimum size in bytes for each subrange chunk. -// Requests smaller than this will not be parallelized. Default: 1MB. -func WithMinChunkSize(size int64) Option { - return func(pt *ParallelTransport) { pt.minChunkSize = size } -} - -// WithTempDir sets the directory for temporary files. If empty, -// os.TempDir() is used. -func WithTempDir(dir string) Option { - return func(pt *ParallelTransport) { pt.tempDir = dir } -} - -// ParallelTransport wraps another http.RoundTripper and parallelizes GET -// requests using concurrent byte-range requests when possible. -type ParallelTransport struct { - // base is the underlying RoundTripper actually used to send requests. - base http.RoundTripper - // maxConcurrentPerHost maps canonicalized hostname to maximum - // concurrent requests. A value of 0 means unlimited. The "" entry is - // the default for unspecified hosts. - maxConcurrentPerHost map[string]uint - // maxConcurrentPerRequest is the maximum number of concurrent - // subrange requests for a single request. - maxConcurrentPerRequest uint - // minChunkSize is the minimum size in bytes for parallelization to be - // worthwhile. - minChunkSize int64 - // tempDir is the directory for temporary files. - tempDir string - // semaphores tracks per-host concurrency limits. - semaphores map[string]*semaphore - // semMu protects the semaphores map. - semMu sync.RWMutex -} - -// New returns a ParallelTransport wrapping base. If base is nil, -// http.DefaultTransport is used. Options configure parallelization behavior. -func New(base http.RoundTripper, opts ...Option) *ParallelTransport { - if base == nil { - base = http.DefaultTransport - } - pt := &ParallelTransport{ - base: base, - maxConcurrentPerHost: map[string]uint{"": 4}, // default 4 per host. - maxConcurrentPerRequest: 4, - minChunkSize: 1024 * 1024, // 1MB. - tempDir: os.TempDir(), - semaphores: make(map[string]*semaphore), - } - for _, o := range opts { - o(pt) - } - return pt -} - -// RoundTrip implements http.RoundTripper. It parallelizes GET requests -// when possible, otherwise passes requests through to the underlying -// transport. -func (pt *ParallelTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Non-GET requests pass through unmodified. - if req.Method != http.MethodGet { - return pt.base.RoundTrip(req) - } - // Respect caller-provided Range requests. We do not parallelize when the - // request already specifies a byte range, to preserve exact semantics. - if strings.TrimSpace(req.Header.Get("Range")) != "" { - return pt.base.RoundTrip(req) - } - - // Check if parallelization is possible and worthwhile. - canParallelize, pInfo, err := pt.checkParallelizable(req) - if err != nil { - return nil, err - } - if !canParallelize || - pInfo.totalSize < pt.minChunkSize*int64(pt.maxConcurrentPerRequest) { - // Fall back to single request. - return pt.base.RoundTrip(req) - } - - // Perform parallel download. - return pt.parallelDownload(req, pInfo) -} - -// parallelInfo holds information needed for parallel downloads. -type parallelInfo struct { - // totalSize is the total size of the resource in bytes. - totalSize int64 - // etag is the strong ETag validator from the HEAD response, used for - // If-Range. - etag string - // lastModified is the Last-Modified header value, used as fallback - // validator for If-Range. - lastModified string - // header is a clone of the server headers (from HEAD) used to seed the - // final response headers without an extra GET probe. - header http.Header - // proto/protoMajor/protoMinor reflect the server protocol from the HEAD - // response for constructing the final response. - proto string - protoMajor int - protoMinor int -} - -// checkParallelizable performs a HEAD request to determine if the resource -// supports byte ranges and returns the parallel info if available. -func (pt *ParallelTransport) checkParallelizable(req *http.Request) (bool, *parallelInfo, error) { - // Create HEAD request. - headReq := req.Clone(req.Context()) - headReq.Method = http.MethodHead - headReq.Body = nil - headReq.ContentLength = 0 - // Clone and sanitize headers to avoid conditional responses and implicit - // compression that could skew metadata. - headReq.Header = req.Header.Clone() - common.ScrubConditionalHeaders(headReq.Header) - headReq.Header.Set("Accept-Encoding", "identity") - - // Perform HEAD request. - headResp, err := pt.base.RoundTrip(headReq) - if err != nil { - return false, nil, err - } - defer headResp.Body.Close() - - // Only proceed on 200 OK or 206 Partial Content. Anything else (e.g., - // 304 Not Modified due to missed scrub, redirects, etc.) is treated as - // non-parallelizable for safety. - if headResp.StatusCode != http.StatusOK && - headResp.StatusCode != http.StatusPartialContent { - return false, nil, nil - } - - // Check if range requests are supported. - if !common.SupportsRange(headResp.Header) { - return false, nil, nil - } - - // Check for compression which would interfere with byte ranges. - if headResp.Header.Get("Content-Encoding") != "" { - return false, nil, nil - } - - // Get total content length. - totalSize := headResp.ContentLength - if totalSize <= 0 { - // Try to parse from Content-Range if present (206 response). - if headResp.StatusCode == http.StatusPartialContent { - if _, _, total, ok := common.ParseContentRange( - headResp.Header.Get("Content-Range")); ok && total > 0 { - totalSize = total - } else { - return false, nil, nil - } - } else { - return false, nil, nil - } - } - - if totalSize <= 0 { - return false, nil, nil - } - - // Capture validators for If-Range to ensure consistency across parallel - // requests. - info := ¶llelInfo{ - totalSize: totalSize, - header: headResp.Header.Clone(), - proto: headResp.Proto, - protoMajor: headResp.ProtoMajor, - protoMinor: headResp.ProtoMinor, - } - - if et := headResp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) { - info.etag = et - } else if lm := headResp.Header.Get("Last-Modified"); lm != "" { - info.lastModified = lm - } - - return true, info, nil -} - -// parallelDownload performs a parallel download by splitting the request -// into multiple concurrent byte-range requests. -func (pt *ParallelTransport) parallelDownload(req *http.Request, pInfo *parallelInfo) (*http.Response, error) { - totalSize := pInfo.totalSize - - // Calculate chunk size and number of chunks. - numChunks := int(pt.maxConcurrentPerRequest) - if totalSize < int64(numChunks)*pt.minChunkSize { - numChunks = int(totalSize / pt.minChunkSize) - if numChunks < 1 { - numChunks = 1 - } - } - - chunkSize := totalSize / int64(numChunks) - remainder := totalSize % int64(numChunks) - - // Get or create semaphore for this host. - sem := pt.getSemaphore(req.URL.Host) - - // Create chunks and temporary files. - chunks := make([]*chunk, numChunks) - var start int64 - for i := 0; i < numChunks; i++ { - size := chunkSize - if i == numChunks-1 { - size += remainder // Last chunk gets the remainder. - } - end := start + size - 1 - - fifo, err := bufferfile.NewFIFOInDir(pt.tempDir) - if err != nil { - // Clean up any created FIFOs. - for j := 0; j < i; j++ { - chunks[j].cleanup() - } - return nil, fmt.Errorf("parallel: failed to create FIFO: %w", err) - } - - chunk := &chunk{ - start: start, - end: end, - fifo: fifo, - state: chunkNotStarted, - } - chunks[i] = chunk - start = end + 1 - } - - // Start downloading chunks concurrently (don't wait for completion). - for i, ch := range chunks { - go func(i int, ch *chunk) { - ch.setSimpleState(chunkDownloading, nil) - if err := pt.downloadChunk(req, ch, sem, pInfo); err != nil { - ch.setSimpleState(chunkFailed, fmt.Errorf("chunk %d: %w", i, err)) - ch.fifo.Close() // Close FIFO on error to interrupt readers. - } else { - ch.setSimpleState(chunkCompleted, nil) - // Close write side to signal no more writes (EOF when all data - // read). - ch.fifo.CloseWrite() - } - }(i, ch) - } - - // Create stitched response. - body := &stitchedBody{ - chunks: chunks, - totalSize: totalSize, - ctx: req.Context(), - } - - // Create response using the header response as template. - resp := &http.Response{ - Status: "200 OK", - StatusCode: http.StatusOK, - Proto: pInfo.proto, - ProtoMajor: pInfo.protoMajor, - ProtoMinor: pInfo.protoMinor, - Header: pInfo.header.Clone(), - Body: body, - ContentLength: totalSize, - Request: req, - } - - // Override headers that we control. - resp.Header.Set("Content-Length", strconv.FormatInt(totalSize, 10)) - resp.Header.Del("Content-Range") // Remove any partial content headers. - - return resp, nil -} - -// downloadChunk downloads a single chunk using a byte-range request. -func (pt *ParallelTransport) downloadChunk(origReq *http.Request, chunk *chunk, sem *semaphore, pInfo *parallelInfo) error { - // Acquire semaphore. - if err := sem.acquire(origReq.Context()); err != nil { - return err - } - defer sem.release() - - // Create range request. - rangeReq := origReq.Clone(origReq.Context()) - rangeReq.Header = origReq.Header.Clone() - rangeReq.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", chunk.start, chunk.end)) - - // Prevent compression which would interfere with byte ranges. - rangeReq.Header.Set("Accept-Encoding", "identity") - - // Add If-Range header for consistency validation. - if pInfo.etag != "" { - rangeReq.Header.Set("If-Range", pInfo.etag) - } else if pInfo.lastModified != "" { - rangeReq.Header.Set("If-Range", pInfo.lastModified) - } - - // Remove conditional headers that could conflict with If-Range. - common.ScrubConditionalHeaders(rangeReq.Header) - - // Perform request. - resp, err := pt.base.RoundTrip(rangeReq) - if err != nil { - return err - } - defer resp.Body.Close() - - // Check for If-Range validation failure (server returns 200 instead of 206). - if resp.StatusCode == http.StatusOK { - return fmt.Errorf( - "server returned 200 to range request, resource may have changed (If-Range validation failed)") - } - - // Verify we got a partial content response. - if resp.StatusCode != http.StatusPartialContent { - return fmt.Errorf( - "expected 206 Partial Content, got %d", resp.StatusCode) - } - - // Verify the range matches what we requested. - if start, end, _, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok { - if start != chunk.start || end != chunk.end { - return fmt.Errorf( - "server returned range %d-%d, requested %d-%d", - start, end, chunk.start, chunk.end) - } - } - - // Copy response body to FIFO and verify full chunk length is received. - buf := make([]byte, 32*1024) // 32KB buffer. - var copied int64 - for { - n, err := resp.Body.Read(buf) - if n > 0 { - // Write to FIFO - if _, writeErr := chunk.fifo.Write(buf[:n]); writeErr != nil { - return fmt.Errorf( - "failed to write chunk data: %w", writeErr) - } - copied += int64(n) - } - - if err == io.EOF { - // Validate that we received the complete range we requested. - expected := (chunk.end - chunk.start + 1) - if copied != expected { - return fmt.Errorf( - "short read for chunk: got %d, want %d", copied, expected) - } - break - } - if err != nil { - return fmt.Errorf( - "failed to read chunk data: %w", err) - } - } - - return nil -} - -// getSemaphore returns the semaphore for the given host, creating it if needed. -func (pt *ParallelTransport) getSemaphore(host string) *semaphore { - canonicalHost := canonicalizeHost(host) - - pt.semMu.RLock() - if sem, exists := pt.semaphores[canonicalHost]; exists { - pt.semMu.RUnlock() - return sem - } - pt.semMu.RUnlock() - - pt.semMu.Lock() - defer pt.semMu.Unlock() - - // Double-check after acquiring write lock. - if sem, exists := pt.semaphores[canonicalHost]; exists { - return sem - } - - // Determine limit for this host. - limit := pt.maxConcurrentPerHost[canonicalHost] - if limit == 0 { - // Check default. - if defaultLimit, exists := pt.maxConcurrentPerHost[""]; exists { - limit = defaultLimit - } - } - - sem := newSemaphore(int(limit)) - pt.semaphores[canonicalHost] = sem - return sem -} - -// canonicalizeHost returns a canonical form of the hostname for semaphore lookup. -func canonicalizeHost(host string) string { - // Remove port if present. - if h, _, err := net.SplitHostPort(host); err == nil { - host = h - } - return strings.ToLower(host) -} - -// chunkState represents the current state of a chunk download. -type chunkState int - -const ( - chunkNotStarted chunkState = iota - chunkDownloading - chunkCompleted - chunkFailed -) - -// chunk represents a byte range chunk being downloaded to a temporary file. -type chunk struct { - // start is the inclusive starting byte offset for this chunk. - start int64 - // end is the inclusive ending byte offset for this chunk. - end int64 - // fifo is the FIFO buffer where this chunk's data is stored. - fifo *bufferfile.FIFO - // state tracks the current download state of this chunk. - state chunkState - // err holds any error that occurred during download. - err error - // mu protects state and err fields. - mu sync.Mutex -} - -// close closes the FIFO handle. -func (c *chunk) close() error { - if c.fifo == nil { - return nil - } - return c.fifo.Close() -} - -// cleanup closes and removes the FIFO. -func (c *chunk) cleanup() { - if c.fifo != nil { - // Only close the FIFO. Do not nil the pointer to avoid races with - // in-flight writer goroutines checking or using this handle. - c.fifo.Close() - } -} - -// setSimpleState updates the chunk state. No condition signaling needed since FIFO handles coordination. -func (c *chunk) setSimpleState(state chunkState, err error) { - c.mu.Lock() - defer c.mu.Unlock() - c.state = state - c.err = err -} - -// readAvailable reads up to len(p) bytes from the chunk, blocking until data is available. -// Returns the number of bytes read and any error. Returns io.EOF when chunk is complete -// and all data has been read. -func (c *chunk) readAvailable(p []byte, ctx context.Context) (int, error) { - // Check for context cancellation - select { - case <-ctx.Done(): - return 0, ctx.Err() - default: - } - - // Check if chunk failed first - c.mu.Lock() - if c.state == chunkFailed && c.err != nil { - err := c.err - c.mu.Unlock() - return 0, err - } - c.mu.Unlock() - - // Try to read from FIFO - n, err := c.fifo.Read(p) - - // If we got data, return it - if n > 0 { - return n, nil - } - - // If FIFO is closed or returned EOF, check chunk state - if err == io.EOF { - // If chunk is completed and FIFO EOF, we're truly done - c.mu.Lock() - if c.state == chunkCompleted { - c.mu.Unlock() - return 0, io.EOF - } - c.mu.Unlock() - // If chunk not completed but FIFO EOF, there might be an error - // Fall through to return the EOF - } - - return n, err -} - -// stitchedBody implements io.ReadCloser by reading from multiple chunk files in sequence. -type stitchedBody struct { - // chunks is the ordered list of chunk files to read from. - chunks []*chunk - // totalSize is the expected total number of bytes across all chunks. - totalSize int64 - // currentIdx is the index of the chunk currently being read from. - currentIdx int - // bytesRead is the total number of bytes delivered to callers so far. - bytesRead int64 - // closed indicates whether Close() has been called. - closed bool - // ctx is the request context for cancellation. - ctx context.Context - // mu protects all fields from concurrent access. - mu sync.Mutex -} - -// Read reads data by stitching together chunks in order. -func (sb *stitchedBody) Read(p []byte) (int, error) { - sb.mu.Lock() - defer sb.mu.Unlock() - - if sb.closed { - return 0, errors.New("stitchedBody: read from closed body") - } - - if sb.currentIdx >= len(sb.chunks) { - return 0, io.EOF - } - - totalRead := 0 - for len(p) > 0 && sb.currentIdx < len(sb.chunks) { - ch := sb.chunks[sb.currentIdx] - - // Unlock while reading from chunk (chunk handles its own locking) - sb.mu.Unlock() - - // Read available data from current chunk - n, err := ch.readAvailable(p, sb.ctx) - - // Re-lock to update state - sb.mu.Lock() - - if sb.closed { - return totalRead, errors.New("stitchedBody: read from closed body") - } - - if n > 0 { - totalRead += n - sb.bytesRead += int64(n) - p = p[n:] - } - - if err == io.EOF { - // Current chunk is complete, move to next - sb.currentIdx++ - } else if err != nil { - return totalRead, fmt.Errorf("stitchedBody: chunk %d error: %w", sb.currentIdx, err) - } else if n == 0 { - // No error but no data read - this shouldn't happen with readAvailable - // but handle it to avoid infinite loops - return totalRead, fmt.Errorf("stitchedBody: chunk %d read 0 bytes without error or EOF", sb.currentIdx) - } - } - - if totalRead == 0 && sb.currentIdx >= len(sb.chunks) { - return 0, io.EOF - } - - return totalRead, nil -} - -// Close closes all chunk files and cleans up temporary files. -func (sb *stitchedBody) Close() error { - sb.mu.Lock() - defer sb.mu.Unlock() - - if sb.closed { - return nil - } - sb.closed = true - - var errs []error - for _, ch := range sb.chunks { - if err := ch.close(); err != nil { - errs = append(errs, err) - } - ch.cleanup() - } - - if len(errs) > 0 { - return fmt.Errorf("stitchedBody: close errors: %v", errs) - } - return nil -} - -// semaphore implements a counting semaphore for limiting concurrency. -type semaphore struct { - // ch is the buffered channel used to limit concurrent operations. - // If nil, no limits are enforced (unlimited concurrency). - ch chan struct{} -} - -// newSemaphore creates a new semaphore with the given capacity. -// If capacity is 0 or negative, the semaphore allows unlimited concurrency. -func newSemaphore(capacity int) *semaphore { - if capacity <= 0 { - // Unlimited semaphore - nil channel means no limits. - return &semaphore{} - } - return &semaphore{ - ch: make(chan struct{}, capacity), - } -} - -// acquire acquires a semaphore slot, blocking until one is available or context is canceled. -func (s *semaphore) acquire(ctx context.Context) error { - if s.ch == nil { - // Unlimited semaphore - no need to acquire. - return nil - } - select { - case s.ch <- struct{}{}: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// release releases a semaphore slot. -func (s *semaphore) release() { - if s.ch == nil { - // Unlimited semaphore - no need to release. - return - } - <-s.ch -} diff --git a/pkg/distribution/transport/parallel/transport_test.go b/pkg/distribution/transport/parallel/transport_test.go deleted file mode 100644 index 70692a3b9..000000000 --- a/pkg/distribution/transport/parallel/transport_test.go +++ /dev/null @@ -1,847 +0,0 @@ -package parallel - -import ( - "bytes" - "io" - "net/http" - "sync" - "testing" - "time" - - testutil "github.com/docker/model-runner/pkg/distribution/transport/internal/testing" -) - -// TestParallelDownload_Success verifies parallel downloads using -// testutil.FakeTransport. -func TestParallelDownload_Success(t *testing.T) { - url := "https://example.com/large-file" - payload := testutil.GenerateTestData(100000) // 100KB. - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test-etag"`, - }) - - client := &http.Client{ - Transport: New(ft, WithMaxConcurrentPerRequest(4), WithMinChunkSize(1024)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Verify parallel requests were made. - reqs := ft.GetRequests() - var headCount, rangeCount, getCount int - for _, req := range reqs { - if req.Method == http.MethodHead { - headCount++ - } else if req.Method == http.MethodGet { - getCount++ - if req.Header.Get("Range") != "" { - rangeCount++ - } - } - t.Logf("Request: %s %s, Range: %s", - req.Method, req.URL, req.Header.Get("Range")) - } - - if headCount != 1 { - t.Errorf("expected 1 HEAD request, got %d", headCount) - } - if rangeCount < 2 { - t.Errorf("expected at least 2 range requests, got %d (total GET: %d)", - rangeCount, getCount) - } -} - -// TestSmallFile_FallsBackToSingle verifies small files aren't parallelized. -func TestSmallFile_FallsBackToSingle(t *testing.T) { - url := "https://example.com/small-file" - payload := []byte("small content") - - ft := testutil.NewFakeTransport() - ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) - - client := &http.Client{ - Transport: New(ft, WithMaxConcurrentPerRequest(4), WithMinChunkSize(1024)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Should only have HEAD and single GET. - reqs := ft.GetRequests() - var headCount, rangeCount, fullGetCount int - for _, req := range reqs { - if req.Method == http.MethodHead { - headCount++ - } else if req.Method == http.MethodGet { - if req.Header.Get("Range") != "" { - rangeCount++ - } else { - fullGetCount++ - } - } - } - - if headCount != 1 { - t.Errorf("expected 1 HEAD request, got %d", headCount) - } - if rangeCount != 0 { - t.Errorf("expected 0 range requests, got %d", rangeCount) - } - if fullGetCount != 1 { - t.Errorf("expected 1 full GET request, got %d", fullGetCount) - } -} - -// TestNoRangeSupport_FallsBack tests fallback when server doesn't support -// ranges. -func TestNoRangeSupport_FallsBack(t *testing.T) { - url := "https://example.com/no-range" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: false, // No range support. - }) - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Should fall back to single request. - reqs := ft.GetRequests() - var rangeCount int - for _, req := range reqs { - if req.Header.Get("Range") != "" { - rangeCount++ - } - } - - if rangeCount != 0 { - t.Errorf("expected no range requests, got %d", rangeCount) - } -} - -// TestContentEncoding_FallsBack tests fallback with Content-Encoding. -func TestContentEncoding_FallsBack(t *testing.T) { - url := "https://example.com/gzip" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - Headers: http.Header{ - "Content-Encoding": []string{"gzip"}, - }, - }) - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Should fall back due to Content-Encoding. - reqs := ft.GetRequests() - var rangeCount int - for _, req := range reqs { - if req.Header.Get("Range") != "" { - rangeCount++ - } - } - - if rangeCount != 0 { - t.Errorf("expected no range requests due to Content-Encoding, got %d", - rangeCount) - } -} - -// TestETagValidation verifies ETag is used for If-Range validation. -func TestETagValidation(t *testing.T) { - url := "https://example.com/etag-test" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"strong-etag"`, - }) - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check If-Range headers. - headers := ft.GetRequestHeaders(url) - for _, h := range headers { - if h.Get("Range") != "" { - if ifRange := h.Get("If-Range"); ifRange != `"strong-etag"` { - t.Errorf("expected If-Range with ETag, got %q", ifRange) - } - } - } -} - -// TestWeakETag_UsesLastModified tests weak ETags trigger Last-Modified usage. -func TestWeakETag_UsesLastModified(t *testing.T) { - url := "https://example.com/weak-etag" - payload := testutil.GenerateTestData(100000) - lastModified := time.Unix(1700000000, 0).UTC().Format(http.TimeFormat) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `W/"weak-etag"`, - LastModified: lastModified, - }) - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check If-Range uses Last-Modified instead of weak ETag. - headers := ft.GetRequestHeaders(url) - for _, h := range headers { - if h.Get("Range") != "" { - ifRange := h.Get("If-Range") - if ifRange != lastModified { - t.Errorf("expected If-Range with Last-Modified, got %q", - ifRange) - } - } - } -} - -// TestConcurrencyLimits verifies per-host concurrency limits. -func TestConcurrencyLimits(t *testing.T) { - url := "https://example.com/large" - payload := testutil.GenerateTestData(500000) // 500KB to ensure parallelization. - - ft := testutil.NewFakeTransport() - ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) - - // Track concurrent requests. maxConcurrent records the peak concurrent range - // downloads observed while currentConcurrent holds the in-flight count at any - // moment. mu ensures those counters are updated atomically. rangeRequests - // counts how many range downloads we observed. wg waits until every tracked - // range request finishes. rangeStartedCh buffers notifications when a new - // tracked range request begins. releaseCh blocks the request until the test - // releases it. releaseOnce ensures releaseCh is only closed once, even on - // early exits. - var maxConcurrent, currentConcurrent int - var mu sync.Mutex - rangeRequests := 0 - var wg sync.WaitGroup - rangeStartedCh := make(chan struct{}, 8) - releaseCh := make(chan struct{}) - var releaseOnce sync.Once - defer releaseOnce.Do(func() { close(releaseCh) }) - - ft.RequestHook = func(req *http.Request) { - rangeHeader := req.Header.Get("Range") - if rangeHeader != "" && rangeHeader != "bytes=0-0" { - wg.Add(1) - - mu.Lock() - currentConcurrent++ - rangeRequests++ - if currentConcurrent > maxConcurrent { - maxConcurrent = currentConcurrent - } - mu.Unlock() - - // Capture the start of the range request without blocking. - select { - case rangeStartedCh <- struct{}{}: - default: - } - - <-releaseCh - - mu.Lock() - currentConcurrent-- - mu.Unlock() - - wg.Done() - } - t.Logf("Request: %s %s, Range: %s", req.Method, req.URL, rangeHeader) - } - - client := &http.Client{ - Transport: New(ft, - WithMaxConcurrentPerHost(map[string]uint{"example.com": 2}), - WithMaxConcurrentPerRequest(4), - WithMinChunkSize(10000)), // Lower min chunk size to ensure parallelization. - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - // Drive the download in a goroutine so range requests can start while the - // test observes concurrency. - readDone := make(chan error, 1) - go func() { - _, err := io.ReadAll(resp.Body) - readDone <- err - }() - - for i := 0; i < 2; i++ { - select { - case <-rangeStartedCh: - case <-time.After(time.Second): - releaseOnce.Do(func() { close(releaseCh) }) - t.Fatalf("timed out waiting for parallel range requests to start") - } - } - - releaseOnce.Do(func() { close(releaseCh) }) - - if err := <-readDone; err != nil { - t.Fatalf("read: %v", err) - } - - wg.Wait() - - mu.Lock() - maxSeen := maxConcurrent - madeRanges := rangeRequests - mu.Unlock() - - if maxSeen > 2 { - t.Errorf("expected max 2 concurrent requests, got %d", maxSeen) - } - - if madeRanges == 0 { - t.Error("no range requests were made") - } -} - -// TestIfRangeValidation tests If-Range validation behavior. -func TestIfRangeValidation(t *testing.T) { - url := "https://example.com/if-range-test" - payload := testutil.GenerateTestData(100000) - etag := `"original-etag"` - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: etag, - }) - - // Change ETag on range requests to simulate resource change. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") != "" { - // Check If-Range validation. - ifRange := resp.Request.Header.Get("If-Range") - if ifRange != etag { - // Resource changed, return full content. - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(bytes.NewReader(payload)) - } - } - } - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) -} - -// TestNoContentLength_FallsBack tests fallback when Content-Length is -// missing. -func TestNoContentLength_FallsBack(t *testing.T) { - url := "https://example.com/no-length" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) - - // Remove Content-Length from HEAD response. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Method == http.MethodHead { - resp.ContentLength = -1 - resp.Header.Del("Content-Length") - } - } - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Should fall back to single request. - reqs := ft.GetRequests() - var rangeCount int - for _, req := range reqs { - if req.Header.Get("Range") != "" { - rangeCount++ - } - } - - if rangeCount != 0 { - t.Errorf("expected no range requests without Content-Length, got %d", - rangeCount) - } -} - -// TestNonGetRequest_PassesThrough verifies non-GET requests are passed -// through unmodified. -func TestNonGetRequest_PassesThrough(t *testing.T) { - url := "https://example.com/resource" - postData := []byte("post data") - responseData := []byte("response") - - ft := testutil.NewFakeTransport() - ft.AddSimple(url, bytes.NewReader(responseData), int64(len(responseData)), false) - - client := &http.Client{Transport: New(ft)} - - // Test POST request. - resp, err := client.Post(url, "application/json", - bytes.NewReader(postData)) - if err != nil { - t.Fatalf("POST failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, responseData) - - // Should not have any HEAD requests. - reqs := ft.GetRequests() - for _, req := range reqs { - if req.Method == http.MethodHead { - t.Error("unexpected HEAD request for non-GET method") - } - if req.Header.Get("Range") != "" { - t.Error("unexpected Range header for non-GET method") - } - } -} - -// TestWrongRangeResponse_HandlesError tests handling of incorrect range -// responses. -func TestWrongRangeResponse_HandlesError(t *testing.T) { - url := "https://example.com/wrong-range" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - // Return wrong range in response. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") == "bytes=1000-1999" { - // Return different range than requested. - resp.Header.Set("Content-Range", "bytes 2000-2999/100000") - } - } - - client := &http.Client{Transport: New(ft)} - - // Make a specific range request. - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("create request: %v", err) - } - req.Header.Set("Range", "bytes=1000-1999") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - // Should still work (parallel transport doesn't validate Content-Range - // for user requests). - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - // Should get the correct range data. - want := payload[1000:2000] - testutil.AssertDataEquals(t, got, want) -} - -// TestChunkBoundaries verifies correct chunk boundary calculation. -func TestChunkBoundaries(t *testing.T) { - url := "https://example.com/boundaries" - // Use specific size to test boundary conditions. - payload := testutil.GenerateTestData(10000) // Exactly 10KB. - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - client := &http.Client{ - Transport: New(ft, - WithMaxConcurrentPerRequest(4), - WithMinChunkSize(2500)), // Should result in 4 chunks of 2500 bytes. - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check the range requests. - reqs := ft.GetRequests() - - var actualRanges []string - for _, req := range reqs { - if r := req.Header.Get("Range"); r != "" && r != "bytes=0-0" { - actualRanges = append(actualRanges, r) - } - } - - // We might not get exactly these ranges due to scheduling, but verify we - // got multiple. - if len(actualRanges) < 2 { - t.Errorf("expected multiple range requests, got %d", len(actualRanges)) - } - - t.Logf("Actual ranges: %v", actualRanges) -} - -// TestETagChanged_FallsBackToSingle tests handling when ETag changes -// mid-download. -func TestETagChanged_FallsBackToSingle(t *testing.T) { - url := "https://example.com/changing" - payload := testutil.GenerateTestData(100000) - originalETag := `"original"` - changedETag := `"changed"` - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: originalETag, - }) - - requestCount := 0 - var mu sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - mu.Lock() - requestCount++ - rc := requestCount - mu.Unlock() - // Change ETag after first request. - if rc > 1 && resp.Request.Header.Get("Range") != "" { - // Simulate resource change - return full content with new ETag. - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Set("ETag", changedETag) - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(bytes.NewReader(payload)) - } - } - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - // Should still get the full payload. - testutil.AssertDataEquals(t, got, payload) -} - -// TestNoValidator_StillWorks tests parallel download without ETag or -// Last-Modified. -func TestNoValidator_StillWorks(t *testing.T) { - url := "https://example.com/no-validator" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - // No ETag or LastModified. - }) - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check that no If-Range headers were sent. - headers := ft.GetRequestHeaders(url) - for _, h := range headers { - if ifRange := h.Get("If-Range"); ifRange != "" { - t.Errorf("unexpected If-Range header: %q", ifRange) - } - } -} - -// TestConditionalHeadersScrubbed verifies conditional headers are removed. -func TestConditionalHeadersScrubbed(t *testing.T) { - url := "https://example.com/conditional" - payload := testutil.GenerateTestData(100000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Track headers and validate scrubbing for both HEAD and range GETs. - ft.RequestHook = func(req *http.Request) { - // For range requests made by parallel transport, - // conditional headers should be removed. - if req.Header.Get("Range") != "" { - if req.Header.Get("If-Match") != "" { - t.Errorf("%s request: If-Match header should be removed", - req.Method) - } - if req.Header.Get("If-None-Match") != "" { - t.Errorf("%s request: If-None-Match header should be removed", - req.Method) - } - if req.Header.Get("If-Modified-Since") != "" { - t.Errorf("%s request: If-Modified-Since header should be removed", - req.Method) - } - if req.Header.Get("If-Unmodified-Since") != "" { - t.Errorf("%s request: If-Unmodified-Since header should be removed", - req.Method) - } - } - // HEAD made by parallel transport should scrub conditional headers and - // force identity encoding. - if req.Method == http.MethodHead { - if req.Header.Get("If-Match") != "" || - req.Header.Get("If-None-Match") != "" || - req.Header.Get("If-Modified-Since") != "" || - req.Header.Get("If-Unmodified-Since") != "" { - t.Error("HEAD request should have conditional headers scrubbed") - } - if ae := req.Header.Get("Accept-Encoding"); ae != "identity" { - t.Errorf("HEAD should set Accept-Encoding=identity, got %q", ae) - } - } - // If-Range should only be present on range requests with proper value. - if ifRange := req.Header.Get("If-Range"); ifRange != "" { - if req.Header.Get("Range") == "" { - t.Error("If-Range without Range header") - } - } - } - - client := &http.Client{Transport: New(ft)} - - // Create request with conditional headers. - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("create request: %v", err) - } - req.Header.Set("If-Match", `"wrong"`) - req.Header.Set("If-None-Match", `"also-wrong"`) - req.Header.Set("If-Modified-Since", "Wed, 21 Oct 2015 07:28:00 GMT") - req.Header.Set("If-Unmodified-Since", "Wed, 21 Oct 2015 07:28:00 GMT") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) -} - -// TestRangeHeader_PassesThrough verifies that requests with an explicit -// Range header are passed through without parallelization, and no HEAD -// request is issued by the transport. -func TestRangeHeader_PassesThrough(t *testing.T) { - url := "https://example.com/ranged" - payload := testutil.GenerateTestData(8192) - - ft := testutil.NewFakeTransport() - ft.AddSimple(url, bytes.NewReader(payload), int64(len(payload)), true) - - client := &http.Client{Transport: New(ft, WithMaxConcurrentPerRequest(4))} - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("create request: %v", err) - } - req.Header.Set("Range", "bytes=1000-1999") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - want := payload[1000:2000] - testutil.AssertDataEquals(t, got, want) - - // Ensure no HEAD was made and that only the user’s single GET with Range - // was sent (no extra parallel range requests). - reqs := ft.GetRequests() - var headCount, rangeGets int - for _, r := range reqs { - if r.Method == http.MethodHead { - headCount++ - } - if r.Method == http.MethodGet && r.Header.Get("Range") != "" { - rangeGets++ - } - } - if headCount != 0 { - t.Errorf("expected 0 HEAD requests, got %d", headCount) - } - if rangeGets != 1 { - t.Errorf("expected exactly 1 ranged GET, got %d", rangeGets) - } -} diff --git a/pkg/distribution/transport/resumable/transport.go b/pkg/distribution/transport/resumable/transport.go deleted file mode 100644 index 528d7fd93..000000000 --- a/pkg/distribution/transport/resumable/transport.go +++ /dev/null @@ -1,527 +0,0 @@ -// Package resumable provides an http.RoundTripper that transparently resumes -// interrupted GET responses from servers that support byte ranges. -// -// ───────────────────────────── How it works ───────────────────────────── -// - For GET responses with status 200 or 206 and "Accept-Ranges: bytes", -// the transport replaces resp.Body with a resumable reader. -// - If a mid-stream read fails (e.g., connection cut), it issues a follow-up -// request with a "Range" header to continue from the last delivered byte. -// It uses ETag (strong only) or Last-Modified via If-Range for safety. -// - If the server doesn’t support ranges (or for non-GET), it passes -// through the response unmodified. -// -// ───────────────────────────── Notes & caveats ─────────────────────────── -// - Only single byte ranges are supported when the original request already -// includes Range (multi-range requests are passed through without resuming). -// - Auto-decompression must not be active, or offsets won’t line up. If the -// initial response was transparently decompressed (resp.Uncompressed == true) -// or Content-Encoding was set, resumption is disabled for that response. -// - Cookies added by an http.Client Jar after the initial response aren’t -// automatically applied to follow-up range requests (since they bypass -// http.Client). Existing request headers (incl. Cookie, Authorization, etc.) -// are preserved, but Set-Cookie from the initial response won't be consulted. -// - Some servers don’t advertise Accept-Ranges but still support Range. -// This implementation requires explicit "Accept-Ranges: bytes" for safety. -package resumable - -import ( - "context" - "errors" - "fmt" - "io" - "math" - "math/rand" - "net/http" - "strings" - "sync" - "time" - - "github.com/docker/model-runner/pkg/distribution/transport/internal/common" -) - -// Option configures a ResumableTransport. -type Option func(*ResumableTransport) - -// WithMaxRetries sets the maximum number of resume attempts after an error. -// Default: 3. -func WithMaxRetries(n int) Option { - return func(rt *ResumableTransport) { rt.maxRetries = n } -} - -// BackoffFunc computes the sleep duration for a given retry attempt (0-based). -type BackoffFunc func(attempt int) time.Duration - -// WithBackoff sets the backoff strategy for resume attempts. -// Default: jittered exponential starting at 200ms, capped at 5s. -func WithBackoff(f BackoffFunc) Option { - return func(rt *ResumableTransport) { rt.backoff = f } -} - -// ResumableTransport wraps another http.RoundTripper and transparently retries -// mid-stream failures for GET requests against servers that support range requests. -type ResumableTransport struct { - // base is the underlying RoundTripper actually used to send requests. - base http.RoundTripper - // maxRetries is the maximum number of resume attempts that will be made - // after a read error before giving up. - maxRetries int - // backoff computes how long to wait before each retry attempt. - // Called with the total number of attempts made so far (0-based). - backoff BackoffFunc -} - -// New returns a ResumableTransport wrapping base. If base is nil, -// http.DefaultTransport is used. Options configure retries/backoff. -func New(base http.RoundTripper, opts ...Option) *ResumableTransport { - if base == nil { - base = http.DefaultTransport - } - rt := &ResumableTransport{ - base: base, - maxRetries: 3, - backoff: func(i int) time.Duration { - // 200ms * 2^i with ±20% jitter, capped at 5s - base := 200 * time.Millisecond - d := time.Duration(float64(base) * math.Pow(2, float64(i))) - if d > 5*time.Second { - d = 5 * time.Second - } - j := 0.2 + rand.Float64()*0.4 // [0.2,0.6) - return time.Duration(float64(d) * j) - }, - } - for _, o := range opts { - o(rt) - } - return rt -} - -// RoundTrip implements http.RoundTripper. It wraps GET requests that return -// 200/206 responses with "Accept-Ranges: bytes" support in a resumable body. -func (rt *ResumableTransport) RoundTrip(req *http.Request) (*http.Response, error) { - // Always use the base transport to perform the initial request. - resp, err := rt.base.RoundTrip(req) - if resp == nil || err != nil { - return resp, err - } - - // If the request doesn't meet our resumability criteria, then return it - // directly. - if !isResumable(req, resp) { - return resp, nil - } - - // Create a resumable body to perform retries if needed. - rb := newResumableBody(req, resp, rt) - resp.Body = rb - if n, ok := rb.plannedLength(); ok { - resp.ContentLength = n - } else { - resp.ContentLength = -1 - } - return resp, nil -} - -// isResumable checks if the pair (request, response) is eligible for resume. -func isResumable(req *http.Request, resp *http.Response) bool { - if req.Method != http.MethodGet { - return false - } - if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusPartialContent { - return false - } - if !common.SupportsRange(resp.Header) { - return false - } - // Disallow when the response was auto-decompressed or has a Content-Encoding. - if resp.Uncompressed || resp.Header.Get("Content-Encoding") != "" { - return false - } - // If the original request specified a Range, only support single-range. - if r := req.Header.Get("Range"); strings.TrimSpace(r) != "" { - if _, _, ok := common.ParseSingleRange(r); !ok { - return false - } - } - return true -} - -// resumableBody wraps a Response.Body to add transparent resume support. -// It keeps track of how many bytes have been delivered and re-issues -// Range requests starting from that offset when a read fails. -type resumableBody struct { - // mu guards access to all fields below. - mu sync.Mutex - // ctx is the request context; canceled if caller cancels. - ctx context.Context - // tr is the owning ResumableTransport (for retry/backoff settings). - tr *ResumableTransport - // base is the underlying RoundTripper to use for follow-up Range requests. - base http.RoundTripper - // origReq is the original *http.Request; used as a template for retries. - origReq *http.Request - // current is the most recent *http.Response from which we are reading. - current *http.Response - // rc is the active body we are currently reading from. - rc io.ReadCloser - // bytesRead is how many bytes we have successfully delivered to the caller - // (relative to initialStart). - bytesRead int64 - // initialStart is the starting offset of the stream on the wire, usually 0. - initialStart int64 - // initialEnd is the inclusive end offset (if known from Range header). - initialEnd *int64 - // totalSize is the total known length of the resource, if available. - totalSize *int64 - // etag is the validator used for If-Range in resumed requests (preferred). - etag string - // lastModified is a fallback validator for If-Range if no ETag is present. - lastModified string - // retriesUsed counts how many resume attempts we have made so far. - retriesUsed int - // originalRangeSpec is the Range header sent on the initial request. - originalRangeSpec string - // done marks that we’ve finished delivering all bytes (EOF). - done bool -} - -// newResumableBody constructs a resumableBody from the initial response. -func newResumableBody(req *http.Request, resp *http.Response, tr *ResumableTransport) *resumableBody { - rb := &resumableBody{ - ctx: req.Context(), - tr: tr, - base: tr.base, - origReq: req, - current: resp, - rc: resp.Body, - originalRangeSpec: req.Header.Get("Range"), - } - - // Extract starting offsets from request Range if present (single-range only). - if start, end, ok := common.ParseSingleRange(rb.originalRangeSpec); ok { - rb.initialStart = start - if end >= 0 { - rb.initialEnd = &end - } - } - - // Refine offsets from Content-Range header if response was 206. - if resp.StatusCode == http.StatusPartialContent { - if s, e, total, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok { - rb.initialStart = s - if e >= 0 { - rb.initialEnd = &e - } - if total >= 0 { - rb.totalSize = &total - } - } - } else if resp.StatusCode == http.StatusOK { - // For 200 OK, the server is sending a full stream starting at 0 - // regardless of any Range header on the request. - rb.initialStart = 0 - rb.initialEnd = nil - if resp.ContentLength >= 0 { - total := int64(resp.ContentLength) - rb.totalSize = &total - } - } - - // Capture validators for If-Range to ensure consistency across resumes. - if et := resp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) { - rb.etag = et - } else if lm := resp.Header.Get("Last-Modified"); lm != "" { - rb.lastModified = lm - } - return rb -} - -// Read delivers bytes to the caller. If an error occurs mid-stream, it will -// transparently try to resume by issuing a new Range request. When the total -// length is unknown (e.g., 200 OK without Content-Length), completeness cannot -// be verified precisely; in such cases EOF is treated as the natural end. -func (rb *resumableBody) Read(p []byte) (int, error) { - for { - // Snapshot state without holding the lock across I/O. - rb.mu.Lock() - if rb.done { - rb.mu.Unlock() - return 0, io.EOF - } - rc := rb.rc - planned, plannedOK := rb.plannedLength() - already := rb.bytesRead - rb.mu.Unlock() - - if rc == nil { - if err := rb.resume(already); err != nil { - return 0, err - } - continue - } - - n, err := rc.Read(p) - - rb.mu.Lock() - rb.bytesRead += int64(n) - - switch { - case err == nil: - rb.mu.Unlock() - return n, nil - case errors.Is(err, io.EOF): - // If planned length is known and we are short, resume. - if plannedOK && already+int64(n) < planned { - _ = rb.rc.Close() - rb.rc = nil - if rb.retriesUsed >= rb.tr.maxRetries { - rb.mu.Unlock() - return n, io.ErrUnexpectedEOF - } - // Return bytes now; resume on next call. - if n > 0 { - rb.mu.Unlock() - return n, nil - } - // Resume outside lock. - nextOffset := rb.bytesRead - rb.mu.Unlock() - if rerr := rb.resume(nextOffset); rerr != nil { - return 0, rerr - } - continue - } - // Completed. - rb.done = true - rb.mu.Unlock() - return n, io.EOF - default: - // Underlying read failed mid-stream. Try to resume. - _ = rb.rc.Close() - rb.rc = nil - - if n > 0 { - rb.mu.Unlock() - // Surface bytes already read; the caller will call Read again. - return n, nil - } - if rb.retriesUsed >= rb.tr.maxRetries { - rb.mu.Unlock() - return 0, err - } - off := rb.bytesRead - rb.mu.Unlock() - if rerr := rb.resume(off); rerr != nil { - return 0, rerr - } - continue - } - } -} - -// Close closes the current response body if present. -func (rb *resumableBody) Close() error { - rb.mu.Lock() - rc := rb.rc - rb.rc = nil - rb.done = true - rb.mu.Unlock() - if rc != nil { - return rc.Close() - } - return nil -} - -// plannedLength returns the exact number of bytes this resumableBody intends to produce, if knowable. -func (rb *resumableBody) plannedLength() (int64, bool) { - if rb.initialEnd != nil { - return *rb.initialEnd - rb.initialStart + 1, true - } - if rb.current.StatusCode == http.StatusOK && rb.totalSize != nil { - // 200 OK with known total size at start-of-stream. - return *rb.totalSize, true - } - return 0, false -} - -// resume attempts to resume the response stream at the given absolute offset -// (relative to the very first byte on the wire). The method will make up to the -// remaining retry budget attempts. On success it swaps rb.rc with a fresh body. -func (rb *resumableBody) resume(absoluteOffset int64) error { - remaining := rb.tr.maxRetries - rb.retriesUsed - for attempt := 0; attempt < remaining; attempt++ { - if err := rb.ctx.Err(); err != nil { - return err - } - - // For safety, do not attempt an unvalidated resume when neither a - // strong ETag nor Last-Modified validator is available. - if rb.etag == "" && rb.lastModified == "" { - return fmt.Errorf("resumable: cannot resume without validator") - } - - start := rb.initialStart + absoluteOffset - rangeVal := buildRangeHeader(start, rb.initialEnd) - req := rb.cloneBaseRequest(rangeVal) - - // Backoff for subsequent attempts. - if attempt > 0 || rb.retriesUsed > 0 { - if err := waitBackoff(rb.ctx, rb.tr.backoff, rb.retriesUsed+attempt); err != nil { - return err - } - } - - resp, err := rb.base.RoundTrip(req) - if err != nil { - continue // try again within budget - } - - switch resp.StatusCode { - case http.StatusPartialContent: - // Validate server honored our starting offset precisely. - s, e, _, ok := common.ParseContentRange(resp.Header.Get("Content-Range")) - if !ok || s != start { - _ = resp.Body.Close() - continue // try again; mismatched range - } - // If we requested a closed range and the end does not match, do - // not accept this response. - if rb.initialEnd != nil && e >= 0 && e != *rb.initialEnd { - _ = resp.Body.Close() - continue - } - // Install the new response under lock. - rb.mu.Lock() - rb.installResponseLocked(resp) - rb.retriesUsed++ - rb.mu.Unlock() - return nil - - case http.StatusOK: - // If we requested a range but got a full response, it likely means the - // validator failed (resource changed) or the server ignored Range. - _ = resp.Body.Close() - return fmt.Errorf("resumable: server returned 200 to a range request; resource may have changed") - - case http.StatusMultipleChoices, http.StatusMovedPermanently, http.StatusFound, - http.StatusSeeOther, http.StatusNotModified, http.StatusUseProxy, - http.StatusTemporaryRedirect, http.StatusPermanentRedirect: - _ = resp.Body.Close() - return fmt.Errorf("resumable: resume received redirect status %d", resp.StatusCode) - - case http.StatusRequestedRangeNotSatisfiable: - // If we've already read to/ past the expected end, we are actually done. - if rb.rangeIsComplete(absoluteOffset) { - rb.done = true - _ = resp.Body.Close() - return io.EOF - } - _ = resp.Body.Close() - - default: - _ = resp.Body.Close() - } - } - return fmt.Errorf("resumable: exceeded retry budget after %d attempts", rb.tr.maxRetries) -} - -// installResponseLocked installs resp as the current response and updates -// validators and size info. Caller must hold rb.mu. -func (rb *resumableBody) installResponseLocked(resp *http.Response) { - if rb.rc != nil && rb.rc != resp.Body { - _ = rb.rc.Close() - } - rb.current = resp - rb.rc = resp.Body - - // Persist validators from the server if they are strong. - if et := resp.Header.Get("ETag"); et != "" && !common.IsWeakETag(et) { - rb.etag = et - } - if lm := resp.Header.Get("Last-Modified"); lm != "" { - rb.lastModified = lm - } - - // Merge any updated size info from the Content-Range. - if s, e, total, ok := common.ParseContentRange(resp.Header.Get("Content-Range")); ok { - _ = s // start validated by caller - if e >= 0 { - rb.initialEnd = &e - } - if total >= 0 { - rb.totalSize = &total - } - } -} - -// cloneBaseRequest builds a new GET request with the same headers as the original, -// except with a different Range and If-Range validator. To avoid mismatched -// encodings, we also force identity encoding. -func (rb *resumableBody) cloneBaseRequest(rangeVal string) *http.Request { - req := rb.origReq.Clone(rb.ctx) - req.Body = nil - req.ContentLength = 0 - req.Header = rb.origReq.Header.Clone() - - // Ensure we control the Range validator set. - req.Header.Set("Range", rangeVal) - // Remove conditional headers that could conflict with If-Range semantics. - common.ScrubConditionalHeaders(req.Header) - - if rb.etag != "" { - req.Header.Set("If-Range", rb.etag) - } else if rb.lastModified != "" { - req.Header.Set("If-Range", rb.lastModified) - } - - // Prevent transparent decompression on resumed requests. - req.Header.Set("Accept-Encoding", "identity") - return req -} - -// buildRangeHeader constructs a "Range" header value for a given start and -// optional inclusive end. -func buildRangeHeader(start int64, end *int64) string { - if end == nil { - return fmt.Sprintf("bytes=%d-", start) - } - return fmt.Sprintf("bytes=%d-%d", start, *end) -} - -// waitBackoff sleeps using the provided backoff function, unless the context -// is canceled. -func waitBackoff(ctx context.Context, bf BackoffFunc, attempt int) error { - d := time.Duration(0) - if bf != nil { - d = bf(attempt) - } - if d <= 0 { - return nil - } - t := time.NewTimer(d) - defer t.Stop() - select { - case <-t.C: - return nil - case <-ctx.Done(): - return ctx.Err() - } -} - -// rangeIsComplete returns true if the bytes we have delivered already meet the -// expected end of the range / resource, so a 416 implies we are done. -func (rb *resumableBody) rangeIsComplete(absoluteOffset int64) bool { - if rb.totalSize != nil { - // If we know total size, we are complete if start+offset >= total. - if rb.initialStart+absoluteOffset >= *rb.totalSize { - return true - } - } - if rb.initialEnd != nil { - // initialEnd is inclusive. - if rb.initialStart+absoluteOffset >= *rb.initialEnd+1 { - return true - } - } - return false -} diff --git a/pkg/distribution/transport/resumable/transport_test.go b/pkg/distribution/transport/resumable/transport_test.go deleted file mode 100644 index 0c86c5855..000000000 --- a/pkg/distribution/transport/resumable/transport_test.go +++ /dev/null @@ -1,1386 +0,0 @@ -package resumable - -import ( - "bytes" - "fmt" - "io" - "net/http" - "strings" - "sync" - "testing" - "time" - - testutil "github.com/docker/model-runner/pkg/distribution/transport/internal/testing" -) - -// blockingBody simulates a response body that blocks on Read until closed. -type blockingBody struct { - ch chan struct{} -} - -func newBlockingBody() *blockingBody { return &blockingBody{ch: make(chan struct{})} } -func (b *blockingBody) Read(p []byte) (int, error) { <-b.ch; return 0, io.EOF } -func (b *blockingBody) Close() error { close(b.ch); return nil } - -// TestResumeSingleFailure_Succeeds tests resuming after a single failure. -func TestResumeSingleFailure_Succeeds(t *testing.T) { - url := "https://example.com/test-file" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test-etag"`, - }) - - // Simulate failure after 2500 bytes on first request. - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Verify resume happened. - reqs := ft.GetRequests() - var rangeRequests int - for _, req := range reqs { - if req.Header.Get("Range") != "" { - rangeRequests++ - t.Logf("Range request: %s", req.Header.Get("Range")) - } - } - - if rangeRequests < 1 { - t.Error("expected at least one range request for resume") - } -} - -// TestResumeMultipleFailuresWithinBudget_Succeeds tests multiple resume -// attempts. -func TestResumeMultipleFailuresWithinBudget_Succeeds(t *testing.T) { - url := "https://example.com/multi-fail" - payload := testutil.GenerateTestData(10000) - - ft := testutil.NewFakeTransport() - - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"multi-fail-etag"`, - }) - - // Hook to inject failures - use SetFailAfter multiple times. - failurePoints := []int{2000, 5000, 7500} - failureIndex := 0 - requestCount := 0 - var mu sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Method == http.MethodGet && - failureIndex < len(failurePoints) { - // For non-range requests, inject failure. - if resp.Request.Header.Get("Range") == "" { - mu.Lock() - idx := failureIndex - failureIndex++ - mu.Unlock() - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload), - int64(len(payload)), - failurePoints[idx], - ) - } else { - // For range requests, check which failure point we're at. - mu.Lock() - requestCount++ - rc := requestCount - fi := failureIndex - mu.Unlock() - if rc <= len(failurePoints) && - fi < len(failurePoints) { - // Parse range to determine data slice. - rangeHeader := resp.Request.Header.Get("Range") - if rangeHeader != "" { - // Simple parsing for bytes=N- format. - var start int - fmt.Sscanf(rangeHeader, "bytes=%d-", &start) - rangeData := payload[start:] - - // Apply next failure point relative to this - // range. - nextFailure := failurePoints[fi] - start - if nextFailure > 0 && - nextFailure < len(rangeData) { - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(rangeData), - int64(len(rangeData)), - nextFailure, - ) - mu.Lock() - failureIndex++ - mu.Unlock() - } - } - } - } - } - } - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(5)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check that multiple resumes happened. - reqs := ft.GetRequests() - var rangeCount int - for _, req := range reqs { - if req.Header.Get("Range") != "" { - rangeCount++ - } - } - - if rangeCount < 2 { - t.Errorf("expected at least 2 range requests, got %d", rangeCount) - } -} - -// TestExceedRetryBudget_Fails tests failure when retry budget is exceeded. -func TestExceedRetryBudget_Fails(t *testing.T) { - url := "https://example.com/too-many-failures" - payload := testutil.GenerateTestData(4096) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"fail-test"`, - }) - - // Always fail after 100 bytes. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Method == http.MethodGet { - resp.Body = testutil.NewFlakyReader(bytes.NewReader(payload), int64(len(payload)), 100) - } - } - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(2)), // Low retry limit. - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil { - t.Error("expected error after exceeding retry budget") - } - - // Check that retries were attempted. - reqs := ft.GetRequests() - var attempts int - for _, req := range reqs { - if req.Method == http.MethodGet { - attempts++ - } - } - - // Initial + 2 retries = 3 total. - if attempts < 2 { - t.Errorf("expected at least 2 GET attempts, got %d", attempts) - } -} - -// TestReadCloseInterleaving ensures Close does not deadlock with a blocked Read -// and unblocks promptly. -func TestReadCloseInterleaving(t *testing.T) { - url := "https://example.com/blocking" - payload := testutil.GenerateTestData(1024) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"etag"`, - }) - // Replace body with a blocking body for the initial GET. - bb := newBlockingBody() - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Method == http.MethodGet && resp.Request.Header.Get("Range") == "" { - resp.Body = bb - } - } - - client := &http.Client{Transport: New(ft)} - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - - done := make(chan struct{}) - go func() { - defer close(done) - _, _ = io.ReadAll(resp.Body) - }() - - // Close should unblock the read goroutine promptly. - if err := resp.Body.Close(); err != nil { - t.Fatalf("close: %v", err) - } - select { - case <-done: - case <-time.After(1 * time.Second): - t.Fatal("read did not unblock after Close") - } -} - -// TestMultiRange_PassThrough ensures multi-range requests are not wrapped. -func TestMultiRange_PassThrough(t *testing.T) { - url := "https://example.com/multirange" - payload := testutil.GenerateTestData(4096) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - client := &http.Client{Transport: New(ft)} - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=0-10,20-30") - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - // FakeTransport does not implement multi-range; it returns 400. - if resp.StatusCode != http.StatusBadRequest { - t.Fatalf("expected 400 from fake transport for multi-range, got %d", resp.StatusCode) - } - - // Ensure no If-Range was injected on request headers. - hdrs := ft.GetRequestHeaders(url) - for _, h := range hdrs { - if h.Get("If-Range") != "" { - t.Error("unexpected If-Range header on multi-range request") - } - } -} - -// TestInitialRange_200OK_Ignored ensures if server responds 200 to a ranged -// request, the stream is treated as starting at 0 and reads succeed. -func TestInitialRange_200OK_Ignored(t *testing.T) { - url := "https://example.com/range-ignored" - payload := testutil.GenerateTestData(2048) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"e"`, - }) - - // Force 200 full response even when Range is present. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") != "" && resp.StatusCode == http.StatusPartialContent { - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(bytes.NewReader(payload)) - } - } - - client := &http.Client{Transport: New(ft)} - req, _ := http.NewRequest("GET", url, nil) - req.Header.Set("Range", "bytes=100-199") - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - data, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - testutil.AssertDataEquals(t, data, payload) -} - -// TestRedirectOnResume returns 3xx for resume request and expects a clear error. -func TestRedirectOnResume(t *testing.T) { - url := "https://example.com/redirect-on-resume" - payload := testutil.GenerateTestData(5000) - etag := `"strong"` - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: etag, - }) - ft.SetFailAfter(url, 2500) - - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") != "" { - resp.StatusCode = http.StatusFound - resp.Status = "302 Found" - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(bytes.NewReader(nil)) - } - } - - client := &http.Client{Transport: New(ft, WithMaxRetries(2))} - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil || !strings.Contains(err.Error(), "redirect status") { - t.Fatalf("expected redirect error, got %v", err) - } -} - -// TestWrongStartOnResume_IsRejected tests handling of unexpected range -// responses. -func TestWrongStartOnResume_IsRejected(t *testing.T) { - url := "https://example.com/wrong-start" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Return wrong range on resume. - resumeAttempted := false - var muResume sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") == "bytes=2500-" { - muResume.Lock() - resumeAttempted = true - muResume.Unlock() - // Return wrong start position. - resp.Header.Set("Content-Range", "bytes 3000-4999/5000") - resp.Body = io.NopCloser(testutil.NewFlakyReader( - bytes.NewReader(payload[3000:]), - int64(len(payload[3000:])), - 0, - )) - } - } - - // First fail after 2500 bytes. - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil { - t.Error("expected error due to wrong range start") - } - - muResume.Lock() - attempted := resumeAttempted - muResume.Unlock() - if !attempted { - t.Error("resume was not attempted") - } -} - -// TestNon206OnResume_IsRejected tests handling when server returns 200 -// instead of 206. -func TestNon206OnResume_IsRejected(t *testing.T) { - url := "https://example.com/non-206" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Return 200 on range request (simulating resource change). - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") == "bytes=2500-" { - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(testutil.NewFlakyReader( - bytes.NewReader(payload), - int64(len(payload)), - 0, - )) - } - } - - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil || - err.Error() != "resumable: server returned 200 to a range request; resource may have changed" { - t.Errorf("expected specific error, got: %v", err) - } -} - -// TestNoRangeSupport_PassesThrough_NoResume tests fallback when server -// doesn't support ranges. -func TestNoRangeSupport_PassesThrough_NoResume(t *testing.T) { - url := "https://example.com/no-range" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: false, // No range support. - }) - - // Simulate failure - should not be able to resume. - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err == nil { - t.Error("expected read error due to no range support and failure") - } - - // Should only get partial data. - if len(got) >= len(payload) { - t.Errorf("got %d bytes, expected less than %d", len(got), len(payload)) - } -} - -// TestIfRange_ETag_Matches_AllowsResume tests If-Range with ETag validation. -func TestIfRange_ETag_Matches_AllowsResume(t *testing.T) { - url := "https://example.com/if-range-etag" - payload := testutil.GenerateTestData(7500) - etag := `"strong-etag"` - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: etag, - }) - - // Simulate failure to trigger resume. - failCount := 0 - var muFail sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - muFail.Lock() - fc := failCount - if resp.Request.Method == http.MethodGet && fc == 0 { - failCount = fc + 1 - muFail.Unlock() - // First request fails after 3000 bytes. - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload), - int64(len(payload)), - 3000, - ) - return - } - muFail.Unlock() - } - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check If-Range header on resume request. - headers := ft.GetRequestHeaders(url) - foundIfRange := false - for _, h := range headers { - if h.Get("Range") != "" { - if ifRange := h.Get("If-Range"); ifRange == etag { - foundIfRange = true - break - } - } - } - - if !foundIfRange { - t.Error("expected If-Range header with ETag on resume") - } -} - -// TestIfRange_ETag_ChangedOnResume_RejectsResume tests ETag change detection. -func TestIfRange_ETag_ChangedOnResume_RejectsResume(t *testing.T) { - url := "https://example.com/etag-changed" - payload := testutil.GenerateTestData(5000) - originalETag := `"original"` - changedETag := `"changed"` - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: originalETag, - }) - - // Change ETag on resume attempt. - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") != "" { - // Simulate resource change. - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Set("ETag", changedETag) - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(testutil.NewFlakyReader( - bytes.NewReader(payload), - int64(len(payload)), - 0, - )) - } - } - - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil || - err.Error() != "resumable: server returned 200 to a range request; resource may have changed" { - t.Errorf("expected resource change error, got: %v", err) - } -} - -// TestIfRange_LastModified_Matches_AllowsResume tests If-Range with Last-Modified -func TestIfRange_LastModified_Matches_AllowsResume(t *testing.T) { - url := "https://example.com/if-range-lm" - payload := testutil.GenerateTestData(6000) - lastModified := "Wed, 21 Oct 2015 07:28:00 GMT" - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - LastModified: lastModified, - // No ETag, so should use Last-Modified - }) - - // Simulate failure - ft.SetFailAfter(url, 3000) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Check If-Range uses Last-Modified - headers := ft.GetRequestHeaders(url) - foundIfRange := false - for _, h := range headers { - if h.Get("Range") != "" { - if ifRange := h.Get("If-Range"); ifRange == lastModified { - foundIfRange = true - break - } - } - } - - if !foundIfRange { - t.Error("expected If-Range header with Last-Modified on resume") - } -} - -// TestIfRange_LastModified_ChangedOnResume_RejectsResume tests Last-Modified change detection -func TestIfRange_LastModified_ChangedOnResume_RejectsResume(t *testing.T) { - url := "https://example.com/lm-changed" - payload := testutil.GenerateTestData(5000) - originalLM := "Wed, 21 Oct 2015 07:28:00 GMT" - changedLM := "Thu, 22 Oct 2015 08:30:00 GMT" - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - LastModified: originalLM, - }) - - // Change Last-Modified on resume - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") != "" { - // Simulate resource change - resp.StatusCode = http.StatusOK - resp.Status = "200 OK" - resp.Header.Set("Last-Modified", changedLM) - resp.Header.Del("Content-Range") - resp.Body = io.NopCloser(testutil.NewFlakyReader( - bytes.NewReader(payload), - int64(len(payload)), - 0, - )) - } - } - - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - if err == nil || - err.Error() != "resumable: server returned 200 to a range request; resource may have changed" { - t.Errorf("expected resource change error, got: %v", err) - } -} - -// TestIfRange_RequiredButUnavailable_MissingRejected tests when no validator is available -func TestIfRange_RequiredButUnavailable_MissingRejected(t *testing.T) { - url := "https://example.com/no-validator" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - // No ETag or LastModified - }) - - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - _, err = io.ReadAll(resp.Body) - // Safer behavior: do not attempt resume without a validator. Expect an - // error to be surfaced when the initial stream fails and cannot resume. - if err == nil { - t.Error("expected error due to missing resume validator") - } -} - -// TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume tests weak ETags fall back to Last-Modified -func TestIfRange_WeakETag_Present_UsesLastModified_AllowsResume(t *testing.T) { - url := "https://example.com/weak-etag" - payload := testutil.GenerateTestData(10000) - lastModified := "Mon, 02 Jan 2006 15:04:05 MST" - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `W/"weak-etag"`, // Weak ETag - LastModified: lastModified, - }) - - // Simulate failure - ft.SetFailAfter(url, 5000) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) - - // Should use Last-Modified for If-Range, not weak ETag - headers := ft.GetRequestHeaders(url) - for _, h := range headers { - if h.Get("Range") != "" { - ifRange := h.Get("If-Range") - if ifRange == `W/"weak-etag"` { - t.Error("should not use weak ETag for If-Range") - } - if ifRange != lastModified { - t.Errorf("expected If-Range with Last-Modified, got %q", ifRange) - } - } - } -} - -// TestGzipContentEncoding_DisablesResume tests that Content-Encoding disables resume -func TestGzipContentEncoding_DisablesResume(t *testing.T) { - url := "https://example.com/gzip" - payload := testutil.GenerateTestData(12000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - Headers: http.Header{ - "Content-Encoding": []string{"gzip"}, - }, - }) - - // Simulate failure - ft.SetFailAfter(url, 6000) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Get(url) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - // Should fail because Content-Encoding prevents resume - if err == nil { - t.Error("expected error due to Content-Encoding preventing resume") - } - - // Should only have partial data - if len(got) >= len(payload) { - t.Errorf("got %d bytes, expected less due to failure", len(got)) - } -} - -// TestResumeHeaders_ScrubbedAndIdentityEncoding tests header handling on resume -func TestResumeHeaders_ScrubbedAndIdentityEncoding(t *testing.T) { - url := "https://example.com/headers" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Check headers on resume - ft.RequestHook = func(req *http.Request) { - if req.Header.Get("Range") != "" { - // Check that Accept-Encoding is set to identity - if ae := req.Header.Get("Accept-Encoding"); ae != "identity" { - t.Errorf("expected Accept-Encoding: identity, got: %q", ae) - } - // Check that conditional headers are removed - if req.Header.Get("If-Modified-Since") != "" { - t.Error("If-Modified-Since should be removed on resume") - } - if req.Header.Get("If-None-Match") != "" { - t.Error("If-None-Match should be removed on resume") - } - } - } - - ft.SetFailAfter(url, 2500) - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - // Create request with various headers - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("create request: %v", err) - } - req.Header.Set("Accept-Encoding", "gzip, deflate") - req.Header.Set("If-Modified-Since", "Wed, 21 Oct 2015 07:28:00 GMT") - req.Header.Set("If-None-Match", `"other"`) - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read: %v", err) - } - - testutil.AssertDataEquals(t, got, payload) -} - -// TestRangeRequest_Initial tests resume with initial Range request -func TestRangeRequest_Initial(t *testing.T) { - url := "https://example.com/range-initial" - payload := testutil.GenerateTestData(10240) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"range-test"`, - }) - - // Simulate failure on range request - failCount := 0 - var muRange sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - muRange.Lock() - fc := failCount - if resp.Request.Header.Get("Range") == "bytes=1024-5119" && fc == 0 { - failCount = fc + 1 - muRange.Unlock() - // Fail after 2000 bytes of the range - rangeData := payload[1024:5120] - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(rangeData), - int64(len(rangeData)), - 2000, - ) - return - } - muRange.Unlock() - } - - // Create request with initial Range header - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatalf("create request: %v", err) - } - req.Header.Set("Range", "bytes=1024-5119") - - client := &http.Client{ - Transport: New(ft, WithMaxRetries(3)), - } - - resp, err := client.Do(req) - if err != nil { - t.Fatalf("GET failed: %v", err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatalf("read failed: %v", err) - } - - want := payload[1024:5120] - testutil.AssertDataEquals(t, got, want) - - // Check resume happened with adjusted range - headers := ft.GetRequestHeaders(url) - foundResume := false - for _, h := range headers { - rangeHeader := h.Get("Range") - if rangeHeader != "" && rangeHeader != "bytes=1024-5119" { - foundResume = true - t.Logf("Resume range: %s", rangeHeader) - } - } - - if !foundResume { - t.Error("expected resume with adjusted range") - } -} - -// Additional range request tests for comprehensive coverage -func TestRangeInitial_ZeroToN_NoCuts_Succeeds(t *testing.T) { - url := "https://example.com/range-0-n" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=0-2499") - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[0:2500] - testutil.AssertDataEquals(t, got, want) -} - -func TestRangeInitial_MidSpan_NoCuts_Succeeds(t *testing.T) { - url := "https://example.com/range-mid" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=1000-1999") - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[1000:2000] - testutil.AssertDataEquals(t, got, want) -} - -func TestRangeInitial_FromNToEnd_NoCuts_Succeeds(t *testing.T) { - url := "https://example.com/range-to-end" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - }) - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=3000-") - - client := &http.Client{Transport: New(ft)} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[3000:] - testutil.AssertDataEquals(t, got, want) -} - -func TestRangeInitial_ZeroToN_WithCut_Resumes(t *testing.T) { - url := "https://example.com/range-0-n-cut" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Fail the range request partway through - failCount := 0 - ft.ResponseHook = func(resp *http.Response) { - if resp.Request.Header.Get("Range") == "bytes=0-2499" && failCount == 0 { - failCount++ - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload[0:2500]), - int64(len(payload[0:2500])), - 1000, - ) - } - } - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=0-2499") - - client := &http.Client{Transport: New(ft, WithMaxRetries(3))} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[0:2500] - testutil.AssertDataEquals(t, got, want) - - // Verify resume happened - headers := ft.GetRequestHeaders(url) - foundResume := false - for _, h := range headers { - rangeHeader := h.Get("Range") - if rangeHeader != "" && rangeHeader != "bytes=0-2499" { - foundResume = true - if rangeHeader != "bytes=1000-2499" { - t.Errorf("expected resume at bytes=1000-2499, got: %s", rangeHeader) - } - } - } - - if !foundResume { - t.Error("expected resume") - } -} - -func TestRangeInitial_MidSpan_WithMultipleCuts_Resumes(t *testing.T) { - url := "https://example.com/range-mid-cuts" - payload := testutil.GenerateTestData(10000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Multiple failures on the range request - failCount := 0 - var muCut sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - rangeHeader := resp.Request.Header.Get("Range") - muCut.Lock() - fc := failCount - if rangeHeader == "bytes=2000-5999" && fc == 0 { - failCount = fc + 1 - muCut.Unlock() - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload[2000:6000]), - int64(len(payload[2000:6000])), - 1000, - ) - return - } else if rangeHeader == "bytes=3000-5999" && fc == 1 { - failCount = fc + 1 - muCut.Unlock() - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload[3000:6000]), - int64(len(payload[3000:6000])), - 1500, - ) - return - } - muCut.Unlock() - } - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=2000-5999") - - client := &http.Client{Transport: New(ft, WithMaxRetries(5))} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[2000:6000] - testutil.AssertDataEquals(t, got, want) - - // Check that multiple resumes happened. - reqs := ft.GetRequests() - var rangeCount int - for _, r := range reqs { - if r.Header.Get("Range") != "" { - rangeCount++ - } - } - - if rangeCount < 3 { - t.Errorf("expected at least 3 range requests, got %d", rangeCount) - } -} - -func TestRangeInitial_FromNToEnd_WithCut_Resumes(t *testing.T) { - url := "https://example.com/range-to-end-cut" - payload := testutil.GenerateTestData(10000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Fail the open-ended range request - failCount := 0 - var muOpen sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - muOpen.Lock() - fc := failCount - if resp.Request.Header.Get("Range") == "bytes=7000-" && fc == 0 { - failCount = fc + 1 - muOpen.Unlock() - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(payload[7000:]), - int64(len(payload[7000:])), - 1500, - ) - return - } - muOpen.Unlock() - } - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=7000-") - - client := &http.Client{Transport: New(ft, WithMaxRetries(3))} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[7000:] - testutil.AssertDataEquals(t, got, want) - - // Verify resume happened - headers := ft.GetRequestHeaders(url) - foundResume := false - for _, h := range headers { - rangeHeader := h.Get("Range") - if rangeHeader != "" && rangeHeader != "bytes=7000-" { - foundResume = true - // Accept either open-ended or closed range - if rangeHeader != "bytes=8500-" && rangeHeader != "bytes=8500-9999" { - t.Errorf("expected resume at bytes=8500- or bytes=8500-9999, got: %s", rangeHeader) - } - } - } - - if !foundResume { - t.Error("expected resume") - } -} - -func TestRangeInitial_ResumeHeaderStart_Correct(t *testing.T) { - url := "https://example.com/range-header-check" - payload := testutil.GenerateTestData(5000) - - ft := testutil.NewFakeTransport() - ft.Add(url, &testutil.FakeResource{ - Data: bytes.NewReader(payload), - Length: int64(len(payload)), - SupportsRange: true, - ETag: `"test"`, - }) - - // Fail at exactly 1234 bytes - failCount := 0 - var muHdr sync.Mutex - ft.ResponseHook = func(resp *http.Response) { - muHdr.Lock() - fc := failCount - if resp.Request.Header.Get("Range") == "bytes=1000-2999" && fc == 0 { - failCount = fc + 1 - muHdr.Unlock() - rangeData := payload[1000:3000] - resp.Body = testutil.NewFlakyReader( - bytes.NewReader(rangeData), - int64(len(rangeData)), - 234, - ) // Will have read 1234 total - return - } - muHdr.Unlock() - } - - req, err := http.NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - req.Header.Set("Range", "bytes=1000-2999") - - client := &http.Client{Transport: New(ft, WithMaxRetries(3))} - - resp, err := client.Do(req) - if err != nil { - t.Fatal(err) - } - defer resp.Body.Close() - - got, err := io.ReadAll(resp.Body) - if err != nil { - t.Fatal(err) - } - - want := payload[1000:3000] - testutil.AssertDataEquals(t, got, want) - - // Check the resume request has correct start position - headers := ft.GetRequestHeaders(url) - for _, h := range headers { - rangeHeader := h.Get("Range") - if rangeHeader != "" && rangeHeader != "bytes=1000-2999" { - if rangeHeader != "bytes=1234-2999" { - t.Errorf("expected resume at bytes=1234-2999, got: %s", rangeHeader) - } - break - } - } -} diff --git a/tools/benchmarks/parallelget/README.md b/tools/benchmarks/parallelget/README.md deleted file mode 100644 index 2db5a75f1..000000000 --- a/tools/benchmarks/parallelget/README.md +++ /dev/null @@ -1,110 +0,0 @@ -# ParallelGet Benchmark Tool - -A command-line benchmarking tool that compares the performance of standard HTTP GET requests against parallelized requests using the `transport/parallel` package. - -## Features - -- **Performance Comparison**: Downloads the same URL twice (standard vs parallel) and compares timing -- **Response Validation**: Ensures both downloads produce identical results (byte-for-byte comparison) -- **Configurable Parameters**: Adjustable chunk size and concurrency settings -- **Detailed Metrics**: Reports download speeds, timing differences, and performance improvements -- **Dynamic Progress Display**: Shows real-time progress bars during downloads with percentage and byte counts -- **Clean Output**: User-friendly performance summary with emojis and clear formatting - -## Usage - -```bash -go run ./tools/benchmarks/parallelget [flags] -``` - -or - -```bash -go build ./tools/benchmarks/parallelget -./parallelget [flags] -``` - -### Arguments - -- ``: The HTTP URL to benchmark (required) - -### Flags - -- `--chunk-size int`: Minimum chunk size in bytes for parallelization (default: 1MB) -- `--max-concurrent uint`: Maximum concurrent requests for parallel transport (default: 4) -- `-h, --help`: Show help information - -### Examples - -```bash -# Basic usage -./parallelget https://example.com/large-file.zip - -# Custom chunk size (512KB) and higher concurrency -./parallelget https://example.com/large-file.zip --chunk-size 524288 --max-concurrent 8 - -# Small chunk size for testing with smaller files -./parallelget https://httpbin.org/bytes/10485760 --chunk-size 262144 --max-concurrent 6 -``` - -## Output - -The tool provides detailed output including: - -1. **Configuration**: Shows the chunk size and concurrency settings -2. **Progress**: Real-time updates for each benchmark phase -3. **Individual Results**: Download speed and timing for each approach -4. **Validation**: Confirms that both downloads produced identical content -5. **Performance Summary**: - - Speedup factor (e.g., "3.2x faster") - - Time saved/penalty - - Detailed timing breakdown - -### Sample Output - -``` -Benchmarking HTTP GET performance for: https://example.com/large-file.zip -Configuration: chunk-size=1048576 bytes, max-concurrent=4 - -Running non-parallel benchmark... - Progress: [██████████████████████████████] 100.0% (10485760/10485760 bytes) -✓ Non-parallel: 10485760 bytes in 2.1s (4.76 MB/s) -Running parallel benchmark... - Progress: [██████████████████████████████] 100.0% (10485760/10485760 bytes) -✓ Parallel: 10485760 bytes in 650ms (15.38 MB/s) -Validating response consistency... -✓ Responses match perfectly - -============================================================ -PERFORMANCE COMPARISON -============================================================ -🚀 Parallel was 3.23x faster than non-parallel -⏱️ Time saved: 1.45s (69.0%) - -Detailed timing: - Non-parallel: 2.1s - Parallel: 650ms - Difference: -1.45s -``` - -## How It Works - -1. **Non-Parallel Benchmark**: Uses `net/http.DefaultClient` with `net/http.DefaultTransport` -2. **Parallel Benchmark**: Uses `net/http.DefaultClient` with `transport/parallel.ParallelTransport` wrapping `net/http.DefaultTransport` -3. **Response Storage**: Both responses are written to temporary files for validation -4. **Validation**: Performs byte-by-byte comparison to ensure identical content -5. **Cleanup**: Automatically removes temporary files after completion - -## Notes - -- The tool requires the server to support HTTP range requests (`Accept-Ranges: bytes`) for parallel downloads to work -- If the server doesn't support range requests or the file is too small, the parallel transport will automatically fall back to a single request -- Temporary files are automatically cleaned up, even if the tool exits unexpectedly -- The tool validates that both downloads produce identical results before reporting performance metrics - -## Use Cases - -- **Performance Testing**: Evaluate the effectiveness of parallel downloads for different URLs -- **Configuration Tuning**: Find optimal chunk size and concurrency settings for specific servers or file types -- **Server Compatibility**: Test whether servers properly support range requests -- **Network Optimization**: Understand the impact of parallel downloads on different network conditions diff --git a/tools/benchmarks/parallelget/main.go b/tools/benchmarks/parallelget/main.go deleted file mode 100644 index 8e3468612..000000000 --- a/tools/benchmarks/parallelget/main.go +++ /dev/null @@ -1,348 +0,0 @@ -package main - -import ( - "bytes" - "crypto/sha256" - "fmt" - "io" - "net/http" - "os" - "strings" - "sync" - "time" - - "github.com/spf13/cobra" - - "github.com/docker/model-runner/pkg/distribution/transport/parallel" -) - -var ( - minChunkSize int64 - maxConcurrent uint -) - -var rootCmd = &cobra.Command{ - Use: "parallelget ", - Short: "Benchmark parallel vs non-parallel HTTP GET requests", - Long: `parallelget is a benchmarking tool that compares the performance of standard -HTTP GET requests against parallelized requests using the transport/parallel package. - -It downloads the same URL twice - once using the standard HTTP client and once -using a parallel transport - then compares the results and reports performance metrics.`, - Args: cobra.ExactArgs(1), - RunE: runBenchmark, - SilenceUsage: true, -} - -func init() { - rootCmd.Flags().Int64Var(&minChunkSize, "chunk-size", 1024*1024, "Minimum chunk size in bytes for parallelization (default 1MB)") - rootCmd.Flags().UintVar(&maxConcurrent, "max-concurrent", 4, "Maximum concurrent requests for parallel transport (default 4)") -} - -func main() { - if err := rootCmd.Execute(); err != nil { - fmt.Fprintf(os.Stderr, "Error: %v\n", err) - os.Exit(1) - } -} - -func runBenchmark(cmd *cobra.Command, args []string) error { - url := args[0] - - fmt.Printf("Benchmarking HTTP GET performance for: %s\n", url) - fmt.Printf("Configuration: chunk-size=%d bytes, max-concurrent=%d\n\n", minChunkSize, maxConcurrent) - - // Create temporary files for storing responses. - nonParallelFile, err := os.CreateTemp("", "benchmark-non-parallel-*.tmp") - if err != nil { - return fmt.Errorf("failed to create temp file for non-parallel response: %w", err) - } - defer func() { - nonParallelFile.Close() - os.Remove(nonParallelFile.Name()) - }() - - parallelFile, err := os.CreateTemp("", "benchmark-parallel-*.tmp") - if err != nil { - return fmt.Errorf("failed to create temp file for parallel response: %w", err) - } - defer func() { - parallelFile.Close() - os.Remove(parallelFile.Name()) - }() - - // Run non-parallel benchmark. - fmt.Println("Running non-parallel benchmark...") - nonParallelDuration, nonParallelSize, err := benchmarkNonParallel(url, nonParallelFile) - if err != nil { - return fmt.Errorf("non-parallel benchmark failed: %w", err) - } - fmt.Printf("✓ Non-parallel: %d bytes in %v (%.2f MB/s)\n", nonParallelSize, nonParallelDuration, - float64(nonParallelSize)/nonParallelDuration.Seconds()/(1024*1024)) - - // Run parallel benchmark. - fmt.Println("Running parallel benchmark...") - parallelDuration, parallelSize, err := benchmarkParallel(url, parallelFile) - if err != nil { - return fmt.Errorf("parallel benchmark failed: %w", err) - } - fmt.Printf("✓ Parallel: %d bytes in %v (%.2f MB/s)\n", parallelSize, parallelDuration, - float64(parallelSize)/parallelDuration.Seconds()/(1024*1024)) - - // Validate responses match. - fmt.Println("Validating response consistency...") - if err := validateResponses(nonParallelFile, parallelFile); err != nil { - return fmt.Errorf("response validation failed: %w", err) - } - fmt.Println("✓ Responses match perfectly") - - // Print performance comparison. - fmt.Println("\n" + strings.Repeat("=", 60)) - fmt.Println("PERFORMANCE COMPARISON") - fmt.Println(strings.Repeat("=", 60)) - - speedup := float64(nonParallelDuration) / float64(parallelDuration) - if speedup > 1.0 { - fmt.Printf("🚀 Parallel was %.2fx faster than non-parallel\n", speedup) - timeSaved := nonParallelDuration - parallelDuration - fmt.Printf("⏱️ Time saved: %v (%.1f%%)\n", timeSaved, (1.0-1.0/speedup)*100) - } else if speedup < 1.0 { - slowdown := 1.0 / speedup - fmt.Printf("⚠️ Parallel was %.2fx slower than non-parallel\n", slowdown) - fmt.Printf("⏱️ Time penalty: %v (%.1f%%)\n", parallelDuration-nonParallelDuration, (slowdown-1.0)*100) - } else { - fmt.Println("📊 Both approaches performed equally") - } - - fmt.Printf("\nDetailed timing:\n") - fmt.Printf(" Non-parallel: %v\n", nonParallelDuration) - fmt.Printf(" Parallel: %v\n", parallelDuration) - fmt.Printf(" Difference: %v\n", parallelDuration-nonParallelDuration) - - return nil -} - -// performHTTPGet executes an HTTP GET request using the specified transport -// and measures the time taken to download the entire response body. -// The response is written to outputFile and progress is displayed during the download. -func performHTTPGet(url string, transport http.RoundTripper, outputFile *os.File) (time.Duration, int64, error) { - client := &http.Client{ - Transport: transport, - } - - start := time.Now() - - resp, err := client.Get(url) - if err != nil { - return 0, 0, err - } - defer resp.Body.Close() - - if resp.StatusCode != http.StatusOK { - return 0, 0, fmt.Errorf("HTTP %d: %s", resp.StatusCode, resp.Status) - } - - // Create progress writer with content length if available. - contentLength := resp.ContentLength - if contentLength <= 0 { - contentLength = -1 // Unknown size. - } - progressWriter := newProgressWriter(outputFile, contentLength, " Progress") - - written, err := io.Copy(progressWriter, resp.Body) - progressWriter.finish() // Ensure final progress is shown. - - if err != nil { - return 0, 0, err - } - - duration := time.Since(start) - return duration, written, nil -} - -// benchmarkNonParallel performs a standard HTTP GET request using the default transport -// and measures the time taken to download the entire response body. -// The response is written to outputFile and progress is displayed during the download. -func benchmarkNonParallel(url string, outputFile *os.File) (time.Duration, int64, error) { - return performHTTPGet(url, http.DefaultTransport, outputFile) -} - -// benchmarkParallel performs an HTTP GET request using the parallel transport -// and measures the time taken to download the entire response body. -// The parallel transport uses byte-range requests to download chunks concurrently. -// The response is written to outputFile and progress is displayed during the download. -func benchmarkParallel(url string, outputFile *os.File) (time.Duration, int64, error) { - // Create parallel transport with configuration. - parallelTransport := parallel.New( - http.DefaultTransport, - parallel.WithMaxConcurrentPerHost(map[string]uint{"": 0}), - parallel.WithMinChunkSize(minChunkSize), - parallel.WithMaxConcurrentPerRequest(maxConcurrent), - ) - - return performHTTPGet(url, parallelTransport, outputFile) -} - -func validateResponses(file1, file2 *os.File) error { - // Get file sizes first for quick comparison. - stat1, err := file1.Stat() - if err != nil { - return fmt.Errorf("failed to stat non-parallel file: %w", err) - } - - stat2, err := file2.Stat() - if err != nil { - return fmt.Errorf("failed to stat parallel file: %w", err) - } - - // Compare file sizes - if they differ, no need to compute hashes. - if stat1.Size() != stat2.Size() { - return fmt.Errorf("file sizes differ: non-parallel=%d bytes, parallel=%d bytes", - stat1.Size(), stat2.Size()) - } - - // Compute SHA-256 hash for first file. - hash1, err := computeFileHash(file1) - if err != nil { - return fmt.Errorf("failed to compute hash for non-parallel file: %w", err) - } - - // Compute SHA-256 hash for second file. - hash2, err := computeFileHash(file2) - if err != nil { - return fmt.Errorf("failed to compute hash for parallel file: %w", err) - } - - // Compare the hashes. - if !bytes.Equal(hash1, hash2) { - return fmt.Errorf("file contents differ: SHA-256 hashes do not match") - } - - return nil -} - -// computeFileHash computes the SHA-256 hash of a file's contents. -// The file is read from the beginning using a single io.Copy operation for efficiency. -func computeFileHash(file *os.File) ([]byte, error) { - // Seek to beginning of file. - if _, err := file.Seek(0, io.SeekStart); err != nil { - return nil, fmt.Errorf("failed to seek to beginning: %w", err) - } - - // Create SHA-256 hasher. - hasher := sha256.New() - - // Copy entire file content to hasher in a single operation. - _, err := io.Copy(hasher, file) - if err != nil { - return nil, fmt.Errorf("failed to read file for hashing: %w", err) - } - - // Return the computed hash. - return hasher.Sum(nil), nil -} - -// progressWriter wraps an io.Writer and provides progress updates during writes. -// It displays a progress bar with percentage completion and transfer rates, -// updating the display at regular intervals to avoid excessive output. -type progressWriter struct { - // writer is the underlying writer to write data to. - writer io.Writer - // total is the total expected bytes (-1 if unknown). - total int64 - // written is the number of bytes written so far. - written int64 - // lastUpdate is the last time the progress display was updated. - lastUpdate time.Time - // label is the label to display with the progress bar. - label string - // finished indicates whether the progress display has been finalized. - finished bool - // mu protects concurrent access to progress state. - mu sync.Mutex -} - -// newProgressWriter creates a new progress writer that wraps the given writer. -// The total parameter specifies the expected number of bytes (use -1 if unknown). -// The label parameter is displayed alongside the progress bar. -func newProgressWriter(writer io.Writer, total int64, label string) *progressWriter { - return &progressWriter{ - writer: writer, - total: total, - label: label, - lastUpdate: time.Now(), - } -} - -// Write implements io.Writer, writing data to the underlying writer and updating progress. -// Progress is displayed at most every 100ms to avoid overwhelming the terminal with updates. -// The final progress update is handled by the finish() method to ensure clean display. -func (pw *progressWriter) Write(data []byte) (int, error) { - // Write data to the underlying writer first. - n, err := pw.writer.Write(data) - if n > 0 { - pw.mu.Lock() - pw.written += int64(n) - now := time.Now() - - // Update progress every 100ms to balance responsiveness and performance. - // Don't update on completion - let finish() handle the final display. - if now.Sub(pw.lastUpdate) >= 100*time.Millisecond && (pw.total < 0 || pw.written < pw.total) { - pw.printProgress() - pw.lastUpdate = now - } - pw.mu.Unlock() - } - return n, err -} - -// printProgress displays the current progress to the terminal. -// For files with known size, shows a progress bar with percentage and bytes. -// For files with unknown size, shows only the bytes transferred. -// Uses carriage return (\r) to overwrite the previous progress line. -func (pw *progressWriter) printProgress() { - if pw.finished { - return - } - - if pw.total < 0 { - // Unknown total size - just show bytes transferred. - fmt.Printf("\r%s: %d bytes", pw.label, pw.written) - return - } - - // Calculate percentage, capping at 100% to handle edge cases. - percent := float64(pw.written) / float64(pw.total) * 100 - if percent > 100 { - percent = 100 - } - - // Create a visual progress bar using filled and empty characters. - barWidth := 30 - filled := int(percent / 100 * float64(barWidth)) - if filled > barWidth { - filled = barWidth - } - - bar := strings.Repeat("█", filled) + strings.Repeat("░", barWidth-filled) - - // Display progress bar with percentage and byte counts. - fmt.Printf("\r%s: [%s] %.1f%% (%d/%d bytes)", - pw.label, bar, percent, pw.written, pw.total) -} - -// finish completes the progress display by showing the final progress state -// and adding a newline to move the cursor to the next line. -// This ensures the progress bar doesn't interfere with subsequent output. -// It's safe to call multiple times - subsequent calls are ignored. -func (pw *progressWriter) finish() { - pw.mu.Lock() - defer pw.mu.Unlock() - if !pw.finished { - // Display final progress state. - pw.printProgress() - // Move to next line to prevent interference with subsequent output. - fmt.Println() - pw.finished = true - } -}