Skip to content

Commit 6dce5ba

Browse files
Copilot and ericcurtin committed
Add docker model bench command for benchmarking model inference performance
Co-authored-by: ericcurtin <[email protected]>
1 parent 206fe5c commit 6dce5ba

File tree

3 files changed

+597
-0
lines changed

3 files changed

+597
-0
lines changed

cmd/cli/commands/bench.go

Lines changed: 390 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,390 @@
1+
package commands
2+
3+
import (
4+
"bytes"
5+
"context"
6+
"encoding/json"
7+
"errors"
8+
"fmt"
9+
"io"
10+
"math"
11+
"net/http"
12+
"sort"
13+
"strings"
14+
"sync"
15+
"time"
16+
17+
"github.com/docker/model-runner/cmd/cli/commands/completion"
18+
"github.com/docker/model-runner/cmd/cli/desktop"
19+
"github.com/docker/model-runner/pkg/inference"
20+
"github.com/fatih/color"
21+
"github.com/spf13/cobra"
22+
)
23+
24+
// defaultBenchPrompt is the prompt used when the user does not pass --prompt.
const defaultBenchPrompt = "Write a short story about a robot learning to paint."
25+
26+
// benchResult holds the result of a single benchmark request
type benchResult struct {
	Duration         time.Duration // wall-clock time for the request, including response parsing
	PromptTokens     int           // prompt token count as reported by the server's usage payload
	CompletionTokens int           // generated token count as reported by the server's usage payload
	TotalTokens      int           // prompt + completion tokens as reported by the server
	Error            error         // non-nil when the request failed; token fields are then zero
}
34+
35+
// benchStats holds aggregated statistics for benchmark results
type benchStats struct {
	Concurrency      int           // concurrency level these statistics were measured at
	TotalRequests    int           // total requests attempted at this level
	SuccessfulReqs   int           // requests that completed without error
	FailedReqs       int           // requests that returned an error
	TotalDuration    time.Duration // wall-clock time for the whole batch
	MeanDuration     time.Duration // mean duration across successful requests
	MinDuration      time.Duration // fastest successful request (0 when none succeeded)
	MaxDuration      time.Duration // slowest successful request
	StdDevDuration   time.Duration // population standard deviation of successful durations
	TotalTokens      int           // sum of total tokens across successful requests
	CompletionTokens int           // sum of completion tokens across successful requests
	TokensPerSecond  float64       // completion tokens divided by TotalDuration (wall clock)
}
50+
51+
func newBenchCmd() *cobra.Command {
52+
var prompt string
53+
var concurrencies []int
54+
var numRequests int
55+
56+
const cmdArgs = "MODEL"
57+
c := &cobra.Command{
58+
Use: "bench " + cmdArgs,
59+
Short: "Benchmark a model's performance with concurrent requests",
60+
Long: `Benchmark a model's performance by measuring tokens per second with varying levels of concurrency.
61+
62+
This command provides a hyperfine-like experience for benchmarking LLM inference performance.
63+
It runs the specified model with 1, 2, 4, and 8 concurrent requests by default and reports
64+
timing statistics including tokens per second.`,
65+
Example: ` # Benchmark with default prompt and concurrency levels
66+
docker model bench llama3.2
67+
68+
# Benchmark with custom prompt
69+
docker model bench llama3.2 --prompt "Explain quantum computing"
70+
71+
# Benchmark with specific concurrency levels
72+
docker model bench llama3.2 --concurrency 1,4,8
73+
74+
# Run more requests per concurrency level
75+
docker model bench llama3.2 --requests 5`,
76+
RunE: func(cmd *cobra.Command, args []string) error {
77+
model := args[0]
78+
79+
if _, err := ensureStandaloneRunnerAvailable(cmd.Context(), asPrinter(cmd), false); err != nil {
80+
return fmt.Errorf("unable to initialize standalone model runner: %w", err)
81+
}
82+
83+
// Check if model exists locally
84+
_, err := desktopClient.Inspect(model, false)
85+
if err != nil {
86+
if !errors.Is(err, desktop.ErrNotFound) {
87+
return handleClientError(err, "Failed to inspect model")
88+
}
89+
cmd.Println("Unable to find model '" + model + "' locally. Pulling from the server.")
90+
if err := pullModel(cmd, desktopClient, model, false); err != nil {
91+
return err
92+
}
93+
}
94+
95+
return runBenchmark(cmd, model, prompt, concurrencies, numRequests)
96+
},
97+
ValidArgsFunction: completion.ModelNames(getDesktopClient, 1),
98+
}
99+
c.Args = requireExactArgs(1, "bench", cmdArgs)
100+
101+
c.Flags().StringVar(&prompt, "prompt", defaultBenchPrompt, "Prompt to use for benchmarking")
102+
c.Flags().IntSliceVarP(&concurrencies, "concurrency", "c", []int{1, 2, 4, 8}, "Concurrency levels to test")
103+
c.Flags().IntVarP(&numRequests, "requests", "n", 3, "Number of requests per concurrency level")
104+
105+
return c
106+
}
107+
108+
func runBenchmark(cmd *cobra.Command, model, prompt string, concurrencies []int, numRequests int) error {
109+
boldCyan := color.New(color.FgCyan, color.Bold)
110+
bold := color.New(color.Bold)
111+
green := color.New(color.FgGreen)
112+
yellow := color.New(color.FgYellow)
113+
114+
boldCyan.Fprintf(cmd.OutOrStdout(), "Benchmark: %s\n", model)
115+
cmd.Printf(" Prompt: %s\n", truncateString(prompt, 50))
116+
cmd.Printf(" Requests per concurrency level: %d\n\n", numRequests)
117+
118+
// Warm-up run
119+
cmd.Print("Warming up...")
120+
_, err := runSingleBenchmark(cmd.Context(), model, prompt)
121+
if err != nil {
122+
cmd.Println(" failed!")
123+
return fmt.Errorf("warm-up failed: %w", err)
124+
}
125+
cmd.Println(" done")
126+
127+
var allStats []benchStats
128+
129+
for _, concurrency := range concurrencies {
130+
bold.Fprintf(cmd.OutOrStdout(), "Running with %d concurrent request(s)...\n", concurrency)
131+
132+
stats, err := runConcurrentBenchmarks(cmd.Context(), model, prompt, concurrency, numRequests)
133+
if err != nil {
134+
return fmt.Errorf("benchmark failed at concurrency %d: %w", concurrency, err)
135+
}
136+
137+
allStats = append(allStats, stats)
138+
139+
// Print progress
140+
if stats.FailedReqs > 0 {
141+
yellow.Fprintf(cmd.OutOrStdout(), " Completed: %d/%d requests (%.1f%% success rate)\n",
142+
stats.SuccessfulReqs, stats.TotalRequests,
143+
float64(stats.SuccessfulReqs)/float64(stats.TotalRequests)*100)
144+
} else {
145+
green.Fprintf(cmd.OutOrStdout(), " Completed: %d/%d requests\n",
146+
stats.SuccessfulReqs, stats.TotalRequests)
147+
}
148+
cmd.Printf(" Mean: %s ± %s\n", formatDuration(stats.MeanDuration), formatDuration(stats.StdDevDuration))
149+
cmd.Printf(" Range: [%s ... %s]\n", formatDuration(stats.MinDuration), formatDuration(stats.MaxDuration))
150+
cmd.Printf(" Tokens/sec: %.2f\n\n", stats.TokensPerSecond)
151+
}
152+
153+
// Print summary table
154+
printBenchmarkSummary(cmd, allStats)
155+
156+
return nil
157+
}
158+
159+
func runConcurrentBenchmarks(ctx context.Context, model, prompt string, concurrency, numRequests int) (benchStats, error) {
160+
results := make([]benchResult, numRequests)
161+
var wg sync.WaitGroup
162+
sem := make(chan struct{}, concurrency)
163+
164+
startTime := time.Now()
165+
166+
for i := 0; i < numRequests; i++ {
167+
wg.Add(1)
168+
go func(idx int) {
169+
defer wg.Done()
170+
sem <- struct{}{}
171+
defer func() { <-sem }()
172+
173+
result, err := runSingleBenchmark(ctx, model, prompt)
174+
if err != nil {
175+
results[idx] = benchResult{Error: err}
176+
return
177+
}
178+
results[idx] = result
179+
}(i)
180+
}
181+
182+
wg.Wait()
183+
totalDuration := time.Since(startTime)
184+
185+
return calculateStats(results, concurrency, totalDuration), nil
186+
}
187+
188+
func runSingleBenchmark(ctx context.Context, model, prompt string) (benchResult, error) {
189+
start := time.Now()
190+
191+
reqBody := desktop.OpenAIChatRequest{
192+
Model: model,
193+
Messages: []desktop.OpenAIChatMessage{
194+
{
195+
Role: "user",
196+
Content: prompt,
197+
},
198+
},
199+
Stream: true,
200+
}
201+
202+
jsonData, err := json.Marshal(reqBody)
203+
if err != nil {
204+
return benchResult{}, fmt.Errorf("error marshaling request: %w", err)
205+
}
206+
207+
completionsPath := inference.InferencePrefix + "/v1/chat/completions"
208+
req, err := http.NewRequestWithContext(ctx, http.MethodPost, modelRunner.URL(completionsPath), bytes.NewReader(jsonData))
209+
if err != nil {
210+
return benchResult{}, fmt.Errorf("error creating request: %w", err)
211+
}
212+
req.Header.Set("Content-Type", "application/json")
213+
req.Header.Set("User-Agent", "docker-model-cli/"+desktop.Version)
214+
215+
resp, err := modelRunner.Client().Do(req)
216+
if err != nil {
217+
return benchResult{}, fmt.Errorf("error making request: %w", err)
218+
}
219+
defer resp.Body.Close()
220+
221+
if resp.StatusCode != http.StatusOK {
222+
body, _ := io.ReadAll(resp.Body)
223+
return benchResult{}, fmt.Errorf("error response: status=%d body=%s", resp.StatusCode, body)
224+
}
225+
226+
// Read and parse the streaming response
227+
var finalUsage struct {
228+
CompletionTokens int `json:"completion_tokens"`
229+
PromptTokens int `json:"prompt_tokens"`
230+
TotalTokens int `json:"total_tokens"`
231+
}
232+
233+
body, err := io.ReadAll(resp.Body)
234+
if err != nil {
235+
return benchResult{}, fmt.Errorf("error reading response: %w", err)
236+
}
237+
238+
// Parse SSE events to get the usage
239+
lines := strings.Split(string(body), "\n")
240+
for _, line := range lines {
241+
if !strings.HasPrefix(line, "data: ") {
242+
continue
243+
}
244+
data := strings.TrimPrefix(line, "data: ")
245+
if data == "[DONE]" {
246+
break
247+
}
248+
249+
var streamResp desktop.OpenAIChatResponse
250+
if err := json.Unmarshal([]byte(data), &streamResp); err != nil {
251+
continue
252+
}
253+
254+
if streamResp.Usage != nil {
255+
finalUsage.CompletionTokens = streamResp.Usage.CompletionTokens
256+
finalUsage.PromptTokens = streamResp.Usage.PromptTokens
257+
finalUsage.TotalTokens = streamResp.Usage.TotalTokens
258+
}
259+
}
260+
261+
duration := time.Since(start)
262+
263+
return benchResult{
264+
Duration: duration,
265+
PromptTokens: finalUsage.PromptTokens,
266+
CompletionTokens: finalUsage.CompletionTokens,
267+
TotalTokens: finalUsage.TotalTokens,
268+
}, nil
269+
}
270+
271+
func calculateStats(results []benchResult, concurrency int, totalDuration time.Duration) benchStats {
272+
stats := benchStats{
273+
Concurrency: concurrency,
274+
TotalRequests: len(results),
275+
MinDuration: time.Duration(math.MaxInt64),
276+
TotalDuration: totalDuration,
277+
}
278+
279+
var durations []time.Duration
280+
281+
for _, r := range results {
282+
if r.Error != nil {
283+
stats.FailedReqs++
284+
continue
285+
}
286+
287+
stats.SuccessfulReqs++
288+
stats.TotalTokens += r.TotalTokens
289+
stats.CompletionTokens += r.CompletionTokens
290+
durations = append(durations, r.Duration)
291+
292+
if r.Duration < stats.MinDuration {
293+
stats.MinDuration = r.Duration
294+
}
295+
if r.Duration > stats.MaxDuration {
296+
stats.MaxDuration = r.Duration
297+
}
298+
}
299+
300+
if len(durations) == 0 {
301+
stats.MinDuration = 0
302+
return stats
303+
}
304+
305+
// Calculate mean
306+
var totalDur time.Duration
307+
for _, d := range durations {
308+
totalDur += d
309+
}
310+
stats.MeanDuration = totalDur / time.Duration(len(durations))
311+
312+
// Calculate standard deviation
313+
var sumSquares float64
314+
meanFloat := float64(stats.MeanDuration)
315+
for _, d := range durations {
316+
diff := float64(d) - meanFloat
317+
sumSquares += diff * diff
318+
}
319+
variance := sumSquares / float64(len(durations))
320+
stats.StdDevDuration = time.Duration(math.Sqrt(variance))
321+
322+
// Calculate tokens per second (based on completion tokens generated during the total wall-clock time)
323+
if stats.TotalDuration > 0 {
324+
stats.TokensPerSecond = float64(stats.CompletionTokens) / stats.TotalDuration.Seconds()
325+
}
326+
327+
return stats
328+
}
329+
330+
func printBenchmarkSummary(cmd *cobra.Command, allStats []benchStats) {
331+
bold := color.New(color.Bold)
332+
green := color.New(color.FgGreen)
333+
334+
bold.Fprintln(cmd.OutOrStdout(), "Summary")
335+
cmd.Println(strings.Repeat("─", 70))
336+
cmd.Printf("%-12s %-15s %-15s %-15s\n", "Concurrency", "Mean Time", "Tokens/sec", "Success Rate")
337+
cmd.Println(strings.Repeat("─", 70))
338+
339+
// Find the best tokens per second
340+
var bestTPS float64
341+
for _, s := range allStats {
342+
if s.TokensPerSecond > bestTPS {
343+
bestTPS = s.TokensPerSecond
344+
}
345+
}
346+
347+
for _, s := range allStats {
348+
successRate := float64(s.SuccessfulReqs) / float64(s.TotalRequests) * 100
349+
meanStr := fmt.Sprintf("%s ± %s", formatDuration(s.MeanDuration), formatDuration(s.StdDevDuration))
350+
tpsStr := fmt.Sprintf("%.2f", s.TokensPerSecond)
351+
successStr := fmt.Sprintf("%.0f%%", successRate)
352+
353+
if s.TokensPerSecond == bestTPS {
354+
cmd.Printf("%-12d %-15s ", s.Concurrency, meanStr)
355+
green.Fprintf(cmd.OutOrStdout(), "%-15s", tpsStr)
356+
cmd.Printf(" %-15s\n", successStr)
357+
} else {
358+
cmd.Printf("%-12d %-15s %-15s %-15s\n", s.Concurrency, meanStr, tpsStr, successStr)
359+
}
360+
}
361+
362+
cmd.Println(strings.Repeat("─", 70))
363+
364+
// Find optimal concurrency
365+
sort.Slice(allStats, func(i, j int) bool {
366+
return allStats[i].TokensPerSecond > allStats[j].TokensPerSecond
367+
})
368+
369+
if len(allStats) > 0 {
370+
best := allStats[0]
371+
green.Fprintf(cmd.OutOrStdout(), "\nOptimal concurrency: %d (%.2f tokens/sec)\n", best.Concurrency, best.TokensPerSecond)
372+
}
373+
}
374+
375+
func formatDuration(d time.Duration) string {
376+
if d < time.Millisecond {
377+
return fmt.Sprintf("%.2fµs", float64(d.Microseconds()))
378+
}
379+
if d < time.Second {
380+
return fmt.Sprintf("%.2fms", float64(d.Nanoseconds())/1e6)
381+
}
382+
return fmt.Sprintf("%.2fs", d.Seconds())
383+
}
384+
385+
// truncateString returns s unchanged when it fits within maxLen bytes;
// otherwise it truncates s and appends "..." so the result is exactly maxLen
// bytes long.
//
// NOTE(review): truncation is byte-based, so a multi-byte UTF-8 rune at the
// cut point may be split — acceptable for display-only use, verify if reused.
func truncateString(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}
	// Guard maxLen < 3: the unguarded s[:maxLen-3] used a negative slice
	// index and panicked. With no room for an ellipsis, hard-cut instead.
	if maxLen < 3 {
		if maxLen < 0 {
			maxLen = 0
		}
		return s[:maxLen]
	}
	return s[:maxLen-3] + "..."
}

0 commit comments

Comments
 (0)