|  | 
| 19 | 19 | from typing import Any, Dict, List, Literal, Optional, Type, Union | 
| 20 | 20 | 
 | 
| 21 | 21 | import torch | 
|  | 22 | +from packaging.specifiers import InvalidSpecifier, SpecifierSet | 
| 22 | 23 | 
 | 
| 23 | 24 | from ..configuration_utils import ConfigMixin, FrozenDict | 
| 24 | 25 | from ..utils import is_torch_available, logging | 
|  | 26 | +from ..utils.import_utils import _is_package_available | 
| 25 | 27 | 
 | 
| 26 | 28 | 
 | 
| 27 | 29 | if is_torch_available(): | 
| @@ -670,3 +672,86 @@ def make_doc_string( | 
| 670 | 672 |     output += format_output_params(outputs, indent_level=2) | 
| 671 | 673 | 
 | 
| 672 | 674 |     return output | 
|  | 675 | + | 
|  | 676 | + | 
|  | 677 | +def _validate_requirements(reqs): | 
|  | 678 | +    if reqs is None: | 
|  | 679 | +        normalized_reqs = {} | 
|  | 680 | +    else: | 
|  | 681 | +        if not isinstance(reqs, dict): | 
|  | 682 | +            raise ValueError( | 
|  | 683 | +                "Requirements must be provided as a dictionary mapping package names to version specifiers." | 
|  | 684 | +            ) | 
|  | 685 | +        normalized_reqs = _normalize_requirements(reqs) | 
|  | 686 | + | 
|  | 687 | +    if not normalized_reqs: | 
|  | 688 | +        return {} | 
|  | 689 | + | 
|  | 690 | +    final: Dict[str, str] = {} | 
|  | 691 | +    for req, specified_ver in normalized_reqs.items(): | 
|  | 692 | +        req_available, req_actual_ver = _is_package_available(req) | 
|  | 693 | +        if not req_available: | 
|  | 694 | +            logger.warning(f"{req} was specified in the requirements but wasn't found in the current environment.") | 
|  | 695 | + | 
|  | 696 | +        if specified_ver: | 
|  | 697 | +            try: | 
|  | 698 | +                specifier = SpecifierSet(specified_ver) | 
|  | 699 | +            except InvalidSpecifier as err: | 
|  | 700 | +                raise ValueError(f"Requirement specifier '{specified_ver}' for {req} is invalid.") from err | 
|  | 701 | + | 
|  | 702 | +            if req_actual_ver == "N/A": | 
|  | 703 | +                logger.warning( | 
|  | 704 | +                    f"Version of {req} could not be determined to validate requirement '{specified_ver}'. Things might work unexpected." | 
|  | 705 | +                ) | 
|  | 706 | +            elif not specifier.contains(req_actual_ver, prereleases=True): | 
|  | 707 | +                logger.warning( | 
|  | 708 | +                    f"{req} requirement '{specified_ver}' is not satisfied by the installed version {req_actual_ver}. Things might work unexpected." | 
|  | 709 | +                ) | 
|  | 710 | + | 
|  | 711 | +        final[req] = specified_ver | 
|  | 712 | + | 
|  | 713 | +    return final | 
|  | 714 | + | 
|  | 715 | + | 
|  | 716 | +def _normalize_requirements(reqs): | 
|  | 717 | +    if not reqs: | 
|  | 718 | +        return {} | 
|  | 719 | + | 
|  | 720 | +    normalized: "OrderedDict[str, str]" = OrderedDict() | 
|  | 721 | + | 
|  | 722 | +    def _accumulate(mapping: Dict[str, Any]): | 
|  | 723 | +        for pkg, spec in mapping.items(): | 
|  | 724 | +            if isinstance(spec, dict): | 
|  | 725 | +                # This is recursive because blocks are composable. This way, we can merge requirements | 
|  | 726 | +                # from multiple blocks. | 
|  | 727 | +                _accumulate(spec) | 
|  | 728 | +                continue | 
|  | 729 | + | 
|  | 730 | +            pkg_name = str(pkg).strip() | 
|  | 731 | +            if not pkg_name: | 
|  | 732 | +                raise ValueError("Requirement package name cannot be empty.") | 
|  | 733 | + | 
|  | 734 | +            spec_str = "" if spec is None else str(spec).strip() | 
|  | 735 | +            if spec_str and not spec_str.startswith(("<", ">", "=", "!", "~")): | 
|  | 736 | +                spec_str = f"=={spec_str}" | 
|  | 737 | + | 
|  | 738 | +            existing_spec = normalized.get(pkg_name) | 
|  | 739 | +            if existing_spec is not None: | 
|  | 740 | +                if not existing_spec and spec_str: | 
|  | 741 | +                    normalized[pkg_name] = spec_str | 
|  | 742 | +                elif existing_spec and spec_str and existing_spec != spec_str: | 
|  | 743 | +                    try: | 
|  | 744 | +                        combined_spec = SpecifierSet(",".join(filter(None, [existing_spec, spec_str]))) | 
|  | 745 | +                    except InvalidSpecifier: | 
|  | 746 | +                        logger.warning( | 
|  | 747 | +                            f"Conflicting requirements for '{pkg_name}' detected: '{existing_spec}' vs '{spec_str}'. Keeping '{existing_spec}'." | 
|  | 748 | +                        ) | 
|  | 749 | +                    else: | 
|  | 750 | +                        normalized[pkg_name] = str(combined_spec) | 
|  | 751 | +                continue | 
|  | 752 | + | 
|  | 753 | +            normalized[pkg_name] = spec_str | 
|  | 754 | + | 
|  | 755 | +    _accumulate(reqs) | 
|  | 756 | + | 
|  | 757 | +    return normalized | 
0 commit comments