Commit d03f487

feat: add the 4-bit quantisation option and remove unnecessary base model copying

1 parent d1ff2fb · commit d03f487

10 files changed: +72 −42 lines

.github/workflows/main.yaml
Lines changed: 1 addition & 1 deletion

@@ -24,7 +24,7 @@ jobs:
           python-version: ${{ matrix.python-version }}
       - name: Install dependencies
         run: |
-          uv sync --group dev --group docs --group vllm
+          uv sync --group dev --group docs
       - name: Check types
         run: |
           uv run mypy app

app/cli/cli.py
Lines changed: 6 additions & 2 deletions

@@ -67,6 +67,7 @@ def serve_model(
     streamable: bool = typer.Option(False, help="Serve the streamable endpoints only"),
     device: Device = typer.Option(Device.DEFAULT.value, help="The device to serve the model on"),
     llm_engine: Optional[LlmEngine] = typer.Option(LlmEngine.CMS.value, help="The engine to use for text generation"),
+    load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
     debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
 ) -> None:
     """
@@ -84,6 +85,7 @@ def serve_model(
         streamable (bool): Serve the streamable endpoints only. Defaults to False.
         device (Device): The device to serve the model on. Defaults to Device.DEFAULT.
         llm_engine (LlmEngine): The inference engine to use. Defaults to LlmEngine.CMS.
+        load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
         debug (Optional[bool]): Run in debug mode if set to True.
     """

@@ -135,7 +137,7 @@ def serve_model(
     if model_path:
         model_service = model_service_dep()
         model_service.model_name = model_name
-        model_service.init_model()
+        model_service.init_model(load_in_4bit=load_in_4bit)
         cms_globals.model_manager_dep = ModelManagerDep(model_service)
     elif mlflow_model_uri:
         model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
@@ -187,6 +189,7 @@ def train_model(
     description: Optional[str] = typer.Option(None, help="The description of the training or change logs"),
     model_name: Optional[str] = typer.Option(None, help="The string representation of the model name"),
     device: Device = typer.Option(Device.DEFAULT.value, help="The device to train the model on"),
+    load_in_4bit: Optional[bool] = typer.Option(False, help="Load the model in 4-bit precision, used by 'huggingface_llm' models"),
     debug: Optional[bool] = typer.Option(None, help="Run in the debug mode"),
 ) -> None:
     """
@@ -206,6 +209,7 @@ def train_model(
         description (Optional[str]): The optional description of the training or change logs.
         model_name (Optional[str]): The optional string representation of the model name.
         device (Device): The device to train the model on. Defaults to Device.DEFAULT.
+        load_in_4bit (bool): Load the model in 4-bit precision, used by 'huggingface_llm' models. Defaults to False.
         debug (Optional[bool]): Run in debug mode if set to True.
     """

@@ -229,7 +233,7 @@ def train_model(
             pass
         model_service = model_service_dep()
         model_service.model_name = model_name if model_name is not None else "CMS model"
-        model_service.init_model()
+        model_service.init_model(load_in_4bit=load_in_4bit)
     elif mlflow_model_uri:
         model_service = ModelManager.retrieve_model_service_from_uri(mlflow_model_uri, config, dst_model_path)
         model_service.model_name = model_name if model_name is not None else "CMS model"
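
Because the option is declared through typer.Option, both serve_model and train_model automatically gain a --load-in-4bit/--no-load-in-4bit flag on the command line. A minimal sketch of that mechanism; the command body below is illustrative, not code from this diff:

    import typer

    app = typer.Typer()

    @app.command()
    def serve_model(
        load_in_4bit: bool = typer.Option(False, help="Load the model in 4-bit precision"),
    ) -> None:
        # Typer derives the flag name from the parameter name.
        typer.echo(f"load_in_4bit={load_in_4bit}")

    if __name__ == "__main__":
        app()  # e.g. `python cli.py --load-in-4bit` prints load_in_4bit=True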

app/model_services/base.py
Lines changed: 5 additions & 1 deletion

@@ -154,10 +154,14 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:
         raise NotImplementedError

     @abstractmethod
-    def init_model(self) -> None:
+    def init_model(self, *args: Any, **kwargs: Any) -> None:
         """
         Initialises the model and auxiliary resources.

+        Args:
+            *args (Any): Additional positional arguments to be passed to this method.
+            **kwargs (Any): Additional keyword arguments to be passed to this method.
+
         Raises:
             NotImplementedError: If the method is not implemented by the subclass.
         """

app/model_services/huggingface_llm_model.py
Lines changed: 9 additions & 3 deletions

@@ -174,8 +174,14 @@ def load_model(
         else:
             raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}")

-    def init_model(self) -> None:
-        """Initialises the HuggingFace model and its tokenizer based on the configuration."""
+    def init_model(self, load_in_4bit: bool = False, *args: Any, **kwargs: Any) -> None:
+        """Initialises the HuggingFace model and its tokenizer based on the configuration.
+
+        Args:
+            load_in_4bit (bool): Whether to load the model in 4-bit precision. Defaults to False.
+            *args (Any): Additional positional arguments to be passed to this method.
+            **kwargs (Any): Additional keyword arguments to be passed to this method.
+        """

         if all([
             hasattr(self, "_model"),
@@ -185,7 +191,7 @@ def init_model(self) -> None:
         ]):
             logger.warning("Model service is already initialised and can be initialised only once")
         else:
-            self._model, self._tokenizer = self.load_model(self._model_pack_path)
+            self._model, self._tokenizer = self.load_model(self._model_pack_path, load_in_4bit=load_in_4bit)
             if non_default_device_is_available(get_settings().DEVICE):
                 self._model.to(get_settings().DEVICE)
             if self._enable_trainer:
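
The diff forwards load_in_4bit into load_model but that method's body is not shown here. A hedged sketch of what honouring the flag typically looks like with the transformers and bitsandbytes stack; the helper name and quantisation defaults below are assumptions, not code from this commit:

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

    def load_causal_lm(model_dir: str, load_in_4bit: bool = False):
        quantization_config = None
        if load_in_4bit:
            # NF4 quantisation with bfloat16 compute is the common
            # memory-efficient setup for 4-bit loading.
            quantization_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=torch.bfloat16,
            )
        model = AutoModelForCausalLM.from_pretrained(
            model_dir,
            quantization_config=quantization_config,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_dir)
        return model, tokenizer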

app/model_services/huggingface_ner_model.py
Lines changed: 7 additions & 2 deletions

@@ -175,8 +175,13 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
         else:
             raise ConfigurationException(f"Model package archive format is not supported: {model_file_path}")

-    def init_model(self) -> None:
-        """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration."""
+    def init_model(self, *args: Any, **kwargs: Any) -> None:
+        """Initialises the HuggingFace model, its tokenizer and a NER pipeline based on the configuration.
+
+        Args:
+            *args (Any): Additional positional arguments to be passed to this method.
+            **kwargs (Any): Additional keyword arguments to be passed to this method.
+        """

         if all([
             hasattr(self, "_model"),

app/model_services/medcat_model.py
Lines changed: 7 additions & 2 deletions

@@ -119,8 +119,13 @@ def load_model(model_file_path: str, *args: Tuple, **kwargs: Dict[str, Any]) ->
         else:
             raise ConfigurationException("Model package archive format is not supported")

-    def init_model(self) -> None:
-        """Initializes the MedCAT model based on the configuration."""
+    def init_model(self, *args: Any, **kwargs: Any) -> None:
+        """Initializes the MedCAT model based on the configuration.
+
+        Args:
+            *args (Any): Additional positional arguments to be passed to this method.
+            **kwargs (Any): Additional keyword arguments to be passed to this method.
+        """

         if hasattr(self, "_model") and isinstance(self._model, CAT):
             logger.warning("Model service is already initialised and can be initialised only once")

app/model_services/medcat_model_deid.py
Lines changed: 7 additions & 2 deletions

@@ -178,8 +178,13 @@ def batch_annotate(self, texts: List[str]) -> List[List[Annotation]]:

         return annotations_list

-    def init_model(self) -> None:
-        """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration."""
+    def init_model(self, *args: Any, **kwargs: Any) -> None:
+        """Initializes the MedCAT De-Identification (AnonCAT) model based on the configuration.
+
+        Args:
+            *args (Any): Additional positional arguments to be passed to this method.
+            **kwargs (Any): Additional keyword arguments to be passed to this method.
+        """

         if hasattr(self, "_model") and isinstance(self._model, CAT):
             logger.warning("Model service is already initialised and can be initialised only once")

app/model_services/trf_model_deid.py
Lines changed: 1 addition & 1 deletion

@@ -86,7 +86,7 @@ def load_model(
         logger.info("Model loaded from %s", unpacked_model_dir)
         return tokenizer, model

-    def init_model(self) -> None:
+    def init_model(self, *args: Any, **kwargs: Any) -> None:
         if hasattr(self, "_model") and isinstance(self._model, PreTrainedModel):
             logger.warning("Model service is already initialised and can be initialised only once")
         else:

app/trainers/huggingface_llm_trainer.py
Lines changed: 29 additions & 23 deletions

@@ -88,8 +88,11 @@ def __init__(self, model_service: "HuggingFaceLlmModel") -> None:
         self._model_service = model_service
         self._model_name = model_service.model_name
         self._model_pack_path = model_service._model_pack_path
-        self._retrained_models_dir = os.path.join(model_service._model_parent_dir, "retrained",
-                                                  self._model_name.replace(" ", "_"))
+        self._retrained_models_dir = os.path.join(
+            model_service._model_parent_dir,
+            "retrained",
+            self._model_name.replace(" ", "_"),
+        )
         self._model_manager = ModelManager(type(model_service), model_service._config)
         self._max_length = model_service.model.config.max_position_embeddings
         os.makedirs(self._retrained_models_dir, exist_ok=True)
@@ -306,7 +309,7 @@ def run(
             logger.error("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")
             raise ExtraDependencyRequiredException("Cannot import the GRPO Trainer. Please install it with `pip install cms[vllm]`.")

-        copied_model_pack_path = None
+        trained_model_pack_path = None
         redeploy = self._config.REDEPLOY_TRAINED_MODEL == "true"
         skip_save_model = self._config.SKIP_SAVE_MODEL == "true"
         results_path = os.path.abspath(os.path.join(self._config.TRAINING_CACHE_DIR, "results"))
@@ -319,15 +322,16 @@ def run(

         if not eval_mode:
             try:
-                logger.info("Loading a new model copy for training...")
-                copied_model_pack_path = self._make_model_file_copy(self._model_pack_path, run_id)
-                model, tokenizer = self._model_service.load_model(
-                    copied_model_pack_path,
-                    load_in_4bit=True,  # for memory efficient training
+                logger.info("Loading a PEFT model for training...")
+                model_pack_file_ext = get_model_data_package_extension(self._model_pack_path)
+                trained_model_pack_path = self._model_pack_path.replace(
+                    model_pack_file_ext,
+                    f"_trained_{run_id}{model_pack_file_ext}",
                 )
-                copied_model_directory = os.path.join(
-                    os.path.dirname(copied_model_pack_path),
-                    get_model_data_package_base_name(copied_model_pack_path),
+                model, tokenizer = self._model_service.model, self._model_service.tokenizer
+                trained_model_directory = os.path.join(
+                    os.path.dirname(trained_model_pack_path),
+                    get_model_data_package_base_name(trained_model_pack_path),
                 )

                 if non_default_device_is_available(self._config.DEVICE):
@@ -355,7 +359,7 @@ def run(
                     ],
                 )

-                model = get_peft_model(model, lora_config)
+                peft_model = get_peft_model(model, lora_config)

                 mlflow_logging_callback = MLflowLoggingCallback(self._tracker_client)
                 cancel_event_check_callback = CancelEventCheckCallback(self._cancel_event)
@@ -378,27 +382,26 @@ def run(
                 training_args = GRPOConfig(
                     output_dir=results_path,
                     logging_dir=logs_path,
+                    logging_steps=log_frequency,
                     learning_rate=5e-6,
                     adam_beta1=0.9,
                     adam_beta2=0.99,
                     weight_decay=0.1,
                     warmup_ratio=0.1,
                     lr_scheduler_type="cosine",
                     optim="paged_adamw_8bit",
-                    logging_steps=1,
                     per_device_train_batch_size=6,  # This global batch size must be divisible by the number of generations
                     gradient_accumulation_steps=1,
                     num_generations=6,
                     max_prompt_length=max_prompt_length,
                     max_completion_length=max_seq_length - max_prompt_length,
                     num_train_epochs=training_params["nepochs"],
-                    max_steps=250,
                     save_steps=250,
                     max_grad_norm=0.1,
                     report_to="none",
                 )
                 trainer = GRPOTrainer(
-                    model=model,
+                    model=peft_model,
                     processing_class=tokenizer,
                     reward_funcs=self._get_reward_functions(),
                     args=training_args,
@@ -409,7 +412,7 @@ def run(
             else:
                 raise ConfigurationException(f"Unsupported trainer type: {trainer_type}")

-            self._tracker_client.log_model_config(model.config.to_dict())
+            self._tracker_client.log_model_config({**model.config.to_dict(), **peft_model.peft_config})
             self._tracker_client.log_trainer_version(TrainerBackend.TRANSFORMERS, transformers_version)

             logger.info(f"Performing {trainer_type.upper()} training...")
@@ -422,11 +425,13 @@ def run(
                 model_pack_file_ext = get_model_data_package_extension(self._config.BASE_MODEL_FILE)
                 model_pack_file_name = f"{ModelType.HUGGINGFACE_LLM.value}_{run_id}{model_pack_file_ext}"
                 retrained_model_pack_path = os.path.join(self._retrained_models_dir, model_pack_file_name)
+                model = peft_model.merge_and_unload()
                 model.save_pretrained(
-                    copied_model_directory,
+                    trained_model_directory,
                     safe_serialization=(self._config.TRAINING_SAFE_MODEL_SERIALISATION == "true"),
                 )
-                create_model_data_package(copied_model_directory, retrained_model_pack_path)
+                tokenizer.save_pretrained(trained_model_directory)
+                create_model_data_package(trained_model_directory, retrained_model_pack_path)
                 model_uri = self._tracker_client.save_model(
                     retrained_model_pack_path,
                     self._model_name,
@@ -475,7 +480,7 @@ def run(
             with self._training_lock:
                 self._training_in_progress = False
             self._clean_up_training_cache()
-            self._housekeep_file(copied_model_pack_path)
+            self._housekeep_file(trained_model_pack_path)
             if trainer is not None:
                 del trainer
                 gc.collect()
@@ -505,6 +510,7 @@ def run(
             training_args = GRPOConfig(
                 output_dir=results_path,
                 logging_dir=logs_path,
+                logging_steps=log_frequency,
                 per_device_eval_batch_size=6,
                 num_generations=2,
                 max_prompt_length=max_prompt_length,
@@ -607,19 +613,19 @@ def correctness_reward_func(
             )
             return [2.0 if r == a else 0.0 for r, a in zip(extracted_responses, answer)]

-        def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def int_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             responses = [completion[0]["content"] for completion in completions]
             extracted_responses = [extract_xml_answer(r) for r in responses]
             return [0.5 if r.isdigit() else 0.0 for r in extracted_responses]

-        def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def strict_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             """Reward function that checks if the completion has a specific format."""
             pattern = r"^<reasoning>\n.*?\n</reasoning>\n<answer>\n.*?\n</answer>\n$"
             responses = [completion[0]["content"] for completion in completions]
             matches = [re.match(pattern, r) for r in responses]
             return [0.5 if match else 0.0 for match in matches]

-        def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def soft_format_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             """Reward function that checks if the completion has a specific format."""
             pattern = r"<reasoning>.*?</reasoning>\s*<answer>.*?</answer>"
             responses = [completion[0]["content"] for completion in completions]
@@ -640,7 +646,7 @@ def count_xml(text: str) -> float:
             count -= (len(text.split("\n</answer>")[-1]) - 1) * 0.001
             return count

-        def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> list[float]:
+        def xmlcount_reward_func(completions: Tuple[Any], **kwargs: Dict[str, Any]) -> List[float]:
             contents = [completion[0]["content"] for completion in completions]
             return [count_xml(c) for c in contents]
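
The net effect of these trainer changes: instead of copying the packed base model and reloading it from disk, the trainer wraps the already-served model in LoRA adapters, trains only the adapters, and merges them back before packaging. A hedged sketch of that flow with the peft library; the function and argument values below are illustrative, not the trainer's actual code:

    from peft import LoraConfig, get_peft_model

    def train_and_merge(model, tokenizer, output_dir: str) -> None:
        """Sketch of the adapter-train-then-merge flow; shapes are assumed."""
        lora_config = LoraConfig(r=16, lora_alpha=32, task_type="CAUSAL_LM")
        peft_model = get_peft_model(model, lora_config)  # only adapter weights train
        # ... run GRPOTrainer(model=peft_model, ...).train() here ...
        merged = peft_model.merge_and_unload()  # fold adapters into the base weights
        merged.save_pretrained(output_dir)
        tokenizer.save_pretrained(output_dir)  # the tokenizer ships with the package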

app/utils.py
Lines changed: 0 additions & 5 deletions

@@ -547,11 +547,6 @@ def unpack_model_data_package(model_data_file_path: str, model_data_folder_path:
     elif model_data_file_path.endswith(".tar.gz"):
         with tarfile.open(model_data_file_path, "r:gz") as f:
             for member in f.getmembers():
-                path_parts = member.name.split(os.sep)
-                stripped_path = os.sep.join(path_parts[1:])
-                if not stripped_path:
-                    continue
-                member.name = stripped_path
                 f.extract(member, path=model_data_folder_path)
         return True
     else:
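
With the stripping logic removed, members of a .tar.gz package are now extracted under their original paths rather than having the leading path component dropped. A small sketch of the new behaviour; the archive and directory names are illustrative:

    import tarfile

    # Previously a member named "model/weights.bin" would land at
    # "unpacked/weights.bin"; now it is extracted to "unpacked/model/weights.bin".
    with tarfile.open("model.tar.gz", "r:gz") as f:
        for member in f.getmembers():
            f.extract(member, path="unpacked")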
