|
| 1 | +import argparse |
| 2 | +import json |
| 3 | +import logging |
| 4 | +import os |
| 5 | +import pathlib |
| 6 | +import re |
| 7 | + |
| 8 | +from azure.identity import AzureDeveloperCliCredential, get_bearer_token_provider |
| 9 | +from azure.search.documents import SearchClient |
| 10 | +from dotenv_azd import load_azd_env |
| 11 | +from langchain_core.documents import Document as LCDocument |
| 12 | +from langchain_openai import AzureChatOpenAI, AzureOpenAIEmbeddings |
| 13 | +from ragas.embeddings import LangchainEmbeddingsWrapper |
| 14 | +from ragas.llms import LangchainLLMWrapper |
| 15 | +from ragas.testset import TestsetGenerator |
| 16 | +from ragas.testset.graph import KnowledgeGraph, Node, NodeType |
| 17 | +from ragas.testset.transforms import apply_transforms, default_transforms |
| 18 | +from rich.logging import RichHandler |
| 19 | + |
| 20 | +logger = logging.getLogger("ragapp") |
| 21 | + |
| 22 | +root_dir = pathlib.Path(__file__).parent |
| 23 | + |
| 24 | + |
def get_azure_credential():
    """Return an AzureDeveloperCliCredential, pinned to AZURE_TENANT_ID when that env var is set."""
    tenant_id = os.getenv("AZURE_TENANT_ID")
    if tenant_id:
        logger.info("Setting up Azure credential using AzureDeveloperCliCredential with tenant_id %s", tenant_id)
        return AzureDeveloperCliCredential(tenant_id=tenant_id, process_timeout=60)
    logger.info("Setting up Azure credential using AzureDeveloperCliCredential for home tenant")
    return AzureDeveloperCliCredential(process_timeout=60)
| 34 | + |
| 35 | + |
def get_search_documents(azure_credential, num_search_documents=None) -> list:
    """Fetch document chunks from the Azure AI Search index named by env vars.

    Reads AZURE_SEARCH_SERVICE and AZURE_SEARCH_INDEX from the environment.

    :param azure_credential: credential used to authenticate the SearchClient.
    :param num_search_documents: maximum number of chunks to fetch; None fetches
        (effectively) all of them via a large cap.
    :return: list of search result documents (dict-like, includes at least
        "content" and "sourcepage" fields as used by callers in this file).
    """
    search_client = SearchClient(
        endpoint=f"https://{os.getenv('AZURE_SEARCH_SERVICE')}.search.windows.net",
        index_name=os.getenv("AZURE_SEARCH_INDEX"),
        credential=azure_credential,
    )
    if num_search_documents is None:
        logger.info("Fetching all document chunks from Azure AI Search")
        # Service-side cap standing in for "all documents".
        num_search_documents = 100000
    else:
        logger.info("Fetching %d document chunks from Azure AI Search", num_search_documents)

    all_documents = []
    # "*" matches every document; by_page() streams results page by page.
    response = search_client.search(search_text="*", top=num_search_documents).by_page()
    for page in response:
        all_documents.extend(page)
    return all_documents
| 53 | + |
| 54 | + |
def _load_knowledge_graph(kg_file):
    """Load a previously saved RAGAS knowledge graph (path relative to this file)."""
    full_path_to_kg = root_dir / kg_file
    if not os.path.exists(full_path_to_kg):
        raise FileNotFoundError(f"Knowledge graph file {full_path_to_kg} not found.")
    logger.info("Loading existing knowledge graph from %s", full_path_to_kg)
    return KnowledgeGraph.load(full_path_to_kg)


def _build_knowledge_graph(azure_credential, generator_llm, generator_embeddings, num_search_documents):
    """Build a RAGAS knowledge graph from Azure AI Search chunks, apply transforms, and save it."""
    search_docs = get_search_documents(azure_credential, num_search_documents)

    logger.info("Creating a RAGAS knowledge graph based off of %d search documents", len(search_docs))
    nodes = []
    for doc in search_docs:
        content = doc["content"]
        citation = doc["sourcepage"]
        # Embed the citation inside the page content so it survives into the generated
        # samples' reference_contexts, where _qa_pairs_from_dataset extracts it again.
        nodes.append(
            Node(
                type=NodeType.DOCUMENT,
                properties={
                    "page_content": f"[[{citation}]]: {content}",
                    "document_metadata": {"citation": citation},
                },
            )
        )

    kg = KnowledgeGraph(nodes=nodes)

    logger.info("Using RAGAS to apply transforms to knowledge graph")
    transforms = default_transforms(
        documents=[LCDocument(page_content=doc["content"]) for doc in search_docs],
        llm=generator_llm,
        embedding_model=generator_embeddings,
    )
    apply_transforms(kg, transforms)

    # Persist so later runs can skip the (expensive) transform step via kg_file.
    kg.save(root_dir / "ground_truth_kg.json")
    return kg


