diff --git a/Cargo.toml b/Cargo.toml
index 92d5498..3e9f2a7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,3 +36,4 @@ tokio = { git = "https://github.com/second-state/wasi_tokio.git", branch = "v1.3
 
 [features]
 default = []
+search = []
diff --git a/README.md b/README.md
index fb8b00e..10c2617 100644
--- a/README.md
+++ b/README.md
@@ -444,16 +444,16 @@ git clone https://github.com/LlamaEdge/rag-api-server.git
 cd rag-api-server
 
 # (Optional) Add the `wasm32-wasi` target to the Rust toolchain
-rustup target add wasm32-wasi
+rustup target add wasm32-wasip1
 
-# Build `rag-api-server.wasm` with the `http` support only, or
-cargo build --target wasm32-wasi --release
+# Build `rag-api-server.wasm` without internet search
+cargo build --target wasm32-wasip1 --release
 
-# Build `rag-api-server.wasm` with both `http` and `https` support
-cargo build --target wasm32-wasi --release --features full
+# Build `rag-api-server.wasm` with internet search capability
+cargo build --target wasm32-wasip1 --release --features search
 
 # Copy the `rag-api-server.wasm` to the root directory
-cp target/wasm32-wasi/release/rag-api-server.wasm .
+cp target/wasm32-wasip1/release/rag-api-server.wasm .
 ```
 To check the CLI options,
@@ -524,6 +524,19 @@ To check the CLI options of the `rag-api-server` wasm app, you can run the follo
          Print version
 ```
 
+When the server is compiled with the `search` feature enabled (either by passing `--features search` at build time or by editing `Cargo.toml`), the following extra CLI arguments become available:
+
+```bash
+  --api-key
+      API key to be supplied to the endpoint, if supported
+      [default: ]
+  --query-server-url
+      The URL for the LlamaEdge query server. Supplying this implies usage
+  --search-backend
+      The search API backend to use for internet search
+      [default: tavily]
+```
+
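A quick way to confirm that a given `rag-api-server.wasm` was built with the `search` feature is to look for the new options in its help output. This is a sketch rather than part of the patch; it assumes the app prints clap-style help when run under `wasmedge`, and the `grep` pattern is only an illustration:

```bash
# Show only the search-related options; no output means the binary was built without `--features search`.
wasmedge rag-api-server.wasm --help | grep -E -A 1 'api-key|query-server-url|search-backend'
```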
 ## Execute
 
@@ -547,19 +560,40 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
 docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
 ```
 
+### Start without Internet Search
+
 - Start an instance of LlamaEdge-RAG API server
 
-```bash
-wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
-  --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
-  rag-api-server.wasm \
-  --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
-  --ctx-size 4096,384 \
-  --prompt-template llama-2-chat,embedding \
-  --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
-  --log-prompts \
-  --log-stat
-```
+  ```bash
+  wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
+    --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
+    rag-api-server.wasm \
+    --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
+    --ctx-size 4096,384 \
+    --prompt-template llama-2-chat,embedding \
+    --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
+    --log-prompts \
+    --log-stat
+  ```
+
+### Start with Internet Search
+
+- Start an instance of LlamaEdge-RAG API server with the URL of your chosen [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance. The query server can be run locally.
+
+  ```bash
+  # --api-key is only needed if your chosen LlamaEdge query server endpoint requires one.
+  # --query-server-url points at the LlamaEdge query server of your choosing; the value below is the default local endpoint.
+  wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
+    --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
+    rag-api-server.wasm \
+    --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
+    --ctx-size 4096,384 \
+    --prompt-template llama-2-chat,embedding \
+    --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
+    --api-key "xxx" \
+    --query-server-url "http://0.0.0.0:8081/" \
+    --log-prompts \
+    --log-stat
+  ```
+
 ## Usage Example
 
@@ -580,6 +614,8 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
   -d '{"messages":[{"role":"system", "content": "You are a helpful assistant."}, {"role":"user", "content": "What is the location of Paris, France along the Seine River?"}], "model":"Llama-2-7b-chat-hf-Q5_K_M"}'
 ```
 
+Internet search is only used when the question cannot be answered from the retrieved RAG context (that is, when no relevant points are returned from the vector store). In that case, the user message is sent to the `/query/summarize` endpoint of the [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance, which responds with a summary of the internet search results if it decides a search is necessary.
+
 ## Set Log Level
 
 You can set the log level of the API server by setting the `LLAMA_LOG` environment variable.
 For example, to set the log level to `debug`, you can run the following command:
diff --git a/src/backend/ggml.rs b/src/backend/ggml.rs
index 4f2285d..d2a69e2 100644
--- a/src/backend/ggml.rs
+++ b/src/backend/ggml.rs
@@ -261,6 +261,9 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
     info!(target: "stdout", "Compute embeddings for user query.");
 
+    #[cfg(feature = "search")]
+    let query: String;
+
     // * compute embeddings for user query
     let embedding_response = match chat_request.messages.is_empty() {
         true => {
@@ -287,6 +290,10 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
         }
     };
 
+    #[cfg(feature = "search")]
+    {
+        query = query_text.clone();
+    }
     // log
     info!(target: "stdout", "query text: {}", query_text);
 
@@ -372,6 +379,9 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
         }
     };
 
+    #[cfg(feature = "search")]
+    let mut web_search_allowed: bool = false;
+
     if let Some(ro) = res {
         match ro.points {
             Some(scored_points) => {
@@ -379,6 +389,12 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
                     true => {
                         // log
                         warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold));
+
+                        #[cfg(feature = "search")]
+                        {
+                            info!(target: "stdout", "No points retrieved, enabling web search.");
+                            web_search_allowed = true;
+                        }
                     }
                     false => {
                         // update messages with retrieved context
@@ -435,10 +451,132 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
                 // log
                 warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold ));
+
+                #[cfg(feature = "search")]
+                {
+                    info!(target: "stdout", "No points retrieved, enabling web search.");
+                    web_search_allowed = true;
+                }
             }
         }
     }
 
+    #[cfg(feature = "search")]
+    if web_search_allowed {
+        let search_arguments = match crate::SEARCH_ARGUMENTS.get() {
+            Some(sc) => sc,
+            None => {
+                let err_msg = "Failed to obtain SEARCH_ARGUMENTS. Was it set?".to_string();
+                error!(target: "stdout", "{}", &err_msg);
+
+                return error::internal_server_error(err_msg);
+            }
+        };
+
+        let endpoint: hyper::Uri = match search_arguments.query_server_url.parse() {
+            Ok(uri) => uri,
+            Err(e) => {
+                let err_msg = format!("LlamaEdge Query server URL could not be parsed: {}", e);
+                error!(target: "stdout", "{}", &err_msg);
+
+                return error::internal_server_error(err_msg);
+            }
+        };
+
+        let summary_endpoint = match hyper::Uri::builder()
+            .scheme(endpoint.scheme().unwrap().to_string().as_str())
+            .authority(endpoint.authority().unwrap().to_string().as_str())
+            .path_and_query("/query/summarize")
+            .build()
+        {
+            Ok(se) => se,
+            Err(_) => {
+                let err_msg = "Couldn't build summary_endpoint from query_server_url".to_string();
+                error!(target: "stdout", "{}", &err_msg);
+
+                return error::internal_server_error(err_msg);
+            }
+        };
+
+        // perform the query, extract the summary, and add it to the conversation context
+        let req = match Request::builder()
+            .method(Method::POST)
+            .uri(summary_endpoint)
+            .header("content-type", "application/json")
+            .body(Body::from(
+                serde_json::json!({
+                    "search_config" : {
+                        "api_key": search_arguments.api_key,
+                    },
+                    "backend": search_arguments.search_backend,
+                    "query": query,
+                })
+                .to_string(),
+            )) {
+            Ok(request) => request,
+            Err(_) => {
+                let err_msg = "Failed to build request to LlamaEdge query server.".to_string();
+                error!(target: "stdout", "{}", &err_msg);
+                return error::internal_server_error(err_msg);
+            }
+        };
+
+        info!(target: "stdout", "Querying the LlamaEdge query server.");
+
+        let client = hyper::client::Client::new();
+        match client.request(req).await {
+            Ok(res) => {
+                let is_success = res.status().is_success();
+
+                let body_bytes = match hyper::body::to_bytes(res.into_body()).await {
+                    Ok(bytes) => bytes,
+                    Err(e) => {
+                        let err_msg = format!("Couldn't convert body into bytes: {}", e);
+                        error!(target: "stdout", "{}", &err_msg);
+
+                        return error::internal_server_error(err_msg);
+                    }
+                };
+
+                let body_json: serde_json::Value = match serde_json::from_slice(&body_bytes) {
+                    Ok(json) => json,
+                    Err(e) => {
+                        let err_msg = format!("Couldn't convert body into json: {}", e);
+                        error!(target: "stdout", "{}", &err_msg);
+
+                        return error::internal_server_error(err_msg);
+                    }
+                };
+
+                info!(target: "stdout", "processed query server response json body: \n{}", body_json);
+
+                // if the request is a success, check the decision and inject results accordingly.
+                if is_success && body_json["decision"].as_bool().unwrap_or(true) {
+                    // the logic to ensure "results" is a serde_json::Value::String is present on the
+                    // llamaedge-query-server.
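+                    // If the summary is missing or not a string, fall back to an empty string rather
+                    // than failing the chat request (`decision` above likewise defaults to true).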
+                    let results = body_json["results"].as_str().unwrap_or("");
+
+                    info!(target: "stdout", "injecting search summary into conversation context.");
+                    // inject search results
+                    let system_search_result_message: ChatCompletionRequestMessage =
+                        ChatCompletionRequestMessage::new_system_message(results, None);
+
+                    chat_request.messages.insert(
+                        chat_request.messages.len() - 1,
+                        system_search_result_message,
+                    )
+                }
+            }
+            Err(e) => {
+                let err_msg = format!(
+                    "Couldn't make request to LlamaEdge query server, switching to regular RAG: {}",
+                    e
+                );
+                warn!(target: "stdout", "{}", &err_msg);
+            }
+        };
+    }
+
     // chat completion
     let res = match llama_core::chat::chat(&mut chat_request).await {
         Ok(result) => match result {
diff --git a/src/main.rs b/src/main.rs
index 4f8ab6e..606bd5d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,6 +3,7 @@ extern crate log;
 
 mod backend;
 mod error;
+
 mod utils;
 
 use anyhow::Result;
@@ -21,6 +22,8 @@ use once_cell::sync::OnceCell;
 use serde::{Deserialize, Serialize};
 use std::{collections::HashMap, net::SocketAddr, path::PathBuf};
 use tokio::net::TcpListener;
+#[cfg(feature = "search")]
+use utils::SearchArguments;
 use utils::{is_valid_url, LogLevel};
 
 type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -29,6 +32,9 @@ type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
 // global system prompt
 pub(crate) static GLOBAL_RAG_PROMPT: OnceCell<String> = OnceCell::new();
 // server info
 pub(crate) static SERVER_INFO: OnceCell<ServerInfo> = OnceCell::new();
+// search cli arguments
+#[cfg(feature = "search")]
+pub(crate) static SEARCH_ARGUMENTS: OnceCell<SearchArguments> = OnceCell::new();
 // default socket address
 const DEFAULT_SOCKET_ADDRESS: &str = "0.0.0.0:8080";
@@ -127,6 +133,18 @@ struct Cli {
     /// Deprecated. Print all log information to stdout
     #[arg(long)]
     log_all: bool,
+    /// API key to be supplied to the endpoint, if supported.
+    #[cfg(feature = "search")]
+    #[arg(long, default_value = "")]
+    api_key: String,
+    /// The URL for the LlamaEdge query server. Supplying this implies usage.
+    #[cfg(feature = "search")]
+    #[arg(long, required = true)]
+    query_server_url: String,
+    /// The search API backend to use for internet search.
+    #[cfg(feature = "search")]
+    #[arg(long, default_value = "tavily", requires = "query_server_url")]
+    search_backend: String,
 }
 
 #[tokio::main(flavor = "current_thread")]
@@ -452,6 +470,19 @@ async fn main() -> Result<(), ServerError> {
         }
     });
 
+    #[cfg(feature = "search")]
+    {
+        let search_arguments = SearchArguments {
+            api_key: cli.api_key,
+            query_server_url: cli.query_server_url,
+            search_backend: cli.search_backend,
+        };
+
+        SEARCH_ARGUMENTS
+            .set(search_arguments)
+            .map_err(|_| ServerError::Operation("Failed to set `SEARCH_ARGUMENTS`.".to_string()))?;
+    }
+
     // let server = Server::bind(&addr).serve(new_service);
 
     let tcp_listener = TcpListener::bind(addr).await.unwrap();
diff --git a/src/utils.rs b/src/utils.rs
index 837da3f..d973680 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -9,6 +9,18 @@ pub(crate) fn gen_chat_id() -> String {
     format!("chatcmpl-{}", uuid::Uuid::new_v4())
 }
 
+/// Search related items that aren't directly supported by SearchConfig
+#[cfg(feature = "search")]
+#[derive(Debug)]
+pub(crate) struct SearchArguments {
+    /// API key to be supplied to the endpoint, if supported. Not used by Bing.
+    pub(crate) api_key: String,
+    /// The URL for the LlamaEdge query server. Supplying this implies usage.
+    pub(crate) query_server_url: String,
+    /// The search API backend to use for internet search.
+    pub(crate) search_backend: String,
+}
+
 #[derive(
     Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum, Serialize, Deserialize,
 )]
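The exchange that `rag_query_handler` performs against the query server can also be reproduced by hand, which is useful when reviewing this patch. The request below mirrors the JSON body built in `ggml.rs` and targets the default local endpoint mentioned in the README; the response shown is illustrative only, assuming a running LlamaEdge query server, and `decision` and `results` are the two fields the handler actually reads:

```bash
# POST the user query to the query server, the same way the search-enabled rag-api-server does.
curl -s -X POST http://0.0.0.0:8081/query/summarize \
  -H 'Content-Type: application/json' \
  -d '{
        "search_config": { "api_key": "xxx" },
        "backend": "tavily",
        "query": "What is the location of Paris, France along the Seine River?"
      }'

# Illustrative response (values made up): when "decision" is true, "results" is injected into the
# conversation as a system message just before the latest user message.
# {"decision": true, "results": "Paris sits on both banks of the Seine in northern France ..."}
```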