|
7 | 7 | import gradio as gr |
8 | 8 |
|
9 | 9 | from gradio_i18n import Translate, gettext as _ |
| 10 | +from numba.cuda import shared |
| 11 | + |
10 | 12 | from test_api import test_api_connection |
11 | 13 | from cache_utils import setup_workspace, cleanup_workspace |
12 | 14 | from count_tokens import count_tokens |
@@ -180,9 +182,6 @@ def sum_tokens(client): |
180 | 182 | json.dump(output_data, tmpfile, ensure_ascii=False) |
181 | 183 | output_file = tmpfile.name |
182 | 184 |
|
183 | | - # Clean up workspace |
184 | | - cleanup_workspace(graph_gen.working_dir) |
185 | | - |
186 | 185 | synthesizer_tokens = sum_tokens(graph_gen.synthesizer_llm_client) |
187 | 186 | trainee_tokens = sum_tokens(graph_gen.trainee_llm_client) if config['if_trainee_model'] else 0 |
188 | 187 | total_tokens = synthesizer_tokens + trainee_tokens |
@@ -217,6 +216,10 @@ def sum_tokens(client): |
217 | 216 | except Exception as e: # pylint: disable=broad-except |
218 | 217 | raise gr.Error(f"Error occurred: {str(e)}") |
219 | 218 |
|
| 219 | + finally: |
| 220 | + # Clean up workspace |
| 221 | + cleanup_workspace(graph_gen.working_dir) |
| 222 | + |
220 | 223 |
|
221 | 224 | with (gr.Blocks(title="GraphGen Demo", theme=gr.themes.Glass(), |
222 | 225 | css=css) as demo): |
@@ -476,4 +479,5 @@ def sum_tokens(client): |
476 | 479 | ) |
477 | 480 |
|
478 | 481 | if __name__ == "__main__": |
| 482 | + demo.queue(api_open=False, default_concurrency_limit=10) |
479 | 483 | demo.launch(server_name='0.0.0.0') |