def _qa_pairs_from_dataset(dataset):
    """Convert RAGAS testset samples into {"question", "truth"} dicts, appending [citation] tags."""
    qa_pairs = []
    for sample in dataset.samples:
        question = sample.eval_sample.user_input
        truth = sample.eval_sample.reference
        # Grab the citation in square brackets from the reference_contexts and add it to the truth
        citations = []
        for context in sample.eval_sample.reference_contexts:
            match = re.search(r"\[\[(.*?)\]\]", context)
            if match:
                citations.append(f"[{match.group(1)}]")
        # Only append when something was found — the original unconditionally added a
        # trailing space to truth even when no citation matched.
        if citations:
            truth += " " + " ".join(citations)
        qa_pairs.append({"question": question, "truth": truth})
    return qa_pairs


def generate_ground_truth_ragas(num_questions=200, num_search_documents=None, kg_file=None):
    """Generate ground-truth QA pairs with RAGAS and append them to ground_truth.jsonl.

    :param num_questions: number of questions to generate.
    :param num_search_documents: number of search chunks to fetch (None = all);
        ignored when kg_file is provided.
    :param kg_file: optional path (relative to this file) to a previously saved
        knowledge graph; when given, the search/transform steps are skipped.
    :raises FileNotFoundError: if kg_file is given but does not exist.
    """
    azure_credential = get_azure_credential()
    azure_openai_api_version = os.getenv("AZURE_OPENAI_API_VERSION") or "2024-06-01"
    azure_endpoint = f"https://{os.getenv('AZURE_OPENAI_SERVICE')}.openai.azure.com"
    azure_ad_token_provider = get_bearer_token_provider(
        azure_credential, "https://cognitiveservices.azure.com/.default"
    )
    generator_llm = LangchainLLMWrapper(
        AzureChatOpenAI(
            openai_api_version=azure_openai_api_version,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            azure_deployment=os.getenv("AZURE_OPENAI_EVAL_DEPLOYMENT"),
            model=os.environ["AZURE_OPENAI_EVAL_MODEL"],
            validate_base_url=False,
        )
    )

    # init the embeddings for answer_relevancy, answer_correctness and answer_similarity
    generator_embeddings = LangchainEmbeddingsWrapper(
        AzureOpenAIEmbeddings(
            openai_api_version=azure_openai_api_version,
            azure_endpoint=azure_endpoint,
            azure_ad_token_provider=azure_ad_token_provider,
            azure_deployment=os.getenv("AZURE_OPENAI_EMB_DEPLOYMENT"),
            model=os.environ["AZURE_OPENAI_EMB_MODEL_NAME"],
        )
    )

    # Load or create the knowledge graph
    if kg_file:
        kg = _load_knowledge_graph(kg_file)
    else:
        kg = _build_knowledge_graph(azure_credential, generator_llm, generator_embeddings, num_search_documents)

    logger.info("Using RAGAS knowledge graph to generate %d questions", num_questions)
    generator = TestsetGenerator(llm=generator_llm, embedding_model=generator_embeddings, knowledge_graph=kg)
    dataset = generator.generate(testset_size=num_questions, with_debugging_logs=True)

    qa_pairs = _qa_pairs_from_dataset(dataset)

    # Append (not overwrite) so repeated runs accumulate QA pairs in the same file;
    # utf-8 is pinned so output does not depend on the platform's locale encoding.
    with open(root_dir / "ground_truth.jsonl", "a", encoding="utf-8") as f:
        logger.info("Writing %d QA pairs to %s", len(qa_pairs), f.name)
        for qa_pair in qa_pairs:
            f.write(json.dumps(qa_pair) + "\n")
| 143 | + |
| 144 | + |
if __name__ == "__main__":
    # Rich handler for readable console logs; only this app's logger is verbose.
    logging.basicConfig(
        level=logging.WARNING,
        format="%(message)s",
        datefmt="[%X]",
        handlers=[RichHandler(rich_tracebacks=True)],
    )
    logger.setLevel(logging.INFO)
    load_azd_env()

    parser = argparse.ArgumentParser(description="Generate ground truth data using AI Search index and RAGAS.")
    parser.add_argument("--numsearchdocs", type=int, help="Specify the number of search results to fetch")
    parser.add_argument("--numquestions", type=int, default=200, help="Specify the number of questions to generate.")
    parser.add_argument("--kgfile", type=str, help="Specify the path to an existing knowledge graph file")
    args = parser.parse_args()

    generate_ground_truth_ragas(
        num_questions=args.numquestions,
        num_search_documents=args.numsearchdocs,
        kg_file=args.kgfile,
    )
0 commit comments