1 change: 1 addition & 0 deletions Cargo.toml
@@ -36,3 +36,4 @@ tokio = { git = "https://github.com/second-state/wasi_tokio.git", branch = "v1.3

[features]
default = []
search = []
70 changes: 53 additions & 17 deletions README.md
@@ -444,16 +444,16 @@ git clone https://github.com/LlamaEdge/rag-api-server.git
cd rag-api-server

-# (Optional) Add the `wasm32-wasi` target to the Rust toolchain
-rustup target add wasm32-wasi
+# (Optional) Add the `wasm32-wasip1` target to the Rust toolchain
+rustup target add wasm32-wasip1

-# Build `rag-api-server.wasm` with the `http` support only, or
-cargo build --target wasm32-wasi --release
+# Build `rag-api-server.wasm` without internet search
+cargo build --target wasm32-wasip1 --release

-# Build `rag-api-server.wasm` with both `http` and `https` support
-cargo build --target wasm32-wasi --release --features full
+# Build `rag-api-server.wasm` with internet search capability
+cargo build --target wasm32-wasip1 --release --features search

# Copy the `rag-api-server.wasm` to the root directory
-cp target/wasm32-wasi/release/rag-api-server.wasm .
+cp target/wasm32-wasip1/release/rag-api-server.wasm .
```

<details> <summary> To check the CLI options, </summary>
@@ -524,6 +524,19 @@ To check the CLI options of the `rag-api-server` wasm app, you can run the follo
Print version
```

When the server is compiled with the `search` feature enabled (either by passing the `--features search` flag at build time or by editing `Cargo.toml`), the following extra CLI arguments become available:

```bash
  --api-key <API_KEY>
      API key to be supplied to the endpoint, if supported
      [default: ]
  --query-server-url <QUERY_SERVER_URL>
      The URL for the LlamaEdge query server. Supplying this implies usage
  --search-backend <SEARCH_BACKEND>
      The search API backend to use for internet search
      [default: tavily]
```

</details>

## Execute
@@ -547,19 +560,40 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
docker run -p 6333:6333 -p 6334:6334 -v $(pwd)/qdrant_storage:/qdrant/storage:z qdrant/qdrant
```

### Start without Internet Search

- Start an instance of the LlamaEdge-RAG API server

```bash
wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
  --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
  rag-api-server.wasm \
  --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
  --ctx-size 4096,384 \
  --prompt-template llama-2-chat,embedding \
  --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
  --log-prompts \
  --log-stat
```

### Start with Internet Search

- Start an instance of the LlamaEdge-RAG API server, passing the URL of your chosen [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance via `--query-server-url` (the value below is the default local endpoint; the query server can be run locally). Supply `--api-key` only if your chosen query server endpoint requires one.

```bash
wasmedge --dir .:. --nn-preload default:GGML:AUTO:Llama-2-7b-chat-hf-Q5_K_M.gguf \
  --nn-preload embedding:GGML:AUTO:all-MiniLM-L6-v2-ggml-model-f16.gguf \
  rag-api-server.wasm \
  --model-name Llama-2-7b-chat-hf-Q5_K_M,all-MiniLM-L6-v2-ggml-model-f16 \
  --ctx-size 4096,384 \
  --prompt-template llama-2-chat,embedding \
  --rag-prompt "Use the following pieces of context to answer the user's question.\nIf you don't know the answer, just say that you don't know, don't try to make up an answer.\n----------------\n" \
  --api-key "xxx" \
  --query-server-url "http://0.0.0.0:8081/" \
  --log-prompts \
  --log-stat
```

## Usage Example

@@ -580,6 +614,8 @@ For the purpose of demonstration, we use the [Llama-2-7b-chat-hf-Q5_K_M.gguf](ht
-d '{"messages":[{"role":"system", "content": "You are a helpful assistant."}, {"role":"user", "content": "What is the location of Paris, France along the Seine River?"}], "model":"Llama-2-7b-chat-hf-Q5_K_M"}'
```

Internet search is used only when the question cannot be answered from the RAG context. When it is needed, the user message is sent to the `/query/summarize` endpoint of the [LlamaEdge Query Server](https://github.com/LlamaEdge/llamaedge-query-server/) instance, which responds with a summary of the internet search results if it decides a search is necessary.
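
For reference, the exchange with the query server looks roughly like the sketch below. The request fields mirror what this PR's handler sends, and the endpoint URL is the default local one; the response body is illustrative, as the authoritative schema is defined by the query server.

```bash
# A hand-written approximation of the request this server sends when RAG
# retrieval comes up empty (assumes a query server at the default local endpoint):
curl -X POST http://0.0.0.0:8081/query/summarize \
  -H 'content-type: application/json' \
  -d '{
        "search_config": { "api_key": "xxx" },
        "backend": "tavily",
        "query": "What is the location of Paris, France along the Seine River?"
      }'

# Illustrative response: when "decision" is true, the server injects "results"
# into the conversation as a system message just before the latest user message.
# {"decision": true, "results": "Paris lies on the banks of the Seine River ..."}
```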

## Set Log Level

You can set the log level of the API server by setting the `LLAMA_LOG` environment variable. For example, to set the log level to `debug`, you can run the following command:
138 changes: 138 additions & 0 deletions src/backend/ggml.rs
@@ -261,6 +261,9 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>

    info!(target: "stdout", "Compute embeddings for user query.");

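    // When built with `search`, keep a copy of the raw user-query text so the
    // web-search fallback further below can reuse it.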
#[cfg(feature = "search")]
let query: String;

// * compute embeddings for user query
let embedding_response = match chat_request.messages.is_empty() {
true => {
@@ -287,6 +290,10 @@ pub(crate) async fn rag_query_handler(mut req: Request<Body>) -> Response<Body>
        }
    };

    #[cfg(feature = "search")]
    {
        query = query_text.clone();
    }
    // log
    info!(target: "stdout", "query text: {}", query_text);

@@ -372,13 +379,22 @@
        }
    };

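    // Tracks whether retrieval came back empty; when true, the handler falls
    // back to an internet search via the query server.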
    #[cfg(feature = "search")]
    let mut web_search_allowed: bool = false;

    if let Some(ro) = res {
        match ro.points {
            Some(scored_points) => {
                match scored_points.is_empty() {
                    true => {
                        // log
                        warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold));

                        #[cfg(feature = "search")]
                        {
                            info!(target: "stdout", "No points retrieved, enabling web search.");
                            web_search_allowed = true;
                        }
                    }
                    false => {
                        // update messages with retrieved context
@@ -435,10 +451,132 @@
                // log
                warn!(target: "stdout", "{}", format!("No point retrieved (score < threshold {})", server_info.qdrant_config.score_threshold));

                #[cfg(feature = "search")]
                {
                    info!(target: "stdout", "No points retrieved, enabling web search.");
                    web_search_allowed = true;
                }
            }
        }
    }

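    // No usable context was retrieved from Qdrant: ask the LlamaEdge query
    // server to run an internet search, summarize the results, and inject the
    // summary into the conversation as a system message.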
    #[cfg(feature = "search")]
    if web_search_allowed {
        let search_arguments = match crate::SEARCH_ARGUMENTS.get() {
            Some(sc) => sc,
            None => {
                let err_msg = "Failed to obtain SEARCH_ARGUMENTS. Was it set?".to_string();
                error!(target: "stdout", "{}", &err_msg);

                return error::internal_server_error(err_msg);
            }
        };

        let endpoint: hyper::Uri = match search_arguments.query_server_url.parse() {
            Ok(uri) => uri,
            Err(e) => {
                let err_msg = format!("LlamaEdge Query server URL could not be parsed: {}", e);
                error!(target: "stdout", "{}", &err_msg);

                return error::internal_server_error(err_msg);
            }
        };

        let summary_endpoint = match hyper::Uri::builder()
            .scheme(endpoint.scheme().unwrap().to_string().as_str())
            .authority(endpoint.authority().unwrap().to_string().as_str())
            .path_and_query("/query/summarize")
            .build()
        {
            Ok(se) => se,
            Err(_) => {
                let err_msg = "Couldn't build summary_endpoint from query_server_url".to_string();
                error!(target: "stdout", "{}", &err_msg);

                return error::internal_server_error(err_msg);
            }
        };

        // perform the query, extract the summary, and add it to the conversation context
        let req = match Request::builder()
            .method(Method::POST)
            .uri(summary_endpoint)
            .header("content-type", "application/json")
            .body(Body::from(
                serde_json::json!({
                    "search_config": {
                        "api_key": search_arguments.api_key,
                    },
                    "backend": search_arguments.search_backend,
                    "query": query,
                })
                .to_string(),
            )) {
            Ok(request) => request,
            Err(_) => {
                let err_msg = "Failed to build request to LlamaEdge query server.".to_string();
                error!(target: "stdout", "{}", &err_msg);
                return error::internal_server_error(err_msg);
            }
        };

        info!(target: "stdout", "Querying the LlamaEdge query server.");

        let client = hyper::client::Client::new();
        match client.request(req).await {
            Ok(res) => {
                let is_success = res.status().is_success();

                let body_bytes = match hyper::body::to_bytes(res.into_body()).await {
                    Ok(bytes) => bytes,
                    Err(e) => {
                        let err_msg = format!("Couldn't convert body into bytes: {}", e);
                        error!(target: "stdout", "{}", &err_msg);

                        return error::internal_server_error(err_msg);
                    }
                };

                let body_json: serde_json::Value = match serde_json::from_slice(&body_bytes) {
                    Ok(json) => json,
                    Err(e) => {
                        let err_msg = format!("Couldn't convert body into json: {}", e);
                        error!(target: "stdout", "{}", &err_msg);

                        return error::internal_server_error(err_msg);
                    }
                };

                info!(target: "stdout", "processed query server response json body: \n{}", body_json);

                // if the request is a success, check the decision and inject results accordingly.
                if is_success && body_json["decision"].as_bool().unwrap_or(true) {
                    // the logic that guarantees `results` is a JSON string lives in
                    // the llamaedge-query-server.
                    let results = body_json["results"].as_str().unwrap_or("");

                    info!(target: "stdout", "injecting search summary into conversation context.");
                    // inject the search summary just before the latest user message
                    let system_search_result_message: ChatCompletionRequestMessage =
                        ChatCompletionRequestMessage::new_system_message(results, None);

                    chat_request.messages.insert(
                        chat_request.messages.len() - 1,
                        system_search_result_message,
                    )
                }
            }
            Err(e) => {
                let err_msg = format!(
                    "Couldn't make request to LlamaEdge query server, switching to regular RAG: {}",
                    e
                );
                warn!(target: "stdout", "{}", &err_msg);
            }
        };
    }

    // chat completion
    let res = match llama_core::chat::chat(&mut chat_request).await {
        Ok(result) => match result {
31 changes: 31 additions & 0 deletions src/main.rs
@@ -3,6 +3,7 @@ extern crate log;

mod backend;
mod error;

mod utils;

use anyhow::Result;
@@ -21,6 +22,8 @@ use once_cell::sync::OnceCell;
use serde::{Deserialize, Serialize};
use std::{collections::HashMap, net::SocketAddr, path::PathBuf};
use tokio::net::TcpListener;
#[cfg(feature = "search")]
use utils::SearchArguments;
use utils::{is_valid_url, LogLevel};

type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
@@ -29,6 +32,9 @@ type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
pub(crate) static GLOBAL_RAG_PROMPT: OnceCell<String> = OnceCell::new();
// server info
pub(crate) static SERVER_INFO: OnceCell<ServerInfo> = OnceCell::new();
// search cli arguments
#[cfg(feature = "search")]
pub(crate) static SEARCH_ARGUMENTS: OnceCell<SearchArguments> = OnceCell::new();

// default socket address
const DEFAULT_SOCKET_ADDRESS: &str = "0.0.0.0:8080";
@@ -127,6 +133,18 @@ struct Cli {
    /// Deprecated. Print all log information to stdout
    #[arg(long)]
    log_all: bool,
    /// API key to be supplied to the endpoint, if supported.
    #[cfg(feature = "search")]
    #[arg(long, default_value = "")]
    api_key: String,
    /// The URL for the LlamaEdge query server. Supplying this implies usage.
    #[cfg(feature = "search")]
    #[arg(long, required = true)]
    query_server_url: String,
    /// The search API backend to use for internet search.
    #[cfg(feature = "search")]
    #[arg(long, default_value = "tavily", requires = "query_server_url")]
    search_backend: String,
}

#[tokio::main(flavor = "current_thread")]
@@ -452,6 +470,19 @@ async fn main() -> Result<(), ServerError> {
        }
    });

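    // Stash the search-related CLI arguments in the `SEARCH_ARGUMENTS` OnceCell
    // so the request handlers in `backend::ggml` can read them later.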
    #[cfg(feature = "search")]
    {
        let search_arguments = SearchArguments {
            api_key: cli.api_key,
            query_server_url: cli.query_server_url,
            search_backend: cli.search_backend,
        };

        SEARCH_ARGUMENTS
            .set(search_arguments)
            .map_err(|_| ServerError::Operation("Failed to set `SEARCH_ARGUMENTS`.".to_string()))?;
    }

    // let server = Server::bind(&addr).serve(new_service);

    let tcp_listener = TcpListener::bind(addr).await.unwrap();
12 changes: 12 additions & 0 deletions src/utils.rs
@@ -9,6 +9,18 @@ pub(crate) fn gen_chat_id() -> String {
format!("chatcmpl-{}", uuid::Uuid::new_v4())
}

/// Search-related items that aren't directly supported by SearchConfig
#[cfg(feature = "search")]
#[derive(Debug)]
pub(crate) struct SearchArguments {
    /// API key to be supplied to the endpoint, if supported. Not used by Bing.
    pub(crate) api_key: String,
    /// The URL for the LlamaEdge query server. Supplying this implies usage.
    pub(crate) query_server_url: String,
    /// The search API backend to use for internet search.
    pub(crate) search_backend: String,
}

#[derive(
    Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, clap::ValueEnum, Serialize, Deserialize,
)]