1818
1919# Standard
2020from dataclasses import dataclass , field
21- from typing import List , Optional , Union , get_origin
21+ from typing import List , Optional , Union , get_args , get_origin
2222
2323# Third Party
2424import torch
@@ -32,7 +32,10 @@ def __post_init__(self):
3232 for name , field_type in self .__annotations__ .items ():
3333 val = self .__dict__ [name ]
3434 invalid_val = False
35- if not get_origin (field_type ) is list :
35+ if get_origin (field_type ) is Union :
36+ if not type (val ) in get_args (field_type ):
37+ invalid_val = True
38+ elif not get_origin (field_type ) is list :
3639 if not isinstance (val , field_type ):
3740 invalid_val = True
3841 else :
@@ -54,8 +57,8 @@ def __post_init__(self):
5457class ModelArguments (TypeChecker ):
5558 """Dataclass for model related arguments."""
5659
57- model_name_or_path : Optional [ str ] = field (default = "facebook/opt-125m" )
58- torch_dtype : Optional [ Union [torch .dtype , str ] ] = torch .bfloat16
60+ model_name_or_path : str = field (default = "facebook/opt-125m" )
61+ torch_dtype : Union [torch .dtype , str ] = torch .bfloat16
5962 use_fast_tokenizer : bool = field (
6063 default = True ,
6164 metadata = {
@@ -110,8 +113,8 @@ class DataArguments(TypeChecker):
110113 default = None ,
111114 metadata = {"help" : "Path to the test data in JSON/JSONL format" },
112115 )
113- max_seq_length : Optional [ int ] = field (default = 2048 )
114- num_calibration_samples : Optional [ int ] = field (default = 512 )
116+ max_seq_length : int = field (default = 2048 )
117+ num_calibration_samples : int = field (default = 512 )
115118
116119
117120@dataclass
0 commit comments