Skip to content

Commit c36d09d

Browse files
Fix time unit embedding generation (#251)
1 parent 93c2c84 commit c36d09d

File tree

4 files changed

+49
-81
lines changed

4 files changed

+49
-81
lines changed

community/fm-asr-streaming-rag/chain-server/chains.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121

2222
from accumulator import TextAccumulator
2323
from retriever import NVRetriever
24-
from common import get_logger, LLMConfig, TimeResponse, UserIntent
25-
from utils import get_llm, classify
24+
from common import get_logger, LLMConfig
25+
from utils import get_llm, classify, TimeResponse, UserIntent
2626
from prompts import RAG_PROMPT, INTENT_PROMPT, RECENCY_PROMPT, SUMMARIZATION_PROMPT
2727

2828
logger = get_logger(__name__)

community/fm-asr-streaming-rag/chain-server/common.py

Lines changed: 1 addition & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,10 @@
1414
# limitations under the License.
1515

1616
import logging
17-
import requests
18-
import json
1917
import os
20-
import numpy as np
2118

22-
from datetime import datetime, timedelta
19+
from datetime import datetime
2320
from pydantic import BaseModel, Field
24-
from typing import Literal
25-
from langchain_community.utils.math import cosine_similarity
2621

2722
NVIDIA_API_KEY = os.environ.get('NVIDIA_API_KEY', 'null')
2823
LLM_URI = os.environ.get('LLM_URI', None)
@@ -65,74 +60,3 @@ class LLMConfig(BaseModel):
6560
temperature: float = Field("Temperature of the LLM response")
6661
max_docs: int = Field("Maximum number of documents to return")
6762
num_tokens: int = Field("The maximum number of tokens in the response")
68-
69-
def nemo_embedding(text):
70-
"""
71-
Uses the NeMo Embedding MS to convert text to embeddings
72-
- ex: embeddings = nemo_embedding(['Chunk A', 'Chunk B'])
73-
"""
74-
port = os.environ.get('NEMO_EMBEDDING_PORT', 1985)
75-
url = f"http://localhost:{port}/v1/embeddings"
76-
payload = json.dumps({
77-
"input": text,
78-
"model": "NV-Embed-QA",
79-
"input_type": "query"
80-
})
81-
headers = {'Content-Type': 'application/json'}
82-
response = requests.request("POST", url, headers=headers, data=payload)
83-
embeddings = [chunk['embedding'] for chunk in response.json()['data']]
84-
return embeddings
85-
86-
def nvapi_embedding(text):
87-
session = requests.Session()
88-
url = "https://ai.api.nvidia.com/v1/retrieval/nvidia/embeddings"
89-
headers = {
90-
"Authorization": f"Bearer {NVIDIA_API_KEY}",
91-
"Accept": "application/json",
92-
}
93-
payload = {
94-
"input": text,
95-
"input_type": "passage",
96-
"model": "NV-Embed-QA"
97-
}
98-
response = session.post(url, headers=headers, json=payload)
99-
embeddings = [chunk['embedding'] for chunk in response.json()['data']]
100-
return embeddings
101-
102-
VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
103-
TIME_VECTORS = nvapi_embedding(VALID_TIME_UNITS)
104-
105-
def sanitize_time_unit(time_unit):
106-
"""
107-
For cases where an LLM returns a time unit that doesn't match one of the
108-
discrete options, find the closest with cosine similarity.
109-
110-
Example: 'min' -> 'minutes'
111-
"""
112-
if time_unit in VALID_TIME_UNITS:
113-
return time_unit
114-
115-
unit_embedding = nvapi_embedding([time_unit])
116-
similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
117-
return VALID_TIME_UNITS[np.argmax(similarity)]
118-
119-
"""
120-
Pydantic classes that are used to detect user intent and plan accordingly
121-
"""
122-
class TimeResponse(BaseModel):
123-
timeNum: float = Field("The number of time units the user asked about")
124-
timeUnit: str = Field("The unit of time the user asked about")
125-
126-
def to_seconds(self):
127-
""" Return the total number of seconds this represents
128-
"""
129-
self.timeUnit = sanitize_time_unit(self.timeUnit)
130-
return timedelta(**{self.timeUnit: self.timeNum}).total_seconds()
131-
132-
class UserIntent(BaseModel):
133-
intentType: Literal[
134-
"SpecificTopic",
135-
"RecentSummary",
136-
"TimeWindow",
137-
"Unknown"
138-
] = Field("The intent of user's query")

community/fm-asr-streaming-rag/chain-server/prompts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# limitations under the License.
1515

1616
from pydantic import BaseModel
17-
from common import UserIntent, TimeResponse
17+
from utils import UserIntent, TimeResponse
1818

1919
def format_schema(pydantic_obj: BaseModel):
2020
return str(

community/fm-asr-streaming-rag/chain-server/utils.py

Lines changed: 45 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,14 @@
1515

1616
import json
1717
import re
18+
import json
19+
import numpy as np
1820

21+
from datetime import timedelta
1922
from langchain_nvidia_ai_endpoints import ChatNVIDIA, NVIDIAEmbeddings, NVIDIARerank
20-
from pydantic import BaseModel
23+
from pydantic import BaseModel, Field
24+
from langchain_community.utils.math import cosine_similarity
25+
from typing import Literal
2126

2227
from common import (
2328
get_logger,
@@ -97,6 +102,45 @@ def get_embedder(local: bool=True):
97102
truncate="NONE"
98103
)
99104

105+
embed_client = get_embedder()
106+
VALID_TIME_UNITS = ["seconds", "minutes", "hours", "days"]
107+
TIME_VECTORS = embed_client.embed_documents(VALID_TIME_UNITS)
108+
109+
def sanitize_time_unit(time_unit):
110+
"""
111+
For cases where an LLM returns a time unit that doesn't match one of the
112+
discrete options, find the closest with cosine similarity.
113+
114+
Example: 'min' -> 'minutes'
115+
"""
116+
if time_unit in VALID_TIME_UNITS:
117+
return time_unit
118+
119+
unit_embedding = embed_client.embed_documents([time_unit])
120+
similarity = cosine_similarity(unit_embedding, TIME_VECTORS)
121+
return VALID_TIME_UNITS[np.argmax(similarity)]
122+
123+
"""
124+
Pydantic classes that are used to detect user intent and plan accordingly
125+
"""
126+
class TimeResponse(BaseModel):
127+
timeNum: float = Field("The number of time units the user asked about")
128+
timeUnit: str = Field("The unit of time the user asked about")
129+
130+
def to_seconds(self):
131+
""" Return the total number of seconds this represents
132+
"""
133+
self.timeUnit = sanitize_time_unit(self.timeUnit)
134+
return timedelta(**{self.timeUnit: self.timeNum}).total_seconds()
135+
136+
class UserIntent(BaseModel):
137+
intentType: Literal[
138+
"SpecificTopic",
139+
"RecentSummary",
140+
"TimeWindow",
141+
"Unknown"
142+
] = Field("The intent of user's query")
143+
100144
def classify(question, chain, pydantic_obj: BaseModel):
101145
""" Parse a question into structured pydantic_obj
102146
"""

0 commit comments

Comments
 (0)