Commit fbccdad

fix import download_utils & support ci set network proxy (#2477)
1 parent 2ad8cbd commit fbccdad

7 files changed: +80 -13 lines changed


paddleformers/transformers/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -18,6 +18,8 @@
 from typing import TYPE_CHECKING
 from ..utils.lazy_import import _LazyModule

+from .download_utils import *
+
 # from .auto.modeling import AutoModelForCausalLM
 import_structure = {
     "kto_criterion": [

paddleformers/utils/download/download.py

Lines changed: 6 additions & 1 deletion
@@ -28,7 +28,12 @@
     RepositoryNotFoundError,
     RevisionNotFoundError,
 )
-from paddle import __version__
+
+try:
+    from paddle import __version__
+except ImportError:
+    __version__ = ""
+
 from requests import HTTPError

 from ..log import logger
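
Note: the try/except fallback above keeps this module importable in environments where PaddlePaddle is not installed. A minimal standalone sketch of the same pattern (illustration only, not the module itself):

# Prefer the real paddle version string; degrade gracefully when paddle is absent,
# so that merely importing the download utilities never hard-requires paddle.
try:
    from paddle import __version__ as paddle_version
except ImportError:
    paddle_version = ""

print(paddle_version or "paddle not installed")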

tests/dataset/test_zero_padding.py

Lines changed: 1 addition & 3 deletions
@@ -74,9 +74,7 @@ def preprocess_fn(
         return_attention_mask=True,
     ):
         inputs = example["sentence"][:2]
-        model_inputs = self.tokenizer(
-            inputs, max_length=max_src_length, truncation=True, return_attention_mask=False, return_position_ids=False
-        )
+        model_inputs = self.tokenizer(inputs, max_length=max_src_length, truncation=True, return_attention_mask=False)
         labels_input_ids = model_inputs["input_ids"] + [self.tokenizer.eos_token_id]
         model_inputs["labels"] = [-100] * len(model_inputs["input_ids"]) + labels_input_ids
         model_inputs["input_ids"] = model_inputs["input_ids"] + labels_input_ids

tests/testing_utils.py

Lines changed: 51 additions & 0 deletions
@@ -30,6 +30,7 @@
 import yaml

 from paddleformers.trainer.argparser import strtobool
+from paddleformers.utils.download import DownloadSource
 from paddleformers.utils.import_utils import is_package_available, is_paddle_available

 __all__ = ["get_vocab_list", "stable_softmax", "cross_entropy"]
@@ -539,3 +540,53 @@ def init_dist_env(self, config: dict = {}):

         fleet.init(is_collective=True, strategy=strategy)
         fleet.get_hybrid_communicate_group()
+
+
+def set_proxy(download_hub: DownloadSource = None):
+    """
+    set network proxy for downloading model from aistudio/huggingface/modelscope
+    """
+
+    def decorator(func):
+        def wrapper(*args, **kwargs):
+            if download_hub is None:
+                return func(*args, **kwargs)
+            elif download_hub == DownloadSource.HUGGINGFACE:
+                command = "source $work_dir/../../../proxy_hf && env"
+            elif download_hub == DownloadSource.AISTUDIO:
+                command = "source $work_dir/../../../proxy_aistudio && env"
+            elif download_hub == DownloadSource.MODELSCOPE:
+                command = "source $work_dir/../../../proxy_aistudio && env"  # proxy_aistudio also suit for modelscope
+
+            proc = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True)
+            out, _ = proc.communicate()
+
+            proxy_env = {}
+            for line in out.decode().splitlines():
+                if "=" not in line:
+                    continue
+                key, _, value = line.partition("=")
+                proxy_env[key] = value
+
+            ori_env = {}
+            proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"]
+            if download_hub == DownloadSource.AISTUDIO:
+                proxy_vars.extend(["STUDIO_GIT_HOST", "STUDIO_CDN_HOST"])
+
+            for key in proxy_vars:
+                if key in proxy_env:
+                    ori_env[key] = os.environ.get(key, "")
+                    os.environ[key] = proxy_env[key]
+
+            try:
+                return func(*args, **kwargs)
+            finally:
+                for key, old_value in ori_env.items():
+                    if old_value is None:
+                        os.environ.pop(key, None)
+                    else:
+                        os.environ[key] = old_value
+
+        return wrapper
+
+    return decorator
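
For reference, a minimal usage sketch of the new set_proxy helper; the test class below is hypothetical and mirrors the aistudio tests updated later in this commit. The decorator shells out to a CI-side proxy script (proxy_hf or proxy_aistudio, resolved relative to $work_dir), copies the resulting HTTP_PROXY/HTTPS_PROXY/NO_PROXY values into os.environ for the duration of the wrapped test, then restores the previous values; it assumes those scripts exist on the CI host and that os and subprocess are already imported in testing_utils.py.

import unittest

from paddleformers.utils.download import DownloadSource
from tests.testing_utils import set_proxy


class TinyLlamaAistudioDownloadTest(unittest.TestCase):  # hypothetical class, for illustration only
    @set_proxy(DownloadSource.AISTUDIO)  # proxy env vars apply only while this test runs
    def test_config_from_aistudio(self):
        from paddleformers.transformers import AutoConfig

        config = AutoConfig.from_pretrained("test_paddleformers/tiny-random-llama", download_hub="aistudio")
        self.assertEqual(config.hidden_size, 192)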

tests/transformers/auto/test_confiugration.py

Lines changed: 11 additions & 6 deletions
@@ -23,7 +23,9 @@
 from paddleformers.transformers import AutoConfig
 from paddleformers.transformers.auto.configuration import CONFIG_MAPPING
 from paddleformers.transformers.bert.configuration import BertConfig
+from paddleformers.utils.download import DownloadSource
 from paddleformers.utils.env import CONFIG_NAME
+from tests.testing_utils import set_proxy

 from ...utils.test_module.custom_configuration import CustomConfig

@@ -66,17 +68,20 @@ def test_community_model_class(self):
         self.assertEqual(auto_config.hidden_size, number)

     @unittest.skip("skipping due to connection error!")
+    # @set_proxy(DownloadSource.HUGGINGFACE)
     def test_from_hf_hub(self):
-        config = AutoConfig.from_pretrained("facebook/opt-66b", download_hub="huggingface")
-        self.assertEqual(config.hidden_size, 9216)
+        config = AutoConfig.from_pretrained("dfargveazd/tiny-random-llama-paddle", download_hub="huggingface")
+        self.assertEqual(config.hidden_size, 192)

-    @unittest.skip("skipping due to connection error!")
+    # @unittest.skip("skipping due to connection error!")
+    @set_proxy(DownloadSource.AISTUDIO)
     def test_from_aistudio(self):
         config = AutoConfig.from_pretrained("test_paddleformers/tiny-random-llama", download_hub="aistudio")
-        self.assertEqual(config.hidden_size, 768)
+        self.assertEqual(config.hidden_size, 192)

-    @unittest.skip("skipping due to connection error!")
-    def test_from_mdoelscope(self):
+    # @unittest.skip("skipping due to connection error!")
+    @set_proxy(DownloadSource.MODELSCOPE)
+    def test_from_modelscope(self):
         config = AutoConfig.from_pretrained("sqlhuman/tiny-random-llama", download_hub="modelscope")
         self.assertEqual(config.hidden_size, 768)
tests/transformers/auto/test_modeling.py

Lines changed: 8 additions & 3 deletions
@@ -33,7 +33,9 @@
 )
 from paddleformers.transformers.auto.configuration import CONFIG_MAPPING
 from paddleformers.transformers.auto.modeling import MODEL_MAPPING
+from paddleformers.utils.download import DownloadSource
 from paddleformers.utils.env import CONFIG_NAME, PADDLE_WEIGHTS_NAME
+from tests.testing_utils import set_proxy

 from ...utils.test_module.custom_configuration import CustomConfig
 from ...utils.test_module.custom_model import CustomModel
@@ -74,16 +76,19 @@ def test_model_from_pretrained_cache_dir(self):
         self.assertFalse(os.path.exists(os.path.join(tempdir, model_name, model_name)))

     @unittest.skip("skipping due to connection error!")
+    # @set_proxy(DownloadSource.HUGGINGFACE)
     def test_from_hf_hub(self):
-        model = AutoModel.from_pretrained("dfargveazd/tiny-random-llama", download_hub="huggingface")
+        model = AutoModel.from_pretrained("dfargveazd/tiny-random-llama-paddle", download_hub="huggingface")
         self.assertIsInstance(model, LlamaModel)

-    @unittest.skip("skipping due to connection error!")
+    # @unittest.skip("skipping due to connection error!")
+    @set_proxy(DownloadSource.AISTUDIO)
     def test_from_aistudio(self):
         model = AutoModel.from_pretrained("test_paddleformers/tiny-random-llama", download_hub="aistudio")
         self.assertIsInstance(model, LlamaModel)

-    @unittest.skip("skipping due to connection error!")
+    # @unittest.skip("skipping due to connection error!")
+    @set_proxy(DownloadSource.MODELSCOPE)
     def test_from_modelscope(self):
         model = AutoModel.from_pretrained("sqlhuman/tiny-random-llama", download_hub="modelscope")
         self.assertIsInstance(model, LlamaModel)
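
The HuggingFace variants stay skipped on CI; the @set_proxy(DownloadSource.HUGGINGFACE) decorator is left commented out so the test can be switched on locally. A hypothetical standalone version with that toggle applied (class name and import paths are assumptions, mirroring test_from_hf_hub above):

import unittest

from paddleformers.transformers import AutoModel, LlamaModel
from paddleformers.utils.download import DownloadSource
from tests.testing_utils import set_proxy


class HFHubDownloadSmokeTest(unittest.TestCase):  # hypothetical class, for illustration only
    # Locally the unittest.skip is dropped and the commented proxy decorator from the diff takes effect.
    @set_proxy(DownloadSource.HUGGINGFACE)
    def test_from_hf_hub(self):
        model = AutoModel.from_pretrained("dfargveazd/tiny-random-llama-paddle", download_hub="huggingface")
        self.assertIsInstance(model, LlamaModel)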

tests/transformers/test_configuration_utils.py

Lines changed: 1 addition & 0 deletions
@@ -153,6 +153,7 @@ def test_from_pretrained_cache_dir(self):
         self.assertFalse(os.path.exists(os.path.join(tempdir, model_id, model_id)))

     @unittest.skip("skipping due to connection error!")
+    # @set_proxy(DownloadSource.HUGGINGFACE)
     def test_load_from_hf(self):
         """test load config from hf"""
         config = BertConfig.from_pretrained("hf-internal-testing/tiny-random-BertModel", download_hub="huggingface")
