diff --git a/Cargo.toml b/Cargo.toml
index 92d5498..3e9f2a7 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,3 +36,4 @@ tokio = { git = "https://github.com/second-state/wasi_tokio.git", branch = "v1.3
[features]
default = []
+search = []
diff --git a/README.md b/README.md
index fb8b00e..10c2617 100644
--- a/README.md
+++ b/README.md
@@ -444,16 +444,16 @@ git clone https://github.com/LlamaEdge/rag-api-server.git
cd rag-api-server
-# (Optional) Add the `wasm32-wasi` target to the Rust toolchain
+# (Optional) Add the `wasm32-wasip1` target to the Rust toolchain
-rustup target add wasm32-wasi
+rustup target add wasm32-wasip1
-# Build `rag-api-server.wasm` with the `http` support only, or
-cargo build --target wasm32-wasi --release
+# Build `rag-api-server.wasm` without internet search
+cargo build --target wasm32-wasip1 --release
-# Build `rag-api-server.wasm` with both `http` and `https` support
-cargo build --target wasm32-wasi --release --features full
+# Build `rag-api-server.wasm` with internet search capability
+cargo build --target wasm32-wasip1 --release --features search
# Copy the `rag-api-server.wasm` to the root directory
-cp target/wasm32-wasi/release/rag-api-server.wasm .
+cp target/wasm32-wasip1/release/rag-api-server.wasm .
```
To check the CLI options,
@@ -524,6 +524,19 @@ To check the CLI options of the `rag-api-server` wasm app, you can run the follo
Print version
```
+When the server is compiled with the `search` feature enabled (either by passing the `--features search` flag at build time or by enabling the feature in `Cargo.toml`), the following extra CLI arguments become available:
+
+ ```bash
+ --api-key
+ API key to be supplied to the endpoint, if supported
+ [default: ]
+ --query-server-url
+ The URL for the LlamaEdge query server. Supplying this implies usage
+ --search-backend
+ The search API backend to use for internet search
+ [default: tavily]
+ ```
+
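+If you prefer to have the feature on by default rather than passing `--features search` on every build, one way to do the `Cargo.toml` edit mentioned above is to add `search` to the crate's default features. A minimal sketch of that change:
+
+```toml
+[features]
+default = ["search"]
+search = []
+```
+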
## Execute
@@ -547,19 +560,40 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
```
+### Start without Internet Search
+
- Start an instance of LlamaEdge-RAG API server
- ```bash
- wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
- --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
- rag-api-server.wasm \
- --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
- --ctx-size 4096,384 \
- --prompt-template llama-2-chat,embedding \
- --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
- --log-prompts \
- --log-stat
- ```
+ ```bash
+ wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
+ --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
+ rag-api-server.wasm \
+ --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
+ --ctx-size 4096,384 \
+ --prompt-template llama-2-chat,embedding \
+ --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
+ --log-prompts \
+ --log-stat
+ ```
+
+### Start with Internet Search
+
+- Start an instance of LlamaEdge-RAG API server, supplying the URL of your chosen [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance. The query server can be run locally; `http://0.0.0.0:8081/` is the default endpoint of a local instance. Supply `--api-key` only if your chosen query server endpoint requires one.
+
+ ```bash
+ wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
+ --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
+ rag-api-server.wasm \
+ --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
+ --ctx-size 4096,384 \
+ --prompt-template llama-2-chat,embedding \
+ --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
+    --api-key "xxx" \
+    --query-server-url "http://0.0.0.0:8081/" \
+ --log-prompts \
+ --log-stat
+ ```
+
## Usage Example
@@ -580,6 +614,8 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
-d '{"messages":[{"role":"system", "content": "You are a helpful assistant."}, {"role":"user", "content": "What is the location of Paris, France along the Seine River?"}], "model":"Llama-2-7b-chat-hf-Q5_K_M"}'
```
+Internet search is only used when the question cannot be answered with RAG alone. When it is needed, the user message is sent to the `/query/summarize` endpoint of the [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance, which responds with a summary of the internet search results if it decides a search is necessary.
+
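+For reference, the request this server builds for the query server (see the `search` code path in `src/backend/ggml.rs`) can be reproduced with `curl`. This is only a sketch: it assumes a query server listening at the default local address, and the `decision` and `results` fields mentioned below are simply the parts of the response this server reads.
+
+```bash
+curl -X POST http://0.0.0.0:8081/query/summarize \
+    -H 'Content-Type: application/json' \
+    -d '{"search_config": {"api_key": "xxx"}, "backend": "tavily", "query": "What is the location of Paris, France along the Seine River?"}'
+```
+
+If the query server's response reports `decision: true`, the string in `results` is inserted into the conversation as an extra system message just before the latest user message.
+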
## Set Log Level
You can set the log level of the API server by setting the `LLAMA_LOG` environment variable. For example, to set the log level to `debug`, you can run the following command:
diff --git a/src/backend/ggml.rs b/src/backend/ggml.rs
index 4f2285d..d2a69e2 100644
--- a/src/backend/ggml.rs
+++ b/src/backend/ggml.rs
@@ -261,6 +261,9 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
info!(target: "stdout", "Compute embeddings for user query.");
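+    // Keep a copy of the user's query text so the optional web-search fallback
+    // further down in this handler can reuse it.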
+ #[cfg(feature = "search")]
+ let query: String;
+
// * compute embeddings for user query
let embedding_response = match chat_request.messages.is_empty() {
true => {
@@ -287,6 +290,10 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
}
};
+ #[cfg(feature = "search")]
+ {
+ query = query_text.clone();
+ }
// log
info!(target: "stdout", "query text: {}", query_text);
@@ -372,6 +379,9 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
}
};
+ #[cfg(feature = "search")]
+ let mut web_search_allowed: bool = false;
+
if let Some(ro) = res {
match ro.points {
Some(scored_points) => {
@@ -379,6 +389,12 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
true => {
// log
warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold));
+
+ #[cfg(feature = "search")]
+ {
+ info!(target: "stdout", "No points retrieved, enabling web search.");
+ web_search_allowed = true;
+ }
}
false => {
// update messages with retrieved context
@@ -435,10 +451,132 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
// log
warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold
));
+
+ #[cfg(feature = "search")]
+ {
+ info!(target: "stdout", "No points retrieved, enabling web search.");
+ web_search_allowed = true;
+ }
}
}
}
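+    // Fall back to internet search via the LlamaEdge query server when RAG retrieval
+    // produced no usable context. The summarized results, if any, are injected into
+    // the conversation as an extra system message before the chat completion runs.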
+ #[cfg(feature = "search")]
+ if web_search_allowed {
+ let search_arguments = match crate::SEARCH_ARGUMENTS.get() {
+ Some(sc) => sc,
+ None => {
+ let err_msg = "Failed to obtain SEARCH_ARGUMENTS. Was it set?".to_string();
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ let endpoint: hyper::Uri = match search_arguments.query_server_url.parse() {
+ Ok(uri) => uri,
+ Err(e) => {
+ let err_msg = format!("LlamaEdge Query server URL could not be parsed: {}", e);
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ let summary_endpoint = match hyper::Uri::builder()
+ .scheme(endpoint.scheme().unwrap().to_string().as_str())
+ .authority(endpoint.authority().unwrap().to_string().as_str())
+ .path_and_query("/query/summarize")
+ .build()
+ {
+ Ok(se) => se,
+ Err(_) => {
+ let err_msg = "Couldn't build summary_endpoint from query_server_url".to_string();
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+    // perform the query, extract the summary, and add it to the conversation context
+ let req = match Request::builder()
+ .method(Method::POST)
+ .uri(summary_endpoint)
+ .header("content-type", "application/json")
+ .body(Body::from(
+ serde_json::json!({
+ "search_config" : {
+ "api_key": search_arguments.api_key,
+ },
+ "backend": search_arguments.search_backend,
+ "query": query,
+ })
+ .to_string(),
+ )) {
+ Ok(request) => request,
+ Err(_) => {
+                let err_msg = "Failed to build request to LlamaEdge query server.".to_string();
+ error!(target: "stdout", "{}", &err_msg);
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ info!(target: "stdout", "Querying the LlamaEdge query server.");
+
+ let client = hyper::client::Client::new();
+ match client.request(req).await {
+ Ok(res) => {
+ let is_success = res.status().is_success();
+
+ let body_bytes = match hyper::body::to_bytes(res.into_body()).await {
+ Ok(bytes) => bytes,
+ Err(e) => {
+ let err_msg = format!("Couldn't convert body into bytes: {}", e);
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ let body_json: serde_json::Value = match serde_json::from_slice(&body_bytes) {
+ Ok(json) => json,
+ Err(e) => {
+ let err_msg = format!("Couldn't convert body into json: {}", e);
+ error!(target: "stdout", "{}", &err_msg);
+
+ return error::internal_server_error(err_msg);
+ }
+ };
+
+ info!(target: "stdout", "processed query server response json body: \n{}", body_json);
+
+ // if the request is a success, check decision and inject results accordingly.
+ if is_success && body_json["decision"].as_bool().unwrap_or(true) {
+ // the logic to ensure "results" is a serde_json::Value::String is present on the
+ // llamaedge-query-server.
+ let results = body_json["results"].as_str().unwrap_or("");
+
+ info!(target: "stdout", "injecting search summary into conversation context.");
+ //inject search results
+ let system_search_result_message: ChatCompletionRequestMessage =
+ ChatCompletionRequestMessage::new_system_message(results, None);
+
+ chat_request.messages.insert(
+ chat_request.messages.len() - 1,
+ system_search_result_message,
+ )
+ }
+ }
+ Err(e) => {
+ let err_msg = format!(
+ "Couldn't make request to LlamaEdge query server, switching to regular RAG: {}",
+ e
+ );
+ warn!(target: "stdout", "{}", &err_msg);
+ }
+ };
+ }
+
// chat completion
let res = match llama_core::chat::chat(&mut chat_request).await {
Ok(result) => match result {
diff --git a/src/main.rs b/src/main.rs
index 4f8ab6e..606bd5d 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -3,6 +3,7 @@ extern crate log;
mod backend;
mod error;
+
mod utils;
use anyhow::Result;
@@ -21,6 +22,8 @@ use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, path::PathBuf};
use tokio::net::TcpListener;
+#[cfg(feature = "search")]
+use utils::SearchArguments;
use utils::{is_valid_url, LogLevel};
type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -29,6 +32,9 @@ type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
pub(crate) static GLOBAL_RAG_PROMPT: OnceCell<String> = OnceCell::new();
// server info
pub(crate) static SERVER_INFO: OnceCell<ServerInfo> = OnceCell::new();
+// search cli arguments
+#[cfg(feature = "search")]
+pub(crate) static SEARCH_ARGUMENTS: OnceCell<SearchArguments> = OnceCell::new();
// default socket address
const DEFAULT_SOCKET_ADDRESS: &str = "0.0.0.0:8080";
@@ -127,6 +133,18 @@ struct Cli {
/// Deprecated. Print all log information to stdout
#[arg(long)]
log_all: bool,
+ /// API key to be supplied to the endpoint, if supported.
+ #[cfg(feature = "search")]
+ #[arg(long, default_value = "")]
+ api_key: String,
+ /// The URL for the LlamaEdge query server. Supplying this implies usage.
+ #[cfg(feature = "search")]
+ #[arg(long, required = true)]
+ query_server_url: String,
+ /// The search API backend to use for internet search.
+ #[cfg(feature = "search")]
+    #[arg(long, default_value = "tavily", requires = "query_server_url")]
+ search_backend: String,
}
#[tokio::main(flavor = "current_thread")]
@@ -452,6 +470,19 @@ async fn main() -> Result<(), ServerError> {
}
});
+ #[cfg(feature = "search")]
+ {
+ let search_arguments = SearchArguments {
+ api_key: cli.api_key,
+ query_server_url: cli.query_server_url,
+ search_backend: cli.search_backend,
+ };
+
+ SEARCH_ARGUMENTS
+ .set(search_arguments)
+ .map_err(|_| ServerError::Operation("Failed to set `SEARCH_ARGUMENTS`.".to_string()))?;
+ }
+
// let server = Server::bind(&addr).serve(new_service);
let tcp_listener = TcpListener::bind(addr).await.unwrap();
diff --git a/src/utils.rs b/src/utils.rs
index 837da3f..d973680 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -9,6 +9,18 @@ pub(crate) fn gen_chat_id() -> String {
format!("chatcmpl-{}", uuid::Uuid::new_v4())
}
+/// Search-related items that aren't directly supported by SearchConfig
+#[cfg(feature = "search")]
+#[derive(Debug)]
+pub(crate) struct SearchArguments {
+ /// API key to be supplied to the endpoint, if supported. Not used by Bing.
+ pub(crate) api_key: String,
+ /// The URL for the LlamaEdge query server. Supplying this implies usage.
+ pub(crate) query_server_url: String,
+ /// The search API backend to use for internet search.
+ pub(crate) search_backend: String,
+}
+
#[derive(
Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum, Serialize, Deserialize,
)]