Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ pub use crate::requests::TokenizeOptions;
use chrono::Local;
use crossterm::ExecutableCommand;
use log::{debug, error, info, warn, Level, LevelFilter};
use reqwest::Url;
use tokenizers::{FromPretrainedParameters, Tokenizer};
use tokio::sync::broadcast::Sender;
use tokio::sync::Mutex;
Expand All @@ -32,7 +33,8 @@ mod table;
mod writers;

pub struct RunConfiguration {
pub url: String,
pub url: Url,
pub api_key: String,
pub tokenizer_name: String,
pub profile: Option<String>,
pub max_vus: u64,
Expand Down Expand Up @@ -84,8 +86,8 @@ pub async fn run(mut run_config: RunConfiguration, stop_sender: Sender<()>) -> a
};
let tokenizer = Arc::new(tokenizer);
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
run_config.url.clone(),
run_config.api_key,
run_config.url,
run_config.model_name.clone(),
tokenizer,
run_config.duration,
Expand Down
20 changes: 9 additions & 11 deletions src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,12 @@ struct Args {
warmup: Duration,
/// The URL of the backend to benchmark. Must be compatible with OpenAI Message API
#[clap(default_value = "http://localhost:8000", short, long, env)]
#[arg(value_parser = parse_url)]
url: String,
url: Url,

/// The API key sent to the [`url`] as the header "Authorization: Bearer {API_KEY}".
#[clap(default_value = "", short, long, env)]
api_key: String,

/// Disable console UI
#[clap(short, long, env)]
no_console: bool,
Expand Down Expand Up @@ -115,13 +119,6 @@ fn parse_duration(s: &str) -> Result<Duration, Error> {
humantime::parse_duration(s).map_err(|_| Error::new(InvalidValue))
}

fn parse_url(s: &str) -> Result<String, Error> {
match Url::parse(s) {
Ok(_) => Ok(s.to_string()),
Err(_) => Err(Error::new(InvalidValue)),
}
}

fn parse_key_val(s: &str) -> Result<HashMap<String, String>, Error> {
let mut key_val_map = HashMap::new();
let items = s.split(",").collect::<Vec<&str>>();
Expand Down Expand Up @@ -197,7 +194,7 @@ async fn main() {
let stop_sender_clone = stop_sender.clone();
// get HF token
let token_env_key = "HF_TOKEN".to_string();
let cache = hf_hub::Cache::default();
let cache = hf_hub::Cache::from_env();
let hf_token = match std::env::var(token_env_key).ok() {
Some(token) => Some(token),
None => cache.token(),
Expand All @@ -210,7 +207,8 @@ async fn main() {
.run_id
.unwrap_or(uuid::Uuid::new_v4().to_string()[..7].to_string());
let run_config = RunConfiguration {
url: args.url.clone(),
url: args.url,
api_key: args.api_key,
profile: args.profile.clone(),
tokenizer_name: args.tokenizer_name.clone(),
max_vus: args.max_vus,
Expand Down
23 changes: 13 additions & 10 deletions src/requests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
use rand_distr::Distribution;
use rayon::iter::split;
use rayon::prelude::*;
use reqwest::Url;
use reqwest_eventsource::{Error, Event, EventSource};
use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
Expand Down Expand Up @@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
#[derive(Debug, Clone)]
pub struct OpenAITextGenerationBackend {
pub api_key: String,
pub base_url: String,
pub base_url: Url,
pub model_name: String,
pub client: reqwest::Client,
pub tokenizer: Arc<Tokenizer>,
Expand Down Expand Up @@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
impl OpenAITextGenerationBackend {
pub fn try_new(
api_key: String,
base_url: String,
base_url: Url,
model_name: String,
tokenizer: Arc<Tokenizer>,
timeout: time::Duration,
Expand All @@ -128,7 +129,9 @@ impl TextGenerationBackend for OpenAITextGenerationBackend {
request: Arc<TextGenerationRequest>,
sender: Sender<TextGenerationAggregatedResponse>,
) {
let url = format!("{base_url}/v1/chat/completions", base_url = self.base_url);
let mut url = self.base_url.clone();
url.set_path("/v1/chat/completions");
// let url = format!("{base_url}", base_url = self.base_url);
let mut aggregated_response = TextGenerationAggregatedResponse::new(request.clone());
let messages = vec![OpenAITextGenerationMessage {
role: "user".to_string(),
Expand Down Expand Up @@ -547,7 +550,7 @@ impl ConversationTextRequestGenerator {
filename: String,
hf_token: Option<String>,
) -> anyhow::Result<PathBuf> {
let api = ApiBuilder::new().with_token(hf_token).build()?;
let api = ApiBuilder::from_env().with_token(hf_token).build()?;
let repo = api.dataset(repo_name);
let dataset = repo.get(&filename)?;
Ok(dataset)
Expand Down Expand Up @@ -829,7 +832,7 @@ mod tests {
w.write_all(b"data: [DONE]\n\n")
})
.create_async().await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down Expand Up @@ -890,7 +893,7 @@ mod tests {
w.write_all(b"data: [DONE]\n\n")
})
.create_async().await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down Expand Up @@ -975,7 +978,7 @@ mod tests {
.with_chunked_body(|w| w.write_all(b"data: {\"error\": \"Internal server error\"}\n\n"))
.create_async()
.await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down Expand Up @@ -1021,7 +1024,7 @@ mod tests {
.with_chunked_body(|w| w.write_all(b"this is wrong\n\n"))
.create_async()
.await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down Expand Up @@ -1067,7 +1070,7 @@ mod tests {
.with_chunked_body(|w| w.write_all(b"data: {\"foo\": \"bar\"}\n\n"))
.create_async()
.await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down Expand Up @@ -1117,7 +1120,7 @@ mod tests {
w.write_all(b"data: [DONE]\n\n")
})
.create_async().await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down
2 changes: 1 addition & 1 deletion src/scheduler.rs
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ mod tests {
w.write_all(b"data: [DONE]\n\n")
})
.create_async().await;
let url = s.url();
let url = s.url().parse().unwrap();
let tokenizer = Arc::new(Tokenizer::from_pretrained("gpt2", None).unwrap());
let backend = OpenAITextGenerationBackend::try_new(
"".to_string(),
Expand Down