2 changes: 2 additions & 0 deletions .gitignore
@@ -206,3 +206,5 @@ marimo/_static/
marimo/_lsp/
__marimo__/
notes.txt

testing_base
32 changes: 32 additions & 0 deletions README.md
@@ -15,6 +15,38 @@ By combining these features, Data Tools helps you move from a collection of sepa
- **Extensible Pipeline Architecture**: Easily add custom analysis steps to the pipeline.
- **DataFrame Agnostic**: Uses a factory pattern to seamlessly handle different dataframe types (e.g., pandas).

## Installation and Setup

### Installation

To install the library and its dependencies, run the following command:

```bash
pip install data_tools
```

### LLM Configuration

This library uses LLMs for features like Business Glossary Generation. It supports any LLM provider compatible with LangChain's `init_chat_model` function. To configure your provider, set the following environment variables.

Set the `LLM_PROVIDER` environment variable to your desired model, optionally prefixed with a provider in the format `provider:model_name`. If the prefix is omitted, LangChain attempts to infer the provider from the model name.

**For OpenAI:**

```bash
# Provider is optional
export LLM_PROVIDER="gpt-4"
export OPENAI_API_KEY="your-super-secret-key"
```

**For Google GenAI:**

```bash
export LLM_PROVIDER="google_genai:gemini-pro"
export GOOGLE_API_KEY="your-google-api-key"
```
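
Either form is handed straight to LangChain, which splits an optional `provider:` prefix and otherwise infers the provider from the model name. A minimal sketch of that resolution path, assuming `LLM_PROVIDER` is exported as above:

```python
import os

from langchain.chat_models import init_chat_model

# "provider:model_name" strings are split by init_chat_model itself; a bare
# name such as "gpt-4" relies on LangChain's provider inference.
model = init_chat_model(os.environ["LLM_PROVIDER"], max_retries=5)
print(model.invoke("Reply with one word: ready").content)
```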


## Usage Examples

### Example 1: Automated Link Prediction (Primary Use Case)
2 changes: 1 addition & 1 deletion notebooks/upstream.ipynb
@@ -1644,7 +1644,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.3"
"version": "3.10.17"
}
},
"nbformat": 4,
12 changes: 11 additions & 1 deletion pyproject.toml
@@ -15,7 +15,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"langchain>=0.3.27",
"langchain[anthropic,aws,cohere,google-genai,google-vertexai,groq]>=0.3.27",
"langchain-community>=0.3.21",
"langchain-openai>=0.3.28",
"langgraph>=0.6.4",
@@ -31,11 +31,16 @@ dependencies = [
"pydantic>=2.11.7",
"pydantic-settings>=2.10.1",
"pyfunctional>=1.5.0",
"python-dotenv>=1.1.1",
"scikit-learn==1.2.2",
"symspellpy>=6.9.0",
"trieregex>=1.0.0",
"xgboost==1.7.5",
"pyyaml>=6.0.2",
"langchain-deepseek>=0.1.4",
"langchain-nvidia-ai-endpoints>=0.3.16",
"langchain-xai>=0.2.5",
"langchain-perplexity>=0.1.2",
]

[dependency-groups]
@@ -46,6 +51,11 @@ test = [
"pytest-asyncio>=1.1.0",
]
lint = ["ruff"]
dev = [
"pytest>=8.4.1",
"pytest-asyncio>=1.1.0",
"pytest-cov>=6.2.1",
]

[tool.ruff]
src = ["src"]
1 change: 0 additions & 1 deletion src/data_tools/analysis/steps.py
@@ -123,7 +123,6 @@ class BusinessGlossaryGenerator(AnalysisStep):
def __init__(self, domain: str):
"""
Initializes the BusinessGlossaryGenerator with optional additional context.

:param domain: The industry domain to which the dataset belongs.
"""
self.domain = domain
94 changes: 33 additions & 61 deletions src/data_tools/core/llms/chat.py
@@ -1,16 +1,13 @@
import logging
import time

from typing import TYPE_CHECKING

import openai

from langchain.chat_models import init_chat_model
from langchain.output_parsers import ResponseSchema, StructuredOutputParser
from langchain.prompts import BaseChatPromptTemplate, ChatPromptTemplate
from langchain_community.callbacks.openai_info import OpenAICallbackHandler
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_core.rate_limiters import InMemoryRateLimiter

from .config import get_llm_config
from data_tools.core import settings

log = logging.getLogger(__name__)

@@ -24,33 +21,21 @@ class ChatModelLLM:
A Wrapper around Chat LLM to invoke on any of the pipeline that uses llm
'''

call_backs = [OpenAICallbackHandler()]

# type of chat models to support
CHAT_MODELS = {
"azure": AzureChatOpenAI,
"openai": ChatOpenAI
}

model = None
# number of retries to the LLM.
MAX_RETRIES = 5

# time to sleep if we hit llm RateLimitError
SLEEP_TIME = 25
MAX_RETRIES = settings.MAX_RETRIES

def __init__(self, model_name: str, response_schemas: list[ResponseSchema] = None,
output_parser=StructuredOutputParser, prompt_template=ChatPromptTemplate, template_string: str = None, config: dict = {},
*args, **kwargs
):

self.model: BaseChatModel = self.CHAT_MODELS[model_name](**config) # the llm model
self.model: BaseChatModel = init_chat_model(model_name, max_retries=self.MAX_RETRIES, rate_limiter=self._get_rate_limiter(), **config) # llm model

self.parser: StructuredOutputParser = output_parser # the output parser

self.prompt_template: BaseChatPromptTemplate = prompt_template # prompt template

self.output_parser = self.__output_parser_builder__(response_schemas=response_schemas) if response_schemas is not None else None # the builded output parser
self.output_parser = self.__output_parser_builder__(response_schemas=response_schemas) if response_schemas is not None else None # the built output parser

self.format_instructions = self.output_parser.get_format_instructions() if self.output_parser is not None else None # the format instructions

@@ -67,57 +52,47 @@ def __output_parser_builder__(self, response_schemas: list[ResponseSchema] = Non
output_parser = self.parser.from_response_schemas(response_schemas=response_schemas)
return output_parser

@classmethod
def _get_rate_limiter(cls):
rate_limiter = None
if settings.ENABLE_RATE_LIMITER:
rate_limiter = InMemoryRateLimiter(
requests_per_second=0.5, # <-- We can only make a request once every 2 seconds!
check_every_n_seconds=0.1, # Wake up every 100 ms to check whether allowed to make a request,
max_bucket_size=5, # Controls the maximum burst size.
)
return rate_limiter

def message_builder():
...

def invoke(self, *args, **kwargs):
"""
The final invoke method that takes any arguments that is to be finally added in the prompt message and invokes the llm call.
"""

# format_instructions = ""
# output_parser = None
# if response_schemas:
# output_parser = self.parser.from_response_schemas(response_schemas=response_schemas)
# format_instructions = output_parser.get_format_instructions()

# prompt = self.prompt_template.from_template(template = template_string)

sucessfull_parsing = False

messages = self.llm_prompt.format(
format_instructions=self.format_instructions,
**kwargs
)
# ()
retries = 0
_message = messages
response = ""

while True:
try:
response = self.model.invoke(_message,
config={"metadata": kwargs.get("metadata", {}),
}).content

_message = messages
except Exception as ex:
# ()
log.warning(f"[!] Error while llm invoke: {ex}")
try:
if retries > self.MAX_RETRIES:
break

response = self.model.invoke(_message,
config={'callbacks': self.call_backs,
"metadata": kwargs.get("metadata", {}),
}).content

_message = messages
break
except openai.RateLimitError:
log.warning(f"[!] LLM API rate limit hit ... sleeping for {self.SLEEP_TIME} seconds")
time.sleep(self.SLEEP_TIME)
except Exception as ex:
# ()
log.warning(f"[!] Error while llm invoke: {ex}")
try:
_message = messages[0].content
except Exception:
return "", sucessfull_parsing, messages
# response = self.model.invoke(messages[0].content).content
retries += 1
_message = messages[0].content
except Exception:
return "", sucessfull_parsing, messages

messages = messages[0].content if isinstance(messages, list) else messages

@@ -135,17 +110,14 @@ def invoke(self, *args, **kwargs):
return response, sucessfull_parsing, messages

@classmethod
def get_llm(cls, model_name: str = "azure", api_config: dict = {}, other_config: dict = {}):

config = {**get_llm_config(api_config, type=model_name), **other_config}
def get_llm(cls, model_name: str, llm_config: dict = {}):

return cls.CHAT_MODELS[model_name](**config)
return init_chat_model(model_name, max_retries=cls.MAX_RETRIES, rate_limiter=cls._get_rate_limiter(), **llm_config)

@classmethod
def build(cls,
model_name: str = "azure",
api_config: dict = {},
other_config: dict = {},
llm_config: dict = {},
prompt_template=ChatPromptTemplate,
output_parser=StructuredOutputParser,
response_schemas: list[ResponseSchema] = None,
@@ -166,7 +138,7 @@ def build(cls,

return cls(
model_name=model_name,
config={**get_llm_config(api_config, type=model_name), **other_config},
config={**llm_config},
prompt_template=prompt_template,
output_parser=output_parser,
template_string=template_string,
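Since this diff replaces the hard-coded `CHAT_MODELS` mapping with `init_chat_model`, a minimal sketch of the new call pattern may help review. The schema and template below are hypothetical, for illustration only:

```python
from langchain.output_parsers import ResponseSchema

from data_tools.core.llms.chat import ChatModelLLM

# Hypothetical response schema and template, not taken from the codebase.
schemas = [ResponseSchema(name="summary", description="one-line table summary")]
llm = ChatModelLLM.build(
    model_name="openai:gpt-4",      # any init_chat_model-compatible identifier
    llm_config={"temperature": 0},  # forwarded to init_chat_model unchanged
    response_schemas=schemas,
    template_string="Summarize this table: {table}\n{format_instructions}",
)
response, parsed_ok, messages = llm.invoke(table="orders(order_id, amount)")
```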
15 changes: 6 additions & 9 deletions src/data_tools/core/pipeline/business_glossary/bg.py
@@ -89,9 +89,8 @@ def __init__(self, profiling_data: pd.DataFrame, *args, **kwargs):
)

self.__table_glossary_llm = ChatModelLLM.build(
model_name=settings.LLM_TYPE,
api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
other_config=self.LLM_CONFIG_1,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG_1,
response_schemas=table_glossary,
template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
"TABLE_GLOSSARY_TEMPLATE"
@@ -100,19 +99,17 @@ def __init__(self, profiling_data: pd.DataFrame, *args, **kwargs):
)

self.__business_glossary_llm = ChatModelLLM.build(
model_name=settings.LLM_TYPE,
api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
other_config=self.LLM_CONFIG_2,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG_2,
response_schemas=column_glossary,
template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
"BUSINESS_GLOSSARY_TEMPLATE"
],
prompt_template=PromptTemplate,
)
self.__business_tags_llm = ChatModelLLM.build(
model_name=settings.LLM_TYPE,
api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
other_config=self.LLM_CONFIG_2,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG_2,
response_schemas=column_tag_glossary,
template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
"BUSINESS_TAGS_TEMPLATE"
@@ -54,9 +54,8 @@ def __init__(self, *args, **kwargs):
# langfuse for monitoring

self.chat_llm = ChatModelLLM.build(
model_name=settings.LLM_TYPE,
api_config=settings.DI_CONFIG,
other_config=self.LLM_CONFIG,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG,
template_string=self.DIM_MEASURE_PROMPT,
response_schemas=self.dm_class_schema,
)
5 changes: 2 additions & 3 deletions src/data_tools/core/pipeline/key_identification/ki.py
@@ -47,9 +47,8 @@ def __init__(self, profiling_data: pd.DataFrame,
*args, **kwargs):

self.__chat_llm = ChatModelLLM.build(
model_name=settings.LLM_TYPE,
api_config=settings.KI_CONFIG,
other_config=self.LLM_CONFIG,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG,
prompt_template=PromptTemplate,
template_string=self.KI_PROMPT_TEMPLATE,
response_schemas=self.primary_key,
5 changes: 2 additions & 3 deletions src/data_tools/core/pipeline/link_prediction/lp.py
@@ -96,9 +96,8 @@ def __init__(

self.llm = (
ChatModelLLM.get_llm(
model_name=settings.LLM_TYPE,
api_config=settings.LP_CONFIG,
other_config=self.LLM_CONFIG,
model_name=settings.LLM_PROVIDER,
llm_config=self.LLM_CONFIG,
)
if llm is None
else llm
27 changes: 25 additions & 2 deletions src/data_tools/core/settings.py
@@ -5,10 +5,16 @@
from functools import lru_cache
from pathlib import Path

from dotenv import load_dotenv
from pydantic_settings import BaseSettings, SettingsConfigDict

from data_tools.core.utilities.configs import load_model_configuration

load_dotenv(dotenv_path=Path(__file__).resolve().parent.parent / ".env")


BASE_PATH = Path(__file__).resolve().parent.parent


class Settings(BaseSettings):
"""Global Configuration"""
@@ -39,11 +45,28 @@ class Settings(BaseSettings):
SQL_DIALECT: str = "postgresql"
DOMAIN: str = "ecommerce"
UNIVERSAL_INSTRUCTIONS: str = ""
L2_SAMPLE_LIMIT: int = 10

# LLM CONFIGS
LLM_PROVIDER: str
LLM_SAMPLE_LIMIT: int = 15
STRATA_SAMPLE_LIMIT: int = 4
MAX_RETRIES: int = 5
SLEEP_TIME: int = 25
ENABLE_RATE_LIMITER: bool = False

# LP
HALLUCINATIONS_MAX_RETRY: int = 2
UNIQUENESS_THRESHOLD: float = 0.9

# DATETIME
DATE_TIME_FORMAT_LIMIT: int = 25
REMOVE_DATETIME_LP: bool = True

model_config = SettingsConfigDict(
env_file=".env",
env_file=f"{BASE_PATH}/.env",
env_file_encoding="utf-8",
extra="ignore",
extra="allow",
case_sensitive=True,
)
L2_SAMPLE_LIMIT: int = 10
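With `LLM_PROVIDER` now a required field and `load_dotenv` pointed at `src/data_tools/.env`, a quick sanity check of the new settings might look like this (assuming that `.env` defines `LLM_PROVIDER`):

```python
from data_tools.core import settings

# LLM_PROVIDER has no default, so Settings() fails validation unless the
# variable is set in the environment or in src/data_tools/.env.
print(settings.LLM_PROVIDER, settings.MAX_RETRIES, settings.ENABLE_RATE_LIMITER)
```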
2 changes: 1 addition & 1 deletion tests/analysis/test_key_identification.py
@@ -48,4 +48,4 @@ def test_key_identification_end_to_end():
# which contains the identified key information.
# We expect one identified key.

assert identified_key.column_name == "order_id"
assert identified_key == "order_id"