@@ -6,6 +6,7 @@ use log::{debug, error, info, trace, warn};
66use rand_distr:: Distribution ;
77use rayon:: iter:: split;
88use rayon:: prelude:: * ;
9+ use reqwest:: Url ;
910use reqwest_eventsource:: { Error , Event , EventSource } ;
1011use serde:: { Deserialize , Serialize } ;
1112use std:: cmp:: Ordering ;
@@ -58,7 +59,7 @@ impl Clone for Box<dyn TextGenerationBackend + Send + Sync> {
5859#[ derive( Debug , Clone ) ]
5960pub struct OpenAITextGenerationBackend {
6061 pub api_key : String ,
61- pub base_url : String ,
62+ pub base_url : Url ,
6263 pub model_name : String ,
6364 pub client : reqwest:: Client ,
6465 pub tokenizer : Arc < Tokenizer > ,
@@ -101,7 +102,7 @@ pub struct OpenAITextGenerationRequest {
101102impl OpenAITextGenerationBackend {
102103 pub fn try_new (
103104 api_key : String ,
104- base_url : String ,
105+ base_url : Url ,
105106 model_name : String ,
106107 tokenizer : Arc < Tokenizer > ,
107108 timeout : time:: Duration ,
@@ -829,7 +830,7 @@ mod tests {
829830 w. write_all ( b"data: [DONE]\n \n " )
830831 } )
831832 . create_async ( ) . await ;
832- let url = s. url ( ) ;
833+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
833834 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
834835 let backend = OpenAITextGenerationBackend :: try_new (
835836 "" . to_string ( ) ,
@@ -890,7 +891,7 @@ mod tests {
890891 w. write_all ( b"data: [DONE]\n \n " )
891892 } )
892893 . create_async ( ) . await ;
893- let url = s. url ( ) ;
894+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
894895 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
895896 let backend = OpenAITextGenerationBackend :: try_new (
896897 "" . to_string ( ) ,
@@ -975,7 +976,7 @@ mod tests {
975976 . with_chunked_body ( |w| w. write_all ( b"data: {\" error\" : \" Internal server error\" }\n \n " ) )
976977 . create_async ( )
977978 . await ;
978- let url = s. url ( ) ;
979+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
979980 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
980981 let backend = OpenAITextGenerationBackend :: try_new (
981982 "" . to_string ( ) ,
@@ -1021,7 +1022,7 @@ mod tests {
10211022 . with_chunked_body ( |w| w. write_all ( b"this is wrong\n \n " ) )
10221023 . create_async ( )
10231024 . await ;
1024- let url = s. url ( ) ;
1025+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
10251026 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
10261027 let backend = OpenAITextGenerationBackend :: try_new (
10271028 "" . to_string ( ) ,
@@ -1067,7 +1068,7 @@ mod tests {
10671068 . with_chunked_body ( |w| w. write_all ( b"data: {\" foo\" : \" bar\" }\n \n " ) )
10681069 . create_async ( )
10691070 . await ;
1070- let url = s. url ( ) ;
1071+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
10711072 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
10721073 let backend = OpenAITextGenerationBackend :: try_new (
10731074 "" . to_string ( ) ,
@@ -1117,7 +1118,7 @@ mod tests {
11171118 w. write_all ( b"data: [DONE]\n \n " )
11181119 } )
11191120 . create_async ( ) . await ;
1120- let url = s. url ( ) ;
1121+ let url = s. url ( ) . parse ( ) . unwrap ( ) ;
11211122 let tokenizer = Arc :: new ( Tokenizer :: from_pretrained ( "gpt2" , None ) . unwrap ( ) ) ;
11221123 let backend = OpenAITextGenerationBackend :: try_new (
11231124 "" . to_string ( ) ,
0 commit comments