
Commit 37ee5dc

Merge pull request stanfordnlp#773 from stanfordnlp/curieo-org-initial-groq-support
Curieo org initial groq support
2 parents a12b362 + 969220b commit 37ee5dc

File tree

8 files changed: +252 -6 lines changed
Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
---
sidebar_position: 9
---

# dspy.GROQ

### Usage

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
```
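Once constructed, the client can be set as the default LM for DSPy programs. A minimal sketch, assuming a valid Groq key in place of the `gsk_***` placeholder:

```python
import dspy

# A sketch: make the Groq client the default LM for subsequent DSPy calls.
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")
dspy.settings.configure(lm=lm)

# Any DSPy module now routes its generations through Groq, e.g.:
qa = dspy.Predict("question -> answer")
print(qa(question="What is the capital of France?").answer)
```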
### Constructor

The constructor initializes the base class `LM` and verifies the provided arguments, such as the `api_key` for the Groq API client. The `kwargs` attribute is initialized with default values for the text generation parameters sent to the Groq API: `temperature`, `max_tokens`, `top_p`, `frequency_penalty`, `presence_penalty`, and `n`.

```python
class GroqLM(LM):
    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
```
**Parameters:**
- `api_key` (_str_): API provider authentication token. Required; the constructor raises a `ValueError` if it is missing.
- `model` (_str_): Model name. Defaults to `"mixtral-8x7b-32768"`; other options include `"llama2-70b-4096"` and `"gemma-7b-it"`.
- `**kwargs`: Additional language model arguments to pass to the API provider.
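Any of the default generation parameters can be overridden at construction time, since `**kwargs` is merged over the defaults. A sketch (the key is a placeholder):

```python
# A sketch: construction-time kwargs are merged over the defaults listed above.
lm = dspy.GROQ(
    model='mixtral-8x7b-32768',
    api_key="gsk_***",   # placeholder key
    max_tokens=512,      # default is 150
    temperature=0.7,     # default is 0.0
)
```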
### Methods

#### `def __call__(self, prompt: str, only_completed: bool = True, return_sorted: bool = False, **kwargs) -> list[dict[str, Any]]:`

Retrieves completions from Groq by calling `request`.

Internally, the method handles the specifics of preparing the request prompt and the corresponding payload to obtain the response.

After generation, the content of each returned choice is accessed as `choice["message"]["content"]`.

**Parameters:**
- `prompt` (_str_): Prompt to send to Groq.
- `only_completed` (_bool_, _optional_): Flag to return only completed responses and ignore completions truncated due to length. Defaults to True.
- `return_sorted` (_bool_, _optional_): Flag to sort the completion choices using the returned averaged log-probabilities. Defaults to False.
- `**kwargs`: Additional keyword arguments for the completion request.

**Returns:**
- `List[Dict[str, Any]]`: List of completion choices.
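Calling the client directly returns the generated text of each choice; a minimal sketch with a placeholder key:

```python
lm = dspy.GROQ(model='mixtral-8x7b-32768', api_key="gsk_***")

# __call__ wraps request() and extracts each choice's message content.
completions = lm("Say hello in one word.", n=1)
print(completions[0])
```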

dsp/modules/__init__.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -8,10 +8,12 @@
 from .databricks import *
 from .google import *
 from .gpt3 import *
+from .groq_client import *
 from .hf import HFModel
 from .hf_client import Anyscale, HFClientTGI, Together
 from .mistral import *
 from .ollama import *
 from .pyserini import *
 from .sbert import *
 from .sentence_vectorizer import *
+
```
dsp/modules/groq_client.py

Lines changed: 169 additions & 0 deletions
@@ -0,0 +1,169 @@
```python
import logging
from typing import Any

import backoff

try:
    import groq
    from groq import Groq
    groq_api_error = (groq.APIError, groq.RateLimitError)
except ImportError:
    groq_api_error = (Exception)

import dsp
from dsp.modules.lm import LM

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[logging.FileHandler("groq_usage.log")],
)


def backoff_hdlr(details):
    """Handler from https://pypi.org/project/backoff/"""
    print(
        "Backing off {wait:0.1f} seconds after {tries} tries "
        "calling function {target} with kwargs "
        "{kwargs}".format(**details),
    )


class GroqLM(LM):
    """Wrapper around Groq's API.

    Args:
        model (str, optional): Groq-supported LLM model to use. Defaults to "mixtral-8x7b-32768".
        api_key (str): API provider authentication token.
        **kwargs: Additional arguments to pass to the API provider.
    """

    def __init__(
        self,
        api_key: str,
        model: str = "mixtral-8x7b-32768",
        **kwargs,
    ):
        super().__init__(model)
        self.provider = "groq"
        if api_key:
            self.api_key = api_key
            self.client = Groq(api_key=api_key)
        else:
            raise ValueError("api_key is required for groq")

        self.kwargs = {
            "temperature": 0.0,
            "max_tokens": 150,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "n": 1,
            **kwargs,
        }
        models = self.client.models.list().data
        if models is not None:
            if model in [m.id for m in models]:
                self.kwargs["model"] = model
        self.history: list[dict[str, Any]] = []

    def log_usage(self, response):
        """Log the total tokens from the Groq API response."""
        usage_data = response.get("usage")
        if usage_data:
            total_tokens = usage_data.get("total_tokens")
            logging.info(f"{total_tokens}")

    def basic_request(self, prompt: str, **kwargs):
        raw_kwargs = kwargs

        kwargs = {**self.kwargs, **kwargs}

        kwargs["messages"] = [{"role": "user", "content": prompt}]
        response = self.chat_request(**kwargs)

        history = {
            "prompt": prompt,
            "response": response.choices[0].message.content,
            "kwargs": kwargs,
            "raw_kwargs": raw_kwargs,
        }

        self.history.append(history)

        return response

    @backoff.on_exception(
        backoff.expo,
        groq_api_error,
        max_time=1000,
        on_backoff=backoff_hdlr,
    )
    def request(self, prompt: str, **kwargs):
        """Handles retrieval of model completions whilst handling rate limiting and caching."""
        if "model_type" in kwargs:
            del kwargs["model_type"]

        return self.basic_request(prompt, **kwargs)

    def _get_choice_text(self, choice) -> str:
        return choice.message.content

    def chat_request(self, **kwargs):
        """Handles retrieval of model completions whilst handling rate limiting and caching."""
        response = self.client.chat.completions.create(**kwargs)
        return response

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ) -> list[dict[str, Any]]:
        """Retrieves completions from the model.

        Args:
            prompt (str): prompt to send to the model
            only_completed (bool, optional): return only completed responses and ignore completions truncated due to length. Defaults to True.
            return_sorted (bool, optional): sort the completion choices using the returned probabilities. Defaults to False.

        Returns:
            list[dict[str, Any]]: list of completion choices
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"
        response = self.request(prompt, **kwargs)

        if dsp.settings.log_openai_usage:
            self.log_usage(response)

        choices = response.choices

        completions = [self._get_choice_text(c) for c in choices]
        if return_sorted and kwargs.get("n", 1) > 1:
            scored_completions = []

            for c in choices:
                tokens, logprobs = (
                    c["logprobs"]["tokens"],
                    c["logprobs"]["token_logprobs"],
                )

                if "<|endoftext|>" in tokens:
                    index = tokens.index("<|endoftext|>") + 1
                    tokens, logprobs = tokens[:index], logprobs[:index]

                avglog = sum(logprobs) / len(logprobs)
                scored_completions.append((avglog, self._get_choice_text(c)))

            scored_completions = sorted(scored_completions, reverse=True)
            completions = [c for _, c in scored_completions]

        return completions
```
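As a quick sanity check of the wrapper above, a sketch (not part of the commit; assumes a valid key in the `GROQ_API_KEY` environment variable) that exercises `basic_request` through `__call__` and shows what gets recorded in `history`:

```python
import os

# A sketch: exercise the wrapper end to end with the default model.
lm = GroqLM(api_key=os.environ["GROQ_API_KEY"])  # defaults to mixtral-8x7b-32768

completions = lm("Name one prime number.", max_tokens=16)
print(completions)                    # list of completion strings, e.g. ['2']

last = lm.history[-1]                 # basic_request records every call
print(last["prompt"])                 # the raw prompt string
print(last["kwargs"]["messages"])     # [{'role': 'user', 'content': ...}]
```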

dsp/modules/hf_client.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -435,4 +435,4 @@ def _generate(self, prompt, **kwargs):
 
 @CacheMemory.cache
 def send_hfsglang_request_v00(arg, **kwargs):
-    return requests.post(arg, **kwargs)
+    return requests.post(arg, **kwargs)
```

The removed and added lines are identical; the only difference is a newline added at the end of the file.

dsp/modules/lm.py

Lines changed: 5 additions & 3 deletions
```diff
@@ -77,9 +77,11 @@ def inspect_history(self, n: int = 1, skip: int = 0):
         if provider == "cohere":
             text = choices
         elif provider == "openai" or provider == "ollama":
-            text = " " + self._get_choice_text(choices[0]).strip()
-        elif provider == "clarifai":
-            text = choices
+            text = ' ' + self._get_choice_text(choices[0]).strip()
+        elif provider == "clarifai" or provider == "claude" :
+            text=choices
+        elif provider == "groq":
+            text = ' ' + choices
         elif provider == "google":
             text = choices[0].parts[0].text
         elif provider == "mistral":
```

For `groq`, `basic_request` stores `response.choices[0].message.content` (a plain string) as the history entry's response, which appears to be why the new branch prefixes `choices` with a space and uses it directly rather than indexing into it.

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -21,6 +21,7 @@
 Pyserini = dsp.PyseriniRetriever
 Clarifai = dsp.ClarifaiLLM
 Google = dsp.Google
+GROQ = dsp.GroqLM
 
 HFClientTGI = dsp.HFClientTGI
 HFClientVLLM = HFClientVLLM
```
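Together with the wildcard import added to `dsp/modules/__init__.py`, this makes the client reachable from both namespaces; a quick sketch:

```python
import dsp
import dspy

# The dspy-level name is a plain alias for the dsp-level class.
assert dspy.GROQ is dsp.GroqLM
```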

poetry.lock

Lines changed: 21 additions & 2 deletions
Some generated files are not rendered by default.

pyproject.toml

Lines changed: 2 additions & 0 deletions
```diff
@@ -57,6 +57,7 @@ docs = [
     "autodoc_pydantic",
     "sphinx-reredirects>=0.1.2",
     "sphinx-automodapi==0.16.0",
+
 ]
 dev = ["pytest>=6.2.5"]

@@ -108,6 +109,7 @@ sphinx_rtd_theme = { version = "*", optional = true }
 autodoc_pydantic = { version = "*", optional = true }
 sphinx-reredirects = { version = "^0.1.2", optional = true }
 sphinx-automodapi = { version = "0.16.0", optional = true }
+groq = {version = "^0.4.2", optional = true }
 rich = "^13.7.1"
 psycopg2 = {version = "^2.9.9", optional = true}
 pgvector = {version = "^0.2.5", optional = true}
```

0 commit comments