Skip to content

Commit f2f59a9

Browse files
committed
[Bugfix] Modify TPS metric calculation. Add default cpu threads for hybrid CPU system.
1 parent d8c26ee commit f2f59a9

File tree

2 files changed

+10
-1
lines changed

2 files changed

+10
-1
lines changed

examples/models/llama/runner/generation.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,8 @@ def generate( # noqa: C901
146146

147147
generate_time = time.time() - generate_start
148148
print(f"Prefill time: {prefill_time}")
149-
print(f"Generation tok/s: {len(tokens) / generate_time}")
149+
num_generated_tokens = len(tokens) - len(prompt_tokens) - 1
150+
print(f"Generation tok/s: {num_generated_tokens / generate_time}")
150151

151152
return tokens if echo else tokens[len(prompt_tokens) :]
152153

examples/models/llama/runner/native.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,21 @@ def build_args_parser() -> argparse.ArgumentParser:
126126
help="Maximum length of the generated response sequence.",
127127
)
128128

129+
parser.add_argument(
130+
"--cpu_threads",
131+
type=int,
132+
default=4,
133+
help="Number of CPU threads to use for inference.",
134+
)
135+
129136
return parser
130137

131138

132139
def main() -> None:
133140
parser = build_args_parser()
134141
args = parser.parse_args()
135142
validate_args(args)
143+
portable_lib._unsafe_reset_threadpool(args.cpu_threads)
136144
runner = NativeLlamaRunner(args)
137145
generated_tokens = runner.text_completion(
138146
prompt=args.prompt,

0 commit comments

Comments (0)