Skip to content

Commit dbeb760

Browse files
Fix test
Signed-off-by: Thara Palanivel <[email protected]>
1 parent 8556ae9 commit dbeb760

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

fms_mo/training_args.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
# Standard
2020
from dataclasses import dataclass, field
21-
from typing import Optional, Union
21+
from typing import List, Optional, Union, get_origin
2222

2323
# Third Party
2424
import torch
@@ -32,12 +32,14 @@ def __post_init__(self):
3232
for name, field_type in self.__annotations__.items():
3333
val = self.__dict__[name]
3434
invalid_val = False
35-
if not field_type is list:
35+
if not get_origin(field_type) is list:
3636
if not isinstance(val, field_type):
3737
invalid_val = True
3838
else:
39-
if not isinstance(val, list) or not all(
40-
isinstance(item, int) for item in val
39+
if not (
40+
get_origin(val) is list
41+
or type(val) is list # pylint: disable=unidiomatic-typecheck
42+
or all(isinstance(item, int) for item in val)
4143
):
4244
invalid_val = True
4345

@@ -189,4 +191,4 @@ class FP8Arguments(TypeChecker):
189191

190192
targets: str = field(default="Linear")
191193
scheme: str = field(default="FP8_DYNAMIC")
192-
ignore: list[str] = field(default_factory=lambda: ["lm_head"])
194+
ignore: List[str] = field(default_factory=lambda: ["lm_head"])

tox.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ deps = {[testenv:ruff]deps}
5959
commands =
6060
ruff check {posargs:--fix} .
6161
ruff format .
62+
isort .
6263
isort --check .
6364

6465
[testenv:spellcheck]

0 commit comments

Comments
 (0)