File tree Expand file tree Collapse file tree 1 file changed +23
-15
lines changed
Expand file tree Collapse file tree 1 file changed +23
-15
lines changed Original file line number Diff line number Diff line change @@ -145,27 +145,35 @@ pub fn test_infer(
145145 let maybe_file = Path :: new ( prompt) ;
146146 if maybe_file. is_file ( ) {
147147 let file = std:: fs:: read_to_string ( maybe_file) . unwrap ( ) ;
148+ let mut correct = 0 ;
149+ let mut total = 0 ;
148150 for line in file. lines ( ) {
149- let line = serde_json:: from_str :: < String > ( line) . unwrap ( ) ;
150- let prompt = format ! ( "<s>{line}\n " ) ;
151+ let line =
152+ serde_json:: from_str :: < serde_json:: Map < String , serde_json:: Value > > ( line) . unwrap ( ) ;
153+ let serde_json:: Value :: String ( prompt) = & line[ "origin_prompt" ] else {
154+ unreachable ! ( )
155+ } ;
156+ let serde_json:: Value :: String ( gold) = & line[ "gold" ] else {
157+ unreachable ! ( )
158+ } ;
159+ let prompt = format ! ( "<s>{prompt}\n " ) ;
151160
152161 // print_now!("{prompt}");
153162
154- let mut tokens = tokenizer. encode ( & prompt) ;
155- let mut pos = 0 ;
156- for _ in 0 ..max_steps {
157- let next = lm ( & tokens, pos) ;
163+ let tokens = tokenizer. encode ( & prompt) ;
164+ let ans = tokenizer. decode ( lm ( & tokens, 0 ) ) ;
158165
159- pos += tokens. len ( ) ;
160- if next == eos {
161- break ;
162- }
163-
164- let piece = tokenizer. decode ( next) ;
165- print_now ! ( "{piece}" ) ;
166- tokens = vec ! [ next] ;
166+ let result = gold. as_str ( ) == ans;
167+ if result {
168+ correct += 1
167169 }
168- println ! ( )
170+ total += 1 ;
171+
172+ println ! (
173+ "({correct:>4}/{total:<4} {:6.2}%) {} {gold} {ans}" ,
174+ 100. * ( correct as f64 / total as f64 ) ,
175+ if result { "✔" } else { "✘" } ,
176+ )
169177 }
170178 return ;
171179 }
You can’t perform that action at this time.
0 commit comments