|
30 | 30 | import yaml
|
31 | 31 |
|
32 | 32 | from paddleformers.trainer.argparser import strtobool
|
| 33 | +from paddleformers.utils.download import DownloadSource |
33 | 34 | from paddleformers.utils.import_utils import is_package_available, is_paddle_available
|
34 | 35 |
|
35 | 36 | __all__ = ["get_vocab_list", "stable_softmax", "cross_entropy"]
|
@@ -539,3 +540,53 @@ def init_dist_env(self, config: dict = {}):
|
539 | 540 |
|
540 | 541 | fleet.init(is_collective=True, strategy=strategy)
|
541 | 542 | fleet.get_hybrid_communicate_group()
|
| 543 | + |
| 544 | + |
| 545 | +def set_proxy(download_hub: DownloadSource = None): |
| 546 | + """ |
| 547 | + set network proxy for downloading model from aistudio/huggingface/modelscope |
| 548 | + """ |
| 549 | + |
| 550 | + def decorator(func): |
| 551 | + def wrapper(*args, **kwargs): |
| 552 | + if download_hub is None: |
| 553 | + return func(*args, **kwargs) |
| 554 | + elif download_hub == DownloadSource.HUGGINGFACE: |
| 555 | + command = "source $work_dir/../../../proxy_hf && env" |
| 556 | + elif download_hub == DownloadSource.AISTUDIO: |
| 557 | + command = "source $work_dir/../../../proxy_aistudio && env" |
| 558 | + elif download_hub == DownloadSource.MODELSCOPE: |
| 559 | + command = "source $work_dir/../../../proxy_aistudio && env" # proxy_aistudio also suit for modelscope |
| 560 | + |
| 561 | + proc = subprocess.Popen(command, stdout=subprocess.PIPE, shell=True) |
| 562 | + out, _ = proc.communicate() |
| 563 | + |
| 564 | + proxy_env = {} |
| 565 | + for line in out.decode().splitlines(): |
| 566 | + if "=" not in line: |
| 567 | + continue |
| 568 | + key, _, value = line.partition("=") |
| 569 | + proxy_env[key] = value |
| 570 | + |
| 571 | + ori_env = {} |
| 572 | + proxy_vars = ["HTTP_PROXY", "HTTPS_PROXY", "NO_PROXY"] |
| 573 | + if download_hub == DownloadSource.AISTUDIO: |
| 574 | + proxy_vars.extend(["STUDIO_GIT_HOST", "STUDIO_CDN_HOST"]) |
| 575 | + |
| 576 | + for key in proxy_vars: |
| 577 | + if key in proxy_env: |
| 578 | + ori_env[key] = os.environ.get(key, "") |
| 579 | + os.environ[key] = proxy_env[key] |
| 580 | + |
| 581 | + try: |
| 582 | + return func(*args, **kwargs) |
| 583 | + finally: |
| 584 | + for key, old_value in ori_env.items(): |
| 585 | + if old_value is None: |
| 586 | + os.environ.pop(key, None) |
| 587 | + else: |
| 588 | + os.environ[key] = old_value |
| 589 | + |
| 590 | + return wrapper |
| 591 | + |
| 592 | + return decorator |
0 commit comments