|
20 | 20 |
|
21 | 21 | import aiohttp |
22 | 22 | from aiohttp import web |
| 23 | +from aiohttp_apispec import ( |
| 24 | + docs, |
| 25 | + request_schema, |
| 26 | + response_schema, |
| 27 | + setup_aiohttp_apispec, |
| 28 | +) |
| 29 | +from marshmallow import Schema, fields |
23 | 30 | import logging |
24 | 31 |
|
25 | 32 | import mimetypes |
|
38 | 45 | from api_server.routes.internal.internal_routes import InternalRoutes |
39 | 46 | from protocol import BinaryEventTypes |
40 | 47 |
|
| 48 | +def inline_schema(name, **fields_map): |
| 49 | + """ Utility function to create an inline schema for aiohttp_apispec """ |
| 50 | + return type(name+"Schema", (Schema,), fields_map)() |
| 51 | + |
| 52 | + |
| 53 | +def inline_schema_list(name, **fields_map): |
| 54 | + """ Utility function to create an inline schema for aiohttp_apispec where the top-level is a list """ |
| 55 | + return type(name+"Schema", (Schema,), fields_map)(many=True) |
| 56 | + |
41 | 57 | async def send_socket_catch_exception(function, message): |
42 | 58 | try: |
43 | 59 | await function(message) |
@@ -258,6 +274,16 @@ async def get_root(request): |
258 | 274 | return response |
259 | 275 |
|
260 | 276 | @routes.get("/embeddings") |
| 277 | + @docs( |
| 278 | + tags=["Core"], |
| 279 | + summary="(UI) Get embeddings", |
| 280 | + description="Returns a list of the files located in the embeddings/ directory that can be used as arguments for embedding nodes. The file extension is omitted.", |
| 281 | + ) |
| 282 | + @request_schema(inline_schema("EmbeddingRequestTestParams", |
| 283 | + my_test_param1=fields.Int(required=True, description="Test description of a parameter"), |
| 284 | + my_test_param2=fields.Boolean(description="description for my_test_param2") |
| 285 | + )) |
| 286 | + @response_schema(inline_schema_list("EmbeddingNames")) |
261 | 287 | def get_embeddings(request): |
262 | 288 | embeddings = folder_paths.get_filename_list("embeddings") |
263 | 289 | return web.json_response(list(map(lambda a: os.path.splitext(a)[0], embeddings))) |
@@ -791,6 +817,18 @@ def add_routes(self): |
791 | 817 | web.static('/', self.web_root), |
792 | 818 | ]) |
793 | 819 |
|
| 820 | + def serve_api_spec(self): |
| 821 | + """ |
| 822 | + Serve the OpenAPI specification for the API. Must be called after routes are added. |
| 823 | + """ |
| 824 | + setup_aiohttp_apispec( |
| 825 | + app=self.app, |
| 826 | + title="ComfyUI API Documentation", |
| 827 | + version="v1", |
| 828 | + url="/api/docs/swagger.json", |
| 829 | + swagger_path="/api/docs", |
| 830 | + ) |
| 831 | + |
794 | 832 | def get_queue_info(self): |
795 | 833 | prompt_info = {} |
796 | 834 | exec_info = {} |
|
0 commit comments