Skip to content

Commit cc3f730

Browse files
Fix invalid type checking test
Signed-off-by: Thara Palanivel <[email protected]>
1 parent e6df893 commit cc3f730

File tree

2 files changed

+40
-14
lines changed

2 files changed

+40
-14
lines changed

fms_mo/training_args.py

Lines changed: 36 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,14 +18,34 @@
1818

1919
# Standard
2020
from dataclasses import dataclass, field
21-
from typing import List, Optional, Union
21+
from typing import Optional, Union
2222

2323
# Third Party
2424
import torch
2525

2626

2727
@dataclass
28-
class ModelArguments:
28+
class TypeChecker:
29+
def __post_init__(self):
30+
for name, field_type in self.__annotations__.items():
31+
val = self.__dict__[name]
32+
invalid_val = False
33+
if not field_type is list:
34+
if not isinstance(val, field_type):
35+
invalid_val = True
36+
else:
37+
if not isinstance(val, list) or not all(isinstance(item, int) for item in val):
38+
invalid_val = True
39+
40+
if invalid_val:
41+
current_type = type(val)
42+
raise TypeError(
43+
f"The field `{name}` was assigned by `{current_type}` instead of `{field_type}`"
44+
)
45+
46+
47+
@dataclass
48+
class ModelArguments(TypeChecker):
2949
"""Dataclass for model related arguments."""
3050

3151
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@@ -60,24 +80,26 @@ class ModelArguments:
6080
)
6181
},
6282
)
63-
device: str = field(
83+
device: Optional[str] = field(
6484
default=None,
6585
metadata={
66-
"help": ("`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device).")
67-
}
86+
"help": (
87+
"`torch.device`: The device on which the module is (assuming that all the module parameters are on the same device)."
88+
)
89+
},
6890
)
6991

7092

7193
@dataclass
72-
class DataArguments:
94+
class DataArguments(TypeChecker):
7395
"""Dataclass for data related arguments."""
7496

75-
training_data_path: str = field(
97+
training_data_path: Optional[str] = field(
7698
default=None,
7799
metadata={"help": "Path to the training data in JSON/JSONL format"},
78100
)
79-
training_data_config: str = field(default=None)
80-
test_data_path: str = field(
101+
training_data_config: Optional[str] = field(default=None)
102+
test_data_path: Optional[str] = field(
81103
default=None,
82104
metadata={"help": "Path to the test data in JSON/JSONL format"},
83105
)
@@ -86,7 +108,7 @@ class DataArguments:
86108

87109

88110
@dataclass
89-
class OptArguments:
111+
class OptArguments(TypeChecker):
90112
"""Dataclass for optimization related arguments."""
91113

92114
quant_method: str = field(
@@ -104,7 +126,7 @@ class OptArguments:
104126

105127

106128
@dataclass
107-
class FMSMOArguments:
129+
class FMSMOArguments(TypeChecker):
108130
"""Dataclass arguments used by fms_mo native quantization functions."""
109131

110132
nbits_w: int = field(default=32, metadata={"help": ("weight precision")})
@@ -139,7 +161,7 @@ class FMSMOArguments:
139161

140162

141163
@dataclass
142-
class GPTQArguments:
164+
class GPTQArguments(TypeChecker):
143165
"""Dataclass for GPTQ related arguments that will be used by auto-gptq."""
144166

145167
bits: int = field(default=4, metadata={"choices": [2, 3, 4, 8]})
@@ -157,9 +179,9 @@ class GPTQArguments:
157179

158180

159181
@dataclass
160-
class FP8Arguments:
182+
class FP8Arguments(TypeChecker):
161183
"""Dataclass for FP8 related arguments that will be used by llm-compressor."""
162184

163185
targets: str = field(default="Linear")
164186
scheme: str = field(default="FP8_DYNAMIC")
165-
ignore: List[str] = field(default_factory=lambda: ["lm_head"])
187+
ignore: list[str] = field(default_factory=lambda: ["lm_head"])

tests/build/test_launch_script.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ def test_config_parsing_error():
209209
assert pytest_wrapped_e.value.code == USER_ERROR_EXIT_CODE
210210
assert os.stat(tempdir + "/termination-log").st_size > 0
211211

212+
with open(tempdir + "/termination-log", "r", encoding="utf-8") as f:
213+
contents = f.read()
214+
assert contents == "Exception raised during optimization. This may be a problem with your input: The field `nbits_w` was assigned by `<class 'str'>` instead of `<class 'int'>`"
215+
212216

213217
def _validate_termination_files_when_quantization_succeeds(base_dir):
214218
# Check termination log and .complete files exist

0 commit comments

Comments
 (0)