Skip to content

Commit 8a275ff

Browse files
Fix import (#266)
1 parent 5c1d0a7 commit 8a275ff

File tree

3 files changed

+23
-43
lines changed

3 files changed

+23
-43
lines changed

swift/llm/__init__.py

Lines changed: 7 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,35 +1,9 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
from typing import TYPE_CHECKING
3-
4-
from swift.utils.import_utils import _LazyModule
2+
from .app_ui import gradio_chat_demo, gradio_generation_demo, llm_app_ui
3+
from .infer import llm_infer, merge_lora, prepare_model_template
4+
from .rome import rome_infer
5+
# Recommend using `xxx_main`
6+
from .run import (app_ui_main, dpo_main, infer_main, merge_lora_main,
7+
rome_main, sft_main)
8+
from .sft import llm_sft
59
from .utils import *
6-
7-
if TYPE_CHECKING:
8-
from .app_ui import gradio_chat_demo, gradio_generation_demo, llm_app_ui
9-
from .infer import llm_infer, merge_lora, prepare_model_template
10-
from .rome import rome_infer
11-
# Recommend using `xxx_main`
12-
from .run import (app_ui_main, dpo_main, infer_main, merge_lora_main,
13-
rome_main, sft_main)
14-
from .sft import llm_sft
15-
else:
16-
_import_structure = {
17-
'app_ui': ['gradio_chat_demo', 'gradio_generation_demo', 'llm_app_ui'],
18-
'infer': ['llm_infer', 'merge_lora', 'prepare_model_template'],
19-
'rome': ['rome_infer'],
20-
'run': [
21-
'app_ui_main', 'dpo_main', 'infer_main', 'merge_lora_main',
22-
'rome_main', 'sft_main'
23-
],
24-
'sft': ['llm_sft'],
25-
}
26-
27-
import sys
28-
29-
sys.modules[__name__] = _LazyModule(
30-
__name__,
31-
globals()['__file__'],
32-
_import_structure,
33-
module_spec=__spec__,
34-
extra_objects={},
35-
)

swift/trainers/__init__.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,24 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
22
from typing import TYPE_CHECKING
33

4-
from transformers.trainer_utils import (EvaluationStrategy, FSDPOption,
5-
HPSearchBackend, HubStrategy,
6-
IntervalStrategy, SchedulerType)
7-
84
from swift.utils.import_utils import _LazyModule
95

106
if TYPE_CHECKING:
117
from .arguments import Seq2SeqTrainingArguments, TrainingArguments
128
from .dpo_trainers import DPOTrainer
139
from .trainers import Seq2SeqTrainer, Trainer
10+
from .utils import EvaluationStrategy, FSDPOption, HPSearchBackend, HubStrategy, \
11+
IntervalStrategy, SchedulerType, ShardedDDPOption
1412
else:
1513
_import_structure = {
1614
'arguments': ['Seq2SeqTrainingArguments', 'TrainingArguments'],
1715
'dpo_trainers': ['DPOTrainer'],
1816
'trainers': ['Seq2SeqTrainer', 'Trainer'],
17+
'utils': [
18+
'EvaluationStrategy', 'FSDPOption', 'HPSearchBackend',
19+
'HubStrategy', 'IntervalStrategy', 'SchedulerType',
20+
'ShardedDDPOption'
21+
]
1922
}
2023

2124
import sys
@@ -27,9 +30,3 @@
2730
module_spec=__spec__,
2831
extra_objects={},
2932
)
30-
31-
try:
32-
# https://github.com/huggingface/transformers/pull/25702
33-
from transformers.trainer_utils import ShardedDDPOption
34-
except ImportError:
35-
pass

swift/trainers/utils.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,15 @@
66
from typing import List, Union
77

88
from torch.nn import Module
9+
from transformers.trainer_utils import (EvaluationStrategy, FSDPOption,
10+
HPSearchBackend, HubStrategy,
11+
IntervalStrategy, SchedulerType)
12+
13+
try:
14+
# https://github.com/huggingface/transformers/pull/25702
15+
from transformers.trainer_utils import ShardedDDPOption
16+
except ImportError:
17+
ShardedDDPOption = None
918

1019

1120
def can_return_loss(model: Module) -> List[str]:

0 commit comments

Comments
 (0)