Skip to content

Commit 1e36b85

Browse files
committed
fix: use the GRPO trainer for evaluation
1 parent 01c4e12 commit 1e36b85

File tree

4 files changed

+182
-116
lines changed

4 files changed

+182
-116
lines changed

app/api/utils.py

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,13 @@
2727
from fastapi_users.jwt import decode_jwt
2828
from app.config import Settings
2929
from app.domain import TagsGenerative
30-
from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException
30+
from app.exception import (
31+
StartTrainingException,
32+
AnnotationException,
33+
ConfigurationException,
34+
ClientException,
35+
ExtraDependencyRequiredException,
36+
)
3137

3238
logger = logging.getLogger("cms")
3339

@@ -118,6 +124,24 @@ async def configuration_exception_handler(_: Request, exception: ConfigurationEx
118124
logger.exception(exception)
119125
return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)})
120126

127+
@app.exception_handler(ExtraDependencyRequiredException)
128+
async def extra_dependency_exception_handler(
129+
_: Request,
130+
exception: ExtraDependencyRequiredException
131+
) -> JSONResponse:
132+
"""
133+
Handles extra dependency required exceptions.
134+
135+
Args:
136+
_ (Request): The request object.
137+
exception (ExtraDependencyRequiredException): The extra dependency required exception.
138+
139+
Returns:
140+
JSONResponse: A JSON response with a 500 status code and an error message.
141+
"""
142+
logger.exception(exception)
143+
return JSONResponse(status_code=HTTP_500_INTERNAL_SERVER_ERROR, content={"message": str(exception)})
144+
121145
@app.exception_handler(ClientException)
122146
async def client_exception_handler(_: Request, exception: ClientException) -> JSONResponse:
123147
"""
@@ -299,8 +323,8 @@ async def init_vllm_engine(app: FastAPI,
299323
)
300324
from vllm import SamplingParams, TokensPrompt
301325
except ImportError:
302-
# Raise a custom exception if vLLM is not installed
303-
raise ConfigurationException("Cannot import the vLLM engine. Please install it with `pip install vllm`.")
326+
logger.error("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")
327+
raise ExtraDependencyRequiredException("Cannot import the vLLM engine. Please install it with `pip install cms[vllm]`.")
304328

305329
parser = FlexibleArgumentParser()
306330
parser = make_arg_parser(parser)

app/exception.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,7 @@ class DatasetException(Exception):
3232

3333
class DeviceNotAvailableError(RuntimeError):
3434
"""An exception raised when a specific device is required but not available."""
35+
36+
37+
class ExtraDependencyRequiredException(Exception):
38+
"""An exception raised when an extra dependency is required but not found."""

app/model_services/huggingface_llm_model.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,9 @@ def create_embeddings(
369369
sum_hidden_states = masked_hidden_states.sum(dim=1)
370370
num_tokens = attention_mask.sum(dim=1, keepdim=True)
371371
embeddings = sum_hidden_states / num_tokens
372+
l2_normalised = torch.nn.functional.normalize(embeddings, p=2, dim=1)
372373

373-
results = embeddings.cpu().numpy().tolist()
374+
results = l2_normalised.cpu().numpy().tolist()
374375
return results[0] if isinstance(text, str) else results
375376

376377
def train_supervised(

0 commit comments

Comments
 (0)