Skip to content

Commit 0eaf455

Browse files
Fix test on py39
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 060bdb0 commit 0eaf455

File tree

1 file changed

+9
-6
lines changed

1 file changed

+9
-6
lines changed

fms_mo/training_args.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# Standard
2020
from dataclasses import dataclass, field
21-
from typing import List, Optional, Union, get_origin
21+
from typing import List, Optional, Union, get_args, get_origin
2222

2323
# Third Party
2424
import torch
@@ -32,7 +32,10 @@ def __post_init__(self):
3232
for name, field_type in self.__annotations__.items():
3333
val = self.__dict__[name]
3434
invalid_val = False
35-
if not get_origin(field_type) is list:
35+
if get_origin(field_type) is Union:
36+
if not type(val) in get_args(field_type):
37+
invalid_val = True
38+
elif not get_origin(field_type) is list:
3639
if not isinstance(val, field_type):
3740
invalid_val = True
3841
else:
@@ -54,8 +57,8 @@ def __post_init__(self):
5457
class ModelArguments(TypeChecker):
5558
"""Dataclass for model related arguments."""
5659

57-
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
58-
torch_dtype: Optional[Union[torch.dtype, str]] = torch.bfloat16
60+
model_name_or_path: str = field(default="facebook/opt-125m")
61+
torch_dtype: Union[torch.dtype, str] = torch.bfloat16
5962
use_fast_tokenizer: bool = field(
6063
default=True,
6164
metadata={
@@ -110,8 +113,8 @@ class DataArguments(TypeChecker):
110113
default=None,
111114
metadata={"help": "Path to the test data in JSON/JSONL format"},
112115
)
113-
max_seq_length: Optional[int] = field(default=2048)
114-
num_calibration_samples: Optional[int] = field(default=512)
116+
max_seq_length: int = field(default=2048)
117+
num_calibration_samples: int = field(default=512)
115118

116119

117120
@dataclass

0 commit comments

Comments
 (0)