Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ dependencies = [
"xgboost==1.6.0",
"xpinyin==0.7.6",
"yfinance==0.2.65",
"zhipuai==2.0.1",
# following modules aren't necessary
# "nltk==3.9.1",
# "numpy>=1.26.0,<2.0.0",
Expand Down Expand Up @@ -279,4 +278,4 @@ exclude_lines = [
# HTML report configuration
directory = "htmlcov"
title = "Test Coverage Report"
# extra_css = "custom.css" # Optional custom CSS
# extra_css = "custom.css" # Optional custom CSS
58 changes: 44 additions & 14 deletions rag/llm/embedding_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import os
import threading
from abc import ABC
from typing import Any
from urllib.parse import urljoin

import dashscope
Expand All @@ -25,7 +26,6 @@
import requests
from ollama import Client
from openai import OpenAI
from zhipuai import ZhipuAI

from common.log_utils import log_exception
from common.token_utils import num_tokens_from_string, truncate, total_token_count_from_response
Expand Down Expand Up @@ -223,35 +223,65 @@ def encode_queries(self, text):
class ZhipuEmbed(Base):
_FACTORY_NAME = "ZHIPU-AI"

def __init__(self, key, model_name="embedding-2", **kwargs):
self.client = ZhipuAI(api_key=key)
def __init__(self, key, model_name="embedding-2", base_url="https://open.bigmodel.cn/api/paas/v4", **kwargs):
if not base_url:
base_url = "https://open.bigmodel.cn/api/paas/v4"
normalized_base_url = base_url.rstrip("/")
if normalized_base_url.endswith("/embeddings"):
self.base_url = normalized_base_url
elif normalized_base_url.endswith("/api"):
self.base_url = f"{normalized_base_url}/paas/v4/embeddings"
else:
self.base_url = f"{normalized_base_url}/embeddings"
self.headers = {
"authorization": f"Bearer {key}",
"content-type": "application/json",
}
self.model_name = model_name

def _request_embeddings(self, input_text: str | list[str]) -> dict[str, Any]:
payload = {
"model": self.model_name,
"input": input_text,
}
response = requests.post(self.base_url, json=payload, headers=self.headers, timeout=120)
try:
res = response.json()
except Exception as _e:
log_exception(_e, response)
raise Exception(f"Error: {response}")
if response.status_code != 200:
err = res.get("error", {})
message = err.get("message", str(res))
raise Exception(f"Error: {message}")
return res

def encode(self, texts: list):
arr = []
tks_num = 0
MAX_LEN = -1
max_len = -1
if self.model_name.lower() == "embedding-2":
MAX_LEN = 512
max_len = 512
if self.model_name.lower() == "embedding-3":
MAX_LEN = 3072
if MAX_LEN > 0:
texts = [truncate(t, MAX_LEN) for t in texts]
max_len = 3072
if max_len > 0:
texts = [truncate(t, max_len) for t in texts]

for txt in texts:
res = self.client.embeddings.create(input=txt, model=self.model_name)
res = None
try:
arr.append(res.data[0].embedding)
res = self._request_embeddings(txt)
arr.append(res["data"][0]["embedding"])
tks_num += total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
log_exception(_e, res if res is not None else {"model": self.model_name, "input": txt})
raise Exception(f"Error: {res}") from _e
return np.array(arr), tks_num

def encode_queries(self, text):
res = self.client.embeddings.create(input=text, model=self.model_name)
res = self._request_embeddings(text)
try:
return np.array(res.data[0].embedding), total_token_count_from_response(res)
return np.array(res["data"][0]["embedding"]), total_token_count_from_response(res)
except Exception as _e:
log_exception(_e, res)
raise Exception(f"Error: {res}")
Expand Down
17 changes: 0 additions & 17 deletions uv.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.