Skip to content

Commit 01c4e12

Browse files
committed
feat: add the trainer for HF LLMs
1 parent fab113b commit 01c4e12

File tree

14 files changed

+1120
-26
lines changed

14 files changed

+1120
-26
lines changed

app/api/api.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import os.path
55
import app.api.globals as cms_globals
66

7-
from typing import Dict, Any, Optional
7+
from typing import Dict, Any, Optional, Union, Type
88
from concurrent.futures import ThreadPoolExecutor
99
from anyio.lowlevel import RunVar
1010
from anyio import CapacityLimiter
@@ -20,7 +20,7 @@
2020
from app.api.dependencies import ModelServiceDep
2121
from app.api.utils import add_exception_handlers, add_rate_limiter, init_vllm_engine
2222
from app.config import Settings
23-
from app.domain import Tags, TagsStreamable
23+
from app.domain import Tags, TagsStreamable, TagsGenerative
2424
from app.management.tracker_client import TrackerClient
2525
from app.utils import get_settings, unpack_model_data_package, get_model_data_package_base_name
2626
from app.exception import ConfigurationException
@@ -131,6 +131,11 @@ def get_generative_server(config: Settings, msd_overwritten: Optional[ModelServi
131131
app = _load_health_check_router(app)
132132
logger.debug("Health check router loaded")
133133

134+
if config.ENABLE_TRAINING_APIS == "true":
135+
app = _load_supervised_training_router(app)
136+
logger.debug("Supervised training router loaded")
137+
app = _load_training_operations(app)
138+
134139
if config.AUTH_USER_ENABLED == "true":
135140
app = _load_auth_router(app)
136141
logger.debug("Auth router loaded")
@@ -198,11 +203,18 @@ def _get_app(
198203
streamable: bool = False,
199204
generative: bool = False,
200205
) -> FastAPI:
201-
tags_metadata = [{ # type: ignore
202-
"name": tag.name,
203-
"description": tag.value
204-
} for tag in (Tags if not streamable else TagsStreamable)]
205206
config = get_settings()
207+
tags: Union[Type[Tags], Type[TagsStreamable], Type[TagsGenerative]]
208+
if generative:
209+
tags = TagsGenerative
210+
elif streamable:
211+
tags = TagsStreamable
212+
else:
213+
tags = Tags
214+
tags_metadata = [{ # type: ignore
215+
"name": tag.name, # type: ignore
216+
"description": tag.value # type: ignore
217+
} for tag in tags]
206218
app = FastAPI(
207219
title="CogStack ModelServe",
208220
summary="A model serving and governance system for CogStack NLP solutions",

app/api/routers/generative.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from starlette.status import HTTP_200_OK, HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
1414
from app.domain import (
1515
Tags,
16+
TagsGenerative,
1617
OpenAIChatRequest,
1718
OpenAIChatResponse,
1819
OpenAIEmbeddingsRequest,
@@ -41,7 +42,7 @@
4142

4243
@router.post(
4344
PATH_GENERATE,
44-
tags=[Tags.Generative.name],
45+
tags=[TagsGenerative.Generative],
4546
response_class=PlainTextResponse,
4647
dependencies=[Depends(cms_globals.props.current_active_user)],
4748
description="Generate text",
@@ -91,7 +92,7 @@ def generate_text(
9192

9293
@router.post(
9394
PATH_GENERATE_ASYNC,
94-
tags=[Tags.Generative.name],
95+
tags=[TagsGenerative.Generative],
9596
response_class=StreamingResponse,
9697
dependencies=[Depends(cms_globals.props.current_active_user)],
9798
description="Generate a stream of texts",

app/api/routers/supervised_training.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,9 @@
1212

1313
import app.api.globals as cms_globals
1414
from app.api.dependencies import validate_tracking_id
15-
from app.domain import Tags
15+
from app.domain import Tags, ModelType
1616
from app.model_services.base import AbstractModelService
17-
from app.processors.metrics_collector import concat_trainer_exports
17+
from app.processors.metrics_collector import concat_json_lists, concat_trainer_exports
1818
from app.utils import filter_by_concept_ids
1919

2020
router = APIRouter()
@@ -72,12 +72,19 @@ async def train_supervised(
7272
files.append(temp_te)
7373
file_names.append("" if te.filename is None else te.filename)
7474

75-
concatenated = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
76-
logger.debug("Training exports concatenated")
77-
data_file = tempfile.NamedTemporaryFile(mode="w")
78-
concatenated = filter_by_concept_ids(cast(Dict[str, Any], concatenated), model_service.info().model_type)
79-
logger.debug("Training exports filtered by concept IDs")
80-
json.dump(concatenated, data_file)
75+
if model_service.info().model_type is not ModelType.HUGGINGFACE_LLM:
76+
concatenated_te = concat_trainer_exports([file.name for file in files], allow_recurring_doc_ids=False)
77+
logger.debug("Training exports concatenated")
78+
data_file = tempfile.NamedTemporaryFile(mode="w+")
79+
concatenated_te = filter_by_concept_ids(cast(Dict[str, Any], concatenated_te), model_service.info().model_type)
80+
logger.debug("Training exports filtered by concept IDs")
81+
json.dump(concatenated_te, data_file)
82+
else:
83+
concatenated = concat_json_lists([file.name for file in files])
84+
logger.debug("Training exports concatenated")
85+
data_file = tempfile.NamedTemporaryFile(mode="w+")
86+
json.dump(concatenated, data_file)
87+
8188
data_file.flush()
8289
data_file.seek(0)
8390
training_id = tracking_id or str(uuid.uuid4())
@@ -102,6 +109,7 @@ async def train_supervised(
102109
return _get_training_response(training_response, training_id)
103110

104111

112+
105113
def _get_training_response(training_response: Tuple[bool, str, str], training_id: str) -> JSONResponse:
106114
training_accepted, experiment_id, run_id = training_response
107115
if training_accepted:

app/api/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
from slowapi.errors import RateLimitExceeded
2727
from fastapi_users.jwt import decode_jwt
2828
from app.config import Settings
29-
from app.domain import Tags
29+
from app.domain import TagsGenerative
3030
from app.exception import StartTrainingException, AnnotationException, ConfigurationException, ClientException
3131

3232
logger = logging.getLogger("cms")
@@ -376,7 +376,7 @@ async def _stream() -> AsyncGenerator[bytes, None]:
376376
endpoint=endpoint,
377377
methods=methods,
378378
include_in_schema=True,
379-
tags=[Tags.Generative],
379+
tags=[TagsGenerative.Generative.name],
380380
)
381381
app.include_router(router)
382382

app/domain.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,15 @@ class Tags(str, Enum):
3131

3232

3333
class TagsStreamable(str, Enum):
34+
Metadata = "Get the model card"
3435
Streaming = "Retrieve NER entities as a stream by running the model"
3536

3637

38+
class TagsGenerative(str, Enum):
39+
Metadata = "Get the model card"
40+
Generative = "Generate text based on the input prompt"
41+
42+
3743
class CodeType(str, Enum):
3844
SNOMED = "SNOMED"
3945
UMLS = "UMLS"
@@ -104,6 +110,19 @@ class LlmEngine(Enum):
104110
CMS = "CMS"
105111
VLLM = "vLLM"
106112

113+
class LlmRole(Enum):
114+
SYSTEM = "system"
115+
USER = "user"
116+
ASSISTANT = "assistant"
117+
TOOL = "tool"
118+
119+
class LlmTrainerType(Enum):
120+
GRPO = "grpo"
121+
PPO = "ppo"
122+
123+
class LlmDatasetType(Enum):
124+
JSON = "json"
125+
CSV = "csv"
107126

108127
class Annotation(BaseModel):
109128
doc_name: Optional[str] = Field(default=None, description="The name of the document to which the annotation belongs")

app/exception.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,8 @@ class ClientException(Exception):
2727

2828

2929
class DatasetException(Exception):
30-
""" An exception raised due to dataset errors"""
30+
"""An exception raised due to dataset errors"""
31+
32+
33+
class DeviceNotAvailableError(RuntimeError):
34+
"""An exception raised when a specific device is required but not available."""

app/model_services/huggingface_llm_model.py

Lines changed: 67 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,19 @@
33
import asyncio
44
import torch
55
from concurrent.futures import ThreadPoolExecutor
6-
from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, Callable, Union
6+
from typing import Dict, List, Optional, Tuple, Any, AsyncIterable, TextIO, Callable, Union
77
from transformers import (
88
AutoModelForCausalLM,
99
AutoTokenizer,
1010
PreTrainedModel,
1111
PreTrainedTokenizerBase,
1212
TextIteratorStreamer,
13+
BitsAndBytesConfig,
1314
)
1415
from app import __version__ as app_version
1516
from app.exception import ConfigurationException
1617
from app.model_services.base import AbstractModelService
18+
from app.trainers.huggingface_llm_trainer import HuggingFaceLlmSupervisedTrainer
1719
from app.domain import ModelCard, ModelType, Annotation
1820
from app.config import Settings
1921
from app.utils import (
@@ -123,13 +125,19 @@ def from_model(cls, model: PreTrainedModel, tokenizer: PreTrainedTokenizerBase)
123125
return model_service
124126

125127
@staticmethod
126-
def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
128+
def load_model(
129+
model_file_path: str,
130+
*args: Tuple,
131+
load_in_4bit: bool = False,
132+
**kwargs: Dict[str, Any]
133+
) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
127134
"""
128135
Loads a pre-trained model and its tokenizer from a model package file.
129136
130137
Args:
131138
model_file_path (str): The path to the model package file.
132139
*args (Tuple): Additional positional arguments.
140+
load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False.
133141
**kwargs (Dict[str, Any]): Additional keyword arguments.
134142
135143
Returns:
@@ -142,7 +150,16 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
142150
model_path = os.path.join(os.path.dirname(model_file_path), get_model_data_package_base_name(model_file_path))
143151
if unpack_model_data_package(model_file_path, model_path):
144152
try:
145-
model = AutoModelForCausalLM.from_pretrained(model_path)
153+
if load_in_4bit:
154+
bnb_config = BitsAndBytesConfig(
155+
load_in_4bit=True,
156+
bnb_4bit_quant_type="nf4",
157+
bnb_4bit_compute_dtype=torch.bfloat16,
158+
bnb_4bit_use_double_quant=True,
159+
)
160+
model = AutoModelForCausalLM.from_pretrained(model_path, quantization_config=bnb_config)
161+
else:
162+
model = AutoModelForCausalLM.from_pretrained(model_path)
146163
ensure_tensor_contiguity(model)
147164
tokenizer = AutoTokenizer.from_pretrained(
148165
model_path,
@@ -172,7 +189,7 @@ def init_model(self) -> None:
172189
if non_default_device_is_available(get_settings().DEVICE):
173190
self._model.to(get_settings().DEVICE)
174191
if self._enable_trainer:
175-
logger.error("Trainers are not yet implemented for HuggingFace Generative models")
192+
self._supervised_trainer = HuggingFaceLlmSupervisedTrainer(self)
176193

177194
def info(self) -> ModelCard:
178195
"""
@@ -355,3 +372,49 @@ def create_embeddings(
355372

356373
results = embeddings.cpu().numpy().tolist()
357374
return results[0] if isinstance(text, str) else results
375+
376+
def train_supervised(
377+
self,
378+
data_file: TextIO,
379+
epochs: int,
380+
log_frequency: int,
381+
training_id: str,
382+
input_file_name: str,
383+
raw_data_files: Optional[List[TextIO]] = None,
384+
description: Optional[str] = None,
385+
synchronised: bool = False,
386+
**hyperparams: Dict[str, Any],
387+
) -> Tuple[bool, str, str]:
388+
"""
389+
Initiates supervised training on the model.
390+
391+
Args:
392+
data_file (TextIO): The file containing the trainer export data.
393+
epochs (int): The number of training epochs.
394+
log_frequency (int): The number of epochs after which training metrics will be logged.
395+
training_id (str): A unique identifier for the training process.
396+
input_file_name (str): The name of the input file to be logged.
397+
raw_data_files (Optional[List[TextIO]]): Additional raw data files to be logged. Defaults to None.
398+
description (Optional[str]): The description of the training or change logs. Defaults to None.
399+
synchronised (bool): Whether to wait for the training to complete.
400+
**hyperparams (Dict[str, Any]): Additional hyperparameters for training.
401+
402+
Returns:
403+
Tuple[bool, str, str]: A tuple with the first element indicating success or failure.
404+
405+
Raises:
406+
ConfigurationException: If the supervised trainer is not enabled.
407+
"""
408+
if self._supervised_trainer is None:
409+
raise ConfigurationException("The supervised trainer is not enabled")
410+
return self._supervised_trainer.train(
411+
data_file,
412+
epochs,
413+
log_frequency,
414+
training_id,
415+
input_file_name,
416+
raw_data_files,
417+
description,
418+
synchronised,
419+
**hyperparams,
420+
)

app/processors/metrics_collector.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,36 @@ def concat_trainer_exports(
194194
return combined
195195

196196

197+
def concat_json_lists(
198+
data_file_paths: List[str],
199+
combined_data_file_path: Optional[str] = None,
200+
) -> Union[List[Dict[str, Any]], str]:
201+
"""
202+
Concatenates multiple json list files into a single combined file.
203+
204+
Args:
205+
data_file_paths (List[str]): List of paths to files each containing a json list.
206+
combined_data_file_path (Optional[str]): The file path where the combined data will be saved. If None, the combined data will be returned as a list.
207+
208+
209+
Returns:
210+
Union[List[Dict[str, Any]], str]: The path to the combined data file if `combined_data_file_path` is provided, or the combined data as a list otherwise.
211+
"""
212+
combined: List = []
213+
for path in data_file_paths:
214+
with open(path, "r") as f:
215+
data = json.load(f)
216+
combined.extend(data)
217+
218+
if isinstance(combined_data_file_path, str):
219+
with open(combined_data_file_path, "w") as f:
220+
json.dump(combined, f)
221+
222+
return combined_data_file_path
223+
else:
224+
return combined
225+
226+
197227
def get_stats_from_trainer_export(
198228
trainer_export: Union[str, IO, Dict],
199229
return_df: bool = False,

0 commit comments

Comments
 (0)