Skip to content

Commit 8156743

Browse files
[FEATURE] Capture metrics (#136)
* Added metrics_mixin Signed-off-by: Deepak <[email protected]> * Added capture_metrics decorator func Signed-off-by: Deepak <[email protected]> * Used capture_metrics decorator wherever needed Signed-off-by: Deepak <[email protected]> * Added clear_metrics function Signed-off-by: Deepak <[email protected]> * Added redis dependency Signed-off-by: Deepak <[email protected]> * Addressed review comments Signed-off-by: Deepak <[email protected]> * Addressed review comments Signed-off-by: Deepak <[email protected]> * Version bump Signed-off-by: Deepak <[email protected]> * Addressed review comments Signed-off-by: Deepak <[email protected]> --------- Signed-off-by: Deepak <[email protected]>
1 parent 5da46eb commit 8156743

File tree

7 files changed

+816
-626
lines changed

7 files changed

+816
-626
lines changed

pdm.lock

Lines changed: 649 additions & 612 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ dependencies = [
4141
#Unable to import llm adapters : No module named 'mistralai.models.chat_completion'
4242
#Looks like mistralai>0.4.2 is not backward compatible
4343
"mistralai==0.4.2",
44-
4544
"llama-index-llms-anyscale==0.1.4",
4645
"llama-index-llms-anthropic==0.1.16",
4746
"llama-index-llms-azure-openai==0.1.10",
@@ -58,6 +57,7 @@ dependencies = [
5857
"singleton-decorator~=1.0.0",
5958
"httpx>=0.25.2",
6059
"pdfplumber>=0.11.2",
60+
"redis>=5.2.1",
6161
]
6262
readme = "README.md"
6363
urls = { Homepage = "https://unstract.com", "Release notes" = "https://github.com/Zipstack/unstract-sdk/releases", Source = "https://github.com/Zipstack/unstract-sdk" }
@@ -120,4 +120,6 @@ path = "src/unstract/sdk/__init__.py"
120120
# Adding the following override to resolve dependency version
121121
# for environs. Otherwise, it stays stuck while resolving pins
122122
[tool.pdm.resolution.overrides]
123-
grpcio = ">=1.62.1"
123+
grpcio = "1.62.3"
124+
grpcio-tools = "1.62.3"
125+
grpcio-health-checking = "1.62.3"

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.54.0rc7"
1+
__version__ = "0.54.0rc8"
22

33

44
def get_sdk_version():

src/unstract/sdk/index.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from unstract.sdk.file_storage import FileStorage, FileStorageProvider
2828
from unstract.sdk.tool.base import BaseTool
2929
from unstract.sdk.utils import ToolUtils
30-
from unstract.sdk.utils.common_utils import log_elapsed
30+
from unstract.sdk.utils.common_utils import capture_metrics, log_elapsed
3131
from unstract.sdk.vector_db import VectorDB
3232
from unstract.sdk.x2txt import X2Text
3333

@@ -39,10 +39,19 @@ class Constants:
3939

4040

4141
class Index:
42-
def __init__(self, tool: BaseTool):
42+
def __init__(
43+
self,
44+
tool: BaseTool,
45+
run_id: Optional[str] = None,
46+
capture_metrics: bool = False,
47+
):
4348
# TODO: Inherit from StreamMixin and avoid using BaseTool
4449
self.tool = tool
50+
self._run_id = run_id
51+
self._capture_metrics = capture_metrics
52+
self._metrics = {}
4553

54+
@capture_metrics
4655
def query_index(
4756
self,
4857
embedding_instance_id: str,
@@ -180,6 +189,7 @@ def extract_text(
180189
return extracted_text
181190

182191
@log_elapsed(operation="CHECK_AND_INDEX(overall)")
192+
@capture_metrics
183193
def index(
184194
self,
185195
tool_id: str,
@@ -449,6 +459,12 @@ def generate_index_key(
449459
hashed_index_key = ToolUtils.hash_str(json.dumps(index_key, sort_keys=True))
450460
return hashed_index_key
451461

462+
def get_metrics(self):
463+
return self._metrics
464+
465+
def clear_metrics(self):
466+
self._metrics = {}
467+
452468
@deprecated(version="0.45.0", reason="Use generate_index_key() instead")
453469
def generate_file_id(
454470
self,

src/unstract/sdk/llm.py

Lines changed: 20 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unstract.sdk.helper import SdkHelper
2020
from unstract.sdk.tool.base import BaseTool
2121
from unstract.sdk.utils.callback_manager import CallbackManager
22+
from unstract.sdk.utils.common_utils import capture_metrics
2223

2324
logger = logging.getLogger(__name__)
2425

@@ -36,6 +37,7 @@ def __init__(
3637
tool: BaseTool,
3738
adapter_instance_id: Optional[str] = None,
3839
usage_kwargs: dict[Any, Any] = {},
40+
capture_metrics: bool = False,
3941
):
4042
"""Creates an instance of this LLM class.
4143
@@ -50,6 +52,10 @@ def __init__(
5052
self._adapter_instance_id = adapter_instance_id
5153
self._llm_instance: LlamaIndexLLM = None
5254
self._usage_kwargs = usage_kwargs
55+
self._capture_metrics = capture_metrics
56+
self._run_id = usage_kwargs.get("run_id")
57+
self._usage_reason = usage_kwargs.get("llm_usage_reason")
58+
self._metrics = {}
5359
self._initialise()
5460

5561
def _initialise(self):
@@ -65,14 +71,16 @@ def _initialise(self):
6571
kwargs=self._usage_kwargs,
6672
)
6773

74+
@capture_metrics
6875
def complete(
6976
self,
7077
prompt: str,
7178
extract_json: bool = True,
7279
process_text: Optional[Callable[[str], str]] = None,
7380
**kwargs: Any,
7481
) -> dict[str, Any]:
75-
"""Generates a completion response for the given prompt.
82+
"""Generates a completion response for the given prompt and captures
83+
metrics if run_id is provided.
7684
7785
Args:
7886
prompt (str): The input text prompt for generating the completion.
@@ -85,12 +93,8 @@ def complete(
8593
**kwargs (Any): Additional arguments passed to the completion function.
8694
8795
Returns:
88-
dict[str, Any]: A dictionary containing the result of the completion
89-
and any processed output.
90-
91-
Raises:
92-
LLMError: If an error occurs during the completion process, it will be
93-
raised after being processed by `parse_llm_err`.
96+
dict[str, Any]: A dictionary containing the result of the completion,
97+
any processed output, and the captured metrics (if applicable).
9498
"""
9599
try:
96100
response: CompletionResponse = self._llm_instance.complete(prompt, **kwargs)
@@ -105,12 +109,19 @@ def complete(
105109
if not isinstance(process_text_output, dict):
106110
process_text_output = {}
107111
except Exception as e:
108-
logger.error(f"Error occured inside function 'process_text': {e}")
112+
logger.error(f"Error occurred inside function 'process_text': {e}")
109113
process_text_output = {}
110-
return {LLM.RESPONSE: response, **process_text_output}
114+
response_data = {LLM.RESPONSE: response, **process_text_output}
115+
return response_data
111116
except Exception as e:
112117
raise parse_llm_err(e, self._llm_adapter_class) from e
113118

119+
def get_metrics(self):
120+
return self._metrics
121+
122+
def get_usage_reason(self):
123+
return self._usage_reason
124+
114125
def stream_complete(
115126
self,
116127
prompt: str,

src/unstract/sdk/metrics_mixin.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
import logging
2+
import os
3+
import time
4+
import uuid
5+
from typing import Any
6+
7+
from redis import StrictRedis
8+
9+
logger = logging.getLogger(__name__)
10+
11+
12+
class MetricsMixin:
    """Best-effort elapsed-time tracker backed by Redis.

    On construction a start timestamp is written to Redis under a key that is
    unique to this instance (``metrics:<run_id>:<op_id>``). A later call to
    ``collect_metrics`` reads the timestamp back, computes the elapsed seconds
    and removes the key.

    Metrics are auxiliary data, so no method of this class lets a Redis
    failure propagate to the caller: failures are logged and reported as a
    ``None`` measurement instead.
    """

    # Key under which the elapsed time (in seconds) is reported.
    TIME_TAKEN_KEY = "time_taken(s)"

    def __init__(self, run_id):
        """Create the Redis client and record the start time.

        Args:
            run_id (str): Unique identifier for the run.
        """
        self.run_id = run_id
        self.op_id = str(uuid.uuid4())  # Unique identifier for this instance
        self.redis_client = None
        try:
            # NOTE(review): redis-py is expected to open the connection lazily
            # on the first command, so this block mostly guards against bad
            # configuration (e.g. a non-numeric REDIS_PORT), not outages.
            self.redis_client = StrictRedis(
                host=os.getenv("REDIS_HOST", "unstract-redis"),
                port=int(os.getenv("REDIS_PORT", 6379)),
                username=os.getenv("REDIS_USER", "default"),
                password=os.getenv("REDIS_PASSWORD", ""),
                db=1,
                decode_responses=True,
            )
        except Exception as e:
            logger.error(
                f"Failed to initialize Redis client for run_id={run_id}: {e}"
            )

        self.redis_key = f"metrics:{self.run_id}:{self.op_id}"

        # Set the start time immediately upon initialization
        self.set_start_time()

    def set_start_time(self, ttl=86400):
        """Store the current timestamp in Redis under this instance's key.

        Args:
            ttl (int): Expiry of the Redis key in seconds, so that keys of
                operations that never call ``collect_metrics`` do not pile up.
        """
        if self.redis_client is None:
            logger.error("Redis client is not initialized. Cannot set start time.")
            return
        try:
            self.redis_client.set(self.redis_key, time.time(), ex=ttl)
        except Exception as e:
            # Broad on purpose: a metrics failure must never break the
            # operation that is being measured.
            logger.error(
                "Failed to record start time for %s: %s", self.redis_key, e
            )

    def collect_metrics(self) -> dict[str, Any]:
        """Calculate the time elapsed since the start time and delete the key.

        Returns:
            dict: ``{TIME_TAKEN_KEY: <seconds rounded to 3 decimals>}``, or
            ``{TIME_TAKEN_KEY: None}`` when the start time is unavailable
            (no client, Redis error, missing/expired key or unparsable value).
        """
        unavailable = {self.TIME_TAKEN_KEY: None}

        if self.redis_client is None:
            logger.error("Redis client is not initialized. Cannot collect metrics.")
            return unavailable

        # A single GET (instead of EXISTS followed by GET) avoids the window
        # in which the key can expire between the two commands.
        try:
            raw_start = self.redis_client.get(self.redis_key)
        except Exception as e:
            logger.error(
                "Failed to read start time for %s: %s", self.redis_key, e
            )
            return unavailable

        if raw_start is None:
            return unavailable

        try:
            start_time = float(raw_start)
        except (TypeError, ValueError):
            logger.error(
                "Unexpected start time value for %s: %r", self.redis_key, raw_start
            )
            return unavailable

        time_taken = round(time.time() - start_time, 3)

        # Delete the Redis key after use. A failed delete is harmless because
        # the key carries a TTL, so the measurement is still returned.
        try:
            self.redis_client.delete(self.redis_key)
        except Exception as e:
            logger.error("Failed to delete %s: %s", self.redis_key, e)

        return {self.TIME_TAKEN_KEY: time_taken}

src/unstract/sdk/utils/common_utils.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import uuid
55

66
from unstract.sdk.constants import LogLevel
7+
from unstract.sdk.metrics_mixin import MetricsMixin
78

89
logger = logging.getLogger(__name__)
910

@@ -54,3 +55,53 @@ def wrapper(*args, **kwargs):
5455
return wrapper
5556

5657
return decorator
58+
59+
60+
def capture_metrics(func):
    """Decorator that records how long a method call takes.

    The decorated method's instance must expose ``_run_id``,
    ``_capture_metrics`` and ``_metrics``. When any of them is missing, or
    when ``_run_id`` is empty or ``_capture_metrics`` is falsy, the method is
    simply executed without any measurement.

    Otherwise the elapsed time is stored in ``self._metrics`` under
    ``MetricsMixin.TIME_TAKEN_KEY``. Repeated calls accumulate: the new
    measurement is added to the existing one. If either side is ``None``
    (measurement unavailable) the total becomes ``None``, because a partial
    sum would be misleading.

    Metric capture is best-effort: a failure while starting or collecting the
    measurement is logged and never replaces the wrapped method's own result
    or exception.

    Args:
        func: The method to wrap; it is called as ``func(self, *args, **kwargs)``.

    Returns:
        The wrapped method, with ``func``'s metadata preserved.
    """
    required_attrs = ("_run_id", "_capture_metrics", "_metrics")

    @functools.wraps(func)
    def wrapper(self, *args, **kwargs):
        # Instances that do not participate in metrics capture, or have it
        # switched off, are passed straight through.
        if not all(hasattr(self, attr) for attr in required_attrs):
            return func(self, *args, **kwargs)
        if not (self._run_id and self._capture_metrics):
            return func(self, *args, **kwargs)

        time_taken_key = MetricsMixin.TIME_TAKEN_KEY

        metrics_mixin = None
        try:
            metrics_mixin = MetricsMixin(run_id=self._run_id)
        except Exception as e:
            # Broad on purpose: the wrapped call must still run even when the
            # metrics backend cannot be set up.
            logger.error(
                "Unable to start metrics capture for run_id=%s: %s",
                self._run_id,
                e,
            )

        try:
            return func(self, *args, **kwargs)
        finally:
            # Collect even when func raised, so failed calls are timed too.
            if metrics_mixin is not None:
                try:
                    new_metrics = metrics_mixin.collect_metrics()
                except Exception as e:
                    # Raising here would mask func's own result/exception.
                    logger.error(
                        "Unable to collect metrics for run_id=%s: %s",
                        self._run_id,
                        e,
                    )
                    new_metrics = {time_taken_key: None}

                if (
                    self._metrics
                    and time_taken_key in self._metrics
                    and time_taken_key in new_metrics
                ):
                    previously_measured_time = self._metrics[time_taken_key]
                    newly_measured_time = new_metrics[time_taken_key]

                    # Compare against None explicitly: 0.0 is a valid
                    # measurement (durations are rounded to milliseconds) and
                    # must not wipe the accumulated total.
                    if (
                        previously_measured_time is not None
                        and newly_measured_time is not None
                    ):
                        self._metrics[time_taken_key] = (
                            previously_measured_time + newly_measured_time
                        )
                    else:
                        self._metrics[time_taken_key] = None
                else:
                    # Nothing accumulated yet: start from the new measurement.
                    self._metrics = new_metrics

    return wrapper

0 commit comments

Comments
 (0)