Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Zen instruments the following AI SDKs to track which models are used and how man
* ✅ [`anthropic`](https://pypi.org/project/anthropic/)
* ✅ [`mistralai`](https://pypi.org/project/mistralai) ^1.0.0
* ✅ [`boto3`](https://pypi.org/project/boto3) (AWS Bedrock)
* ✅ [`groq`](https://pypi.org/project/groq)

Zen is compatible with Python 3.8-3.13 and can run on Windows, Linux, and Mac OS X.

Expand Down
44 changes: 44 additions & 0 deletions aikido_zen/sinks/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from aikido_zen.helpers.on_ai_call import on_ai_call
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import on_import, patch_function, after, after_async


def get_provider_and_model_from_groq_model(groq_model: str):
    """Split a Groq model id into (provider, model).

    Groq model ids are usually prefixed with the upstream provider,
    e.g. 'openai/gpt-oss-20b' -> ('openai', 'gpt-oss-20b'). Only the
    first '/' is treated as the separator, so any further slashes stay
    in the model name.

    Fix: ids without a '/' (e.g. 'llama-3.3-70b-versatile') previously
    came back as (whole_id, '') — an empty model name. Those now fall
    back to provider 'groq' with the id kept as the model.
    """
    provider, sep, model = groq_model.partition("/")
    if not sep:
        # No provider prefix in the id: attribute it to Groq itself.
        return "groq", groq_model
    return provider, model


@after
def _completions_create(func, instance, args, kwargs, return_value):
    """Record AI-call stats after a sync Groq chat-completion call.

    Registers the patched operation as an 'ai_op' sink call, then reports
    the provider/model (parsed from the response's model id) and the
    prompt/completion token counts from the response's usage block.
    """
    # Plain literal — was an f-string with no placeholders (F541).
    op = "groq.resources.chat.completions.Completions.create"
    register_call(op, "ai_op")

    provider, model = get_provider_and_model_from_groq_model(return_value.model)
    on_ai_call(
        provider=provider,
        model=model,
        # NOTE(review): assumes a non-streaming response; streamed chunks
        # may not carry a usage object — confirm against the Groq SDK.
        input_tokens=return_value.usage.prompt_tokens,
        output_tokens=return_value.usage.completion_tokens,
    )


@after_async
async def _completions_create_async(func, instance, args, kwargs, return_value):
    """Record AI-call stats after an async Groq chat-completion call.

    Async counterpart of _completions_create: registers the patched
    operation as an 'ai_op' sink call and reports provider/model plus
    token usage from the awaited response.
    """
    # Plain literal — was an f-string with no placeholders (F541).
    op = "groq.resources.chat.completions.AsyncCompletions.create"
    register_call(op, "ai_op")

    provider, model = get_provider_and_model_from_groq_model(return_value.model)
    on_ai_call(
        provider=provider,
        model=model,
        # NOTE(review): assumes a non-streaming response; streamed chunks
        # may not carry a usage object — confirm against the Groq SDK.
        input_tokens=return_value.usage.prompt_tokens,
        output_tokens=return_value.usage.completion_tokens,
    )


@on_import("groq.resources.chat.completions")
def patch(m):
    """Instrument the sync and async Groq chat-completion entry points."""
    targets = (
        ("Completions.create", _completions_create),
        ("AsyncCompletions.create", _completions_create_async),
    )
    for dotted_name, wrapper in targets:
        patch_function(m, dotted_name, wrapper)
70 changes: 70 additions & 0 deletions aikido_zen/sinks/tests/groq_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import pytest
import aikido_zen.sinks.groq
from groq import Groq, AsyncGroq
import asyncio

from aikido_zen.thread.thread_cache import get_cache

skip_no_api_key = pytest.mark.skipif(
"GROQ_API_KEY" not in os.environ,
reason="GROQ_API_KEY environment variable not set",
)


@pytest.fixture(autouse=True)
def setup():
    """Start every test from an empty cache and clean up afterwards."""
    get_cache().reset()
    try:
        yield
    finally:
        get_cache().reset()


def get_ai_stats():
    """Return the AI-call stats currently recorded in the thread cache."""
    cache = get_cache()
    return cache.ai_stats.get_stats()


@skip_no_api_key
def test_groq_messages_create():
    """Sync Groq chat completion is tracked with model, provider and tokens."""
    client = Groq()
    response = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Explain the importance of low latency LLMs in 10-15 words.",
            }
        ],
        model="openai/gpt-oss-20b",
        max_completion_tokens=20,
    )
    print(response.choices[0].message.content)

    entry = get_ai_stats()[0]
    assert entry["model"] == "gpt-oss-20b"
    assert entry["calls"] == 1
    assert entry["provider"] == "openai"
    assert entry["tokens"]["input"] == 87
    assert entry["tokens"]["output"] == 20
    assert entry["tokens"]["total"] == 107


@skip_no_api_key
@pytest.mark.asyncio
async def test_groq_messages_create_async():
    """Async Groq chat completion is tracked with model, provider and tokens."""
    client = AsyncGroq()
    response = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Explain the importance of low latency LLMs in great length.",
            }
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        max_completion_tokens=20,
    )
    print(response.choices[0].message.content)

    entry = get_ai_stats()[0]
    assert entry["model"] == "llama-4-scout-17b-16e-instruct"
    assert entry["calls"] == 1
    assert entry["provider"] == "meta-llama"
    assert entry["tokens"]["input"] == 22
    assert entry["tokens"]["output"] == 20
    assert entry["tokens"]["total"] == 42
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ mistralai = { version = "^1.8.2", python = ">=3.9,<4.0" }
boto3 = { version = "^1.40.17", python = ">=3.9,<4.0" }
django = "4"
requests = "^2.32.3"
groq = "^0.31.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
Loading