
Commit 99517a8

[0.5.0] add finetuning API
1 parent 2fabdae commit 99517a8

14 files changed: +227 −257 lines


docs/changelog.rst

Lines changed: 20 additions & 0 deletions

@@ -7,6 +7,26 @@ minor versions.
 
 All relevant steps to be taken will be mentioned here.
 
+0.5.0 **(breaking)**
+--------------------
+
+In this release we have moved all the Tune Studio specific APIs out of ``tuneapi.apis`` to ``tuneapi.endpoints`` to avoid
+cluttering the ``apis`` namespace.
+
+.. code-block:: patch
+
+   - from tuneapi import apis as ta
+   + from tuneapi import endpoints as te
+   ...
+   - ta.ThreadsAPI(...)
+   + te.ThreadsAPI(...)
+
+- Add support for finetuning APIs with ``tuneapi.endpoints.FinetuningAPI``
+- Primary environment variables have been changed from ``TUNE_API_KEY`` to ``TUNEAPI_TOKEN`` and from ``TUNE_ORG_ID``
+  to ``TUNEORG_ID``; if you were using these, please update your environment variables
+- Removed CLI methods ``test_models`` and ``benchmark_models``; if you want to use those, please copy the code from
+  `this commit <https://github.com/NimbleBoxAI/tuneapi/blob/2fabdae461f4187621fe8ffda73a58a5ab7485b0/tuneapi/apis/__init__.py#L26>`_
+
 0.4.18
 ------
 
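
For downstream users, a minimal migration sketch based only on the changelog entry above; the shim itself is
illustrative and not part of this commit:

    import os

    # rename the environment variables as described in the changelog;
    # the values are whatever you had stored under the old names
    for old, new in [("TUNE_API_KEY", "TUNEAPI_TOKEN"), ("TUNE_ORG_ID", "TUNEORG_ID")]:
        if old in os.environ and new not in os.environ:
            os.environ[new] = os.environ.pop(old)

    # Tune Studio specific classes now come from tuneapi.endpoints
    from tuneapi import endpoints as te  # was: from tuneapi import apis as ta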

docs/conf.py

Lines changed: 1 addition & 1 deletion

@@ -13,7 +13,7 @@
 project = "tuneapi"
 copyright = "2024, Frello Technologies"
 author = "Frello Technologies"
-release = "0.4.18"
+release = "0.5.0"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

pyproject.toml

Lines changed: 1 addition & 1 deletion

@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "tuneapi"
-version = "0.4.18"
+version = "0.5.0"
 description = "Tune AI APIs."
 authors = ["Frello Technology Private Limited <[email protected]>"]
 license = "MIT"

tuneapi/__init__.py

Lines changed: 1 addition & 1 deletion

@@ -1,3 +1,3 @@
 # Copyright © 2023- Frello Technology Private Limited
 
-__version__ = "0.4.18"
+__version__ = "0.5.0"

tuneapi/__main__.py

Lines changed: 3 additions & 9 deletions

@@ -2,18 +2,12 @@
 
 from fire import Fire
 
+from tuneapi import __version__
+
 
 def main():
-    from tuneapi.apis import test_models, benchmark_models
 
-    Fire(
-        {
-            "models": {
-                "test": test_models,
-                "benchmark": benchmark_models,
-            },
-        }
-    )
+    Fire({"version": __version__})
 
 
 if __name__ == "__main__":
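
With this change the CLI surface shrinks to a single sub-command. A sketch of the resulting behavior, assuming
python-fire's usual mapping of dict keys to sub-commands; the printed value is the release set elsewhere in this
commit:

    # $ python -m tuneapi version
    # 0.5.0
    #
    # the former `models test` / `models benchmark` sub-commands are removed;
    # see the changelog entry above for how to recover that code
    from fire import Fire
    from tuneapi import __version__

    Fire({"version": __version__})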

tuneapi/apis/__init__.py

Lines changed: 0 additions & 184 deletions

@@ -8,193 +8,9 @@
 from tuneapi.apis.model_mistral import Mistral
 from tuneapi.apis.model_gemini import Gemini
 
-# projectX APIs
-from tuneapi.apis.threads import ThreadsAPI
-from tuneapi.apis.assistants import AssistantsAPI
 
 # other imports
 import os
 import random
 from time import time
 from typing import List, Optional
-
-# other tuneapi modules
-import tuneapi.types as tt
-import tuneapi.utils as tu
-
-
-def test_models(thread: str | tt.Thread, models: Optional[List[str]] = None):
-    """
-    Runs thread on all the models and prints the time taken and response.
-    """
-    if os.path.exists(thread):
-        thread = tt.Thread.from_dict(tu.from_json(thread))
-
-    # get all the models
-    models_to_test = [TuneModel, Openai, Anthropic, Groq, Mistral, Gemini]
-    if models and models != "all":
-        models_to_test = []
-        for m in models:
-            models_to_test.append(globals()[m])
-
-    # run all in a loop
-    for model in models_to_test:
-        print(tu.color.bold(f"[{model.__name__}]"), end=" ", flush=True)
-        try:
-            st = time()
-            m = model()
-            out = m.chat(thread)
-            et = time()
-            print(
-                tu.color.blue(f"[{et-st:0.2f}s]"),
-                tu.color.green(f"[SUCCESS]", True),
-                out,
-            )
-        except Exception as e:
-            et = time()
-            print(
-                tu.color.blue(f"[{et-st:0.2f}s]"),
-                tu.color.red(f"[ERROR]", True),
-                str(e),
-            )
-            continue
-
-
-def benchmark_models(
-    thread: str | tt.Thread,
-    models: Optional[List[str]] = "all",
-    n: int = 20,
-    max_threads: int = 5,
-    o: str = "benchmark.csv",
-):
-    """
-    Benchmarks a thread on all the models and saves the time taken and response in a CSV file and creates matplotlib
-    histogram chart with latency and char count distribution. Runs `n` iterations for each model.
-
-    It requires `matplotlib>=3.8.2` and `pandas>=2.2.0` to be installed.
-    """
-
-    try:
-        import matplotlib.pyplot as plt
-        import pandas as pd
-    except ImportError:
-        tu.logger.error(
-            "This is a special CLI helper function. If you want to use this then run: pip install matplotlib>=3.8.2 pandas>=2.2.0"
-        )
-        raise ImportError("Please install the required packages")
-
-    # if this is a JSON then load the thread
-    if os.path.exists(thread):
-        thread = tt.Thread.from_dict(tu.from_json(thread))
-
-    # get all the models
-    models_to_test = [TuneModel, Openai, Anthropic, Groq, Mistral, Gemini]
-    if models and models != "all":
-        models_to_test = []
-        for m in models:
-            models_to_test.append(globals()[m])
-
-    # function to perform benchmarking
-    def _bench(thread, model):
-        try:
-            st = time()
-            m = model()
-            out = m.chat(thread)
-            return model.__name__, time() - st, out, False
-        except Exception as e:
-            return model.__name__, time() - st, str(e), True
-
-    # threaded map and get the results
-    inputs = []
-    for m in models_to_test:
-        for _ in range(n):
-            inputs.append((thread, m))
-    random.shuffle(inputs)
-    print(f"Total combinations: {len(inputs)}")
-    results = tu.threaded_map(
-        fn=_bench,
-        inputs=inputs,
-        pbar=True,
-        safe=False,
-        max_threads=max_threads,
-    )
-    model_wise_errors = {}
-    all_results = []
-    for r in results:
-        name, time_taken, out, error = r
-        if error:
-            model_wise_errors.setdefault(name, 0)
-            model_wise_errors[name] += 1
-        else:
-            all_results.append(
-                {
-                    "model": name,
-                    "time": time_taken,
-                    "response": out,
-                }
-            )
-    n_errors = sum(model_wise_errors.values())
-    if n_errors:
-        print(
-            tu.color.red(f"{n_errors} FAILED", True)
-            + f" ie. {n_errors/len(inputs)*100:.2f}% failure rate"
-        )
-    n_success = len(inputs) - n_errors
-    print(
-        tu.color.green(f"{n_success} SUCCESS", True)
-        + f" ie. {n_success/len(inputs)*100:.2f}% success rate"
-    )
-
-    # create the report and save it
-    df = pd.DataFrame(all_results)
-    print("Created the benchmark report at:", tu.color.bold(o))
-    df.to_csv(o, index=False)
-
-    # create the histogram
-    fig, axs = plt.subplots(3, 1, figsize=(15, 10))
-    latency_by_models = {}
-    char_count_by_models = {}
-    for res in all_results:
-        latency_by_models.setdefault(res["model"], []).append(res["time"])
-        char_count_by_models.setdefault(res["model"], []).append(len(res["response"]))
-
-    # histogram for latency
-    axs[0].hist(
-        latency_by_models.values(),
-        bins=20,
-        alpha=0.7,
-        label=list(latency_by_models.keys()),
-    )
-    axs[0].set_title("Latency Distribution (lower is better)")
-    axs[0].set_xlabel("Time (s)")
-    axs[0].set_ylabel("Frequency")
-    axs[0].legend()
-
-    # histogram for character count
-    axs[1].hist(
-        char_count_by_models.values(),
-        bins=20,
-        alpha=0.7,
-        label=list(char_count_by_models.keys()),
-    )
-    axs[1].set_title("Character Count Distribution")
-    axs[1].set_xlabel("Count")
-    axs[1].set_ylabel("Frequency")
-    axs[1].legend()
-    plt.tight_layout()
-
-    # bar graph for success and failure rate
-    axs[2].bar(
-        model_wise_errors.keys(),
-        model_wise_errors.values(),
-        color="red",
-        label="Failed",
-    )
-    axs[2].set_title("Failure Rate (lower is better)")
-    axs[2].set_xlabel("Model")
-    axs[2].set_ylabel("Count")
-    axs[2].legend()
-
-    # save the plot
-    print("Created the benchmark plot at:", tu.color.bold("benchmark.png"))
-    plt.savefig("benchmark.png")

tuneapi/apis/model_tune.py

Lines changed: 2 additions & 1 deletion

@@ -19,11 +19,12 @@ def __init__(
         self,
         id: Optional[str] = None,
         base_url: str = "https://proxy.tune.app/chat/completions",
+        api_token: Optional[str] = None,
        org_id: Optional[str] = None,
     ):
         self.tune_model_id = id or tu.ENV.TUNEAPI_MODEL("")
         self.base_url = base_url
-        self.tune_api_token = tu.ENV.TUNEAPI_TOKEN("")
+        self.tune_api_token = api_token or tu.ENV.TUNEAPI_TOKEN("")
         self.tune_org_id = org_id or tu.ENV.TUNEORG_ID("")
 
     def __repr__(self) -> str:
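
The new ``api_token`` parameter lets callers pass credentials explicitly instead of relying on the environment. A
hedged usage sketch, assuming ``TuneModel`` is the class defined in ``model_tune.py`` and re-exported from
``tuneapi.apis``; the model id and token values are placeholders:

    import os
    from tuneapi.apis import TuneModel

    # an explicitly passed token wins over the environment
    m1 = TuneModel(id="some-model-id", api_token="tune-xxxxxxxx")

    # with no api_token, the constructor falls back to TUNEAPI_TOKEN
    # (and TUNEORG_ID for org_id), per the diff above
    os.environ["TUNEAPI_TOKEN"] = "tune-xxxxxxxx"
    m2 = TuneModel(id="some-model-id")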

tuneapi/endpoints/__init__.py

Lines changed: 6 additions & 0 deletions

@@ -0,0 +1,6 @@
+# Copyright © 2024- Frello Technology Private Limited
+
+# projectX APIs
+from tuneapi.endpoints.threads import ThreadsAPI
+from tuneapi.endpoints.assistants import AssistantsAPI
+from tuneapi.endpoints.finetune import FinetuningAPI
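
A short sketch of the new namespace in use. ``FinetuningAPI``'s constructor is not shown in this diff, so the keyword
argument below is an assumption carried over from ``AssistantsAPI``; the token is a placeholder:

    from tuneapi import endpoints as te

    # all three Tune Studio clients now live under one namespace
    threads = te.ThreadsAPI(tune_api_key="tune-xxxxxxxx")
    assistants = te.AssistantsAPI(tune_api_key="tune-xxxxxxxx")
    finetuning = te.FinetuningAPI(tune_api_key="tune-xxxxxxxx")  # assumed signature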
tuneapi/endpoints/assistants.py

Lines changed: 6 additions & 20 deletions

@@ -9,18 +9,7 @@
 import tuneapi.utils as tu
 import tuneapi.types as tt
 
-
-@cache
-def get_sub(
-    url,
-    tune_org_id,
-    tune_api_key,
-) -> tu.Subway:
-    sess = tu.Subway._get_session()
-    sess.headers.update({"x-tune-key": tune_api_key})
-    if tune_org_id:
-        sess.headers.update({"X-Organization-Id": tune_org_id})
-    return tu.Subway(url, sess)
+from tuneapi.endpoints.common import get_sub
 
 
 @dataclass
@@ -36,17 +25,14 @@ def __init__(
         tune_api_key: str = None,
         base_url: str = "https://studio.tune.app/v1/assistants",
     ):
-        self.tune_org_id = tune_org_id or tu.ENV.TUNE_ORG_ID()
-        self.tune_api_key = tune_api_key or tu.ENV.TUNE_API_KEY()
+        self.tune_org_id = tune_org_id or tu.ENV.TUNEORG_ID()
+        self.tune_api_key = tune_api_key or tu.ENV.TUNEAPI_TOKEN()
+        self.base_url = base_url
         if not tune_api_key:
-            raise ValueError("Either pass tune_api_key or set Env var TUNE_API_KEY")
+            raise ValueError("Either pass tune_api_key or set Env var TUNEAPI_TOKEN")
         self.sub = get_sub(base_url, self.tune_org_id, self.tune_api_key)
 
-    def list_assistants(
-        self,
-        limit: int = 10,
-        order: str = "desc",
-    ):
+    def list_assistants(self, limit: int = 10, order: str = "desc"):
         out = self.sub(params={"limit": limit, "order": order})
         return out["data"]
 
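Hypothetical usage of the updated class, with a placeholder token; credentials can also come from the renamed
``TUNEAPI_TOKEN`` / ``TUNEORG_ID`` environment variables:

    from tuneapi.endpoints import AssistantsAPI

    api = AssistantsAPI(tune_api_key="tune-xxxxxxxx")
    # list_assistants returns the "data" field of the response,
    # assumed here to be a list of assistant records
    for assistant in api.list_assistants(limit=5, order="desc"):
        print(assistant)
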
tuneapi/endpoints/common.py

Lines changed: 19 additions & 0 deletions

@@ -0,0 +1,19 @@
+# Copyright © 2024- Frello Technology Private Limited
+
+from functools import cache
+
+import tuneapi.utils as tu
+
+
+@cache
+def get_sub(
+    base_url,
+    tune_org_id: str,
+    tune_api_key: str,
+) -> tu.Subway:
+
+    sess = tu.Subway._get_session()
+    sess.headers.update({"x-tune-key": tune_api_key})
+    if tune_org_id:
+        sess.headers.update({"x-organization-id": tune_org_id})
+    return tu.Subway(base_url, sess)
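
Because ``get_sub`` is wrapped in ``functools.cache``, one ``Subway`` (and one underlying session) is reused per
unique ``(base_url, tune_org_id, tune_api_key)`` tuple. A small sketch of that behavior; the URLs and credentials are
placeholders:

    from tuneapi.endpoints.common import get_sub

    sub_a = get_sub("https://studio.tune.app/v1/assistants", "org-1", "key-1")
    sub_b = get_sub("https://studio.tune.app/v1/assistants", "org-1", "key-1")
    assert sub_a is sub_b  # cache hit: same arguments, same client object

    sub_c = get_sub("https://studio.tune.app/v1/threads", "org-1", "key-1")
    assert sub_c is not sub_a  # different URL, new client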
