Skip to content

Commit b78f41b

Browse files
Added adapter_instance_id and run_id (#40)
* Added adapter_instance_id and run_id
* Updated field model_type -> model_name
* Review comment fixes
* Version bump
---------
Co-authored-by: Rahul Johny <[email protected]>
1 parent 6aa69f6 commit b78f41b

File tree

5 files changed

+51
-57
lines changed

5 files changed

+51
-57
lines changed

src/unstract/sdk/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "0.22.1"
1+
__version__ = "0.23.0"
22

33

44
def get_sdk_version():

src/unstract/sdk/audit.py

Lines changed: 27 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,6 @@ class Audit(StreamMixin):
1515
1616
Attributes:
1717
None
18-
19-
Example usage:
20-
audit = Audit()
21-
audit.push_usage_data(
22-
token_counter,
23-
workflow_id,
24-
execution_id,
25-
external_service,
26-
event_type)
2718
"""
2819

2920
def __init__(self, log_level: LogLevel = LogLevel.INFO) -> None:
@@ -33,23 +24,28 @@ def push_usage_data(
3324
self,
3425
platform_api_key: str,
3526
token_counter: TokenCountingHandler = None,
36-
workflow_id: str = "",
37-
execution_id: str = "",
38-
external_service: str = "",
27+
model_name: str = "",
3928
event_type: CBEventType = None,
29+
**kwargs,
4030
) -> None:
4131
"""Pushes the usage data to the platform service.
4232
4333
Args:
34+
platform_api_key (str): The platform API key.
4435
token_counter (TokenCountingHandler, optional): The token counter
45-
object. Defaults to None.
46-
workflow_id (str, optional): The ID of the workflow. Defaults to "".
47-
execution_id (str, optional): The ID of the execution. Defaults
48-
to "".
49-
external_service (str, optional): The name of the external service.
50-
Defaults to "".
36+
object. Defaults to None.
37+
model_name (str, optional): The name of the model.
38+
Defaults to "".
5139
event_type (CBEventType, optional): The type of the event. Defaults
52-
to None.
40+
to None.
41+
**kwargs: Optional keyword arguments.
42+
workflow_id (str, optional): The ID of the workflow.
43+
Defaults to "".
44+
execution_id (str, optional): The ID of the execution. Defaults
45+
to "".
46+
adapter_instance_id (str, optional): The adapter instance ID.
47+
Defaults to "".
48+
run_id (str, optional): The run ID. Defaults to "".
5349
5450
Returns:
5551
None
@@ -66,11 +62,18 @@ def push_usage_data(
6662
)
6763
bearer_token = platform_api_key
6864

65+
workflow_id = kwargs.get("workflow_id", "")
66+
execution_id = kwargs.get("execution_id", "")
67+
adapter_instance_id = kwargs.get("adapter_instance_id", "")
68+
run_id = kwargs.get("run_id", "")
69+
6970
data = {
70-
"usage_type": event_type,
71-
"external_service": external_service,
7271
"workflow_id": workflow_id,
7372
"execution_id": execution_id,
73+
"adapter_instance_id": adapter_instance_id,
74+
"run_id": run_id,
75+
"usage_type": event_type,
76+
"model_name": model_name,
7477
"embedding_tokens": token_counter.total_embedding_token_count,
7578
"prompt_tokens": token_counter.prompt_llm_token_count,
7679
"completion_tokens": token_counter.completion_llm_token_count,
@@ -100,3 +103,6 @@ def push_usage_data(
100103
log=f"Error while pushing usage details: {e}",
101104
level=LogLevel.ERROR,
102105
)
106+
107+
finally:
108+
token_counter.reset_counts()

src/unstract/sdk/llm.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -57,32 +57,26 @@ def run_completion(
5757
) -> Optional[dict[str, Any]]:
5858
# Setup callback manager to collect Usage stats
5959
UNCallbackManager.set_callback_manager(
60-
platform_api_key=platform_api_key, llm=llm
60+
platform_api_key=platform_api_key, llm=llm, **kwargs
6161
)
62+
# Removing specific keys from kwargs
63+
new_kwargs = kwargs.copy()
64+
for key in [
65+
"workflow_id",
66+
"execution_id",
67+
"adapter_instance_id",
68+
"run_id",
69+
]:
70+
new_kwargs.pop(key, None)
6271
for i in range(retries):
6372
try:
64-
response: CompletionResponse = llm.complete(prompt, **kwargs)
73+
response: CompletionResponse = llm.complete(
74+
prompt, **new_kwargs
75+
)
6576
match = cls.json_regex.search(response.text)
6677
if match:
6778
response.text = match.group(0)
68-
69-
usage = {}
70-
llm_token_counts = llm.callback_manager.handlers[
71-
0
72-
].llm_token_counts
73-
if llm_token_counts:
74-
llm_token_count = llm_token_counts[0]
75-
usage[
76-
"prompt_token_count"
77-
] = llm_token_count.prompt_token_count
78-
usage[
79-
"completion_token_count"
80-
] = llm_token_count.completion_token_count
81-
usage[
82-
"total_token_count"
83-
] = llm_token_count.total_token_count
84-
85-
return {"response": response, "usage": usage}
79+
return {"response": response}
8680

8781
except Exception as e:
8882
if i == retries - 1:

src/unstract/sdk/utils/callback_manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ def set_callback_manager(
4141
platform_api_key: str,
4242
llm: Optional[LLM] = None,
4343
embedding: Optional[BaseEmbedding] = None,
44-
workflow_id: str = "",
45-
execution_id: str = "",
44+
**kwargs,
4645
) -> LlamaIndexCallbackManager:
4746
"""Sets the standard callback manager for the llm. This is to be called
4847
explicitly whenever there is a need for the callback handling defined
@@ -52,7 +51,7 @@ def set_callback_manager(
5251
llm (LLM): The LLM type
5352
5453
Returns:
55-
CallbackManager tyoe of llama index
54+
CallbackManager type of llama index
5655
5756
Example:
5857
UNCallbackManager.set_callback_manager(
@@ -73,8 +72,7 @@ def set_callback_manager(
7372
platform_api_key=platform_api_key,
7473
llm_model=llm,
7574
embed_model=embedding,
76-
workflow_id=workflow_id,
77-
execution_id=execution_id,
75+
**kwargs,
7876
)
7977

8078
callback_manager: LlamaIndexCallbackManager = LlamaIndexCallbackManager(

src/unstract/sdk/utils/usage_handler.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,16 +36,14 @@ def __init__(
3636
platform_api_key: str,
3737
llm_model: LLM = None,
3838
embed_model: BaseEmbedding = None,
39-
workflow_id: str = "",
40-
execution_id: str = "",
4139
event_starts_to_ignore: Optional[list[CBEventType]] = None,
4240
event_ends_to_ignore: Optional[list[CBEventType]] = None,
4341
verbose: bool = False,
4442
log_level: LogLevel = LogLevel.INFO,
43+
**kwargs,
4544
) -> None:
45+
self.kwargs = kwargs
4646
self._verbose = verbose
47-
self.workflow_id = workflow_id
48-
self.execution_id = execution_id
4947
self.token_counter = token_counter
5048
self.llm_model = llm_model
5149
self.embed_model = embed_model
@@ -96,9 +94,8 @@ def on_event_end(
9694
platform_api_key=self.platform_api_key,
9795
token_counter=self.token_counter,
9896
event_type=event_type,
99-
external_service=self.llm_model.metadata.model_name,
100-
workflow_id=self.workflow_id,
101-
execution_id=self.execution_id,
97+
model_name=self.llm_model.metadata.model_name,
98+
**self.kwargs,
10299
)
103100

104101
elif (
@@ -113,7 +110,6 @@ def on_event_end(
113110
platform_api_key=self.platform_api_key,
114111
token_counter=self.token_counter,
115112
event_type=event_type,
116-
external_service=self.embed_model.model_name,
117-
workflow_id=self.workflow_id,
118-
execution_id=self.execution_id,
113+
model_name=self.embed_model.model_name,
114+
**self.kwargs,
119115
)

0 commit comments

Comments (0)