|
38 | 38 |
|
39 | 39 | logger = logging.getLogger(__name__) |
40 | 40 |
|
41 | | - |
42 | 41 | if sys.version_info >= (3, 5): |
43 | 42 | from json.decoder import JSONDecodeError |
44 | 43 | else: |
45 | 44 | JSONDecodeError = ValueError |
46 | 45 |
|
47 | | - |
48 | 46 | DEPENDENCIES_PATTERN = re.compile( |
49 | 47 | r"^(?P<name>[\w\-]+)((?P<operation>==|>=|>)" |
50 | 48 | r"(?P<version>(\d+\.)?(\d+\.)?(\d+)?(dev)?[0-9]*))?$" |
51 | 49 | ) |
52 | 50 |
|
53 | | - |
54 | 51 | SIMPLE_NUMPY_TYPES = [ |
55 | 52 | nptype |
56 | 53 | for type_cat, nptypes in np.sctypes.items() |
@@ -580,15 +577,11 @@ def _is_cross_validator(self, o: Any) -> bool: |
580 | 577 |
|
581 | 578 | @classmethod |
582 | 579 | def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool: |
583 | | - if getattr(flow, "dependencies", None) is not None and "sklearn" in flow.dependencies: |
584 | | - return True |
585 | | - if flow.external_version is None: |
586 | | - return False |
587 | | - else: |
588 | | - return ( |
589 | | - flow.external_version.startswith("sklearn==") |
590 | | - or ",sklearn==" in flow.external_version |
591 | | - ) |
| 580 | + sklearn_dependency = isinstance(flow.dependencies, str) and "sklearn" in flow.dependencies |
| 581 | + sklearn_as_external = isinstance(flow.external_version, str) and ( |
| 582 | + flow.external_version.startswith("sklearn==") or ",sklearn==" in flow.external_version |
| 583 | + ) |
| 584 | + return sklearn_dependency or sklearn_as_external |
592 | 585 |
|
593 | 586 | def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str: |
594 | 587 | """Fetches the sklearn function docstring for the flow description |
@@ -1867,24 +1860,22 @@ def is_subcomponent_specification(values): |
1867 | 1860 | # Checks whether the current value can be a specification of |
1868 | 1861 | # subcomponents, for example the value of the steps parameter |
1869 | 1862 | # (in Pipeline) or of the transformers parameter (in |
1870 | | - # ColumnTransformer). These are always lists/tuples of lists/ |
1871 | | - # tuples, size bigger than 2 and an OpenMLFlow item involved. |
1872 | | - if not isinstance(values, (tuple, list)): |
1873 | | - return False |
1874 | | - for item in values: |
1875 | | - if not isinstance(item, (tuple, list)): |
1876 | | - return False |
1877 | | - if len(item) < 2: |
1878 | | - return False |
1879 | | - if not isinstance(item[1], (openml.flows.OpenMLFlow, str)): |
1880 | | - if ( |
| 1863 | + # ColumnTransformer). |
| 1864 | + return ( |
| 1865 | + # A specification is a list/tuple whose items are each a |
| 1866 | + # list/tuple of length at least 2. |
| 1867 | + isinstance(values, (tuple, list)) |
| 1868 | + and all(isinstance(item, (tuple, list)) and len(item) > 1 for item in values) |
| 1869 | + # And the second entry of each item must be a flow or a known pipeline string |
| 1870 | + and all( |
| 1871 | + isinstance(item[1], openml.flows.OpenMLFlow) |
| 1872 | + or ( |
1881 | 1873 | isinstance(item[1], str) |
1882 | 1874 | and item[1] in SKLEARN_PIPELINE_STRING_COMPONENTS |
1883 | | - ): |
1884 | | - pass |
1885 | | - else: |
1886 | | - return False |
1887 | | - return True |
| 1875 | + ) |
| 1876 | + for item in values |
| 1877 | + ) |
| 1878 | + ) |
1888 | 1879 |
|
1889 | 1880 | # _flow is openml flow object, _param dict maps from flow name to flow |
1890 | 1881 | # id for the main call, the param dict can be overridden (useful for |
|
0 commit comments