Skip to content

Commit 00e9249

Browse files
committed
Add averaging over three runs; refactor testProviderMetrics for improved error handling and logging; enhance metric calculations and reporting for better clarity and accuracy
1 parent 2d9823c commit 00e9249

File tree

1 file changed

+136
-127
lines changed

1 file changed

+136
-127
lines changed

main.go

Lines changed: 136 additions & 127 deletions
Original file line number | Diff line number | Diff line change
@@ -40,34 +40,19 @@ type TestResult struct {
4040
Error string `json:"error,omitempty"`
4141
}
4242

43-
// testProviderMetrics runs a full benchmark test against a single provider.
44-
// It is designed to be run as a goroutine.
45-
func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync.WaitGroup, logDir, resultsDir string, results *[]TestResult, resultsMutex *sync.Mutex) {
46-
// Defer wg.Done() if this is part of a concurrent group
47-
if wg != nil {
48-
defer wg.Done()
49-
}
50-
51-
// Create log file for this provider
52-
timestamp := time.Now().Format("20060102-150405")
53-
logFile, err := os.Create(filepath.Join(logDir, fmt.Sprintf("%s-%s.log", config.Name, timestamp)))
54-
if err != nil {
55-
log.Printf("Error creating log file for %s: %v", config.Name, err)
56-
return
57-
}
58-
defer logFile.Close()
59-
60-
// Create a logger for this provider that writes to both stdout and file
61-
providerLogger := log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags)
62-
63-
providerLogger.Printf("--- Testing: %s (%s) ---", config.Name, config.Model)
43+
// formatDuration formats a duration as decimal seconds
44+
func formatDuration(d time.Duration) string {
45+
return fmt.Sprintf("%.3fs", d.Seconds())
46+
}
6447

65-
// 5. Configure the OpenAI Client
48+
// singleTestRun performs one test run and returns metrics or error
49+
func singleTestRun(config ProviderConfig, tke *tiktoken.Tiktoken, providerLogger *log.Logger, ctx context.Context) (e2e, ttft time.Duration, throughput float64, tokens int, err error) {
50+
// Configure the OpenAI Client
6651
clientConfig := openai.DefaultConfig(config.APIKey)
6752
clientConfig.BaseURL = config.BaseURL
6853
client := openai.NewClientWithConfig(clientConfig)
6954

70-
// 6. Define the request
55+
// Define the request
7156
prompt := "You are a helpful assistant. Please write a short, 150-word story about a curious robot exploring an ancient, overgrown library on a forgotten planet."
7257
messages := []openai.ChatCompletionMessage{
7358
{
@@ -83,65 +68,37 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
8368
Stream: true,
8469
}
8570

86-
// 7. Execute the stream and measure metrics
87-
startTime := time.Now() // ---- START TIMER
71+
// Execute the stream and measure metrics
72+
startTime := time.Now()
8873
var firstTokenTime time.Time
8974
var fullResponseContent strings.Builder
9075

91-
// Add timeout context to prevent indefinite hangs (2 minutes)
92-
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
93-
defer cancel()
94-
95-
stream, err := client.CreateChatCompletionStream(ctx, req)
96-
if err != nil {
97-
providerLogger.Printf("Error creating stream for %s: %v", config.Name, err)
98-
// Save error result
99-
result := TestResult{
100-
Provider: config.Name,
101-
Model: config.Model,
102-
Timestamp: time.Now(),
103-
Success: false,
104-
Error: err.Error(),
105-
}
106-
saveResult(resultsDir, result)
107-
appendResult(results, resultsMutex, result)
108-
return
76+
stream, streamErr := client.CreateChatCompletionStream(ctx, req)
77+
if streamErr != nil {
78+
return 0, 0, 0, 0, fmt.Errorf("error creating stream: %w", streamErr)
10979
}
110-
defer stream.Close() // IMPORTANT: Always close the stream
80+
defer stream.Close()
11181

11282
providerLogger.Printf("[%s] ... Request sent. Waiting for stream ...", config.Name)
11383

11484
for {
115-
response, err := stream.Recv()
85+
response, recvErr := stream.Recv()
11686

11787
// Check for end of stream
118-
if errors.Is(err, io.EOF) {
88+
if errors.Is(recvErr, io.EOF) {
11989
providerLogger.Printf("[%s] ... Stream complete.", config.Name)
12090
break
12191
}
12292

123-
if err != nil {
124-
errMsg := err.Error()
93+
if recvErr != nil {
12594
if ctx.Err() == context.DeadlineExceeded {
126-
errMsg = "Timeout: stream took longer than 5 minutes"
95+
return 0, 0, 0, 0, fmt.Errorf("timeout exceeded")
12796
}
128-
providerLogger.Printf("Stream error for %s: %v", config.Name, errMsg)
129-
// Save error result
130-
result := TestResult{
131-
Provider: config.Name,
132-
Model: config.Model,
133-
Timestamp: time.Now(),
134-
Success: false,
135-
Error: errMsg,
136-
}
137-
saveResult(resultsDir, result)
138-
appendResult(results, resultsMutex, result)
139-
return
97+
return 0, 0, 0, 0, fmt.Errorf("stream error: %w", recvErr)
14098
}
14199

142-
// Check if Choices array is empty (some APIs send empty chunks)
100+
// Check if Choices array is empty
143101
if len(response.Choices) == 0 {
144-
providerLogger.Printf("[%s] ... Received empty chunk (no Choices)", config.Name)
145102
continue
146103
}
147104

@@ -150,7 +107,7 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
150107

151108
// Check if this is the first chunk with actual text
152109
if content != "" && firstTokenTime.IsZero() {
153-
firstTokenTime = time.Now() // ---- TTFT METRIC
110+
firstTokenTime = time.Now()
154111
providerLogger.Printf("[%s] ... First token received!", config.Name)
155112
}
156113

@@ -160,23 +117,10 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
160117
}
161118
}
162119

163-
endTime := time.Now() // ---- E2E METRIC
164-
165-
// --- 8. Calculate and Print Results ---
120+
endTime := time.Now()
166121

167122
if firstTokenTime.IsZero() {
168-
providerLogger.Printf("Error for %s: Did not receive any content from the API.", config.Name)
169-
// Save error result
170-
result := TestResult{
171-
Provider: config.Name,
172-
Model: config.Model,
173-
Timestamp: time.Now(),
174-
Success: false,
175-
Error: "No content received from API",
176-
}
177-
saveResult(resultsDir, result)
178-
appendResult(results, resultsMutex, result)
179-
return
123+
return 0, 0, 0, 0, fmt.Errorf("no content received from API")
180124
}
181125

182126
// Get accurate token count
@@ -185,60 +129,125 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
185129
completionTokens := len(tokenList)
186130

187131
if completionTokens == 0 {
188-
providerLogger.Printf("Error for %s: Received response with 0 tokens.", config.Name)
189-
// Save error result
190-
result := TestResult{
191-
Provider: config.Name,
192-
Model: config.Model,
193-
Timestamp: time.Now(),
194-
Success: false,
195-
Error: "Received 0 tokens",
196-
}
197-
saveResult(resultsDir, result)
198-
appendResult(results, resultsMutex, result)
199-
return
132+
return 0, 0, 0, 0, fmt.Errorf("received 0 tokens")
200133
}
201134

202-
// 1. End-to-End Latency
135+
// Calculate metrics
203136
e2eLatency := endTime.Sub(startTime)
137+
ttftLatency := firstTokenTime.Sub(startTime)
138+
generationTime := e2eLatency - ttftLatency
204139

205-
// 2. Time to First Token (TTFT)
206-
ttft := firstTokenTime.Sub(startTime)
207-
208-
// 3. Throughput (Tokens per Second)
209-
// This is (Total Tokens - 1) / (Time from first token to last token)
210-
generationTime := e2eLatency - ttft
211-
var throughput float64
212-
140+
var throughputVal float64
213141
if generationTime.Seconds() <= 0 {
214-
// Handle edge case where generation is too fast or only 1 token
215-
throughput = 0.0
142+
throughputVal = 0.0
216143
} else {
217-
throughput = (float64(completionTokens) - 1.0) / generationTime.Seconds()
144+
throughputVal = (float64(completionTokens) - 1.0) / generationTime.Seconds()
145+
}
146+
147+
return e2eLatency, ttftLatency, throughputVal, completionTokens, nil
148+
}
149+
150+
// testProviderMetrics runs a full benchmark test against a single provider.
151+
// It runs up to 3 iterations under a single shared 2-minute timeout, stops at the first failed run, and reports results averaged over the successful runs.
152+
func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync.WaitGroup, logDir, resultsDir string, results *[]TestResult, resultsMutex *sync.Mutex) {
153+
// Defer wg.Done() if this is part of a concurrent group
154+
if wg != nil {
155+
defer wg.Done()
218156
}
219157

220-
// --- Print Results (use providerLogger for thread-safety) ---
158+
// Create log file for this provider
159+
timestamp := time.Now().Format("20060102-150405")
160+
logFile, err := os.Create(filepath.Join(logDir, fmt.Sprintf("%s-%s.log", config.Name, timestamp)))
161+
if err != nil {
162+
log.Printf("Error creating log file for %s: %v", config.Name, err)
163+
return
164+
}
165+
defer logFile.Close()
166+
167+
// Create a logger for this provider that writes to both stdout and file
168+
providerLogger := log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags)
169+
170+
providerLogger.Printf("--- Testing: %s (%s) - Running 3 iterations ---", config.Name, config.Model)
171+
172+
// Create 2-minute timeout context for all runs
173+
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
174+
defer cancel()
175+
176+
// Run up to 3 iterations and collect metrics
177+
const maxIterations = 3
178+
var e2eSum, ttftSum time.Duration
179+
var throughputSum float64
180+
var tokensSum int
181+
successfulRuns := 0
182+
183+
for i := 1; i <= maxIterations; i++ {
184+
// Check if timeout exceeded before starting next run
185+
if ctx.Err() != nil {
186+
providerLogger.Printf("[%s] Timeout reached after %d run(s)", config.Name, successfulRuns)
187+
break
188+
}
189+
190+
providerLogger.Printf("[%s] Run %d/%d", config.Name, i, maxIterations)
191+
192+
e2e, ttft, throughput, tokens, runErr := singleTestRun(config, tke, providerLogger, ctx)
193+
if runErr != nil {
194+
providerLogger.Printf("[%s] Run %d failed: %v", config.Name, i, runErr)
195+
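// Save an error result only when the final iteration fails and no run has succeeded. NOTE(review): the loop breaks on the first failure, so reaching i == maxIterations implies earlier runs succeeded; this branch looks unreachable and failed providers may never be recorded - confirm.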
196+
if successfulRuns == 0 && i == maxIterations {
197+
result := TestResult{
198+
Provider: config.Name,
199+
Model: config.Model,
200+
Timestamp: time.Now(),
201+
Success: false,
202+
Error: runErr.Error(),
203+
}
204+
saveResult(resultsDir, result)
205+
appendResult(results, resultsMutex, result)
206+
}
207+
break
208+
}
209+
210+
e2eSum += e2e
211+
ttftSum += ttft
212+
throughputSum += throughput
213+
tokensSum += tokens
214+
successfulRuns++
215+
216+
providerLogger.Printf("[%s] Run %d complete: E2E=%s TTFT=%s Throughput=%.2f tok/s",
217+
config.Name, i, formatDuration(e2e), formatDuration(ttft), throughput)
218+
}
219+
220+
if successfulRuns == 0 {
221+
providerLogger.Printf("[%s] All runs failed", config.Name)
222+
return
223+
}
224+
225+
// Calculate averages
226+
avgE2E := e2eSum / time.Duration(successfulRuns)
227+
avgTTFT := ttftSum / time.Duration(successfulRuns)
228+
avgThroughput := throughputSum / float64(successfulRuns)
229+
avgTokens := tokensSum / successfulRuns
230+
231+
// Print averaged results
221232
providerLogger.Println("==============================================")
222-
providerLogger.Printf(" LLM Metrics for: %s", config.Name)
233+
providerLogger.Printf(" LLM Metrics for: %s (averaged over %d run(s))", config.Name, successfulRuns)
223234
providerLogger.Printf(" Model: %s", config.Model)
224-
providerLogger.Printf(" Total Output Tokens: %d", completionTokens)
235+
providerLogger.Printf(" Avg Output Tokens: %d", avgTokens)
225236
providerLogger.Println("----------------------------------------------")
226-
providerLogger.Printf(" End-to-End Latency: %v", e2eLatency)
227-
providerLogger.Printf(" Latency (TTFT): %v", ttft)
228-
providerLogger.Printf(" Throughput (Tokens/sec): %.2f tokens/s", throughput)
237+
providerLogger.Printf(" End-to-End Latency: %s", formatDuration(avgE2E))
238+
providerLogger.Printf(" Latency (TTFT): %s", formatDuration(avgTTFT))
239+
providerLogger.Printf(" Throughput (Tokens/sec): %.2f tokens/s", avgThroughput)
229240
providerLogger.Println("==============================================")
230-
// Uncomment to see the full response
231-
// providerLogger.Printf("[%s] Full Response:\n%s\n", config.Name, fullResponse)
232241

233242
// Save successful result
234243
result := TestResult{
235244
Provider: config.Name,
236245
Model: config.Model,
237246
Timestamp: time.Now(),
238-
E2ELatency: e2eLatency,
239-
TTFT: ttft,
240-
Throughput: throughput,
241-
CompletionTokens: completionTokens,
247+
E2ELatency: avgE2E,
248+
TTFT: avgTTFT,
249+
Throughput: avgThroughput,
250+
CompletionTokens: avgTokens,
242251
Success: true,
243252
}
244253
saveResult(resultsDir, result)
@@ -306,11 +315,11 @@ func generateMarkdownReport(resultsDir string, results []TestResult, sessionTime
306315

307316
for _, r := range results {
308317
if r.Success {
309-
report.WriteString(fmt.Sprintf("| %s | %s | %v | %v | %.2f tok/s | %d |\n",
318+
report.WriteString(fmt.Sprintf("| %s | %s | %s | %s | %.2f tok/s | %d |\n",
310319
r.Provider,
311320
r.Model,
312-
r.E2ELatency,
313-
r.TTFT,
321+
formatDuration(r.E2ELatency),
322+
formatDuration(r.TTFT),
314323
r.Throughput,
315324
r.CompletionTokens))
316325
}
@@ -361,12 +370,12 @@ func generateMarkdownReport(resultsDir string, results []TestResult, sessionTime
361370
report.WriteString("|------|----------|------------|------|-------------|\n")
362371

363372
for i, r := range successfulResults {
364-
report.WriteString(fmt.Sprintf("| %d | %s | %.2f tok/s | %v | %v |\n",
373+
report.WriteString(fmt.Sprintf("| %d | %s | %.2f tok/s | %s | %s |\n",
365374
i+1,
366375
r.Provider,
367376
r.Throughput,
368-
r.TTFT,
369-
r.E2ELatency))
377+
formatDuration(r.TTFT),
378+
formatDuration(r.E2ELatency)))
370379
}
371380
report.WriteString("\n")
372381

@@ -385,12 +394,12 @@ func generateMarkdownReport(resultsDir string, results []TestResult, sessionTime
385394
report.WriteString("|------|----------|------|------------|-------------|\n")
386395

387396
for i, r := range successfulResults {
388-
report.WriteString(fmt.Sprintf("| %d | %s | %v | %.2f tok/s | %v |\n",
397+
report.WriteString(fmt.Sprintf("| %d | %s | %s | %.2f tok/s | %s |\n",
389398
i+1,
390399
r.Provider,
391-
r.TTFT,
400+
formatDuration(r.TTFT),
392401
r.Throughput,
393-
r.E2ELatency))
402+
formatDuration(r.E2ELatency)))
394403
}
395404
report.WriteString("\n")
396405

@@ -409,11 +418,11 @@ func generateMarkdownReport(resultsDir string, results []TestResult, sessionTime
409418
report.WriteString("|------|----------|-------------|------|------------|\n")
410419

411420
for i, r := range successfulResults {
412-
report.WriteString(fmt.Sprintf("| %d | %s | %v | %v | %.2f tok/s |\n",
421+
report.WriteString(fmt.Sprintf("| %d | %s | %s | %s | %.2f tok/s |\n",
413422
i+1,
414423
r.Provider,
415-
r.E2ELatency,
416-
r.TTFT,
424+
formatDuration(r.E2ELatency),
425+
formatDuration(r.TTFT),
417426
r.Throughput))
418427
}
419428
report.WriteString("\n")

0 commit comments

Comments
 (0)