Skip to content

Commit 41312fb

Browse files
fix(router): fix panics on partial_cmp and empty req.texts (#138)
1 parent d05c949 commit 41312fb

File tree

3 files changed

+81
-30
lines changed

3 files changed

+81
-30
lines changed

backends/candle/build.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,4 @@
1-
use anyhow::{Context, Result, bail};
1+
use anyhow::{bail, Context, Result};
22

33
fn main() {
44
println!("cargo:rerun-if-env-changed=CUDA_COMPUTE_CAP");

router/src/grpc/server.rs

Lines changed: 45 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -143,18 +143,21 @@ impl TextEmbeddingsService {
143143
response.inference,
144144
);
145145

146-
let mut predictions: Vec<Prediction> = {
146+
let mut predictions = Vec::with_capacity(response.results.len());
147+
for (i, s) in response.results.into_iter().enumerate() {
148+
// Check that s is not NaN or the partial_cmp below will panic
149+
if s.is_nan() {
150+
Err(ErrorResponse {
151+
error: "score is NaN".to_string(),
152+
error_type: ErrorType::Backend,
153+
})?;
154+
}
147155
// Map score to label
148-
response
149-
.results
150-
.into_iter()
151-
.enumerate()
152-
.map(|(i, s)| Prediction {
153-
score: s,
154-
label: id2label.get(&i.to_string()).unwrap().clone(),
155-
})
156-
.collect()
157-
};
156+
predictions.push(Prediction {
157+
score: s,
158+
label: id2label.get(&i.to_string()).unwrap().clone(),
159+
});
160+
}
158161
// Reverse sort
159162
predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap());
160163
predictions.reverse();
@@ -455,6 +458,17 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService {
455458

456459
let request = request.into_inner();
457460

461+
if request.texts.is_empty() {
462+
let message = "`texts` cannot be empty".to_string();
463+
tracing::error!("{message}");
464+
let err = ErrorResponse {
465+
error: message,
466+
error_type: ErrorType::Validation,
467+
};
468+
metrics::increment_counter!("te_request_failure", "err" => "validation");
469+
Err(err)?;
470+
}
471+
458472
match &self.info.model_type {
459473
ModelType::Classifier(_) => {
460474
metrics::increment_counter!("te_request_failure", "err" => "model_type");
@@ -549,10 +563,19 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService {
549563
None
550564
};
551565

566+
let score = r.4;
567+
// Check that s is not NaN or the partial_cmp below will panic
568+
if score.is_nan() {
569+
Err(ErrorResponse {
570+
error: "score is NaN".to_string(),
571+
error_type: ErrorType::Backend,
572+
})?;
573+
}
574+
552575
ranks.push(Rank {
553576
index: index as u32,
554577
text,
555-
score: r.4,
578+
score,
556579
})
557580
}
558581

@@ -766,10 +789,19 @@ impl grpc::rerank_server::Rerank for TextEmbeddingsService {
766789
None
767790
};
768791

792+
let score = r.5;
793+
// Check that s is not NaN or the partial_cmp below will panic
794+
if score.is_nan() {
795+
Err(ErrorResponse {
796+
error: "score is NaN".to_string(),
797+
error_type: ErrorType::Backend,
798+
})?;
799+
}
800+
769801
ranks.push(Rank {
770802
index: r.0 as u32,
771803
text,
772-
score: r.5,
804+
score,
773805
})
774806
}
775807

router/src/http/server.rs

Lines changed: 35 additions & 16 deletions
Original file line number | Diff line number | Diff line change
@@ -116,18 +116,21 @@ async fn predict(
116116
_ => panic!(),
117117
};
118118

119-
let mut predictions: Vec<Prediction> = {
119+
let mut predictions = Vec::with_capacity(response.results.len());
120+
for (i, s) in response.results.into_iter().enumerate() {
121+
// Check that s is not NaN or the partial_cmp below will panic
122+
if s.is_nan() {
123+
return Err(ErrorResponse {
124+
error: "score is NaN".to_string(),
125+
error_type: ErrorType::Backend,
126+
});
127+
}
120128
// Map score to label
121-
response
122-
.results
123-
.into_iter()
124-
.enumerate()
125-
.map(|(i, s)| Prediction {
126-
score: s,
127-
label: id2label.get(&i.to_string()).unwrap().clone(),
128-
})
129-
.collect()
130-
};
129+
predictions.push(Prediction {
130+
score: s,
131+
label: id2label.get(&i.to_string()).unwrap().clone(),
132+
});
133+
}
131134
// Reverse sort
132135
predictions.sort_by(|x, y| x.score.partial_cmp(&y.score).unwrap());
133136
predictions.reverse();
@@ -282,6 +285,17 @@ async fn rerank(
282285
let span = tracing::Span::current();
283286
let start_time = Instant::now();
284287

288+
if req.texts.is_empty() {
289+
let message = "`texts` cannot be empty".to_string();
290+
tracing::error!("{message}");
291+
let err = ErrorResponse {
292+
error: message,
293+
error_type: ErrorType::Validation,
294+
};
295+
metrics::increment_counter!("te_request_failure", "err" => "validation");
296+
Err(err)?;
297+
}
298+
285299
match &info.model_type {
286300
ModelType::Classifier(_) => {
287301
metrics::increment_counter!("te_request_failure", "err" => "model_type");
@@ -383,11 +397,16 @@ async fn rerank(
383397
None
384398
};
385399

386-
ranks.push(Rank {
387-
index,
388-
text,
389-
score: r.4,
390-
})
400+
let score = r.4;
401+
// Check that s is not NaN or the partial_cmp below will panic
402+
if score.is_nan() {
403+
Err(ErrorResponse {
404+
error: "score is NaN".to_string(),
405+
error_type: ErrorType::Backend,
406+
})?;
407+
}
408+
409+
ranks.push(Rank { index, text, score })
391410
}
392411

393412
// Reverse sort

0 commit comments

Comments
 (0)