
Commit 49e5f31

lazy more_elegant (#2451)

1 parent 215e3e4 commit 49e5f31
File tree

29 files changed (+1850, -209 lines)

paddleformers/__init__.py

Lines changed: 48 additions & 17 deletions
@@ -14,10 +14,19 @@

 import os
 import sys
+from contextlib import suppress
 from datetime import datetime
+from typing import TYPE_CHECKING
+
+from .utils.lazy_import import _LazyModule

 PADDLEFORMERS_STABLE_VERSION = "PADDLEFORMERS_STABLE_VERSION"

+with suppress(Exception):
+    import paddle
+
+    paddle.disable_signal_handler()
+
 # this version is used for develop and test.
 # release version will be added fixed version by setup.py.
 __version__ = "0.1.2.post"
@@ -38,20 +47,42 @@
         "This may cause PaddleFormers datasets to be unavailable in intranet. "
         "Please import paddleformers before datasets module to avoid download issues"
     )
-import paddle
-
-from . import (
-    data,
-    datasets,
-    mergekit,
-    ops,
-    peft,
-    quantization,
-    trainer,
-    transformers,
-    trl,
-    utils,
-    version,
-)
-
-paddle.disable_signal_handler()
+
+# module index
+modules = [
+    "data",
+    "datasets",
+    "mergekit",
+    "ops",
+    "peft",
+    "quantization",
+    "trainer",
+    "transformers",
+    "trl",
+    "utils",
+    "version",
+]
+import_structure = {module: [] for module in modules}
+
+if TYPE_CHECKING:
+    from . import (
+        data,
+        datasets,
+        mergekit,
+        ops,
+        peft,
+        quantization,
+        trainer,
+        transformers,
+        trl,
+        utils,
+        version,
+    )
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+        extra_objects={"__version__": __version__},
+    )
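
A note on the mechanics: the heavy lifting is delegated to `_LazyModule` from `paddleformers/utils/lazy_import.py`, which this excerpt does not show. The sketch below is illustrative only, modeled on the widely copied Hugging Face Transformers `_LazyModule` pattern that this diff resembles; the actual PaddleFormers implementation may differ in detail. The idea: subclass `types.ModuleType`, record which submodule defines each exported name, and perform the real import on first attribute access.

```python
# Illustrative sketch only: the real _LazyModule lives in
# paddleformers/utils/lazy_import.py and is not shown in this commit.
import importlib
import os
from types import ModuleType


class _LazyModule(ModuleType):
    """Module proxy that defers submodule imports until first attribute access."""

    def __init__(self, name, module_file, import_structure, module_spec=None, extra_objects=None):
        super().__init__(name)
        self._modules = set(import_structure.keys())
        # Reverse map: exported symbol -> submodule that defines it.
        self._class_to_module = {
            symbol: module for module, symbols in import_structure.items() for symbol in symbols
        }
        self._objects = extra_objects if extra_objects is not None else {}
        self.__file__ = module_file
        self.__spec__ = module_spec
        self.__path__ = [os.path.dirname(module_file)]
        # Advertise every lazy name so dir() and tab completion still work.
        self.__all__ = list(self._modules) + list(self._class_to_module.keys())

    def __dir__(self):
        return sorted(set(super().__dir__()) | set(self.__all__))

    def __getattr__(self, name):
        if name in self._objects:  # eagerly supplied values such as __version__
            return self._objects[name]
        if name in self._modules:  # e.g. paddleformers.data -> import the submodule
            value = importlib.import_module("." + name, self.__name__)
        elif name in self._class_to_module:  # e.g. a class name -> import its home module
            module = importlib.import_module("." + self._class_to_module[name], self.__name__)
            value = getattr(module, name)
        else:
            raise AttributeError(f"module {self.__name__!r} has no attribute {name!r}")
        setattr(self, name, value)  # cache: the next access is a plain attribute lookup
        return value
```

Note that the top-level `import_structure` maps every submodule to an empty list, so only the submodules themselves are lazy here; the per-package `__init__.py` files below also list individual symbols, letting `from paddleformers.data import DataCollatorForSeq2Seq` resolve without importing the other `data` submodules.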

paddleformers/data/__init__.py

Lines changed: 93 additions & 7 deletions
@@ -12,10 +12,96 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .blendable_dataset import *
-from .causal_dataset import *
-from .collate import *
-from .data_collator import *
-from .dist_dataloader import *
-from .sampler import *
-from .vocab import *
+
+import sys
+from typing import TYPE_CHECKING
+
+from ..utils.lazy_import import _LazyModule
+
+import_structure = {
+    "sampler": ["SamplerHelper"],
+    "causal_dataset": [
+        "check_data_split",
+        "get_train_valid_test_split_",
+        "get_datasets_weights_and_num_samples",
+        "print_rank_0",
+        "build_train_valid_test_datasets",
+        "_build_train_valid_test_datasets",
+        "get_indexed_dataset_",
+        "GPTDataset",
+        "_build_index_mappings",
+        "_num_tokens",
+        "_num_epochs",
+        "_build_doc_idx",
+        "_build_sample_idx",
+        "_build_shuffle_idx",
+    ],
+    "data_collator": [
+        "DataCollatorForSeq2Seq",
+        "default_data_collator",
+        "DataCollator",
+        "DataCollatorWithPadding",
+        "InputDataClass",
+        "DataCollatorMixin",
+        "paddle_default_data_collator",
+        "numpy_default_data_collator",
+        "DefaultDataCollator",
+        "DataCollatorForTokenClassification",
+        "DataCollatorForEmbedding",
+        "_paddle_collate_batch",
+        "_numpy_collate_batch",
+        "tolist",
+        "DataCollatorForLanguageModeling",
+    ],
+    "dist_dataloader": ["DummyDataset", "IterableDummyDataset", "DistDataLoader", "init_dataloader_comm_group"],
+    "blendable_dataset": ["print_rank_0", "BlendableDataset"],
+    "collate": ["Dict", "Pad", "Stack", "Tuple"],
+    "vocab": ["Vocab"],
+    "tokenizer": ["BaseTokenizer"],
+    "indexed_dataset": [
+        "print_rank_0",
+        "get_available_dataset_impl",
+        "make_dataset",
+        "make_sft_dataset",
+        "dataset_exists",
+        "read_longs",
+        "write_longs",
+        "read_shorts",
+        "write_shorts",
+        "dtypes",
+        "code",
+        "index_file_path",
+        "sft_index_file_path",
+        "sft_data_file_path",
+        "data_file_path",
+        "loss_mask_file_path",
+        "create_doc_idx",
+        "IndexedDataset",
+        "IndexedDatasetBuilder",
+        "_warmup_mmap_file",
+        "MMapIndexedDataset",
+        "SFTMMapIndexedDataset",
+        "make_builder",
+        "SFTMMapIndexedDatasetBuilder",
+        "MMapIndexedDatasetBuilder",
+        "get_indexed_dataset_",
+        "CompatibleIndexedDataset",
+    ],
+}
+
+
+if TYPE_CHECKING:
+    from .blendable_dataset import *
+    from .causal_dataset import *
+    from .collate import *
+    from .data_collator import *
+    from .dist_dataloader import *
+    from .sampler import *
+    from .vocab import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
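
Assuming `_LazyModule` behaves like the sketch after the first file above, the observable effect is that importing the package binds only the proxy, and a submodule listed in `import_structure` is executed the first time one of its names is touched. A hypothetical interactive check (not part of the PR):

```python
# Hypothetical check of the lazy behavior; names come from import_structure above.
import sys

import paddleformers.data  # binds the lazy proxy; data_collator is not executed yet

print("paddleformers.data.data_collator" in sys.modules)  # expected: False

collator_cls = paddleformers.data.DataCollatorForSeq2Seq  # first access triggers the import
print("paddleformers.data.data_collator" in sys.modules)  # expected: True
```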

paddleformers/datasets/__init__.py

Lines changed: 44 additions & 3 deletions
@@ -12,7 +12,48 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import sys
+from typing import TYPE_CHECKING

-from .dataset import *
-from .embedding_dataset import *
-from .zero_padding_dataset import *
+from ..utils.lazy_import import _LazyModule
+
+import_structure = {
+    "zero_padding_dataset": [
+        "block_diag",
+        "generate_greedy_packs",
+        "ZeroPadding",
+        "ZeroPaddingMapDataset",
+        "ZeroPaddingIterableDataset",
+    ],
+    "dataset": [
+        "load_from_ppnlp",
+        "DatasetTuple",
+        "import_main_class",
+        "load_from_hf",
+        "load_dataset",
+        "MapDataset",
+        "IterDataset",
+        "DatasetBuilder",
+        "SimpleBuilder",
+    ],
+    "embedding_dataset": [
+        "Example",
+        "Sequence",
+        "Pair",
+        "EmbeddingDatasetMixin",
+        "EmbeddingDataset",
+        "EmbeddingIterableDataset",
+    ],
+}
+
+if TYPE_CHECKING:
+    from .dataset import *
+    from .embedding_dataset import *
+    from .zero_padding_dataset import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )

paddleformers/generation/__init__.py

Lines changed: 84 additions & 21 deletions
@@ -11,24 +11,87 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from .configuration_utils import GenerationConfig
-from .logits_process import (
-    ForcedBOSTokenLogitsProcessor,
-    ForcedEOSTokenLogitsProcessor,
-    HammingDiversityLogitsProcessor,
-    LogitsProcessor,
-    LogitsProcessorList,
-    MinLengthLogitsProcessor,
-    RepetitionPenaltyLogitsProcessor,
-    TopKProcess,
-    TopPProcess,
-)
-from .stopping_criteria import (
-    MaxLengthCriteria,
-    MaxTimeCriteria,
-    StoppingCriteria,
-    StoppingCriteriaList,
-    validate_stopping_criteria,
-)
-from .streamers import BaseStreamer, TextIteratorStreamer, TextStreamer
-from .utils import BeamSearchScorer, GenerationMixin, get_unfinished_flag
+
+import sys
+from typing import TYPE_CHECKING
+
+from ..utils.lazy_import import _LazyModule
+
+import_structure = {
+    "utils": [
+        "GenerationMixin",
+        "MinLengthLogitsProcessor",
+        "convert_dtype",
+        "get_unfinished_flag",
+        "LogitsProcessor",
+        "BeamHypotheses",
+        "RepetitionPenaltyLogitsProcessor",
+        "LogitsProcessorList",
+        "TopKProcess",
+        "map_structure",
+        "BeamSearchScorer",
+        "TopPProcess",
+        "get_scale_by_dtype",
+        "validate_stopping_criteria",
+    ],
+    "model_outputs": ["ModelOutput"],
+    "configuration_utils": ["GenerationConfig", "resolve_hf_generation_config_path"],
+    "logits_process": [
+        "MinLengthLogitsProcessor",
+        "SequenceBiasLogitsProcessor",
+        "NoRepeatNGramLogitsProcessor",
+        "PrefixConstrainedLogitsProcessor",
+        "TopPProcess",
+        "LogitsWarper",
+        "HammingDiversityLogitsProcessor",
+        "ForcedEOSTokenLogitsProcessor",
+        "ForcedBOSTokenLogitsProcessor",
+        "LogitsProcessor",
+        "RepetitionPenaltyLogitsProcessor",
+        "TemperatureLogitsWarper",
+        "TopKProcess",
+        "_get_ngrams",
+        "_get_generated_ngrams",
+        "LogitsProcessorList",
+        "NoBadWordsLogitsProcessor",
+        "_calc_banned_ngram_tokens",
+    ],
+    "stopping_criteria": [
+        "validate_stopping_criteria",
+        "StoppingCriteria",
+        "MaxLengthCriteria",
+        "StoppingCriteriaList",
+        "MaxTimeCriteria",
+    ],
+    "streamers": ["BaseStreamer", "TextIteratorStreamer", "TextStreamer"],
+}
+
+if TYPE_CHECKING:
+    from .configuration_utils import GenerationConfig
+    from .logits_process import (
+        ForcedBOSTokenLogitsProcessor,
+        ForcedEOSTokenLogitsProcessor,
+        HammingDiversityLogitsProcessor,
+        LogitsProcessor,
+        LogitsProcessorList,
+        MinLengthLogitsProcessor,
+        RepetitionPenaltyLogitsProcessor,
+        TopKProcess,
+        TopPProcess,
+    )
+    from .stopping_criteria import (
+        MaxLengthCriteria,
+        MaxTimeCriteria,
+        StoppingCriteria,
+        StoppingCriteriaList,
+        validate_stopping_criteria,
+    )
+    from .streamers import BaseStreamer, TextIteratorStreamer, TextStreamer
+    from .utils import BeamSearchScorer, GenerationMixin, get_unfinished_flag
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
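
One more note on the two branches seen in each of these files: `typing.TYPE_CHECKING` is `False` at runtime but treated as `True` by static analyzers, so IDEs and type checkers still see the eager imports while the interpreter only ever takes the `_LazyModule` path. An informal way to observe the payoff is to time the top-level import; this is a sanity check under the assumptions above, not a benchmark reported by the PR:

```python
# Informal import-latency check (not a measurement from this PR). Note that
# paddle itself is still imported eagerly by paddleformers/__init__.py inside
# the suppress(Exception) block, so the savings come from the PaddleFormers
# submodules that are no longer executed up front.
import time

start = time.perf_counter()
import paddleformers

print(f"import paddleformers: {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
_ = paddleformers.transformers  # first access pays the deferred import cost
print(f"first access to .transformers: {time.perf_counter() - start:.3f}s")

start = time.perf_counter()
_ = paddleformers.transformers  # cached by the proxy: effectively free now
print(f"second access: {time.perf_counter() - start:.6f}s")
```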

paddleformers/mergekit/__init__.py

Lines changed: 27 additions & 5 deletions
@@ -12,8 +12,30 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from .merge_config import *
-from .merge_method import *
-from .merge_model import *
-from .merge_utils import *
-from .sparsify_method import *
+
+import sys
+from typing import TYPE_CHECKING
+
+from ..utils.lazy_import import _LazyModule
+
+import_structure = {
+    "merge_model": ["save_file", "device_guard", "divide_lora_key_list", "divide_positions", "MergeModel"],
+    "merge_method": ["MergeMethod"],
+    "sparsify_method": ["SparsifyMethod"],
+    "merge_utils": ["divide_positions", "divide_lora_key_list", "divide_safetensor_key_list"],
+    "merge_config": ["MergeConfig"],
+}
+
+if TYPE_CHECKING:
+    from .merge_config import *
+    from .merge_method import *
+    from .merge_model import *
+    from .merge_utils import *
+    from .sparsify_method import *
+else:
+    sys.modules[__name__] = _LazyModule(
+        __name__,
+        globals()["__file__"],
+        import_structure,
+        module_spec=__spec__,
+    )
