Commit 9ce5ea3

Features/merged llm main (#5)
* added DI L1 (needs testing)
* DI L1 stable; sample values changed to 10,000
* DI L2 stable
* added KI; refactored column profiling
* LP boilerplate
* added test conditions for LP architecture
* LP testing
* LP stable
* added KI and LP configs
* updated test cases, added dates
* BG testing
* updated README
* moved common code to dataframe
* centralized LLM configs for all steps; added libraries for different LLMs

Co-authored-by: JaskaranIntugle <[email protected]>
1 parent: cda53fe

13 files changed: +3760 / -2605 lines

.gitignore

Lines changed: 2 additions & 0 deletions

```diff
@@ -206,3 +206,5 @@ marimo/_static/
 marimo/_lsp/
 __marimo__/
 notes.txt
+
+testing_base
```

README.md

Lines changed: 32 additions & 0 deletions

````diff
@@ -15,6 +15,38 @@ By combining these features, Data Tools helps you move from a collection of sepa
 - **Extensible Pipeline Architecture**: Easily add custom analysis steps to the pipeline.
 - **DataFrame Agnostic**: Uses a factory pattern to seamlessly handle different dataframe types (e.g., pandas).
 
+## Installation and Setup
+
+### Installation
+
+To install the library and its dependencies, run the following command:
+
+```bash
+pip install data_tools
+```
+
+### LLM Configuration
+
+This library uses LLMs for features like Business Glossary Generation. It supports any LLM provider compatible with LangChain's `init_chat_model` function. To configure your LLM provider, set the following environment variables.
+
+The `LLM_CONFIG` environment variable should be set to your desired model, optionally including the provider, in the format `provider:model_name`. If the provider is omitted, it will be inferred from the model name.
+
+**For OpenAI:**
+
+```bash
+# Provider is optional
+export LLM_CONFIG="gpt-4"
+export OPENAI_API_KEY="your-super-secret-key"
+```
+
+**For Google GenAI:**
+
+```bash
+export LLM_CONFIG="google_genai:gemini-pro"
+export GOOGLE_API_KEY="your-google-api-key"
+```
+
+
 ## Usage Examples
 
 ### Example 1: Automated Link Prediction (Primary Use Case)
````
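The README diff documents only OpenAI and Google GenAI, but the widened `langchain[...]` extras in `pyproject.toml` below suggest other providers follow the same `provider:model_name` pattern. A hypothetical example for Anthropic, assuming the standard LangChain environment variable (this snippet is not part of the commit):

```bash
# Assumed: "anthropic" provider via the langchain[anthropic] extra
export LLM_CONFIG="anthropic:claude-3-5-sonnet-latest"
export ANTHROPIC_API_KEY="your-anthropic-api-key"
```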

notebooks/upstream.ipynb

Lines changed: 1 addition & 1 deletion

```diff
@@ -1832,7 +1832,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.12.3"
+   "version": "3.10.17"
   }
  },
 "nbformat": 4,
```

pyproject.toml

Lines changed: 11 additions & 1 deletion

```diff
@@ -15,7 +15,7 @@ classifiers = [
     "Operating System :: OS Independent",
 ]
 dependencies = [
-    "langchain>=0.3.27",
+    "langchain[anthropic,aws,cohere,google-genai,google-vertexai,groq]>=0.3.27",
     "langchain-community>=0.3.21",
     "langchain-openai>=0.3.28",
     "langgraph>=0.6.4",
@@ -31,11 +31,16 @@ dependencies = [
     "pydantic>=2.11.7",
     "pydantic-settings>=2.10.1",
     "pyfunctional>=1.5.0",
+    "python-dotenv>=1.1.1",
     "scikit-learn==1.2.2",
     "symspellpy>=6.9.0",
     "trieregex>=1.0.0",
     "xgboost==1.7.5",
     "pyyaml>=6.0.2",
+    "langchain-deepseek>=0.1.4",
+    "langchain-nvidia-ai-endpoints>=0.3.16",
+    "langchain-xai>=0.2.5",
+    "langchain-perplexity>=0.1.2",
 ]
 
 [dependency-groups]
@@ -46,6 +51,11 @@ test = [
     "pytest-asyncio>=1.1.0",
 ]
 lint = ["ruff"]
+dev = [
+    "pytest>=8.4.1",
+    "pytest-asyncio>=1.1.0",
+    "pytest-cov>=6.2.1",
+]
 
 [tool.ruff]
 src = ["src"]
```

src/data_tools/analysis/steps.py

Lines changed: 0 additions & 1 deletion

```diff
@@ -123,7 +123,6 @@ class BusinessGlossaryGenerator(AnalysisStep):
     def __init__(self, domain: str):
         """
         Initializes the BusinessGlossaryGenerator with optional additional context.
-
         :param domain: The industry domain to which the dataset belongs.
         """
         self.domain = domain
```

src/data_tools/core/llms/chat.py

Lines changed: 33 additions & 61 deletions

```diff
@@ -1,16 +1,13 @@
 import logging
-import time
 
 from typing import TYPE_CHECKING
 
-import openai
-
+from langchain.chat_models import init_chat_model
 from langchain.output_parsers import ResponseSchema, StructuredOutputParser
 from langchain.prompts import BaseChatPromptTemplate, ChatPromptTemplate
-from langchain_community.callbacks.openai_info import OpenAICallbackHandler
-from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
+from langchain_core.rate_limiters import InMemoryRateLimiter
 
-from .config import get_llm_config
+from data_tools.core import settings
 
 log = logging.getLogger(__name__)
 
@@ -24,33 +21,21 @@ class ChatModelLLM:
     A Wrapper around Chat LLM to invoke on any of the pipeline that uses llm
     '''
 
-    call_backs = [OpenAICallbackHandler()]
-
-    # type of chat models to support
-    CHAT_MODELS = {
-        "azure": AzureChatOpenAI,
-        "openai": ChatOpenAI
-    }
-
-    model = None
     # number of retries to the LLM.
-    MAX_RETRIES = 5
-
-    # time to sleep if we hit llm RateLimitError
-    SLEEP_TIME = 25
+    MAX_RETRIES = settings.MAX_RETRIES
 
     def __init__(self, model_name: str, response_schemas: list[ResponseSchema] = None,
                  output_parser=StructuredOutputParser, prompt_template=ChatPromptTemplate, template_string: str = None, config: dict = {},
                  *args, **kwargs
                  ):
 
-        self.model: BaseChatModel = self.CHAT_MODELS[model_name](**config)  # the llm model
+        self.model: BaseChatModel = init_chat_model(model_name, max_retries=self.MAX_RETRIES, rate_limiter=self._get_rate_limiter(), **config)  # llm model
 
         self.parser: StructuredOutputParser = output_parser  # the output parser
 
         self.prompt_template: BaseChatPromptTemplate = prompt_template  # prompt template
 
-        self.output_parser = self.__output_parser_builder__(response_schemas=response_schemas) if response_schemas is not None else None  # the builded output parser
+        self.output_parser = self.__output_parser_builder__(response_schemas=response_schemas) if response_schemas is not None else None  # the built output parser
 
         self.format_instructions = self.output_parser.get_format_instructions() if self.output_parser is not None else None  # the format instructions
 
@@ -67,57 +52,47 @@ def __output_parser_builder__(self, response_schemas: list[ResponseSchema] = None
         output_parser = self.parser.from_response_schemas(response_schemas=response_schemas)
         return output_parser
 
+    @classmethod
+    def _get_rate_limiter(cls):
+        rate_limiter = None
+        if settings.ENABLE_RATE_LIMITER:
+            rate_limiter = InMemoryRateLimiter(
+                requests_per_second=0.5,  # <-- We can only make a request once every 2 seconds!
+                check_every_n_seconds=0.1,  # Wake up every 100 ms to check whether allowed to make a request,
+                max_bucket_size=5,  # Controls the maximum burst size.
+            )
+        return rate_limiter
+
     def message_builder():
         ...
 
     def invoke(self, *args, **kwargs):
         """
         The final invoke method that takes any arguments that is to be finally added in the prompt message and invokes the llm call.
         """
-
-        # format_instructions = ""
-        # output_parser = None
-        # if response_schemas:
-        #     output_parser = self.parser.from_response_schemas(response_schemas=response_schemas)
-        #     format_instructions = output_parser.get_format_instructions()
-
-        # prompt = self.prompt_template.from_template(template = template_string)
 
         sucessfull_parsing = False
 
         messages = self.llm_prompt.format(
            format_instructions=self.format_instructions,
            **kwargs
        )
-        # ()
-        retries = 0
         _message = messages
         response = ""
 
-        while True:
+        try:
+            response = self.model.invoke(_message,
+                                         config={"metadata": kwargs.get("metadata", {}),
+                                                 }).content
+
+            _message = messages
+        except Exception as ex:
+            # ()
+            log.warning(f"[!] Error while llm invoke: {ex}")
             try:
-                if retries > self.MAX_RETRIES:
-                    break
-
-                response = self.model.invoke(_message,
-                                             config={'callbacks': self.call_backs,
-                                                     "metadata": kwargs.get("metadata", {}),
-                                                     }).content
-
-                _message = messages
-                break
-            except openai.RateLimitError:
-                log.warning(f"[!] LLM API rate limit hit ... sleeping for {self.SLEEP_TIME} seconds")
-                time.sleep(self.SLEEP_TIME)
-            except Exception as ex:
-                # ()
-                log.warning(f"[!] Error while llm invoke: {ex}")
-                try:
-                    _message = messages[0].content
-                except Exception:
-                    return "", sucessfull_parsing, messages
-                # response = self.model.invoke(messages[0].content).content
-                retries += 1
+                _message = messages[0].content
+            except Exception:
+                return "", sucessfull_parsing, messages
 
         messages = messages[0].content if isinstance(messages, list) else messages
 
@@ -135,17 +110,14 @@ def invoke(self, *args, **kwargs):
         return response, sucessfull_parsing, messages
 
     @classmethod
-    def get_llm(cls, model_name: str = "azure", api_config: dict = {}, other_config: dict = {}):
-
-        config = {**get_llm_config(api_config, type=model_name), **other_config}
+    def get_llm(cls, model_name: str, llm_config: dict = {}):
 
-        return cls.CHAT_MODELS[model_name](**config)
+        return init_chat_model(model_name, max_retries=cls.MAX_RETRIES, rate_limiter=cls._get_rate_limiter(), **llm_config)
 
     @classmethod
     def build(cls,
               model_name: str = "azure",
-              api_config: dict = {},
-              other_config: dict = {},
+              llm_config: dict = {},
               prompt_template=ChatPromptTemplate,
               output_parser=StructuredOutputParser,
               response_schemas: list[ResponseSchema] = None,
@@ -166,7 +138,7 @@ def build(cls,
 
         return cls(
             model_name=model_name,
-            config={**get_llm_config(api_config, type=model_name), **other_config},
+            config={**llm_config},
             prompt_template=prompt_template,
             output_parser=output_parser,
             template_string=template_string,
```
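This rewrite drops the hard-coded Azure/OpenAI class map and the manual retry/sleep loop in favor of LangChain's provider-agnostic `init_chat_model`, with retries and throttling handled by the client itself. A minimal standalone sketch of the new construction path (the rate-limiter values mirror `_get_rate_limiter()` above; the model string and retry count are illustrative, and a matching API key must be set in the environment):

```python
from langchain.chat_models import init_chat_model
from langchain_core.rate_limiters import InMemoryRateLimiter

# Client-side token bucket, as wired up in _get_rate_limiter() above.
rate_limiter = InMemoryRateLimiter(
    requests_per_second=0.5,    # at most one request every 2 seconds
    check_every_n_seconds=0.1,  # poll every 100 ms for an available token
    max_bucket_size=5,          # maximum burst size
)

# "provider:model_name" strings are parsed by init_chat_model;
# a bare model name lets LangChain infer the provider.
llm = init_chat_model(
    "google_genai:gemini-pro",  # requires GOOGLE_API_KEY to be set
    max_retries=5,
    rate_limiter=rate_limiter,
)
print(llm.invoke("ping").content)
```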

src/data_tools/core/pipeline/business_glossary/bg.py

Lines changed: 6 additions & 9 deletions

```diff
@@ -89,9 +89,8 @@ def __init__(self, profiling_data: pd.DataFrame, *args, **kwargs):
         )
 
         self.__table_glossary_llm = ChatModelLLM.build(
-            model_name=settings.LLM_TYPE,
-            api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
-            other_config=self.LLM_CONFIG_1,
+            model_name=settings.LLM_PROVIDER,
+            llm_config=self.LLM_CONFIG_1,
             response_schemas=table_glossary,
             template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
                 "TABLE_GLOSSARY_TEMPLATE"
@@ -100,19 +99,17 @@ def __init__(self, profiling_data: pd.DataFrame, *args, **kwargs):
         )
 
         self.__business_glossary_llm = ChatModelLLM.build(
-            model_name=settings.LLM_TYPE,
-            api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
-            other_config=self.LLM_CONFIG_2,
+            model_name=settings.LLM_PROVIDER,
+            llm_config=self.LLM_CONFIG_2,
             response_schemas=column_glossary,
             template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
                 "BUSINESS_GLOSSARY_TEMPLATE"
             ],
             prompt_template=PromptTemplate,
         )
         self.__business_tags_llm = ChatModelLLM.build(
-            model_name=settings.LLM_TYPE,
-            api_config=settings.BG_CONFIG["NORMAL_INFERENCE"],
-            other_config=self.LLM_CONFIG_2,
+            model_name=settings.LLM_PROVIDER,
+            llm_config=self.LLM_CONFIG_2,
             response_schemas=column_tag_glossary,
             template_string=BUSINESS_GLOSSARY_PROMPTS[self.TEMPLATE_NAME][
                 "BUSINESS_TAGS_TEMPLATE"
```

src/data_tools/core/pipeline/datatype_identification/l2_model.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -54,9 +54,8 @@ def __init__(self, *args, **kwargs):
         # langfuse for monitoring
 
         self.chat_llm = ChatModelLLM.build(
-            model_name=settings.LLM_TYPE,
-            api_config=settings.DI_CONFIG,
-            other_config=self.LLM_CONFIG,
+            model_name=settings.LLM_PROVIDER,
+            llm_config=self.LLM_CONFIG,
             template_string=self.DIM_MEASURE_PROMPT,
             response_schemas=self.dm_class_schema,
         )
```

src/data_tools/core/pipeline/key_identification/ki.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -47,9 +47,8 @@ def __init__(self, profiling_data: pd.DataFrame,
                  *args, **kwargs):
 
         self.__chat_llm = ChatModelLLM.build(
-            model_name=settings.LLM_TYPE,
-            api_config=settings.KI_CONFIG,
-            other_config=self.LLM_CONFIG,
+            model_name=settings.LLM_PROVIDER,
+            llm_config=self.LLM_CONFIG,
             prompt_template=PromptTemplate,
             template_string=self.KI_PROMPT_TEMPLATE,
             response_schemas=self.primary_key,
```

src/data_tools/core/pipeline/link_prediction/lp.py

Lines changed: 2 additions & 3 deletions

```diff
@@ -96,9 +96,8 @@ def __init__(
 
         self.llm = (
             ChatModelLLM.get_llm(
-                model_name=settings.LLM_TYPE,
-                api_config=settings.LP_CONFIG,
-                other_config=self.LLM_CONFIG,
+                model_name=settings.LLM_PROVIDER,
+                llm_config=self.LLM_CONFIG,
             )
             if llm is None
             else llm
```
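With this change all four pipeline steps (business glossary, datatype identification L2, key identification, link prediction) construct their LLMs through the same two-argument interface, replacing the per-step `api_config`/`other_config` pairs. A minimal sketch of the unified pattern, assuming `settings.LLM_PROVIDER` carries the `provider:model_name` string described in the README and using an illustrative `llm_config`:

```python
from data_tools.core import settings
from data_tools.core.llms.chat import ChatModelLLM

# Bare model handle, as used by link prediction above.
llm = ChatModelLLM.get_llm(
    model_name=settings.LLM_PROVIDER,
    llm_config={"temperature": 0},  # illustrative extra kwargs
)
```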
