@@ -2,10 +2,10 @@ use crate::{
22 common:: { TokenClassification , TokenClassificationResult , TokenInfo } ,
33 config:: ModelConfig ,
44 error:: ApiError ,
5- inference:: utils:: softmax,
65} ;
76use ndarray:: { Axis , Ix3 } ;
87use ndarray_stats:: QuantileExt ;
8+ use ort:: tensor:: ArrayExtensions ;
99use tokenizers:: Encoding ;
1010
1111pub fn token_classification < ' a > (
@@ -29,35 +29,16 @@ pub fn token_classification<'a>(
2929
3030 for ( encoding, logits) in encodings. iter ( ) . zip ( outputs. axis_iter ( Axis ( 0 ) ) ) {
3131 let logits = logits. to_owned ( ) ;
32-
33- let scores = softmax ( & logits, Axis ( 1 ) ) ;
34-
35- let mut token_ids = encoding. get_ids ( ) . iter ( ) ;
36- let mut tokens = encoding. get_tokens ( ) . iter ( ) ;
37- let mut special_tokens_mask = encoding. get_special_tokens_mask ( ) . iter ( ) ;
38- let mut offsets = encoding. get_offsets ( ) . iter ( ) ;
39- let mut logs_iter = logits. axis_iter ( Axis ( 0 ) ) ;
40- let mut scores_iter = scores. axis_iter ( Axis ( 0 ) ) ;
32+ let scores = logits. softmax ( Axis ( 1 ) ) ;
4133
4234 let mut results = Vec :: new ( ) ;
4335
44- while let (
45- Some ( token_id) ,
46- Some ( token) ,
47- Some ( special_tokens_mask) ,
48- Some ( offset) ,
49- Some ( logs) ,
50- Some ( scores) ,
51- ) = (
52- token_ids. next ( ) ,
53- tokens. next ( ) ,
54- special_tokens_mask. next ( ) ,
55- offsets. next ( ) ,
56- logs_iter. next ( ) ,
57- scores_iter. next ( ) ,
58- ) {
59- let argmax = scores. argmax ( ) . expect ( "Model has 0 labels" ) ;
60- let score = scores[ argmax] ;
36+ for i in 0 ..encoding. len ( ) {
37+ let argmax = scores
38+ . index_axis ( Axis ( 0 ) , i)
39+ . argmax ( )
40+ . expect ( "Model has 0 labels" ) ;
41+ let score = scores. index_axis ( Axis ( 0 ) , i) [ argmax] ;
6142 let label = match config. id2label ( argmax as u32 ) {
6243 Some ( l) => l. to_string ( ) ,
6344 None => {
@@ -67,24 +48,31 @@ pub fn token_classification<'a>(
6748 )
6849 }
6950 } ;
51+ let ( start, end) = encoding. get_offsets ( ) [ i] ;
7052
71- let ( start, end) = * offset;
72-
73- if * special_tokens_mask == 1 {
53+ if encoding. get_special_tokens_mask ( ) [ i] == 1 {
7454 continue ;
7555 }
7656
7757 results. push ( TokenClassification {
7858 token_info : TokenInfo {
79- token_id : * token_id ,
80- token : token . clone ( ) ,
59+ token_id : encoding . get_ids ( ) [ i ] ,
60+ token : encoding . get_tokens ( ) [ i ] . clone ( ) ,
8161 start,
8262 end,
8363 } ,
8464 score : score,
8565 label,
86- logits : logs. iter ( ) . map ( |i| * i) . collect ( ) ,
87- scores : scores. iter ( ) . map ( |i| * i) . collect ( ) ,
66+ logits : logits
67+ . index_axis ( Axis ( 0 ) , i)
68+ . to_owned ( )
69+ . into_raw_vec_and_offset ( )
70+ . 0 ,
71+ scores : scores
72+ . index_axis ( Axis ( 0 ) , i)
73+ . to_owned ( )
74+ . into_raw_vec_and_offset ( )
75+ . 0 ,
8876 } )
8977 }
9078
0 commit comments