Skip to content

Commit f86e1e6

Browse files
committed
fix: code refactor
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent d8f51e6 commit f86e1e6

File tree

3 files changed

+10
-12
lines changed

3 files changed

+10
-12
lines changed

tests/acceleration/test_acceleration_dataclasses.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@
3333
FastKernelsConfig,
3434
FusedLoraConfig,
3535
)
36-
from tuning.config.acceleration_configs.odm import ODM, ODMConfig
3736
from tuning.config.acceleration_configs.quantized_lora_config import (
3837
AutoGPTQLoraConfig,
3938
BNBQLoraConfig,
@@ -97,13 +96,6 @@ def test_dataclass_parse_successfully():
9796
)
9897
assert isinstance(cfg.fast_moe, FastMoe)
9998

100-
# 5. Specifing "--odm" will parse an ODM class
101-
parser = transformers.HfArgumentParser(dataclass_types=ODMConfig)
102-
(cfg,) = parser.parse_args_into_dataclasses(
103-
["--odm", "2", "1", "train_loss", "0.1", "0.2"],
104-
)
105-
assert isinstance(cfg.odm, ODM)
106-
10799

108100
def test_two_dataclasses_parse_successfully_together():
109101
"""Ensure that the two dataclasses can parse arguments successfully

tests/test_sft_trainer.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2151,6 +2151,10 @@ def test_handler(element, **kwargs):
21512151
_validate_training(tempdir)
21522152

21532153

2154+
@pytest.mark.skipif(
2155+
not is_fms_accelerate_available(plugins="odm"),
2156+
reason="Only runs if fms-accelerate is installed along with online-data-mixing plugin",
2157+
)
21542158
@pytest.mark.parametrize(
21552159
"datafiles, datasetconfigname, reward_type",
21562160
[
@@ -2258,12 +2262,16 @@ def test_online_data_mixing_plugin_sample_training(
22582262
)
22592263
assert len(output_inference) > 0
22602264
assert (
2261-
"It takes 10 days for digging a trench of 100 m long, 50 m broad and 10 m deep."
2265+
"It takes 10 days for digging a trench of 100 m long, 50 m broad and 10 m deep. "
22622266
"What length of trench,\n25 m broad and 15 m deep can be dug in 30 days ?"
22632267
in output_inference
22642268
), f"{output_inference} does not include the prompt"
22652269

22662270

2271+
@pytest.mark.skipif(
2272+
not is_fms_accelerate_available(plugins="odm"),
2273+
reason="Only runs if fms-accelerate is installed along with online-data-mixing plugin",
2274+
)
22672275
@pytest.mark.parametrize(
22682276
"datafiles, datasetconfigname, reward_type",
22692277
[
@@ -2333,7 +2341,7 @@ def test_online_data_mixing_plugin_sample_training_no_validation_split(
23332341
)
23342342
assert len(output_inference) > 0
23352343
assert (
2336-
"It takes 10 days for digging a trench of 100 m long, 50 m broad and 10 m deep."
2344+
"It takes 10 days for digging a trench of 100 m long, 50 m broad and 10 m deep. "
23372345
"What length of trench,\n25 m broad and 15 m deep can be dug in 30 days ?"
23382346
in output_inference
23392347
), f"{output_inference} does not include the prompt"

tuning/sft_trainer.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -632,8 +632,6 @@ def parse_arguments(parser, json_config=None):
632632
Configuration for padding free and packing.
633633
FastMoeConfig
634634
Configuration for accelerated MoE.
635-
ODMConfig
636-
Configuration for online data mixing feature.
637635
TrackerConfigs
638636
Configuration for all trackers.
639637
dict[str, str]

0 commit comments

Comments
 (0)