
Commit f2c99a8

Since we want a URL from the start we can actually use a URL all the way. Fixing.
1 parent 7739717 commit f2c99a8

4 files changed: +15 −20 lines

src/lib.rs

Lines changed: 3 additions & 2 deletions

@@ -14,6 +14,7 @@ pub use crate::requests::TokenizeOptions;
 use chrono::Local;
 use crossterm::ExecutableCommand;
 use log::{debug, error, info, warn, Level, LevelFilter};
+use reqwest::Url;
 use tokenizers::{FromPretrainedParameters, Tokenizer};
 use tokio::sync::broadcast::Sender;
 use tokio::sync::Mutex;
@@ -32,7 +33,7 @@ mod table;
 mod writers;

 pub struct RunConfiguration {
-    pub url: String,
+    pub url: Url,
     pub tokenizer_name: String,
     pub profile: Option<String>,
     pub max_vus: u64,
@@ -85,7 +86,7 @@ pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> a
     let tokenizer = Arc::new(tokenizer);
     let backend = OpenAITextGenerationBackend::try_new(
         "".to_string(),
-        run_config.url.clone(),
+        run_config.url,
         run_config.model_name.clone(),
         tokenizer,
         run_config.duration,
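
The change in src/lib.rs is the heart of the commit: RunConfiguration now stores an already-parsed reqwest::Url instead of a raw String, so validation happens once when the configuration is built and downstream code never re-checks the URL. A minimal sketch of the pattern (a hypothetical, simplified struct, not the crate's full RunConfiguration):

use reqwest::Url;

// Hypothetical, simplified stand-in for the crate's RunConfiguration:
// the field holds a parsed Url, so building the config is the only
// place URL parsing can fail.
struct RunConfiguration {
    url: Url,
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Parse once at the edge; everything downstream can rely on
    // `config.url` already being a valid URL.
    let config = RunConfiguration {
        url: "http://localhost:8000".parse::<Url>()?,
    };
    println!("host = {:?}", config.url.host_str());
    Ok(())
}

Passing run_config.url by value into try_new (rather than cloning a String) also falls out of this: the Url is only needed once at that call site.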

src/main.rs

Lines changed: 2 additions & 9 deletions

@@ -45,8 +45,8 @@ struct Args {
     warmup: Duration,
     /// The URL of the backend to benchmark. Must be compatible with OpenAI Message API
     #[clap(default_value = "http://localhost:8000", short, long, env)]
-    #[arg(value_parser = parse_url)]
-    url: String,
+    url: Url,
+
     /// Disable console UI
     #[clap(short, long, env)]
     no_console: bool,
@@ -115,13 +115,6 @@ fn parse_duration(s: &str) -> Result<Duration, Error> {
     humantime::parse_duration(s).map_err(|_| Error::new(InvalidValue))
 }

-fn parse_url(s: &str) -> Result<String, Error> {
-    match Url::parse(s) {
-        Ok(_) => Ok(s.to_string()),
-        Err(_) => Err(Error::new(InvalidValue)),
-    }
-}
-
 fn parse_key_val(s: &str) -> Result<HashMap<String, String>, Error> {
     let mut key_val_map = HashMap::new();
     let items = s.split(",").collect::<Vec<&str>>();
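
Dropping parse_url works because clap's derive can lean on Url's FromStr implementation directly: an invalid value is rejected at argument-parsing time, and the already-parsed Url flows straight into RunConfiguration. A rough sketch of the idea, assuming a trimmed-down Args struct rather than the real one:

use clap::Parser;
use reqwest::Url;

// Trimmed-down, hypothetical version of the Args struct: because Url
// implements FromStr, clap parses and validates the flag itself, so no
// custom value_parser is needed.
#[derive(Parser, Debug)]
struct Args {
    /// The URL of the backend to benchmark.
    #[clap(short, long, default_value = "http://localhost:8000")]
    url: Url,
}

fn main() {
    let args = Args::parse();
    // `args.url` is already a validated reqwest::Url at this point.
    println!("benchmarking {}", args.url);
}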

src/requests.rs

Lines changed: 9 additions & 8 deletions

@@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
 use rand_distr::Distribution;
 use rayon::iter::split;
 use rayon::prelude::*;
+use reqwest::Url;
 use reqwest_eventsource::{Error, Event, EventSource};
 use serde::{Deserialize, Serialize};
 use std::cmp::Ordering;
@@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
 #[derive(Debug, Clone)]
 pub struct OpenAITextGenerationBackend {
     pub api_key: String,
-    pub base_url: String,
+    pub base_url: Url,
     pub model_name: String,
     pub client: reqwest::Client,
     pub tokenizer: Arc<Tokenizer>,
@@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
 impl OpenAITextGenerationBackend {
     pub fn try_new(
         api_key: String,
-        base_url: String,
+        base_url: Url,
         model_name: String,
         tokenizer: Arc<Tokenizer>,
         timeout: time::Duration,
@@ -829,7 +830,7 @@ mod tests {
             w.write_all(b"data: [DONE]\n\n")
         })
         .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),
@@ -890,7 +891,7 @@ mod tests {
             w.write_all(b"data: [DONE]\n\n")
         })
         .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),
@@ -975,7 +976,7 @@ mod tests {
         .with_chunked_body(|w| w.write_all(b"data: {\"error\": \"Internal server error\"}\n\n"))
         .create_async()
         .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),
@@ -1021,7 +1022,7 @@ mod tests {
         .with_chunked_body(|w| w.write_all(b"this is wrong\n\n"))
         .create_async()
         .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),
@@ -1067,7 +1068,7 @@ mod tests {
         .with_chunked_body(|w| w.write_all(b"data: {\"foo\": \"bar\"}\n\n"))
         .create_async()
         .await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),
@@ -1117,7 +1118,7 @@ mod tests {
             w.write_all(b"data: [DONE]\n\n")
         })
         .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),

src/scheduler.rs

Lines changed: 1 addition & 1 deletion

@@ -232,7 +232,7 @@ mod tests {
             w.write_all(b"data: [DONE]\n\n")
         })
         .create_async().await;
-        let url = s.url();
+        let url = s.url().parse().unwrap();
        let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
        let backend = OpenAITextGenerationBackend::try_new(
            "".to_string(),

0 commit comments
