1818
1919# Standard
2020from dataclasses import dataclass , field
21- from typing import List , Optional , Union
21+ from typing import Optional , Union
2222
2323# Third Party
2424import torch
2525
2626
@dataclass
class TypeChecker:
    """Base dataclass that validates field values against their annotations.

    Subclasses decorated with ``@dataclass`` inherit ``__post_init__``, so every
    instance is checked right after the generated ``__init__`` has assigned its
    fields.
    """

    def __post_init__(self):
        """Check each dataclass field of this instance against its declared type.

        Raises:
            TypeError: If a field holds a value that does not satisfy its
                annotation.
        """
        # Function-scope import: the module itself only imports `dataclass`
        # and `field` from `dataclasses`.
        from dataclasses import fields

        # `fields()` covers inherited fields and leaves out pseudo-fields such
        # as `ClassVar`, which have no per-instance value to validate.
        for spec in fields(self):
            val = getattr(self, spec.name)
            if not TypeChecker._matches(val, spec.type):
                raise TypeError(
                    f"The field `{spec.name}` was assigned by `{type(val)}` "
                    f"instead of `{spec.type}`"
                )

    @staticmethod
    def _matches(val, expected):
        """Return True when ``val`` satisfies the annotation ``expected``.

        Handles plain classes, ``Any``, ``Optional``/``Union`` (including the
        ``X | Y`` form), bare ``list`` and parameterized ``list[T]``.
        Annotations that cannot be checked at runtime (for example string
        annotations) are accepted instead of rejected.
        """
        import types
        import typing

        if expected is typing.Any:
            return True

        if expected is list:
            # Historical contract of this checker: a bare `list` annotation
            # means "list of ints".
            return isinstance(val, list) and all(
                isinstance(item, int) for item in val
            )

        origin = typing.get_origin(expected)
        args = typing.get_args(expected)

        # `types.UnionType` (the `X | Y` spelling) only exists on 3.10+.
        pep604_union = getattr(types, "UnionType", None)
        if origin is typing.Union or (
            pep604_union is not None and origin is pep604_union
        ):
            return any(TypeChecker._matches(val, arg) for arg in args)

        if origin is list:
            if not isinstance(val, list):
                return False
            if not args:
                return True
            # `isinstance(val, list[str])` raises, so check elements one by one.
            return all(TypeChecker._matches(item, args[0]) for item in val)

        if origin is not None:
            # Other parameterized generics: only the container type is checked.
            return isinstance(val, origin) if isinstance(origin, type) else True

        if isinstance(expected, type):
            return isinstance(val, expected)

        # Unresolved string annotations, TypeVars, etc. cannot be verified here.
        return True
45+
46+
47+ @dataclass
48+ class ModelArguments (TypeChecker ):
2949 """Dataclass for model related arguments."""
3050
3151 model_name_or_path : Optional [str ] = field (default = "facebook/opt-125m" )
@@ -60,24 +80,26 @@ class ModelArguments:
6080 )
6181 },
6282 )
63- device : str = field (
83+ device : Optional [ str ] = field (
6484 default = None ,
6585 metadata = {
66- "help" : ("`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device)." )
67- }
86+ "help" : (
87+ "`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device)."
88+ )
89+ },
6890 )
6991
7092
7193@dataclass
72- class DataArguments :
94+ class DataArguments ( TypeChecker ) :
7395 """Dataclass for data related arguments."""
7496
75- training_data_path : str = field (
97+ training_data_path : Optional [ str ] = field (
7698 default = None ,
7799 metadata = {"help" : "Path to the training data in JSON/JSONL format" },
78100 )
79- training_data_config : str = field (default = None )
80- test_data_path : str = field (
101+ training_data_config : Optional [ str ] = field (default = None )
102+ test_data_path : Optional [ str ] = field (
81103 default = None ,
82104 metadata = {"help" : "Path to the test data in JSON/JSONL format" },
83105 )
@@ -86,7 +108,7 @@ class DataArguments:
86108
87109
88110@dataclass
89- class OptArguments :
111+ class OptArguments ( TypeChecker ) :
90112 """Dataclass for optimization related arguments."""
91113
92114 quant_method : str = field (
@@ -104,7 +126,7 @@ class OptArguments:
104126
105127
106128@dataclass
107- class FMSMOArguments :
129+ class FMSMOArguments ( TypeChecker ) :
108130 """Dataclass arguments used by fms_mo native quantization functions."""
109131
110132 nbits_w : int = field (default = 32 , metadata = {"help" : ("weight precision" )})
@@ -139,7 +161,7 @@ class FMSMOArguments:
139161
140162
141163@dataclass
142- class GPTQArguments :
164+ class GPTQArguments ( TypeChecker ) :
143165 """Dataclass for GPTQ related arguments that will be used by auto-gptq."""
144166
145167 bits : int = field (default = 4 , metadata = {"choices" : [2 , 3 , 4 , 8 ]})
@@ -157,9 +179,9 @@ class GPTQArguments:
157179
158180
@dataclass
class FP8Arguments(TypeChecker):
    """FP8-related arguments that will be used by llm-compressor."""

    # Plain string defaults; immutable, so no default factory is needed.
    targets: str = "Linear"
    scheme: str = "FP8_DYNAMIC"
    # Lists are mutable, so the default is built per instance by a factory.
    ignore: list[str] = field(default_factory=lambda: ["lm_head"])
0 commit comments