Skip to content

Commit 8d1e39a

Browse files
[FIX] Added headers to prompt service calls (#173)
* Added headers to prompt service calls Signed-off-by: Deepak <[email protected]> * Fixed pre-commit issues Signed-off-by: Deepak <[email protected]> --------- Signed-off-by: Deepak <[email protected]>
1 parent 4903774 commit 8d1e39a

File tree

5 files changed

+32
-16
lines changed

5 files changed

+32
-16
lines changed

src/unstract/sdk/adapters/embedding/open_ai/src/open_ai.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from llama_index.embeddings.openai import OpenAIEmbedding
77

88
from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter
9-
from unstract.sdk.adapters.embedding.helper import EmbeddingHelper
109
from unstract.sdk.adapters.exceptions import AdapterError
1110

1211

src/unstract/sdk/adapters/embedding/qdrant_fast_embed/src/qdrant_fast_embed.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from llama_index.embeddings.fastembed import FastEmbedEmbedding
66

77
from unstract.sdk.adapters.embedding.embedding_adapter import EmbeddingAdapter
8-
from unstract.sdk.adapters.embedding.helper import EmbeddingHelper
98
from unstract.sdk.adapters.exceptions import AdapterError
109

1110

src/unstract/sdk/adapters/embedding/vertex_ai/src/vertex_ai.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import json
2-
import logging
32
import os
4-
from typing import Any, Optional
3+
from typing import Any
54

65
from google.auth.transport import requests as google_requests
76
from google.oauth2.service_account import Credentials

src/unstract/sdk/adapters/x2text/llm_whisperer_v2/src/llm_whisperer_v2.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import json
21
import logging
32
import os
43
from typing import Any, Optional
@@ -11,7 +10,6 @@
1110
TextExtractionResult,
1211
)
1312
from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.constants import (
14-
HTTPMethod,
1513
WhispererEndpoint,
1614
)
1715
from unstract.sdk.adapters.x2text.llm_whisperer_v2.src.dto import WhispererRequestParams

src/unstract/sdk/prompt.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,36 +36,50 @@ def __init__(
3636

3737
@log_elapsed(operation="ANSWER_PROMPTS")
3838
def answer_prompt(
39-
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
39+
self,
40+
payload: dict[str, Any],
41+
params: Optional[dict[str, str]] = None,
42+
headers: Optional[dict[str, str]] = None,
4043
) -> dict[str, Any]:
4144
url_path = "answer-prompt"
4245
if self.is_public_call:
4346
url_path = "answer-prompt-public"
4447
return self._post_call(
45-
url_path=url_path,
46-
payload=payload,
47-
params=params,
48+
url_path=url_path, payload=payload, params=params, headers=headers
4849
)
4950

5051
def single_pass_extraction(
51-
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
52+
self,
53+
payload: dict[str, Any],
54+
params: Optional[dict[str, str]] = None,
55+
headers: Optional[dict[str, str]] = None,
5256
) -> dict[str, Any]:
5357
return self._post_call(
5458
url_path="single-pass-extraction",
5559
payload=payload,
5660
params=params,
61+
headers=headers,
5762
)
5863

5964
def summarize(
60-
self, payload: dict[str, Any], params: Optional[dict[str, str]] = None
65+
self,
66+
payload: dict[str, Any],
67+
params: Optional[dict[str, str]] = None,
68+
headers: Optional[dict[str, str]] = None,
6169
) -> dict[str, Any]:
62-
return self._post_call(url_path="summarize", payload=payload, params=params)
70+
return self._post_call(
71+
url_path="summarize",
72+
payload=payload,
73+
params=params,
74+
headers=headers,
75+
)
6376

6477
def _post_call(
6578
self,
6679
url_path: str,
6780
payload: dict[str, Any],
6881
params: Optional[dict[str, str]] = None,
82+
headers: Optional[dict[str, str]] = None,
6983
) -> dict[str, Any]:
7084
"""Invokes and communicates to prompt service to fetch response for the
7185
prompt.
@@ -74,6 +88,7 @@ def _post_call(
7488
url_path (str): URL path to the service endpoint
7589
payload (dict): Payload to send in the request body
7690
params (dict, optional): Query parameters to include in the request
91+
headers (dict, optional): Headers to include in the request
7792
7893
Returns:
7994
dict: Response from the prompt service
@@ -94,13 +109,19 @@ def _post_call(
94109
"status_code": 500,
95110
}
96111
url: str = f"{self.base_url}/{url_path}"
97-
headers: dict[str, str] = {}
112+
113+
default_headers = {}
114+
98115
if not self.is_public_call:
99-
headers = {"Authorization": f"Bearer {self.bearer_token}"}
116+
default_headers = {"Authorization": f"Bearer {self.bearer_token}"}
117+
118+
if headers:
119+
default_headers.update(headers)
120+
100121
response: Response = Response()
101122
try:
102123
response = requests.post(
103-
url=url, json=payload, params=params, headers=headers
124+
url=url, json=payload, params=params, headers=default_headers
104125
)
105126
response.raise_for_status()
106127
result["status"] = "OK"

0 commit comments

Comments
 (0)