Creating augmented suggester #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Commits (30)
a82e872  Created new augmented_model_suggester and corresponding utils. (grace-sng7)
63e6d82  Updated dependencies according to augmented_model_suggester and corre… (grace-sng7)
b81e676  Updated LLM query prompt. (grace-sng7)
2af6350  Minor fixes after testing AugmentedModelSuggester. (grace-sng7)
e349bf5  Edited CauseNet search function. (grace-sng7)
633d76d  Updated README.md to include augmented_model_suggester (grace-sng7)
a474e5d  Update README.md (grace-sng7)
d31fc79  Merge pull request #53 from grace-sng7/creating_augmented_suggester (grace-sng7)
6b71deb  Update README.md (grace-sng7)
d763517  Added augmented model suggester examples notebook (grace-sng7)
e2f57e4  Merge pull request #55 from grace-sng7/creating_augmented_suggester (grace-sng7)
072adc4  Uploaded augmented model suggester examples notebook again. (grace-sng7)
8d82bd9  Merge branch 'py-why:creating_augmented_suggester' into creating_augm… (grace-sng7)
fd17322  Merge pull request #56 from grace-sng7/creating_augmented_suggester (grace-sng7)
05d9aa9  Set to ignore notebook testing for augmented model suggester examples (grace-sng7)
0d1c2b5  Merge pull request #57 from grace-sng7/augmented_suggester (grace-sng7)
00b61a2  Updated augmented_model_suggester_examples notebooks, docstrings, and… (grace-sng7)
83a968f  Merge pull request #58 from grace-sng7/augmented_suggester (grace-sng7)
2c6c7c2  Updated citations (grace-sng7)
bfde305  Merge pull request #59 from grace-sng7/augmented_suggester (grace-sng7)
9148c2a  Edited augmented model suggester llm_query method (grace-sng7)
e868bb2  Merge pull request #60 from grace-sng7/augmented_suggester (grace-sng7)
1e19398  Updated ignore_notebooks in tests (grace-sng7)
ad04e9b  Merge pull request #61 from grace-sng7/augmented_suggester (grace-sng7)
a90e0e4  Added onxruntime dependency (grace-sng7)
59c6012  Merge pull request #62 from grace-sng7/augmented_suggester (grace-sng7)
a68930b  onnxruntime dependency (grace-sng7)
0f769a0  Merge pull request #63 from grace-sng7/augmented_suggester (grace-sng7)
2e1ca67  Removed onnxruntime-silicon (grace-sng7)
de4c85f  Merge pull request #64 from grace-sng7/augmented_suggester (grace-sng7)
docs/notebooks/augmented_model_suggester_examples.ipynb (136 additions, 0 deletions)
```json
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "pip install python-dotenv"
      ],
      "metadata": {
        "id": "cmZerbMu6Uk4"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EulKv3Km4nMa"
      },
      "outputs": [],
      "source": [
        "from dotenv import load_dotenv\n",
        "import os\n",
        "\n",
        "load_dotenv()\n",
        "\n",
        "os.environ[\"OPENAI_API_KEY\"] = ''  # specify your key here"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "pip install pywhyllm"
      ],
      "metadata": {
        "collapsed": true,
        "id": "83sxVcP97xlH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "from pywhyllm.suggesters.augmented_model_suggester import AugmentedModelSuggester\n",
        "\n",
        "model = AugmentedModelSuggester('gpt-4')"
      ],
      "metadata": {
        "id": "VdfEKuDLEYcU"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result = model.suggest_pairwise_relationship(\"smoking\", \"lung cancer\")"
      ],
      "metadata": {
        "id": "D85ec6Pk5JzA"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result"
      ],
      "metadata": {
        "id": "W3bFehXh5SQl"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result = model.suggest_pairwise_relationship(\"income\", \"exercise level\")"
      ],
      "metadata": {
        "id": "odFkp921hQsX"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result"
      ],
      "metadata": {
        "id": "ZIeStj9OwIPe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result = model.suggest_pairwise_relationship(\"flooding\", \"rain\")"
      ],
      "metadata": {
        "id": "Fm5XCFrRwKsV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "result"
      ],
      "metadata": {
        "id": "HDo098ICwzi7"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}
```
pywhyllm/suggesters/augmented_model_suggester.py (45 additions, 0 deletions)
```python
import logging
import re

from .simple_model_suggester import SimpleModelSuggester
from pywhyllm.utils.data_loader import *
from pywhyllm.utils.augmented_model_suggester_utils import *


class AugmentedModelSuggester(SimpleModelSuggester):
    def __init__(self, llm, file_path: str = 'data/causenet-precision.jsonl.bz2'):
        super().__init__(llm)
        self.file_path = file_path

        # Download the CauseNet precision corpus and index it as a dict.
        logging.basicConfig(level=logging.INFO)
        url = "https://groups.uni-paderborn.de/wdqa/causenet/causality-graphs/causenet-precision.jsonl.bz2"
        success = download_causenet(url, file_path)

        if success:
            print(f"File downloaded to {file_path}")
            json_data = load_causenet_json(file_path)
            self.causenet_dict = create_causenet_dict(json_data)
        else:
            print("Download failed")

    def suggest_pairwise_relationship(self, variable1: str, variable2: str):
        # If the pair appears in CauseNet, query the LLM with the retrieved
        # source sentences as context; otherwise fall back to a plain query.
        result = find_top_match_in_causenet(self.causenet_dict, variable1, variable2)
        if result:
            source_text = get_source_text(result)
            retriever = split_data_and_create_vectorstore_retriever(source_text)
            response = query_llm(variable1, variable2, source_text, retriever)
        else:
            response = query_llm(variable1, variable2)

        # The LLM is instructed to wrap its final letter in <answer> tags.
        answer = re.findall(r'<answer>(.*?)</answer>', response)
        answer = [ans.strip() for ans in answer]
        answer_str = "".join(answer)

        if answer_str == "A":
            return [variable1, variable2, response]
        elif answer_str == "B":
            return [variable2, variable1, response]
        elif answer_str == "C":
            return [None, None, response]
        else:
            assert False, "Invalid answer from LLM: " + answer_str
```
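As a quick illustration of the parsing contract above: the method only accepts a response whose final verdict is wrapped in `<answer>` tags, and anything else trips the assertion. A small sketch with a made-up response string:

```python
import re

# Hypothetical LLM output; only the <answer> tag content drives the return value.
response = "Smoking deposits carcinogens in lung tissue, so... <answer>A</answer>"
answer_str = "".join(a.strip() for a in re.findall(r'<answer>(.*?)</answer>', response))
assert answer_str == "A"  # "A" maps to [variable1, variable2, response]
```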
Empty file.

pywhyllm/utils/augmented_model_suggester_utils.py (143 additions, 0 deletions)
```python
import os
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_core.documents import Document
from langchain_chroma import Chroma
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain.chains import create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain
import numpy as np
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer, util


def find_top_match_in_causenet(causenet_dict, variable1, variable2, threshold=0.7):
    # Build "cause-effect" strings for every pair in the CauseNet dict
    pair_strings = [
        f"{causenet_dict[key]['causal_relation']['cause']}-{causenet_dict[key]['causal_relation']['effect']}"
        for key in causenet_dict]

    # Tokenize for BM25
    tokenized_pairs = [text.split() for text in pair_strings]
    bm25 = BM25Okapi(tokenized_pairs)

    # Original and reverse queries
    query = variable1 + "-" + variable2
    reverse_query = variable2 + "-" + variable1
    tokenized_query = query.split()
    tokenized_reverse_query = reverse_query.split()

    # Combine tokens from both queries (remove duplicates)
    combined_query = list(set(tokenized_query + tokenized_reverse_query))

    # Get top-k candidates using BM25 with combined query
    k = 5
    scores = bm25.get_scores(combined_query)
    top_k_indices = np.argsort(scores)[::-1][:k]
    candidate_pairs = [pair_strings[i] for i in top_k_indices]

    # Apply SBERT to re-rank the BM25 candidates
    model = SentenceTransformer('all-MiniLM-L6-v2')
    query_embedding = model.encode(query, convert_to_tensor=True)
    reverse_query_embedding = model.encode(reverse_query, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidate_pairs, convert_to_tensor=True)

    # Compute similarities for both original and reverse queries
    similarities = util.cos_sim(query_embedding, candidate_embeddings).flatten()
    reverse_similarities = util.cos_sim(reverse_query_embedding, candidate_embeddings).flatten()

    # Take the maximum similarity for each candidate (original or reverse)
    max_similarities = np.maximum(similarities, reverse_similarities)

    # Get the top match and its similarity score
    top_idx = np.argmax(max_similarities)
    top_similarity = max_similarities[top_idx]
    top_pair = candidate_pairs[top_idx]

    # Check if the top similarity meets the threshold
    if top_similarity >= threshold:
        print(f"Best match: {top_pair} (Similarity: {top_similarity:.4f})")
        return causenet_dict[top_pair]
    else:
        print(f"No match found with similarity above {threshold} (Best similarity: {top_similarity:.4f})")
        return None


def get_source_text(causenet_query_result):
    # Concatenate the Wikipedia and ClueWeb12 sentences backing the match
    source_text = ""
    if causenet_query_result:
        for item in causenet_query_result["sources"]:
            if item["type"] == 'wikipedia_sentence' or item["type"] == 'clueweb12_sentence':
                source_text += item["payload"]["sentence"] + " "

    return source_text


def split_data_and_create_vectorstore_retriever(source_text):
    document = Document(page_content=source_text)

    # Initialize the text splitter
    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=100,  # Adjust chunk size as needed
        chunk_overlap=20  # Overlap for context
    )
    # Split the documents
    splits = text_splitter.split_documents([document])

    embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

    # Create a vector store from the document splits
    vectorstore = Chroma.from_documents(
        documents=splits,
        embedding=embeddings,
        persist_directory="./chroma_db"  # Optional: Save to disk for reuse
    )

    # Create a retriever from the vector store
    retriever = vectorstore.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 5}  # Retrieve top 5 relevant chunks
    )

    return retriever


def query_llm(variable1, variable2, source_text=None, retriever=None):
    # Initialize the language model
    llm = ChatOpenAI(model="gpt-4")

    # Include retrieved context in the system prompt when source text exists
    if source_text:
        system_prompt = """You are a helpful assistant for causal reasoning.

        Context: {context}
        """
    else:
        system_prompt = """You are a helpful assistant for causal reasoning.
        """

    # Prompt template
    prompt = ChatPromptTemplate.from_messages([
        ("system", system_prompt),
        ("human", "{input}")
    ])

    query = f"""Which cause-and-effect-relationship is more likely? Provide reasoning and you must give your final answer (A, B, or C) in <answer> </answer> tags with the letter only.
A. {variable1} causes {variable2} B. {variable2} causes {variable1} C. neither {variable1} nor {variable2} cause each other."""

    if source_text:
        # Create a document chain to combine retrieved documents
        question_answer_chain = create_stuff_documents_chain(llm, prompt)

        # Create the RAG chain
        rag_chain = create_retrieval_chain(retriever, question_answer_chain)

        response = rag_chain.invoke({"input": query})
        return response['answer']
    else:
        # No CauseNet evidence: query the LLM directly
        default_chain = prompt | llm
        response = default_chain.invoke({"input": query})
        return response.content
```
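To make the expected input shape concrete: `find_top_match_in_causenet` assumes `causenet_dict` is keyed by "cause-effect" strings, each entry carrying the `causal_relation` dict it searches over and the `sources` list that `get_source_text` reads. A small hypothetical example of that layout (the real dict is built by `create_causenet_dict` from the downloaded CauseNet file):

```python
# Hypothetical miniature CauseNet dict matching the access patterns above.
toy_causenet = {
    "smoking-lung cancer": {
        "causal_relation": {"cause": "smoking", "effect": "lung cancer"},
        "sources": [
            {"type": "wikipedia_sentence",
             "payload": {"sentence": "Smoking is the leading cause of lung cancer."}},
        ],
    },
    "rain-flooding": {
        "causal_relation": {"cause": "rain", "effect": "flooding"},
        "sources": [],
    },
}

match = find_top_match_in_causenet(toy_causenet, "smoking", "lung cancer")
if match:
    print(get_source_text(match))  # -> "Smoking is the leading cause of lung cancer. "
```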