|
27 | 27 | import sagemaker.utils |
28 | 28 | from sagemaker.workflow import is_pipeline_variable |
29 | 29 |
|
30 | | -from sagemaker.deprecations import renamed_warning |
| 30 | +from sagemaker.deprecations import renamed_warning, renamed_kwargs |
31 | 31 |
|
32 | 32 | logger = logging.getLogger(__name__) |
33 | 33 |
|
@@ -600,6 +600,106 @@ def _validate_smdataparallel_args( |
600 | 600 | raise ValueError(err_msg) |
601 | 601 |
|
602 | 602 |
|
| 603 | +def validate_distribution( |
| 604 | + distribution, instance_groups, framework_name, framework_version, py_version, image_uri, kwargs |
| 605 | +): |
| 606 | + """Check if distribution strategy is correctly invoked by the user. |
| 607 | +
|
| 608 | + Currently checks for `dataparallel`, `modelparallel`, and heterogeneous cluster setup. |
| 609 | + Validates whether the user-requested strategy is supported. |
| 610 | +
|
| 611 | + Args: |
| 612 | + distribution (dict): A dictionary with information to enable distributed training. |
| 613 | + (Defaults to None if distributed training is not enabled.) For example: |
| 614 | +
|
| 615 | + .. code:: python |
| 616 | +
|
| 617 | + { |
| 618 | + "smdistributed": { |
| 619 | + "dataparallel": { |
| 620 | + "enabled": True |
| 621 | + } |
| 622 | + } |
| 623 | + } |
| 624 | + instance_groups ([InstanceGroup]): A list of instance groups used for training. |
| 625 | + framework_name (str): A string representing the name of the selected framework. |
| 626 | + framework_version (str): A string representing the selected framework version. |
| 627 | + py_version (str): A string representing the selected Python version. |
| 628 | + image_uri (str): A string representing a Docker image URI. |
| 629 | + kwargs (dict): Additional kwargs passed to this function. |
| 630 | +
|
| 631 | + Returns: |
| 632 | + distribution (dict): The updated dictionary with validated information |
| 633 | + to enable distributed training. |
| 634 | +
|
| 635 | + Raises: |
| 636 | + ValueError: If the distribution dictionary is not correctly formatted, |
| 637 | + multiple strategies are requested simultaneously, |
| 638 | + an unsupported strategy is requested, |
| 639 | + strategy-specific inputs are incorrect or unsupported, or |
| 640 | + the heterogeneous cluster setup is incorrect. |
| 641 | + """ |
| 642 | + train_instance_groups = distribution.get("instance_groups", []) |
| 643 | + if instance_groups is None: |
| 644 | + if len(train_instance_groups) >= 1: |
| 645 | + # train instance groups are specified in distribution, but |
| 646 | + # the estimator does not define any instance groups |
| 647 | + raise ValueError("Instance groups not specified in the estimator!") |
| 648 | + else: |
| 649 | + if len(train_instance_groups) > len(instance_groups): |
| 650 | + # distribution requests more instance groups than the estimator defines |
| 651 | + raise ValueError("Train instance groups oversubscribed!") |
| 652 | + if len(instance_groups) == 1 and len(train_instance_groups) == 0: |
| 653 | + # if there is exactly one instance group and distribution omits it, set it for the user |
| 654 | + train_instance_groups = instance_groups |
| 655 | + elif len(instance_groups) > 1 and len(train_instance_groups) != 1: |
| 656 | + # currently only one train instance group is supported |
| 657 | + raise ValueError("Distribution should only contain one instance group name!") |
| 658 | + |
| 659 | + if len(train_instance_groups) != 0: |
| 660 | + # in this case, we are handling a heterogeneous cluster training job |
| 661 | + instance_group_names = [] |
| 662 | + for train_instance_group in train_instance_groups: |
| 663 | + # future versions will support multiple train_instance_groups, so loop here |
| 664 | + if train_instance_group not in instance_groups: |
| 665 | + # check that each train instance group is among those defined on the estimator |
| 666 | + raise ValueError( |
| 667 | + f"Invalid training instance group {train_instance_group.instance_group_name}!" |
| 668 | + ) |
| 669 | + instance_type = train_instance_group.instance_type |
| 670 | + validate_smdistributed( |
| 671 | + instance_type=instance_type, |
| 672 | + framework_name=framework_name, |
| 673 | + framework_version=framework_version, |
| 674 | + py_version=py_version, |
| 675 | + distribution=distribution, |
| 676 | + image_uri=image_uri, |
| 677 | + ) |
| 678 | + warn_if_parameter_server_with_multi_gpu( |
| 679 | + training_instance_type=instance_type, distribution=distribution |
| 680 | + ) |
| 681 | + # collect the validated instance group names |
| 682 | + instance_group_names.append(train_instance_group.instance_group_name) |
| 683 | + distribution["instance_groups"] = instance_group_names |
| 684 | + else: |
| 685 | + # in this case, we are handling a normal training job (without heterogeneous cluster) |
| 686 | + instance_type = renamed_kwargs( |
| 687 | + "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs |
| 688 | + ) |
| 689 | + validate_smdistributed( |
| 690 | + instance_type=instance_type, |
| 691 | + framework_name=framework_name, |
| 692 | + framework_version=framework_version, |
| 693 | + py_version=py_version, |
| 694 | + distribution=distribution, |
| 695 | + image_uri=image_uri, |
| 696 | + ) |
| 697 | + warn_if_parameter_server_with_multi_gpu( |
| 698 | + training_instance_type=instance_type, distribution=distribution |
| 699 | + ) |
| 700 | + return distribution |
| 701 | + |
| 702 | + |
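End to end, the heterogeneous-cluster path of the new helper might be exercised as in the following sketch. The group names, instance types, and framework values here are illustrative only, and `InstanceGroup` is assumed to be importable from `sagemaker.instance_group` with the signature `(instance_group_name, instance_type, instance_count)`:

```python
from sagemaker.instance_group import InstanceGroup

# Two hypothetical groups: one feeds data, one runs the distributed training.
data_group = InstanceGroup("data_group", "ml.c5.xlarge", 2)
train_group = InstanceGroup("train_group", "ml.p4d.24xlarge", 1)

distribution = {
    "smdistributed": {"dataparallel": {"enabled": True}},
    # InstanceGroup objects on input; replaced by their names on return
    "instance_groups": [train_group],
}

distribution = validate_distribution(
    distribution=distribution,
    instance_groups=[data_group, train_group],
    framework_name="pytorch",
    framework_version="1.12.0",
    py_version="py38",
    image_uri=None,
    kwargs={},
)
# On success, distribution["instance_groups"] == ["train_group"]
```

Passing both groups under `distribution["instance_groups"]` would instead raise `ValueError`, since only one train instance group is currently supported.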
603 | 703 | def python_deprecation_warning(framework, latest_supported_version): |
604 | 704 | """Placeholder docstring""" |
605 | 705 | return PYTHON_2_DEPRECATION_WARNING.format( |
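On the non-heterogeneous path, the deprecated `train_instance_type` kwarg is remapped through `renamed_kwargs` from `sagemaker.deprecations`. A rough illustration of that remapping, using a hypothetical instance type:

```python
from sagemaker.deprecations import renamed_kwargs

kwargs = {"train_instance_type": "ml.p3.16xlarge"}  # caller used the old name
instance_type = renamed_kwargs(
    "train_instance_type", "instance_type", kwargs.get("instance_type"), kwargs
)
# renamed_kwargs warns about the deprecated name, copies the value under the
# new name, and returns it:
assert instance_type == "ml.p3.16xlarge"
assert kwargs["instance_type"] == "ml.p3.16xlarge"
```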
|