Skip to content

Commit a8d9c26

Browse files
committed
llamarunner: Record the time for all batches during prompt processing
Currently, we only record the time for the last batch when processing the prompt. This results in unrealistically high prompt eval rates being reported for the old llama runner.

Before:
total duration: 31.273112939s
load duration: 4.97054657s
prompt eval count: 32768 token(s)
prompt eval duration: 235.137439ms
prompt eval rate: 139356.80 tokens/s
eval count: 1873 token(s)
eval duration: 18.173182374s
eval rate: 103.06 tokens/s

After:
total duration: 30.024798033s
load duration: 4.758588663s
prompt eval count: 32768 token(s)
prompt eval duration: 7.779621548s
prompt eval rate: 4212.03 tokens/s
eval count: 1769 token(s)
eval duration: 17.148014223s
eval rate: 103.16 tokens/s
1 parent 0334e67 commit a8d9c26

File tree

1 file changed

+12
-2
lines changed

1 file changed

+12
-2
lines changed

runner/llamarunner/runner.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,7 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
384384
defer s.mu.Unlock()
385385

386386
var batch *llama.Batch
387+
var numOutputs int
387388

388389
seqIdx := s.nextSeq - 1
389390
for range s.seqs {
@@ -446,7 +447,12 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
446447
break
447448
}
448449

449-
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), i+1 == len(seq.inputs), seq.cache.Id)
450+
output := i+1 == len(seq.inputs)
451+
batch.Add(input.token, input.embed, len(seq.cache.Inputs)+len(seq.pendingInputs), output, seq.cache.Id)
452+
if output {
453+
numOutputs++
454+
}
455+
450456
seq.pendingInputs = append(seq.pendingInputs, input)
451457
seq.iBatch = batch.NumTokens() - 1
452458
}
@@ -463,6 +469,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
463469
return fmt.Errorf("failed to decode batch: %w", err)
464470
}
465471

472+
if numOutputs > 0 {
473+
s.lc.Synchronize()
474+
}
475+
466476
for i, seq := range s.seqs {
467477
if seq == nil {
468478
continue
@@ -476,10 +486,10 @@ func (s *Server) processBatch(tokenBatch *llama.Batch, embedBatch *llama.Batch)
476486

477487
// don't sample prompt processing
478488
if len(seq.inputs) != 0 {
489+
seq.processingDuration += time.Since(t)
479490
continue
480491
}
481492

482-
s.lc.Synchronize()
483493
seq.numDecoded++
484494
if seq.numDecoded > 1 {
485495
seq.generationDuration += time.Since(t)

0 commit comments

Comments
 (0)