Skip to content

Commit 1e8ca06

Browse files
committed
Add _test_simple_multiple_platforms_with_extras
Run with: ```console bazel test //tests/pypi/extension:test_simple_multiple_platforms_with_extras ```
1 parent 5fd3c86 commit 1e8ca06

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

tests/pypi/extension/extension_tests.bzl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,96 @@ new-package==0.0.1 --hash=sha256:deadb00f2
401401

402402
_tests.append(_test_simple_multiple_python_versions)
403403

404+
def _test_simple_multiple_platforms_with_extras(env):
405+
"""TODO(hartikainen): Test that reproduces a multi-platform-with-extras issue."""
406+
# This test case is based on my issue where different requirement strings for the same package
407+
# (`jax` vs `jax[cuda12]`) for multiple platforms caused a "duplicate library" error (for details,
408+
# see https://github.com/bazel-contrib/rules_python/issues/2797#issuecomment-3143914644).
409+
pypi = _parse_modules(
410+
env,
411+
module_ctx = _mock_mctx(
412+
_mod(
413+
name = "rules_python",
414+
parse = [
415+
_parse(
416+
hub_name = "pypi",
417+
python_version = "3.12",
418+
download_only = True,
419+
requirements_by_platform = {
420+
"requirements.linux_arm64.txt": "linux_aarch64",
421+
"requirements.linux_x86_64.txt": "linux_x86_64",
422+
"requirements.macos_arm64.txt": "osx_aarch64",
423+
},
424+
experimental_index_url = "pypi.org",
425+
),
426+
],
427+
),
428+
read = lambda x: {
429+
"requirements.linux_arm64.txt": """\
430+
jax==0.7.0 \
431+
--hash=sha256:4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003 \
432+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
433+
""",
434+
"requirements.linux_x86_64.txt": """\
435+
jax[cuda12]==0.7.0 \
436+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
437+
""",
438+
"requirements.macos_arm64.txt": """\
439+
jax==0.7.0 \
440+
--hash=sha256:4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003 \
441+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
442+
""",
443+
}[x],
444+
),
445+
available_interpreters = {
446+
"python_3_12_host": "unit_test_interpreter_target",
447+
},
448+
minor_mapping = {"3.12": "3.12.11"},
449+
simpleapi_download = lambda *_, **__: {
450+
"jax": parse_simpleapi_html(
451+
url = "https://example.com/jax",
452+
content = """
453+
<a href="jax-0.7.0.tar.gz#sha256=4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003" data-requires-python="&gt;=3.11">jax-0.7.0.tar.gz</a>
454+
<a href="jax-0.7.0-py3-none-any.whl#sha256=62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76" data-requires-python="&gt;=3.11" data-dist-info-metadata="sha256=99d99c9ac3b0b8273e2c248da18a3b73dce72cc178336881324e9ecf8da36d0a" data-core-metadata="sha256=99d99c9ac3b0b8273e2c248da18a3b73dce72cc178336881324e9ecf8da36d0a">jax-0.7.0-py3-none-any.whl</a>
455+
""",
456+
),
457+
},
458+
)
459+
460+
pypi.exposed_packages().contains_exactly({"pypi": ["jax"]})
461+
# TODO(hartikainen): Check these expectations.
462+
pypi.hub_whl_map().contains_exactly({"pypi": {
463+
"jax": {
464+
"pypi_312_jax_py3_none_any_62833036": [
465+
whl_config_setting(
466+
# TODO(hartikainen): I think all these platforms use the same `.whl`
467+
# and thus all three platforms should be included in the same
468+
# `target_platforms` here?
469+
target_platforms = ["cp312_linux_arm64", "cp312_linux_x86_64", "cp312_osx_aarch64"],
470+
version = "3.12",
471+
),
472+
],
473+
},
474+
}})
475+
pypi.whl_libraries().contains_exactly({
476+
"pypi_312_jax_py3_none_any_62833036": {
477+
"dep_template": "@pypi//{name}:{target}",
478+
"download_only": True,
479+
"experimental_target_platforms": ["linux_arm64", "linux_x86_64", "osx_aarch64"],
480+
"filename": "jax-0.7.0-py3-none-any.whl",
481+
"python_interpreter_target": "unit_test_interpreter_target",
482+
# NOTE(hartikainen): Perhaps this is part of the problem?
483+
# This should say `jax[cuda12]==0.7.0` for `linux_x86_64` platform and
484+
# `jax==0.7.0` for `linux_arm64` and `osx_aarch64`.
485+
"requirement": "jax[cuda12]==0.7.0",
486+
"sha256": "62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76",
487+
"urls": ["https://example.com/jax-0.7.0-py3-none-any.whl"],
488+
},
489+
})
490+
pypi.whl_mods().contains_exactly({})
491+
492+
_tests.append(_test_simple_multiple_platforms_with_extras)
493+
404494
def _test_simple_with_markers(env):
405495
pypi = _parse_modules(
406496
env,

0 commit comments

Comments
 (0)