@@ -1,27 +1,21 @@
-from typing import Literal, Union, List, Dict, Optional
-from pydantic import BaseModel, FilePath, validator, Field
-
-from huggingface_hub.utils import validate_repo_id
+from typing import List, Literal, Optional, Union
 
 import torch
+from pydantic import BaseModel, Field, FilePath, validator
+
 
 # TODO: Refactor this into multiple files...
 HfModelPath = str
 
+
 class QaConfig(BaseModel):
-    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
-
+    llm_tests: Optional[List[str]] = Field([], description="list of tests that needs to be connected")
+
 
 class DataConfig(BaseModel):
-    file_type: Literal["json", "csv", "huggingface"] = Field(
-        None, description="File type"
-    )
-    path: Union[FilePath, HfModelPath] = Field(
-        None, description="Path to the file or HuggingFace model"
-    )
-    prompt: str = Field(
-        None, description="Prompt for the model. Use {} brackets for column name"
-    )
+    file_type: Literal["json", "csv", "huggingface"] = Field(None, description="File type")
+    path: Union[FilePath, HfModelPath] = Field(None, description="Path to the file or HuggingFace model")
+    prompt: str = Field(None, description="Prompt for the model. Use {} brackets for column name")
     prompt_stub: str = Field(
         None,
         description="Stub for the prompt; this is injected during training. Use {} brackets for column name",
@@ -48,9 +42,7 @@ class DataConfig(BaseModel):
 
 
 class BitsAndBytesConfig(BaseModel):
-    load_in_8bit: Optional[bool] = Field(
-        False, description="Enable 8-bit quantization with LLM.int8()"
-    )
+    load_in_8bit: Optional[bool] = Field(False, description="Enable 8-bit quantization with LLM.int8()")
     llm_int8_threshold: Optional[float] = Field(
         6.0, description="Outlier threshold for outlier detection in 8-bit quantization"
     )
@@ -61,9 +53,7 @@ class BitsAndBytesConfig(BaseModel):
         False,
         description="Enable splitting model parts between int8 on GPU and fp32 on CPU",
     )
-    llm_int8_has_fp16_weight: Optional[bool] = Field(
-        False, description="Run LLM.int8() with 16-bit main weights"
-    )
+    llm_int8_has_fp16_weight: Optional[bool] = Field(False, description="Run LLM.int8() with 16-bit main weights")
 
     load_in_4bit: Optional[bool] = Field(
         True,
@@ -86,14 +76,10 @@ class ModelConfig(BaseModel):
8676 "NousResearch/Llama-2-7b-hf" ,
8777 description = "Path to the model (huggingface repo or local path)" ,
8878 )
89- device_map : Optional [str ] = Field (
90- "auto" , description = "device onto which to load the model"
91- )
79+ device_map : Optional [str ] = Field ("auto" , description = "device onto which to load the model" )
9280
9381 quantize : Optional [bool ] = Field (False , description = "Flag to enable quantization" )
94- bitsandbytes : BitsAndBytesConfig = Field (
95- None , description = "Bits and Bytes configuration"
96- )
82+ bitsandbytes : BitsAndBytesConfig = Field (None , description = "Bits and Bytes configuration" )
9783
9884 # @validator("hf_model_ckpt")
9985 # def validate_model(cls, v, **kwargs):
@@ -116,22 +102,12 @@ def set_device_map_to_none(cls, v, values, **kwargs):
 
 class LoraConfig(BaseModel):
     r: Optional[int] = Field(8, description="Lora rank")
-    task_type: Optional[str] = Field(
-        "CAUSAL_LM", description="Base Model task type during training"
-    )
+    task_type: Optional[str] = Field("CAUSAL_LM", description="Base Model task type during training")
 
-    lora_alpha: Optional[int] = Field(
-        16, description="The alpha parameter for Lora scaling"
-    )
-    bias: Optional[str] = Field(
-        "none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'"
-    )
-    lora_dropout: Optional[float] = Field(
-        0.1, description="The dropout probability for Lora layers"
-    )
-    target_modules: Optional[List[str]] = Field(
-        None, description="The names of the modules to apply Lora to"
-    )
+    lora_alpha: Optional[int] = Field(16, description="The alpha parameter for Lora scaling")
+    bias: Optional[str] = Field("none", description="Bias type for Lora. Can be 'none', 'all' or 'lora_only'")
+    lora_dropout: Optional[float] = Field(0.1, description="The dropout probability for Lora layers")
+    target_modules: Optional[List[str]] = Field(None, description="The names of the modules to apply Lora to")
     fan_in_fan_out: Optional[bool] = Field(
         False,
         description="Flag to indicate if the layer to replace stores weight like (fan_in, fan_out)",
@@ -140,9 +116,7 @@ class LoraConfig(BaseModel):
         None,
         description="List of modules apart from LoRA layers to be set as trainable and saved in the final checkpoint",
     )
-    layers_to_transform: Optional[Union[List[int], int]] = Field(
-        None, description="The layer indexes to transform"
-    )
+    layers_to_transform: Optional[Union[List[int], int]] = Field(None, description="The layer indexes to transform")
     layers_pattern: Optional[str] = Field(None, description="The layer pattern name")
     # rank_pattern: Optional[Dict[str, int]] = Field(
     #     {}, description="The mapping from layer names or regexp expression to ranks"
@@ -155,15 +129,9 @@ class LoraConfig(BaseModel):
 # TODO: Get comprehensive Args!
 class TrainingArgs(BaseModel):
     num_train_epochs: Optional[int] = Field(1, description="Number of training epochs")
-    per_device_train_batch_size: Optional[int] = Field(
-        1, description="Batch size per training device"
-    )
-    gradient_accumulation_steps: Optional[int] = Field(
-        1, description="Number of steps for gradient accumulation"
-    )
-    gradient_checkpointing: Optional[bool] = Field(
-        True, description="Flag to enable gradient checkpointing"
-    )
+    per_device_train_batch_size: Optional[int] = Field(1, description="Batch size per training device")
+    gradient_accumulation_steps: Optional[int] = Field(1, description="Number of steps for gradient accumulation")
+    gradient_checkpointing: Optional[bool] = Field(True, description="Flag to enable gradient checkpointing")
     optim: Optional[str] = Field("paged_adamw_32bit", description="Optimizer")
     logging_steps: Optional[int] = Field(100, description="Number of logging steps")
     learning_rate: Optional[float] = Field(2.0e-4, description="Learning rate")
@@ -172,9 +140,7 @@ class TrainingArgs(BaseModel):
     fp16: Optional[bool] = Field(False, description="Flag to enable fp16")
     max_grad_norm: Optional[float] = Field(0.3, description="Maximum gradient norm")
     warmup_ratio: Optional[float] = Field(0.03, description="Warmup ratio")
-    lr_scheduler_type: Optional[str] = Field(
-        "constant", description="Learning rate scheduler type"
-    )
+    lr_scheduler_type: Optional[str] = Field("constant", description="Learning rate scheduler type")
 
 
 # TODO: Get comprehensive Args!
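
The additions above collapse each hyperparameter onto a single Field(...) line with its description. A minimal usage sketch of the resulting models, not part of the commit: the config_model module path, the file path, and the target module names are illustrative assumptions, and pydantic v1 is assumed because the file imports validator.

# Hypothetical import path; the commit does not show the module name.
from config_model import DataConfig, LoraConfig, TrainingArgs

# Build a data config; a non-existent path falls back to the HfModelPath (str) branch of the Union.
data = DataConfig(
    file_type="json",
    path="data/train.json",          # illustrative, not from the commit
    prompt="Summarize: {text}",      # {} brackets reference a dataset column
    prompt_stub="{summary}",
)

# LoRA and training settings use the defaults declared in the Field() calls unless overridden.
lora = LoraConfig(r=8, lora_alpha=16, target_modules=["q_proj", "v_proj"])
train = TrainingArgs(num_train_epochs=1, per_device_train_batch_size=1)

print(lora.dict())  # pydantic v1 serialization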