@@ -16,7 +16,7 @@ def test_cpu_platform():
1616 assert result == [packages ]
1717
1818
19- def test_cuda_platform (monkeypatch ):
19+ def test_cuda_platform_linux (monkeypatch ):
2020 monkeypatch .setattr ("torchruntime.installer.os_name" , "Linux" )
2121 packages = ["torch" , "torchvision" ]
2222 result = get_install_commands ("cu112" , packages )
@@ -32,7 +32,7 @@ def test_cuda_platform_windows_installs_triton(monkeypatch):
3232 assert result == [packages + ["--index-url" , expected_url ], ["triton-windows" ]]
3333
3434
35- def test_cuda_nightly_platform (monkeypatch ):
35+ def test_cuda_nightly_platform_linux (monkeypatch ):
3636 monkeypatch .setattr ("torchruntime.installer.os_name" , "Linux" )
3737 packages = ["torch" , "torchvision" ]
3838 result = get_install_commands ("nightly/cu112" , packages )
@@ -48,14 +48,15 @@ def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
4848 assert result == [packages + ["--index-url" , expected_url ], ["triton-windows" ]]
4949
5050
51- def test_rocm_platform ():
51+ def test_rocm_4_platform_does_not_install_triton (monkeypatch ):
52+ monkeypatch .setattr ("torchruntime.installer.os_name" , "Linux" )
5253 packages = ["torch" , "torchvision" ]
5354 result = get_install_commands ("rocm4.2" , packages )
5455 expected_url = "https://download.pytorch.org/whl/rocm4.2"
5556 assert result == [packages + ["--index-url" , expected_url ]]
5657
5758
58- def test_rocm_platform_linux_installs_triton (monkeypatch ):
59+ def test_rocm_6_platform_linux_installs_triton (monkeypatch ):
5960 monkeypatch .setattr ("torchruntime.installer.os_name" , "Linux" )
6061 packages = ["torch" , "torchvision" ]
6162 result = get_install_commands ("rocm6.2" , packages )
0 commit comments