 from tuneapi.apis.model_mistral import Mistral
 from tuneapi.apis.model_gemini import Gemini
 
-# projectX APIs
-from tuneapi.apis.threads import ThreadsAPI
-from tuneapi.apis.assistants import AssistantsAPI
 
 # other imports
 import os
 import random
 from time import time
 from typing import List, Optional
-
-# other tuneapi modules
-import tuneapi.types as tt
-import tuneapi.utils as tu
-
-
-def test_models(thread: str | tt.Thread, models: Optional[List[str]] = None):
-    """
-    Runs thread on all the models and prints the time taken and response.
-    """
-    if os.path.exists(thread):
-        thread = tt.Thread.from_dict(tu.from_json(thread))
-
-    # get all the models
-    models_to_test = [TuneModel, Openai, Anthropic, Groq, Mistral, Gemini]
-    if models and models != "all":
-        models_to_test = []
-        for m in models:
-            models_to_test.append(globals()[m])
-
-    # run all in a loop
-    for model in models_to_test:
-        print(tu.color.bold(f"[{model.__name__}]"), end=" ", flush=True)
-        try:
-            st = time()
-            m = model()
-            out = m.chat(thread)
-            et = time()
-            print(
-                tu.color.blue(f"[{et-st:0.2f}s]"),
-                tu.color.green(f"[SUCCESS]", True),
-                out,
-            )
-        except Exception as e:
-            et = time()
-            print(
-                tu.color.blue(f"[{et-st:0.2f}s]"),
-                tu.color.red(f"[ERROR]", True),
-                str(e),
-            )
-            continue
-
-
-def benchmark_models(
-    thread: str | tt.Thread,
-    models: Optional[List[str]] = "all",
-    n: int = 20,
-    max_threads: int = 5,
-    o: str = "benchmark.csv",
-):
-    """
-    Benchmarks a thread on all the models and saves the time taken and response in a CSV file and creates matplotlib
-    histogram chart with latency and char count distribution. Runs `n` iterations for each model.
-
-    It requires `matplotlib>=3.8.2` and `pandas>=2.2.0` to be installed.
-    """
-
-    try:
-        import matplotlib.pyplot as plt
-        import pandas as pd
-    except ImportError:
-        tu.logger.error(
-            "This is a special CLI helper function. If you want to use this then run: pip install matplotlib>=3.8.2 pandas>=2.2.0"
-        )
-        raise ImportError("Please install the required packages")
-
-    # if this is a JSON then load the thread
-    if os.path.exists(thread):
-        thread = tt.Thread.from_dict(tu.from_json(thread))
-
-    # get all the models
-    models_to_test = [TuneModel, Openai, Anthropic, Groq, Mistral, Gemini]
-    if models and models != "all":
-        models_to_test = []
-        for m in models:
-            models_to_test.append(globals()[m])
-
-    # function to perform benchmarking
-    def _bench(thread, model):
-        try:
-            st = time()
-            m = model()
-            out = m.chat(thread)
-            return model.__name__, time() - st, out, False
-        except Exception as e:
-            return model.__name__, time() - st, str(e), True
-
-    # threaded map and get the results
-    inputs = []
-    for m in models_to_test:
-        for _ in range(n):
-            inputs.append((thread, m))
-    random.shuffle(inputs)
-    print(f"Total combinations: {len(inputs)}")
-    results = tu.threaded_map(
-        fn=_bench,
-        inputs=inputs,
-        pbar=True,
-        safe=False,
-        max_threads=max_threads,
-    )
-    model_wise_errors = {}
-    all_results = []
-    for r in results:
-        name, time_taken, out, error = r
-        if error:
-            model_wise_errors.setdefault(name, 0)
-            model_wise_errors[name] += 1
-        else:
-            all_results.append(
-                {
-                    "model": name,
-                    "time": time_taken,
-                    "response": out,
-                }
-            )
-    n_errors = sum(model_wise_errors.values())
-    if n_errors:
-        print(
-            tu.color.red(f"{n_errors} FAILED", True)
-            + f" ie. {n_errors/len(inputs)*100:.2f}% failure rate"
-        )
-    n_success = len(inputs) - n_errors
-    print(
-        tu.color.green(f"{n_success} SUCCESS", True)
-        + f" ie. {n_success/len(inputs)*100:.2f}% success rate"
-    )
-
-    # create the report and save it
-    df = pd.DataFrame(all_results)
-    print("Created the benchmark report at:", tu.color.bold(o))
-    df.to_csv(o, index=False)
-
-    # create the histogram
-    fig, axs = plt.subplots(3, 1, figsize=(15, 10))
-    latency_by_models = {}
-    char_count_by_models = {}
-    for res in all_results:
-        latency_by_models.setdefault(res["model"], []).append(res["time"])
-        char_count_by_models.setdefault(res["model"], []).append(len(res["response"]))
-
-    # histogram for latency
-    axs[0].hist(
-        latency_by_models.values(),
-        bins=20,
-        alpha=0.7,
-        label=list(latency_by_models.keys()),
-    )
-    axs[0].set_title("Latency Distribution (lower is better)")
-    axs[0].set_xlabel("Time (s)")
-    axs[0].set_ylabel("Frequency")
-    axs[0].legend()
-
-    # histogram for character count
-    axs[1].hist(
-        char_count_by_models.values(),
-        bins=20,
-        alpha=0.7,
-        label=list(char_count_by_models.keys()),
-    )
-    axs[1].set_title("Character Count Distribution")
-    axs[1].set_xlabel("Count")
-    axs[1].set_ylabel("Frequency")
-    axs[1].legend()
-    plt.tight_layout()
-
-    # bar graph for success and failure rate
-    axs[2].bar(
-        model_wise_errors.keys(),
-        model_wise_errors.values(),
-        color="red",
-        label="Failed",
-    )
-    axs[2].set_title("Failure Rate (lower is better)")
-    axs[2].set_xlabel("Model")
-    axs[2].set_ylabel("Count")
-    axs[2].legend()
-
-    # save the plot
-    print("Created the benchmark plot at:", tu.color.bold("benchmark.png"))
-    plt.savefig("benchmark.png")
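
For context, the removed helpers were thin CLI-style wrappers: `test_models` runs a thread once per model and prints the latency and response, while `benchmark_models` runs `n` iterations per model over a thread pool and writes `benchmark.csv` plus `benchmark.png`. A minimal usage sketch against a pre-removal version follows; the `thread.json` filename is a placeholder for any JSON file that `tt.Thread.from_dict` can load, and the `from tuneapi.apis import ...` path is an assumption about where these functions were exposed, not something stated in this commit.

# Hypothetical usage of the now-removed helpers (pre-removal tuneapi only).
# "thread.json" is a placeholder path; the import location is assumed.
from tuneapi.apis import test_models, benchmark_models

# run the thread once on two providers and print per-model latency
test_models("thread.json", models=["Openai", "Anthropic"])

# 20 runs per model across all built-in providers, fanned out over 5 worker threads;
# writes benchmark.csv and benchmark.png to the working directory
benchmark_models("thread.json", models="all", n=20, max_threads=5, o="benchmark.csv")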