|
27 | 27 | from fastapi_users.jwt import decode_jwt |
28 | 28 | from app.config import Settings |
29 | 29 | from app.domain import TagsGenerative |
30 | | -from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException |
| 30 | +from app.exception import ( |
| 31 | + StartTrainingException, |
| 32 | + AnnotationException, |
| 33 | + ConfigurationException, |
| 34 | + ClientException, |
| 35 | + ExtraDependencyRequiredException, |
| 36 | +) |
31 | 37 |
|
32 | 38 | logger = logging.getLogger("cms") |
33 | 39 |
|
@@ -118,6 +124,24 @@ async def configuration_exception_handler(_: Request, exception: ConfigurationEx |
118 | 124 | logger.exception(exception) |
119 | 125 | return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) |
120 | 126 |
|
| 127 | + @app.exception_handler(ExtraDependencyRequiredException) |
| 128 | + async def extra_dependency_exception_handler( |
| 129 | + _: Request, |
| 130 | + exception: ExtraDependencyRequiredException |
| 131 | + ) -> JSONResponse: |
| 132 | + """ |
| 133 | + Handles extra dependency required exceptions. |
| 134 | +
|
| 135 | + Args: |
| 136 | + _ (Request): The request object. |
| 137 | + exception (ExtraDependencyRequiredException): The extra dependency required exception. |
| 138 | +
|
| 139 | + Returns: |
| 140 | + JSONResponse: A JSON response with a 500 status code and an error message. |
| 141 | + """ |
| 142 | + logger.exception(exception) |
| 143 | + return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)}) |
| 144 | + |
121 | 145 | @app.exception_handler(ClientException) |
122 | 146 | async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse: |
123 | 147 | """ |
@@ -299,8 +323,8 @@ async def init_vllm_engine(app: FastAPI, |
299 | 323 | ) |
300 | 324 | from vllm import SamplingParams, TokensPrompt |
301 | 325 | except ImportError: |
302 | | - # Raise a custom exception if vLLM is not installed |
303 | | - raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.") |
| 326 | + logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") |
| 327 | + raise ExtraDependencyRequiredException("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.") |
304 | 328 |
|
305 | 329 | parser = FlexibleArgumentParser() |
306 | 330 | parser = make_arg_parser(parser) |
|
0 commit comments