Commit 726eec2

Decouple LastRunMetrics class from Llama and reuse it for Mistral
1 parent 115f25a

3 files changed: 40 additions & 35 deletions
src/main/java/com/example/auxiliary/LastRunMetrics.java

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
+package com.example.auxiliary;
+
+/**
+ * Record to store metrics from the last model run.
+ * @param totalTokens The total number of tokens processed
+ * @param totalSeconds The total time in seconds
+ */
+public record LastRunMetrics(int totalTokens, double totalSeconds) {
+    /**
+     * Singleton instance to store the latest metrics
+     */
+    private static LastRunMetrics latestMetrics;
+
+    /**
+     * Sets the metrics for the latest run
+     *
+     * @param tokens The total number of tokens processed
+     * @param seconds The total time in seconds
+     */
+    public static void setMetrics(int tokens, double seconds) {
+        latestMetrics = new LastRunMetrics(tokens, seconds);
+    }
+
+    /**
+     * Prints the metrics from the latest run to stderr
+     */
+    public static void printMetrics() {
+        if (latestMetrics != null) {
+            double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds();
+            System.err.printf("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds());
+        }
+    }
+}
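For context, a minimal sketch of how a model runner is expected to feed this shared record. It is not part of this commit; runGeneration() and the timing variables are hypothetical placeholders:

    // Hypothetical call site at the end of a model's decode loop (illustration only).
    long startNanos = System.nanoTime();
    int generatedTokens = runGeneration();   // placeholder for the model-specific generation step
    double elapsedSeconds = (System.nanoTime() - startNanos) / 1_000_000_000.0;

    // Record the totals once, then print them wherever a summary is wanted.
    LastRunMetrics.setMetrics(generatedTokens, elapsedSeconds);
    LastRunMetrics.printMetrics();           // writes "achieved tok/s: ..." to stderr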

src/main/java/com/example/model/llama/Llama.java

Lines changed: 4 additions & 35 deletions
@@ -1,7 +1,8 @@
 package com.example.model.llama;
 
-import com.example.auxiliary.Parallel;
+import com.example.auxiliary.LastRunMetrics;
 import com.example.auxiliary.format.LlamaChatFormat;
+import com.example.auxiliary.Parallel;
 import com.example.core.model.tensor.FloatTensor;
 import com.example.model.Configuration;
 import com.example.inference.sampler.Sampler;
@@ -480,7 +481,7 @@ public void runInteractive(Sampler sampler, Options options) {
 
                // Optionally print performance metrics after each response
                if (SHOW_PERF_INTERACTIVE) {
-                    Llama.LastRunMetrics.printMetrics();
+                    LastRunMetrics.printMetrics();
                }
            }
        } finally {
@@ -538,44 +539,12 @@ public void runInstructOnce(Sampler sampler, Options options) {
             System.out.println(responseText);
         }
 
-        Llama.LastRunMetrics.printMetrics();
+        LastRunMetrics.printMetrics();
 
         if (tornadoVMPlan != null) {
             tornadoVMPlan.freeTornadoExecutionPlan();
         }
     }
 
-    /**
-     * Record to store metrics from the last model run.
-     * @param totalTokens The total number of tokens processed
-     * @param totalSeconds The total time in seconds
-     */
-    public record LastRunMetrics(int totalTokens, double totalSeconds) {
-        /**
-         * Singleton instance to store the latest metrics
-         */
-        private static LastRunMetrics latestMetrics;
-
-        /**
-         * Sets the metrics for the latest run
-         *
-         * @param tokens The total number of tokens processed
-         * @param seconds The total time in seconds
-         */
-        public static void setMetrics(int tokens, double seconds) {
-            latestMetrics = new LastRunMetrics(tokens, seconds);
-        }
-
-        /**
-         * Prints the metrics from the latest run to stderr
-         */
-        public static void printMetrics() {
-            if (latestMetrics != null) {
-                double tokensPerSecond = latestMetrics.totalTokens() / latestMetrics.totalSeconds();
-                System.err.printf("\n\nachieved tok/s: %.2f. Tokens: %d, seconds: %.2f\n", tokensPerSecond, latestMetrics.totalTokens(), latestMetrics.totalSeconds());
-            }
-        }
-    }
-
 }

src/main/java/com/example/model/mistral/Mistral.java

Lines changed: 3 additions & 0 deletions
@@ -1,6 +1,7 @@
 package com.example.model.mistral;
 
 import com.example.auxiliary.Parallel;
+import com.example.auxiliary.LastRunMetrics;
 import com.example.auxiliary.format.MistralChatFormat;
 import com.example.core.model.tensor.FloatTensor;
 import com.example.model.Configuration;
@@ -323,5 +324,7 @@ public void runInstructOnce(Sampler sampler, Options options) {
             String responseText = tokenizer.decode(responseTokens);
             System.out.println(responseText);
         }
+
+        LastRunMetrics.printMetrics();
     }
 }
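As a quick sanity check on the decoupled record (again not part of this commit; the class name LastRunMetricsDemo and the sample numbers are made up), a self-contained snippet and the stderr line it should produce:

    package com.example.auxiliary;

    // Hypothetical demo class, assuming LastRunMetrics is on the classpath.
    public class LastRunMetricsDemo {
        public static void main(String[] args) {
            // No metrics recorded yet, so this prints nothing.
            LastRunMetrics.printMetrics();

            // 128 tokens in 2.5 seconds -> 128 / 2.5 = 51.20 tok/s.
            LastRunMetrics.setMetrics(128, 2.5);
            LastRunMetrics.printMetrics();
            // stderr: achieved tok/s: 51.20. Tokens: 128, seconds: 2.50
        }
    }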
