diff --git a/CHANGELOG.md b/CHANGELOG.md index f38732f7d8..7d9b648bea 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -81,6 +81,8 @@ Unreleased changes template. {#v0-0-0-fixed} ### Fixed +* (pypi) Platform specific extras are now correctly handled when using + universal lock files with environment markers. Fixes [#2690](https://github.com/bazel-contrib/rules_python/pull/2690). * (runfiles) ({obj}`--bootstrap_impl=script`) Follow symlinks when searching for runfiles. * (toolchains) Do not try to run `chmod` when downloading non-windows hermetic toolchain repositories on Windows. Fixes diff --git a/python/private/pypi/parse_requirements.bzl b/python/private/pypi/parse_requirements.bzl index d2014a7eb9..1cbf094f5c 100644 --- a/python/private/pypi/parse_requirements.bzl +++ b/python/private/pypi/parse_requirements.bzl @@ -30,22 +30,9 @@ load("//python/private:normalize_name.bzl", "normalize_name") load("//python/private:repo_utils.bzl", "repo_utils") load(":index_sources.bzl", "index_sources") load(":parse_requirements_txt.bzl", "parse_requirements_txt") +load(":pep508_requirement.bzl", "requirement") load(":whl_target_platforms.bzl", "select_whls") -def _extract_version(entry): - """Extract the version part from the requirement string. - - - Args: - entry: {type}`str` The requirement string. - """ - version_start = entry.find("==") - if version_start != -1: - # Extract everything after '==' until the next space or end of the string - version, _, _ = entry[version_start + 2:].partition(" ") - return version - return None - def parse_requirements( ctx, *, @@ -111,19 +98,20 @@ def parse_requirements( # The requirement lines might have duplicate names because lines for extras # are returned as just the base package name. e.g., `foo[bar]` results # in an entry like `("foo", "foo[bar] == 1.0 ...")`. - requirements_dict = { - (normalize_name(entry[0]), _extract_version(entry[1])): entry - for entry in sorted( - parse_result.requirements, - # Get the longest match and fallback to original WORKSPACE sorting, - # which should get us the entry with most extras. - # - # FIXME @aignas 2024-05-13: The correct behaviour might be to get an - # entry with all aggregated extras, but it is unclear if we - # should do this now. - key = lambda x: (len(x[1].partition("==")[0]), x), - ) - }.values() + # Lines with different markers are not condidered duplicates. + requirements_dict = {} + for entry in sorted( + parse_result.requirements, + # Get the longest match and fallback to original WORKSPACE sorting, + # which should get us the entry with most extras. + # + # FIXME @aignas 2024-05-13: The correct behaviour might be to get an + # entry with all aggregated extras, but it is unclear if we + # should do this now. + key = lambda x: (len(x[1].partition("==")[0]), x), + ): + req = requirement(entry[1]) + requirements_dict[(req.name, req.version, req.marker)] = entry tokenized_options = [] for opt in parse_result.options: @@ -132,7 +120,7 @@ def parse_requirements( pip_args = tokenized_options + extra_pip_args for plat in plats: - requirements[plat] = requirements_dict + requirements[plat] = requirements_dict.values() options[plat] = pip_args requirements_by_platform = {} diff --git a/python/private/pypi/pep508_requirement.bzl b/python/private/pypi/pep508_requirement.bzl index 11f2b3e8fa..ee7b5dfc35 100644 --- a/python/private/pypi/pep508_requirement.bzl +++ b/python/private/pypi/pep508_requirement.bzl @@ -30,6 +30,16 @@ def requirement(spec): """ spec = spec.strip() requires, _, maybe_hashes = spec.partition(";") + + version_start = requires.find("==") + version = None + if version_start != -1: + # Extract everything after '==' until the next space or end of the string + version, _, _ = requires[version_start + 2:].partition(" ") + + # Remove any trailing characters from the version string + version = version.strip(" ") + marker, _, _ = maybe_hashes.partition("--hash") requires, _, extras_unparsed = requires.partition("[") extras_unparsed, _, _ = extras_unparsed.partition("]") @@ -42,4 +52,5 @@ def requirement(spec): name = normalize_name(name).replace("_", "-"), marker = marker.strip(" "), extras = extras, + version = version, ) diff --git a/tests/pypi/extension/extension_tests.bzl b/tests/pypi/extension/extension_tests.bzl index 1652e76156..66c9e0549e 100644 --- a/tests/pypi/extension/extension_tests.bzl +++ b/tests/pypi/extension/extension_tests.bzl @@ -856,6 +856,84 @@ git_dep @ git+https://git.server/repo/project@deadbeefdeadbeef _tests.append(_test_simple_get_index) +def _test_optimum_sys_platform_extra(env): + pypi = _parse_modules( + env, + module_ctx = _mock_mctx( + _mod( + name = "rules_python", + parse = [ + _parse( + hub_name = "pypi", + python_version = "3.15", + requirements_lock = "universal.txt", + ), + ], + ), + read = lambda x: { + "universal.txt": """\ +optimum[onnxruntime]==1.17.1 ; sys_platform == 'darwin' +optimum[onnxruntime-gpu]==1.17.1 ; sys_platform == 'linux' +""", + }[x], + ), + available_interpreters = { + "python_3_15_host": "unit_test_interpreter_target", + }, + ) + + pypi.exposed_packages().contains_exactly({"pypi": []}) + pypi.hub_group_map().contains_exactly({"pypi": {}}) + pypi.hub_whl_map().contains_exactly({ + "pypi": { + "optimum": { + "pypi_315_optimum_linux_aarch64_linux_arm_linux_ppc_linux_s390x_linux_x86_64": [ + whl_config_setting( + version = "3.15", + target_platforms = [ + "cp315_linux_aarch64", + "cp315_linux_arm", + "cp315_linux_ppc", + "cp315_linux_s390x", + "cp315_linux_x86_64", + ], + config_setting = None, + filename = None, + ), + ], + "pypi_315_optimum_osx_aarch64_osx_x86_64": [ + whl_config_setting( + version = "3.15", + target_platforms = [ + "cp315_osx_aarch64", + "cp315_osx_x86_64", + ], + config_setting = None, + filename = None, + ), + ], + }, + }, + }) + + pypi.whl_libraries().contains_exactly({ + "pypi_315_optimum_linux_aarch64_linux_arm_linux_ppc_linux_s390x_linux_x86_64": { + "dep_template": "@pypi//{name}:{target}", + "python_interpreter_target": "unit_test_interpreter_target", + "repo": "pypi_315", + "requirement": "optimum[onnxruntime-gpu]==1.17.1", + }, + "pypi_315_optimum_osx_aarch64_osx_x86_64": { + "dep_template": "@pypi//{name}:{target}", + "python_interpreter_target": "unit_test_interpreter_target", + "repo": "pypi_315", + "requirement": "optimum[onnxruntime]==1.17.1", + }, + }) + pypi.whl_mods().contains_exactly({}) + +_tests.append(_test_optimum_sys_platform_extra) + def extension_test_suite(name): """Create the test suite. diff --git a/tests/pypi/pep508/requirement_tests.bzl b/tests/pypi/pep508/requirement_tests.bzl index 7c81ea50fc..9afb43a437 100644 --- a/tests/pypi/pep508/requirement_tests.bzl +++ b/tests/pypi/pep508/requirement_tests.bzl @@ -20,20 +20,21 @@ _tests = [] def _test_requirement_line_parsing(env): want = { - " name1[ foo ] ": ("name1", ["foo"]), - "Name[foo]": ("name", ["foo"]), - "name [fred,bar] @ http://foo.com ; python_version=='2.7'": ("name", ["fred", "bar"]), - "name; (os_name=='a' or os_name=='b') and os_name=='c'": ("name", [""]), - "name@http://foo.com": ("name", [""]), - "name[ Foo123 ]": ("name", ["Foo123"]), - "name[extra]@http://foo.com": ("name", ["extra"]), - "name[foo]": ("name", ["foo"]), - "name[quux, strange];python_version<'2.7' and platform_version=='2'": ("name", ["quux", "strange"]), - "name_foo[bar]": ("name-foo", ["bar"]), + " name1[ foo ] ": ("name1", ["foo"], None, ""), + "Name[foo]": ("name", ["foo"], None, ""), + "name [fred,bar] @ http://foo.com ; python_version=='2.7'": ("name", ["fred", "bar"], None, "python_version=='2.7'"), + "name; (os_name=='a' or os_name=='b') and os_name=='c'": ("name", [""], None, "(os_name=='a' or os_name=='b') and os_name=='c'"), + "name@http://foo.com": ("name", [""], None, ""), + "name[ Foo123 ]": ("name", ["Foo123"], None, ""), + "name[extra]@http://foo.com": ("name", ["extra"], None, ""), + "name[foo]": ("name", ["foo"], None, ""), + "name[quux, strange];python_version<'2.7' and platform_version=='2'": ("name", ["quux", "strange"], None, "python_version<'2.7' and platform_version=='2'"), + "name_foo[bar]": ("name-foo", ["bar"], None, ""), + "name_foo[bar]==0.25": ("name-foo", ["bar"], "0.25", ""), } got = { - i: (parsed.name, parsed.extras) + i: (parsed.name, parsed.extras, parsed.version, parsed.marker) for i, parsed in {case: requirement(case) for case in want}.items() } env.expect.that_dict(got).contains_exactly(want)