Binary file added recipes/rag-example/OJ_L_202401689_EN_TXT.pdf
52 changes: 52 additions & 0 deletions recipes/rag-example/README.md
@@ -0,0 +1,52 @@
# Minimal RAG example

## Context

This recipe is intended as a minimal example of Retrieval-Augmented Generation (RAG) with Eole models.
It relies on additional tools, such as LangChain's [loaders](https://python.langchain.com/docs/integrations/document_loaders/pymupdf/) and [splitters](https://python.langchain.com/v0.1/docs/modules/data_connection/document_transformers/recursive_text_splitter/), as well as [ChromaDB](https://docs.trychroma.com/getting-started) for vector search.
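At a glance, the retrieval side consists of loading the PDF, splitting it into chunks, upserting those chunks into a Chroma collection, and querying that collection. A minimal sketch of this flow (the full version, including the Eole inference step, lives in `test_rag.py`; the query string here is just an illustration):

```python
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb

# Load the PDF and split it into overlapping chunks
docs = PyMuPDFLoader("./OJ_L_202401689_EN_TXT.pdf").load()
chunks = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150).split_documents(docs)

# Store the chunks in a persistent Chroma collection, then retrieve the closest ones for a query
collection = chromadb.PersistentClient(path="chromadb_data").get_or_create_collection(name="test-eu")
collection.upsert(
    ids=[str(i) for i in range(len(chunks))],
    documents=[c.page_content for c in chunks],
    metadatas=[c.metadata for c in chunks],
)
results = collection.query(query_texts=["What does the act say about biometrics?"], n_results=5)
```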

For the sake of the exercise, the example uses the rather hard-to-digest ["EU AI Act" full text](https://digital-strategy.ec.europa.eu/en/policies/regulatory-framework-ai#:~:text=The%20AI%20Act%20(Regulation%20(EU,regarding%20specific%20uses%20of%20AI.).

This is a quickly assembled proof of concept and is not expected to give perfect answers.

## Usage

### 0. Install requirements

```bash
pip install langchain pymupdf chromadb
```

### 1. Convert the model you want to use (Llama-3.1-8B by default)

**Set environment variables**

```bash
export EOLE_MODEL_DIR=<where_to_store_models>
export HF_TOKEN=<your_hf_token>
```

**Download and convert model**

```bash
eole convert HF --model_dir meta-llama/Meta-Llama-3.1-8B --output $EOLE_MODEL_DIR/llama3.1-8b --token $HF_TOKEN
```

### 2. Adapt and run the script

If needed, modify the `model_path` in `PredictConfig`:
```python
...
config = PredictConfig(
    model_path=os.path.expandvars("${EOLE_MODEL_DIR}/llama3.1-8b"),  # <------ change if needed
    src="dummy",
    max_length=500,
    ...
```

**Run the script**
```bash
python3 test_rag.py
```

Note: You can test various queries by changing the `QUERY` variable.
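If you want to compare answers across several questions in one run, a minimal sketch is shown below. It assumes `engine.infer_list` handles a list with more than one prompt (the script only calls it with a single-element list) and reuses the `collection`, `PROMPT`, and `engine` objects from `test_rag.py`:

```python
queries = [
    "What are the main obligations of importers?",
    "What is the maximum fine for potential offenders?",
]

# Build one prompt per query by retrieving its best-matching chunk and neighbours
prompts = []
for query in queries:
    results = collection.query(query_texts=[query], n_results=5)
    best_id = int(results["ids"][0][0])
    context_docs = collection.get(ids=[str(best_id - 1), str(best_id), str(best_id + 1)])
    prompts.append(PROMPT.format(question=query, context="\n".join(context_docs["documents"])))

_, _, predictions = engine.infer_list(prompts)
for query, preds in zip(queries, predictions):
    answer = preds[0].replace("⦅newline⦆", "\n")
    print(f"Q: {query}\nA: {answer}\n")
```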
124 changes: 124 additions & 0 deletions recipes/rag-example/test_rag.py
@@ -0,0 +1,124 @@
# flake8: noqa

import os
from rich import print
from tqdm import tqdm
from langchain_community.document_loaders import PyMuPDFLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
import chromadb

from eole.utils.logging import init_logger
from eole.config.run import PredictConfig
from eole.inference_engine import InferenceEnginePY

# Set up logging
logger = init_logger()

# 1. Load and Split the Document
logger.info("Loading and splitting the document...")
loader = PyMuPDFLoader("./OJ_L_202401689_EN_TXT.pdf")
docs = loader.load()

text_splitter = RecursiveCharacterTextSplitter(chunk_size=1500, chunk_overlap=150)
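# Split the document into ~1500-character chunks with a 150-character overlap so passages keep some surrounding context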
documents = text_splitter.split_documents(documents=docs)
print(f"[INFO] Total chunks: {len(documents)}")

# 2. Set Up ChromaDB Client and Collection
logger.info("Setting up ChromaDB client...")
chroma_client = chromadb.PersistentClient(path="chromadb_data")
collection = chroma_client.get_or_create_collection(name="test-eu")

# 3. Insert Documents into Collection
logger.info("Checking for existing data in the collection...")

# Retrieve all existing IDs from the collection
existing_ids = set(collection.get(ids=None)["ids"]) # Fetches all IDs in the collection
logger.info(f"Found {len(existing_ids)} existing documents in the collection.")
batch_size = 100
for i in tqdm(
    range(0, len(documents), batch_size),
    desc="Upserting batches in the vector database",
):
    batch_ids = [str(k) for k in range(i, min(len(documents), i + batch_size))]
    new_ids = [
        id_ for id_ in batch_ids if id_ not in existing_ids
    ]  # Filter out existing IDs

    if new_ids:  # Only upsert if there are new IDs
        new_docs = [documents[int(id_)] for id_ in new_ids]
        collection.upsert(
            ids=new_ids,
            documents=[doc.page_content for doc in new_docs],
            metadatas=[doc.metadata for doc in new_docs],
        )
        logger.info(f"Upserted {len(new_ids)} new documents.")
    else:
        logger.info(f"Skipping batch {i // batch_size + 1}, all IDs already exist.")

# print(collection.peek(10))

# 4. Query the Collection
QUERY = "What is the general position around using biometrics and facial recognition in public places?"
# QUERY = "Are there any derogations for specific actors?"
# QUERY = "What are the main obligations of importers?"
# QUERY = "What are the main risks and penalties incurred?"
# QUERY = "What is the maximum fine for potential offenders?"
# QUERY = "What are the main prohibited practices coverd by the act?"
# QUERY = "What are the main accepted practices covered by the act?"

print(f"[INFO] Querying collection with: {QUERY}")
results = collection.query(query_texts=[QUERY], n_results=5)

best_id = int(results["ids"][0][0])
print(f"[INFO] Best result ID: {best_id}")
context_docs = collection.get(ids=[str(best_id - 1), str(best_id), str(best_id + 1)])

# 5. Prepare the Prompt for Inference
PROMPT = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise. The answer should be understandable outside of its context.
The context comes from this document: "Regulation (EU) 2024/1689 of the European Parliament and of the Council of 13 June 2024 laying down harmonised rules on artificial intelligence and amending Regulations (EC) No 300/2008, (EU) No 167/2013, (EU) No 168/2013, (EU) 2018/858, (EU) 2018/1139 and (EU) 2019/2144 and Directives 2014/90/EU, (EU) 2016/797 and (EU) 2020/1828 (Artificial Intelligence Act) Text with EEA relevance."
Question: {question}
Context:  {context}
Answer:"""

context = "\n".join(context_docs["documents"])
prompt = PROMPT.format(question=QUERY, context=context)
logger.info("Generated Prompt:")
print(prompt)


# 6. Perform Inference
logger.info("Running inference...")
config = PredictConfig(
    model_path=os.path.expandvars("${EOLE_MODEL_DIR}/llama3.1-8b"),
    src="dummy",
    max_length=500,
    gpu_ranks=[0],
    # Uncomment to activate bnb quantization
    # quant_type="bnb_NF4",
    # quant_layers=[
    #     "gate_up_proj",
    #     "down_proj",
    #     "up_proj",
    #     "linear_values",
    #     "linear_query",
    #     "linear_keys",
    #     "final_linear",
    #     "w_in",
    #     "w_out",
    # ],
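    # Fairly conservative sampling settings, seeded for reproducibility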
    top_p=0.3,
    temperature=0.35,
    beam_size=5,
    seed=42,
    batch_size=1,
    batch_type="sents",
)

engine = InferenceEnginePY(config)

_, _, predictions = engine.infer_list([prompt])

# 7. Display the Prediction
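# Eole encodes line breaks with the ⦅newline⦆ placeholder; restore them before printing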
answer = predictions[0][0].replace("⦅newline⦆", "\n")
logger.info("Final Answer:")
print(answer)