Skip to content

Commit 6a94b26

Browse files
committed
Add _test_simple_multiple_platforms_with_extras
Run with: ```console bazel test //tests/pypi/extension:test_simple_multiple_platforms_with_extras ```
1 parent 5e75007 commit 6a94b26

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

tests/pypi/extension/extension_tests.bzl

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,96 @@ new-package==0.0.1 --hash=sha256:deadb00f2
403403

404404
_tests.append(_test_simple_multiple_python_versions)
405405

406+
def _test_simple_multiple_platforms_with_extras(env):
407+
"""TODO(hartikainen): Test that reproduces a multi-platform-with-extras issue."""
408+
# This test case is based on my issue where different requirement strings for the same package
409+
# (`jax` vs `jax[cuda12]`) for multiple platforms caused a "duplicate library" error (for details,
410+
# see https://github.com/bazel-contrib/rules_python/issues/2797#issuecomment-3143914644).
411+
pypi = _parse_modules(
412+
env,
413+
module_ctx = _mock_mctx(
414+
_mod(
415+
name = "rules_python",
416+
parse = [
417+
_parse(
418+
hub_name = "pypi",
419+
python_version = "3.12",
420+
download_only = True,
421+
requirements_by_platform = {
422+
"requirements.linux_arm64.txt": "linux_aarch64",
423+
"requirements.linux_x86_64.txt": "linux_x86_64",
424+
"requirements.macos_arm64.txt": "osx_aarch64",
425+
},
426+
experimental_index_url = "pypi.org",
427+
),
428+
],
429+
),
430+
read = lambda x: {
431+
"requirements.linux_arm64.txt": """\
432+
jax==0.7.0 \
433+
--hash=sha256:4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003 \
434+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
435+
""",
436+
"requirements.linux_x86_64.txt": """\
437+
jax[cuda12]==0.7.0 \
438+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
439+
""",
440+
"requirements.macos_arm64.txt": """\
441+
jax==0.7.0 \
442+
--hash=sha256:4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003 \
443+
--hash=sha256:62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76
444+
""",
445+
}[x],
446+
),
447+
available_interpreters = {
448+
"python_3_12_host": "unit_test_interpreter_target",
449+
},
450+
minor_mapping = {"3.12": "3.12.11"},
451+
simpleapi_download = lambda *_, **__: {
452+
"jax": parse_simpleapi_html(
453+
url = "https://example.com/jax",
454+
content = """
455+
<a href="jax-0.7.0.tar.gz#sha256=4dd8924f171ed73a4f1a6191e2f800ae1745069989b69fabc45593d6b6504003" data-requires-python="&gt;=3.11">jax-0.7.0.tar.gz</a>
456+
<a href="jax-0.7.0-py3-none-any.whl#sha256=62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76" data-requires-python="&gt;=3.11" data-dist-info-metadata="sha256=99d99c9ac3b0b8273e2c248da18a3b73dce72cc178336881324e9ecf8da36d0a" data-core-metadata="sha256=99d99c9ac3b0b8273e2c248da18a3b73dce72cc178336881324e9ecf8da36d0a">jax-0.7.0-py3-none-any.whl</a>
457+
""",
458+
),
459+
},
460+
)
461+
462+
pypi.exposed_packages().contains_exactly({"pypi": ["jax"]})
463+
# TODO(hartikainen): Check these expectations.
464+
pypi.hub_whl_map().contains_exactly({"pypi": {
465+
"jax": {
466+
"pypi_312_jax_py3_none_any_62833036": [
467+
whl_config_setting(
468+
# TODO(hartikainen): I think all these platforms use the same `.whl`
469+
# and thus all three platforms should be included in the same
470+
# `target_platforms` here?
471+
target_platforms = ["cp312_linux_arm64", "cp312_linux_x86_64", "cp312_osx_aarch64"],
472+
version = "3.12",
473+
),
474+
],
475+
},
476+
}})
477+
pypi.whl_libraries().contains_exactly({
478+
"pypi_312_jax_py3_none_any_62833036": {
479+
"dep_template": "@pypi//{name}:{target}",
480+
"download_only": True,
481+
"experimental_target_platforms": ["linux_arm64", "linux_x86_64", "osx_aarch64"],
482+
"filename": "jax-0.7.0-py3-none-any.whl",
483+
"python_interpreter_target": "unit_test_interpreter_target",
484+
# NOTE(hartikainen): Perhaps this is part of the problem?
485+
# This should say `jax[cuda12]==0.7.0` for `linux_x86_64` platform and
486+
# `jax==0.7.0` for `linux_arm64` and `osx_aarch64`.
487+
"requirement": "jax[cuda12]==0.7.0",
488+
"sha256": "62833036cbaf4641d66ae94c61c0446890a91b2c0d153946583a0ebe04877a76",
489+
"urls": ["https://example.com/jax-0.7.0-py3-none-any.whl"],
490+
},
491+
})
492+
pypi.whl_mods().contains_exactly({})
493+
494+
_tests.append(_test_simple_multiple_platforms_with_extras)
495+
406496
def _test_simple_with_markers(env):
407497
pypi = _parse_modules(
408498
env,

0 commit comments

Comments
 (0)