Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ Zen instruments the following AI SDKs to track which models are used and how man
* ✅ [`anthropic`](https://pypi.org/project/anthropic/)
* ✅ [`mistralai`](https://pypi.org/project/mistralai) ^1.0.0
* ✅ [`boto3`](https://pypi.org/project/boto3) (AWS Bedrock)
* ✅ [`groq`](https://pypi.org/project/groq)

Zen is compatible with Python 3.8-3.13 and can run on Windows, Linux, and Mac OS X.

Expand Down
44 changes: 44 additions & 0 deletions aikido_zen/sinks/groq.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from aikido_zen.helpers.on_ai_call import on_ai_call
from aikido_zen.helpers.register_call import register_call
from aikido_zen.sinks import on_import, patch_function, after, after_async


def get_provider_and_model_from_groq_model(groq_model: str):
    """Split a groq model identifier into (provider, model).

    groq model names are of the form "provider/model-name",
    e.g. "openai/gpt-oss-20b" -> ("openai", "gpt-oss-20b").
    If no "/" is present, the whole string is returned as the
    provider and the model is empty.
    """
    provider, _, model = groq_model.partition("/")
    return provider, model


@after
def _completions_create(func, instance, args, kwargs, return_value):
    """Record stats after a sync groq chat-completion call.

    Registers the sink call, then reports provider/model and token
    usage taken from the API response (`return_value`).
    """
    # F541 fix: plain string, no f-prefix needed (no placeholders).
    op = "groq.resources.chat.completions.Completions.create"
    register_call(op, "ai_op")

    provider, model = get_provider_and_model_from_groq_model(return_value.model)
    on_ai_call(
        provider=provider,
        model=model,
        input_tokens=return_value.usage.prompt_tokens,
        output_tokens=return_value.usage.completion_tokens,
    )


@after_async
async def _completions_create_async(func, instance, args, kwargs, return_value):
    """Record stats after an async groq chat-completion call.

    Async counterpart of `_completions_create`: registers the sink
    call and reports provider/model and token usage from the response.
    """
    # F541 fix: plain string, no f-prefix needed (no placeholders).
    op = "groq.resources.chat.completions.AsyncCompletions.create"
    register_call(op, "ai_op")

    provider, model = get_provider_and_model_from_groq_model(return_value.model)
    on_ai_call(
        provider=provider,
        model=model,
        input_tokens=return_value.usage.prompt_tokens,
        output_tokens=return_value.usage.completion_tokens,
    )


@on_import("groq.resources.chat.completions")
def patch(m):
    """Instrument groq's sync and async chat-completion entry points."""
    wrappers = (
        ("Completions.create", _completions_create),
        ("AsyncCompletions.create", _completions_create_async),
    )
    for attr_path, wrapper in wrappers:
        patch_function(m, attr_path, wrapper)
70 changes: 70 additions & 0 deletions aikido_zen/sinks/tests/groq_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import os
import pytest
import aikido_zen.sinks.groq
from groq import Groq, AsyncGroq
import asyncio

from aikido_zen.thread.thread_cache import get_cache

skip_no_api_key = pytest.mark.skipif(
"GROQ_API_KEY" not in os.environ,
reason="GROQ_API_KEY environment variable not set",
)


@pytest.fixture(autouse=True)
def setup():
    """Reset the thread-cache stats before and after every test."""
    get_cache().reset()
    try:
        yield
    finally:
        get_cache().reset()


def get_ai_stats():
    """Return the AI-call statistics recorded in the thread cache."""
    cache = get_cache()
    return cache.ai_stats.get_stats()


@skip_no_api_key
def test_groq_messages_create():
    """Sync groq call is tracked with provider, model and token counts."""
    client = Groq()
    chat_completion = client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Explain the importance of low latency LLMs in 10-15 words.",
            }
        ],
        model="openai/gpt-oss-20b",
        max_completion_tokens=20,
    )
    print(chat_completion.choices[0].message.content)

    stats = get_ai_stats()[0]
    assert stats["model"] == "gpt-oss-20b"
    assert stats["calls"] == 1
    assert stats["provider"] == "openai"
    assert stats["tokens"]["input"] == 87
    assert stats["tokens"]["output"] == 20
    assert stats["tokens"]["total"] == 107


@skip_no_api_key
@pytest.mark.asyncio
async def test_groq_messages_create_async():
    """Async groq call is tracked with provider, model and token counts.

    Renamed from `test_anthropic_messages_create_async` — the old name
    was a copy-paste leftover from the anthropic sink tests; this test
    exercises the groq AsyncGroq client.
    """
    client = AsyncGroq()
    chat_completion = await client.chat.completions.create(
        messages=[
            {
                "role": "user",
                "content": "Explain the importance of low latency LLMs in great length.",
            }
        ],
        model="meta-llama/llama-4-scout-17b-16e-instruct",
        max_completion_tokens=20,
    )
    print(chat_completion.choices[0].message.content)

    stats = get_ai_stats()[0]
    assert stats["model"] == "llama-4-scout-17b-16e-instruct"
    assert stats["calls"] == 1
    assert stats["provider"] == "meta-llama"
    # NOTE(review): exact token counts depend on the live API's tokenizer;
    # these values were observed at the time the test was written.
    assert stats["tokens"]["input"] == 22
    assert stats["tokens"]["output"] == 20
    assert stats["tokens"]["total"] == 42
25 changes: 24 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,7 @@ mistralai = { version = "^1.8.2", python = ">=3.9,<4.0" }
boto3 = { version = "^1.40.17", python = ">=3.9,<4.0" }
django = "4"
requests = "^2.32.3"
groq = "^0.31.1"

[build-system]
requires = ["poetry-core>=1.0.0"]
Expand Down
Loading