Skip to content

Commit 6aaab2e

Browse files
committed
feat: optimize inference postprocessing
1 parent: 11d412a — commit: 6aaab2e

File tree

3 files changed

+34
-60
lines changed

3 files changed

+34
-60
lines changed

encoderfile/src/inference/embedding.rs

Lines changed: 8 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -33,39 +33,26 @@ pub fn embedding<'a>(
3333
false => embs.into_owned(),
3434
};
3535

36-
let mut token_ids = encoding.get_ids().iter();
37-
let mut tokens = encoding.get_tokens().iter();
38-
let mut special_tokens_mask = encoding.get_special_tokens_mask().iter();
39-
let mut offsets = encoding.get_offsets().iter();
40-
let mut embeddings_iter = transformed.axis_iter(Axis(0));
41-
4236
let mut results = Vec::new();
4337

44-
while let (Some(token_id), Some(token), Some(special_tokens_mask), Some(offset), Some(e)) = (
45-
token_ids.next(),
46-
tokens.next(),
47-
special_tokens_mask.next(),
48-
offsets.next(),
49-
embeddings_iter.next(),
50-
) {
51-
if *special_tokens_mask == 1 {
38+
for i in 0..encoding.len() {
39+
if encoding.get_special_tokens_mask()[i] == 1 {
5240
continue;
5341
}
5442

55-
let (start, end) = *offset;
56-
let embedding: Vec<f32> = e.iter().map(|i| *i).collect();
57-
43+
let (start, end) = encoding.get_offsets()[i];
5844
let token_info = TokenInfo {
59-
token: token.clone(),
60-
token_id: *token_id,
45+
token: encoding.get_tokens()[i].clone(),
46+
token_id: encoding.get_ids()[i],
6147
start,
6248
end,
6349
};
6450

51+
let e = transformed.index_axis(Axis(0), i);
6552
results.push(TokenEmbedding {
66-
embedding,
53+
embedding: e.to_owned().into_raw_vec_and_offset().0,
6754
token_info: Some(token_info),
68-
})
55+
});
6956
}
7057

7158
embeddings.push(TokenEmbeddingSequence {
@@ -74,6 +61,4 @@ pub fn embedding<'a>(
7461
}
7562

7663
Ok(embeddings)
77-
78-
// Err(ApiError::InternalError("Not Implemented"))
7964
}

encoderfile/src/inference/sequence_classification.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::{common::SequenceClassificationResult, config::ModelConfig, error::ApiError};
22
use ndarray::{Axis, Ix2};
33
use ndarray_stats::QuantileExt;
4+
use ort::tensor::ArrayExtensions;
45
use tokenizers::Encoding;
56

67
pub fn sequence_classification<'a>(
@@ -23,16 +24,16 @@ pub fn sequence_classification<'a>(
2324
.expect("Model does not return tensor of shape [n_batch, n_labels]")
2425
.into_owned();
2526

26-
let probabilities = super::utils::softmax(&outputs, Axis(1));
27+
let probabilities = outputs.softmax(Axis(1));
2728

2829
let results = outputs
2930
.axis_iter(Axis(0))
3031
.zip(probabilities.axis_iter(Axis(0)))
3132
.map(|(logs, probs)| {
3233
let predicted_index = probs.argmax().expect("Model has 0 labels");
3334
SequenceClassificationResult {
34-
logits: logs.iter().map(|i| *i).collect(),
35-
scores: probs.iter().map(|i| *i).collect(),
35+
logits: logs.to_owned().into_raw_vec_and_offset().0,
36+
scores: probs.to_owned().into_raw_vec_and_offset().0,
3637
predicted_index: (predicted_index as u32),
3738
predicted_label: config
3839
.id2label(predicted_index as u32)

encoderfile/src/inference/token_classification.rs

Lines changed: 22 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,10 @@ use crate::{
22
common::{TokenClassification, TokenClassificationResult, TokenInfo},
33
config::ModelConfig,
44
error::ApiError,
5-
inference::utils::softmax,
65
};
76
use ndarray::{Axis, Ix3};
87
use ndarray_stats::QuantileExt;
8+
use ort::tensor::ArrayExtensions;
99
use tokenizers::Encoding;
1010

1111
pub fn token_classification<'a>(
@@ -29,35 +29,16 @@ pub fn token_classification<'a>(
2929

3030
for (encoding, logits) in encodings.iter().zip(outputs.axis_iter(Axis(0))) {
3131
let logits = logits.to_owned();
32-
33-
let scores = softmax(&logits, Axis(1));
34-
35-
let mut token_ids = encoding.get_ids().iter();
36-
let mut tokens = encoding.get_tokens().iter();
37-
let mut special_tokens_mask = encoding.get_special_tokens_mask().iter();
38-
let mut offsets = encoding.get_offsets().iter();
39-
let mut logs_iter = logits.axis_iter(Axis(0));
40-
let mut scores_iter = scores.axis_iter(Axis(0));
32+
let scores = logits.softmax(Axis(1));
4133

4234
let mut results = Vec::new();
4335

44-
while let (
45-
Some(token_id),
46-
Some(token),
47-
Some(special_tokens_mask),
48-
Some(offset),
49-
Some(logs),
50-
Some(scores),
51-
) = (
52-
token_ids.next(),
53-
tokens.next(),
54-
special_tokens_mask.next(),
55-
offsets.next(),
56-
logs_iter.next(),
57-
scores_iter.next(),
58-
) {
59-
let argmax = scores.argmax().expect("Model has 0 labels");
60-
let score = scores[argmax];
36+
for i in 0..encoding.len() {
37+
let argmax = scores
38+
.index_axis(Axis(0), i)
39+
.argmax()
40+
.expect("Model has 0 labels");
41+
let score = scores.index_axis(Axis(0), i)[argmax];
6142
let label = match config.id2label(argmax as u32) {
6243
Some(l) => l.to_string(),
6344
None => {
@@ -67,24 +48,31 @@ pub fn token_classification<'a>(
6748
)
6849
}
6950
};
51+
let (start, end) = encoding.get_offsets()[i];
7052

71-
let (start, end) = *offset;
72-
73-
if *special_tokens_mask == 1 {
53+
if encoding.get_special_tokens_mask()[i] == 1 {
7454
continue;
7555
}
7656

7757
results.push(TokenClassification {
7858
token_info: TokenInfo {
79-
token_id: *token_id,
80-
token: token.clone(),
59+
token_id: encoding.get_ids()[i],
60+
token: encoding.get_tokens()[i].clone(),
8161
start,
8262
end,
8363
},
8464
score: score,
8565
label,
86-
logits: logs.iter().map(|i| *i).collect(),
87-
scores: scores.iter().map(|i| *i).collect(),
66+
logits: logits
67+
.index_axis(Axis(0), i)
68+
.to_owned()
69+
.into_raw_vec_and_offset()
70+
.0,
71+
scores: scores
72+
.index_axis(Axis(0), i)
73+
.to_owned()
74+
.into_raw_vec_and_offset()
75+
.0,
8876
})
8977
}
9078

0 commit comments

Comments (0)