|
| 1 | +import g4f |
| 2 | +from .main import GPT4FREE |
| 3 | +from pathlib import Path |
| 4 | +from pytgpt.utils import default_path |
| 5 | +from json import dump, load |
| 6 | +from time import time |
| 7 | +from threading import Thread as thr |
| 8 | +from functools import wraps |
| 9 | +import datetime |
| 10 | +from rich.progress import Progress |
| 11 | +import click |
| 12 | +import logging |
| 13 | + |
| 14 | +results_path = Path(default_path) / "provider_test.json" |
| 15 | + |
| 16 | + |
| 17 | +def exception_handler(func): |
| 18 | + |
| 19 | + @wraps(func) |
| 20 | + def decorator(*args, **kwargs): |
| 21 | + try: |
| 22 | + return func(*args, **kwargs) |
| 23 | + except Exception as e: |
| 24 | + pass |
| 25 | + |
| 26 | + return decorator |
| 27 | + |
| 28 | + |
| 29 | +@exception_handler |
| 30 | +def is_working(provider: str) -> bool: |
| 31 | + """Test working status of a provider |
| 32 | +
|
| 33 | + Args: |
| 34 | + provider (str): Provider name |
| 35 | +
|
| 36 | + Returns: |
| 37 | + bool: is_working status |
| 38 | + """ |
| 39 | + bot = GPT4FREE(provider=provider, is_conversation=False) |
| 40 | + text = bot.chat("hello") |
| 41 | + assert isinstance(text, str) |
| 42 | + assert bool(text.strip()) |
| 43 | + assert "<" not in text |
| 44 | + assert len(text) > 2 |
| 45 | + return True |
| 46 | + |
| 47 | + |
| 48 | +class TestProviders: |
| 49 | + |
| 50 | + def __init__(self, test_at_once: int = 5, quiet: bool = False, timeout: int = 20): |
| 51 | + """Constructor |
| 52 | +
|
| 53 | + Args: |
| 54 | + test_at_once (int, optional): Test n providers at once. Defaults to 5. |
| 55 | + quiet (bool, optinal): Disable stdout. Defaults to False. |
| 56 | + timout (int, optional): Thread timeout for each provider. Defaults to 20. |
| 57 | + """ |
| 58 | + self.test_at_once: int = test_at_once |
| 59 | + self.quiet = quiet |
| 60 | + self.timeout = timeout |
| 61 | + self.working_providers: list = [ |
| 62 | + provider.__name__ |
| 63 | + for provider in g4f.Provider.__providers__ |
| 64 | + if provider.working |
| 65 | + ] |
| 66 | + self.results_path: Path = results_path |
| 67 | + self.__create_empty_file(ignore_if_found=True) |
| 68 | + |
| 69 | + def __create_empty_file(self, ignore_if_found: bool = False): |
| 70 | + if ignore_if_found and self.results_path.is_file(): |
| 71 | + return |
| 72 | + with self.results_path.open("w") as fh: |
| 73 | + dump({"results": []}, fh) |
| 74 | + |
| 75 | + def test_provider(self, name: str): |
| 76 | + """Test each provider and save successful ones |
| 77 | +
|
| 78 | + Args: |
| 79 | + name (str): Provider name |
| 80 | + """ |
| 81 | + |
| 82 | + try: |
| 83 | + bot = GPT4FREE(provider=name, is_conversation=False) |
| 84 | + start_time = time() |
| 85 | + text = bot.chat("hello there") |
| 86 | + assert isinstance(text, str), "Non-string response returned" |
| 87 | + assert bool(text.strip()), "Empty string" |
| 88 | + assert "<" not in text, "Html code returned." |
| 89 | + assert len(text) > 2 |
| 90 | + except Exception as e: |
| 91 | + pass |
| 92 | + else: |
| 93 | + with self.results_path.open() as fh: |
| 94 | + current_results = load(fh) |
| 95 | + new_result = dict(time=time() - start_time, name=name) |
| 96 | + current_results["results"].append(new_result) |
| 97 | + logging.info(f"Test result - {new_result['name']} - {new_result['time']}") |
| 98 | + |
| 99 | + with self.results_path.open("w") as fh: |
| 100 | + dump(current_results, fh) |
| 101 | + |
| 102 | + @exception_handler |
| 103 | + def main( |
| 104 | + self, |
| 105 | + ): |
| 106 | + self.__create_empty_file() |
| 107 | + threads = [] |
| 108 | + # Create a progress bar |
| 109 | + total = len(self.working_providers) |
| 110 | + with Progress() as progress: |
| 111 | + logging.info(f"Testing {total} providers : {self.working_providers}") |
| 112 | + task = progress.add_task( |
| 113 | + f"[cyan]Testing...[{self.test_at_once}]", |
| 114 | + total=total, |
| 115 | + visible=self.quiet == False, |
| 116 | + ) |
| 117 | + while not progress.finished: |
| 118 | + for count, provider in enumerate(self.working_providers, start=1): |
| 119 | + t1 = thr( |
| 120 | + target=self.test_provider, |
| 121 | + args=(provider,), |
| 122 | + ) |
| 123 | + t1.start() |
| 124 | + if count % self.test_at_once == 0 or count == len(provider): |
| 125 | + for t in threads: |
| 126 | + try: |
| 127 | + t.join(self.timeout) |
| 128 | + except Exception as e: |
| 129 | + pass |
| 130 | + threads.clear() |
| 131 | + else: |
| 132 | + threads.append(t1) |
| 133 | + progress.update(task, advance=1) |
| 134 | + |
| 135 | + def get_results(self, run: bool = False, best: bool = False) -> list[dict]: |
| 136 | + """Get test results |
| 137 | +
|
| 138 | + Args: |
| 139 | + run (bool, optional): Run the test first. Defaults to False. |
| 140 | + best (bool, optional): Return name of the best provider. Defaults to False. |
| 141 | +
|
| 142 | + Returns: |
| 143 | + list[dict]|str: Test results. |
| 144 | + """ |
| 145 | + if run: |
| 146 | + self.main() |
| 147 | + |
| 148 | + with self.results_path.open() as fh: |
| 149 | + results: dict = load(fh) |
| 150 | + |
| 151 | + results = results["results"] |
| 152 | + time_list = [] |
| 153 | + |
| 154 | + sorted_list = [] |
| 155 | + for entry in results: |
| 156 | + time_list.append(entry["time"]) |
| 157 | + |
| 158 | + time_list.sort() |
| 159 | + |
| 160 | + for time_value in time_list: |
| 161 | + for entry in results: |
| 162 | + if entry["time"] == time_value: |
| 163 | + sorted_list.append(entry) |
| 164 | + return sorted_list[0]["name"] if best else sorted_list |
| 165 | + |
| 166 | + @property |
| 167 | + def best(self): |
| 168 | + """Fastest provider overally""" |
| 169 | + return self.get_results(run=False, best=True) |
| 170 | + |
| 171 | + @property |
| 172 | + def auto(self): |
| 173 | + """Best working provider""" |
| 174 | + for result in self.get_results(run=False, best=False): |
| 175 | + logging.info("Confirming working status of provider : " + result["name"]) |
| 176 | + if is_working(result["name"]): |
| 177 | + return result["name"] |
0 commit comments