from os.path import join

-import bitsandbytes as bnb
-import torch
from datasets import Dataset
from peft import (
    LoraConfig,
@@ -38,21 +36,10 @@ def __init__(self, config: Config, directory_helper: DirectoryHelper):
        self.model = None
        self.tokenizer = None

-        """ TODO: Figure out how to handle multi-gpu
-        if config.accelerate:
-            self.accelerator = Accelerator()
-            self.accelerator.state.deepspeed_plugin.deepspeed_config[
-                "train_micro_batch_size_per_gpu"
-            ] = self.config.training.training_args.per_device_train_batch_size
-
-        if config.accelerate:
-            # device_index = Accelerator().process_index
-            self.device_map = None  # {"": device_index}
-        else:
-        """
        self.device_map = self._model_config.device_map

        self._load_model_and_tokenizer()
+        self._inject_lora()

    def _load_model_and_tokenizer(self):
        ckpt = self._model_config.hf_model_ckpt
@@ -67,11 +54,7 @@ def _load_model_and_tokenizer(self):
    def _get_model(self):
        model = AutoModelForCausalLM.from_pretrained(
            self._model_config.hf_model_ckpt,
-            quantization_config=(
-                BitsAndBytesConfig(**self._model_config.bitsandbytes.model_dump())
-                if not self.config.accelerate
-                else None
-            ),
+            quantization_config=BitsAndBytesConfig(**self._model_config.bitsandbytes.model_dump()),
            use_cache=False,
            device_map=self.device_map,
            torch_dtype=self._model_config.casted_torch_dtype,
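
Note: with this change the quantization config is applied unconditionally. As a point of reference, a minimal sketch of what the bitsandbytes section of the model config might expand to for a typical 4-bit QLoRA setup (the field values below are illustrative assumptions, not taken from this repo):

import torch
from transformers import BitsAndBytesConfig

# Illustrative 4-bit settings; in the code above these come from
# self._model_config.bitsandbytes and are unpacked via model_dump().
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)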
@@ -90,19 +73,10 @@ def _get_tokenizer(self):
        return tokenizer

    def _inject_lora(self):
-        if not self.config.accelerate:
-            self.model.gradient_checkpointing_enable()
-            self.model = prepare_model_for_kbit_training(self.model)
+        self.model.gradient_checkpointing_enable()
+        self.model = prepare_model_for_kbit_training(self.model)
        self.model = get_peft_model(self.model, self._lora_config)

-        if not self.config.accelerate:
-            self.optimizer = bnb.optim.Adam8bit(self.model.parameters(), lr=self._training_args.learning_rate)
-            self.lr_scheduler = torch.optim.lr_scheduler.ConstantLR(self.optimizer)
-        if self.config.accelerate:
-            self.model, self.optimizer, self.lr_scheduler = self.accelerator.prepare(
-                self.model, self.optimizer, self.lr_scheduler
-            )
-
    def finetune(self, train_dataset: Dataset):
        logging_dir = join(self._weights_path, "/logs")
        training_args = TrainingArguments(
@@ -123,7 +97,6 @@ def finetune(self, train_dataset: Dataset):
            args=training_args,
            dataset_text_field="formatted_prompt",  # TODO: maybe move consts to a dedicated folder
            callbacks=[progress_callback],
-            # optimizers=[self.optimizer, self.lr_scheduler],
            **self._sft_args.model_dump(),
        )

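After this simplification, _inject_lora reduces to the standard peft k-bit preparation flow. A minimal standalone sketch of that flow, using an illustrative checkpoint and LoRA hyperparameters (the real values come from self._model_config and self._lora_config):

from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
from transformers import AutoModelForCausalLM

# Illustrative checkpoint; the pipeline loads self._model_config.hf_model_ckpt.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m")
model.gradient_checkpointing_enable()           # trade compute for activation memory
model = prepare_model_for_kbit_training(model)  # cast norms to fp32, enable input grads

# Illustrative LoRA hyperparameters, assumed for this sketch.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)      # wrap the base model with trainable adapters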