Skip to content

Commit 971c1c1

Browse files
committed
feat: 判题
Signed-off-by: YdrMaster <[email protected]>
1 parent 99d583d commit 971c1c1

File tree

1 file changed

+23
-15
lines changed

1 file changed

+23
-15
lines changed

test-utils/src/lib.rs

Lines changed: 23 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -145,27 +145,35 @@ pub fn test_infer(
145145
let maybe_file = Path::new(prompt);
146146
if maybe_file.is_file() {
147147
let file = std::fs::read_to_string(maybe_file).unwrap();
148+
let mut correct = 0;
149+
let mut total = 0;
148150
for line in file.lines() {
149-
let line = serde_json::from_str::<String>(line).unwrap();
150-
let prompt = format!("<s>{line}\n");
151+
let line =
152+
serde_json::from_str::<serde_json::Map<String, serde_json::Value>>(line).unwrap();
153+
let serde_json::Value::String(prompt) = &line["origin_prompt"] else {
154+
unreachable!()
155+
};
156+
let serde_json::Value::String(gold) = &line["gold"] else {
157+
unreachable!()
158+
};
159+
let prompt = format!("<s>{prompt}\n");
151160

152161
// print_now!("{prompt}");
153162

154-
let mut tokens = tokenizer.encode(&prompt);
155-
let mut pos = 0;
156-
for _ in 0..max_steps {
157-
let next = lm(&tokens, pos);
163+
let tokens = tokenizer.encode(&prompt);
164+
let ans = tokenizer.decode(lm(&tokens, 0));
158165

159-
pos += tokens.len();
160-
if next == eos {
161-
break;
162-
}
163-
164-
let piece = tokenizer.decode(next);
165-
print_now!("{piece}");
166-
tokens = vec![next];
166+
let result = gold.as_str() == ans;
167+
if result {
168+
correct += 1
167169
}
168-
println!()
170+
total += 1;
171+
172+
println!(
173+
"({correct:>4}/{total:<4} {:6.2}%) {} {gold} {ans}",
174+
100. * (correct as f64 / total as f64),
175+
if result { "✔" } else { "✘" },
176+
)
169177
}
170178
return;
171179
}

0 commit comments

Comments
 (0)