Skip to content

Commit 56ed7e8

Browse files
committed
Refactor testProviderMetrics to run iterations concurrently; enhance logging for each run and improve error handling for failed runs
1 parent 00e9249 commit 56ed7e8

File tree

1 file changed

+68
-36
lines changed

1 file changed

+68
-36
lines changed

main.go

Lines changed: 68 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -167,58 +167,90 @@ func testProviderMetrics(config ProviderConfig, tke *tiktoken.Tiktoken, wg *sync
167167
// Create a logger for this provider that writes to both stdout and file
168168
providerLogger := log.New(io.MultiWriter(os.Stdout, logFile), "", log.LstdFlags)
169169

170-
providerLogger.Printf("--- Testing: %s (%s) - Running 3 iterations ---", config.Name, config.Model)
170+
providerLogger.Printf("--- Testing: %s (%s) - Running 3 concurrent iterations ---", config.Name, config.Model)
171171

172172
// Create 2-minute timeout context for all runs
173173
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
174174
defer cancel()
175175

176-
// Run up to 3 iterations and collect metrics
176+
// Run 3 iterations concurrently
177177
const maxIterations = 3
178-
var e2eSum, ttftSum time.Duration
179-
var throughputSum float64
180-
var tokensSum int
181-
successfulRuns := 0
178+
type runResult struct {
179+
e2e time.Duration
180+
ttft time.Duration
181+
throughput float64
182+
tokens int
183+
err error
184+
runNum int
185+
}
186+
187+
resultsChan := make(chan runResult, maxIterations)
188+
var runWg sync.WaitGroup
182189

190+
// Launch concurrent workers
183191
for i := 1; i <= maxIterations; i++ {
184-
// Check if timeout exceeded before starting next run
185-
if ctx.Err() != nil {
186-
providerLogger.Printf("[%s] Timeout reached after %d run(s)", config.Name, successfulRuns)
187-
break
188-
}
192+
runWg.Add(1)
193+
go func(runNum int) {
194+
defer runWg.Done()
195+
providerLogger.Printf("[%s] Run %d/%d starting", config.Name, runNum, maxIterations)
189196

190-
providerLogger.Printf("[%s] Run %d/%d", config.Name, i, maxIterations)
191-
192-
e2e, ttft, throughput, tokens, runErr := singleTestRun(config, tke, providerLogger, ctx)
193-
if runErr != nil {
194-
providerLogger.Printf("[%s] Run %d failed: %v", config.Name, i, runErr)
195-
// If no successful runs yet, save error result
196-
if successfulRuns == 0 && i == maxIterations {
197-
result := TestResult{
198-
Provider: config.Name,
199-
Model: config.Model,
200-
Timestamp: time.Now(),
201-
Success: false,
202-
Error: runErr.Error(),
203-
}
204-
saveResult(resultsDir, result)
205-
appendResult(results, resultsMutex, result)
197+
e2e, ttft, throughput, tokens, runErr := singleTestRun(config, tke, providerLogger, ctx)
198+
199+
if runErr != nil {
200+
providerLogger.Printf("[%s] Run %d failed: %v", config.Name, runNum, runErr)
201+
} else {
202+
providerLogger.Printf("[%s] Run %d complete: E2E=%s TTFT=%s Throughput=%.2f tok/s",
203+
config.Name, runNum, formatDuration(e2e), formatDuration(ttft), throughput)
206204
}
207-
break
208-
}
209205

210-
e2eSum += e2e
211-
ttftSum += ttft
212-
throughputSum += throughput
213-
tokensSum += tokens
214-
successfulRuns++
206+
resultsChan <- runResult{
207+
e2e: e2e,
208+
ttft: ttft,
209+
throughput: throughput,
210+
tokens: tokens,
211+
err: runErr,
212+
runNum: runNum,
213+
}
214+
}(i)
215+
}
216+
217+
// Close channel after all workers complete
218+
go func() {
219+
runWg.Wait()
220+
close(resultsChan)
221+
}()
215222

216-
providerLogger.Printf("[%s] Run %d complete: E2E=%s TTFT=%s Throughput=%.2f tok/s",
217-
config.Name, i, formatDuration(e2e), formatDuration(ttft), throughput)
223+
// Collect results from all workers
224+
var e2eSum, ttftSum time.Duration
225+
var throughputSum float64
226+
var tokensSum int
227+
successfulRuns := 0
228+
var firstError error
229+
230+
for result := range resultsChan {
231+
if result.err == nil {
232+
e2eSum += result.e2e
233+
ttftSum += result.ttft
234+
throughputSum += result.throughput
235+
tokensSum += result.tokens
236+
successfulRuns++
237+
} else if firstError == nil {
238+
firstError = result.err
239+
}
218240
}
219241

220242
if successfulRuns == 0 {
221243
providerLogger.Printf("[%s] All runs failed", config.Name)
244+
// Save error result
245+
result := TestResult{
246+
Provider: config.Name,
247+
Model: config.Model,
248+
Timestamp: time.Now(),
249+
Success: false,
250+
Error: firstError.Error(),
251+
}
252+
saveResult(resultsDir, result)
253+
appendResult(results, resultsMutex, result)
222254
return
223255
}
224256

0 commit comments

Comments
 (0)