Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions python/ray/_private/runtime_env/install_ray_or_pip_packages.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,20 @@ def find_first_matching_wheel(whl_dir: str, whl_file_name: str) -> str:
matches = list(dir_path.glob(whl_file_name))
return str(matches[0]) if matches else ""

# Uninstall to ensure the python environments are clean
def uninstall_ray():
pip_uninstall_command = [sys.executable, "-m", "pip", "uninstall", "-y", "ray", "ant-ray", "ant-ray-nightly"]
result = subprocess.run(pip_uninstall_command)
logger.info("Uninstalled ray: {}. Return code: {}".format(pip_uninstall_command, result.returncode))


def install_ray_package(ray_version, whl_dir):
pip_install_command = [sys.executable, "-m", "pip", "install", "-U"]

if whl_dir:
# generate whl file name
whl_file_name = (
f"ant_ray-*cp{sys.version_info.major}{sys.version_info.minor}*.whl"
f"ant_ray*cp{sys.version_info.major}{sys.version_info.minor}*.whl"
)

# got the first matched wheel file
Expand All @@ -34,7 +41,7 @@ def install_ray_package(ray_version, whl_dir):
pip_install_command.append(whl_file_path + "[default]")
else:
# generate ray package name
ray_package_name = f"ant_ray=={ray_version}"
ray_package_name = f"ant_ray[default]=={ray_version}"
pip_install_command.append(ray_package_name)

logger.info("Starting install ray: {}".format(pip_install_command))
Expand Down Expand Up @@ -97,9 +104,11 @@ def install_pip_package(pip_packages, isolate_pip_installation):
)

if args.ray_version or args.whl_dir:
uninstall_ray()
install_ray_package(args.ray_version, args.whl_dir)
if args.packages:
pip_packages = json.loads(
base64.b64decode(args.packages.encode("utf-8")).decode("utf-8")
)
install_pip_package(pip_packages, isolate_pip_installation)

114 changes: 114 additions & 0 deletions python/ray/tests/test_install_ray_or_pip_packages.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Unit tests for uninstall_ray function in install_ray_or_pip_packages.py
"""
import os
import sys
import subprocess
import unittest
from unittest.mock import patch, MagicMock

# Add the project root to Python path to import ray modules
project_root = os.path.join(os.path.dirname(__file__), '..', '..')
sys.path.insert(0, project_root)

from ray._private.runtime_env.install_ray_or_pip_packages import uninstall_ray


class TestUninstallRay(unittest.TestCase):
"""Unit test class for uninstall_ray function"""

def setUp(self):
"""Set up test fixtures before each test method."""
pass

def tearDown(self):
"""Clean up after each test method."""
pass

@patch('ray._private.runtime_env.install_ray_or_pip_packages.subprocess.run')
@patch('ray._private.runtime_env.install_ray_or_pip_packages.logger')
def test_uninstall_ray_success(self, mock_logger, mock_subprocess_run):
"""Test successful uninstallation of ray packages"""
# Set up mock return value
mock_result = MagicMock()
mock_result.returncode = 0
mock_subprocess_run.return_value = mock_result

# Call the function under test
uninstall_ray()

# Verify subprocess.run was called correctly
expected_command = [sys.executable, '-m', 'pip', 'uninstall', '-y',
'ray', 'ant-ray', 'ant-ray-nightly']
mock_subprocess_run.assert_called_once_with(expected_command)

# Verify logger.info was called
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args[0][0]
self.assertIn("Uninstalled ray", call_args)
self.assertIn("Return code: 0", call_args)

@patch('ray._private.runtime_env.install_ray_or_pip_packages.subprocess.run')
@patch('ray._private.runtime_env.install_ray_or_pip_packages.logger')
def test_uninstall_ray_failure(self, mock_logger, mock_subprocess_run):
"""Test uninstallation with non-zero return code"""
# Set up mock return value
mock_result = MagicMock()
mock_result.returncode = 1
mock_subprocess_run.return_value = mock_result

# Call the function under test
uninstall_ray()

# Verify subprocess.run was called
expected_command = [sys.executable, '-m', 'pip', 'uninstall', '-y',
'ray', 'ant-ray', 'ant-ray-nightly']
mock_subprocess_run.assert_called_once_with(expected_command)

# Verify logger.info was called even with non-zero return code
mock_logger.info.assert_called_once()
call_args = mock_logger.info.call_args[0][0]
self.assertIn("Uninstalled ray", call_args)
self.assertIn("Return code: 1", call_args)

@patch('ray._private.runtime_env.install_ray_or_pip_packages.subprocess.run')
@patch('ray._private.runtime_env.install_ray_or_pip_packages.logger')
def test_uninstall_ray_command_structure(self, mock_logger, mock_subprocess_run):
"""Test the structure of the uninstall command"""
mock_result = MagicMock()
mock_result.returncode = 0
mock_subprocess_run.return_value = mock_result

uninstall_ray()

# Verify the command structure
call_args = mock_subprocess_run.call_args[0][0]
self.assertEqual(call_args[0], sys.executable)
self.assertEqual(call_args[1], '-m')
self.assertEqual(call_args[2], 'pip')
self.assertEqual(call_args[3], 'uninstall')
self.assertEqual(call_args[4], '-y')

# Verify all expected packages are included
packages = call_args[5:]
self.assertIn('ray', packages)
self.assertIn('ant-ray', packages)
self.assertIn('ant-ray-nightly', packages)
self.assertEqual(len(packages), 3)

@patch('ray._private.runtime_env.install_ray_or_pip_packages.subprocess.run')
@patch('ray._private.runtime_env.install_ray_or_pip_packages.logger')
def test_uninstall_ray_exception_propagation(self, mock_logger, mock_subprocess_run):
"""Test that exceptions from subprocess.run are propagated"""
# Mock subprocess.run to raise an exception
mock_subprocess_run.side_effect = Exception("Mocked subprocess error")

# The exception should be propagated
with self.assertRaises(Exception) as context:
uninstall_ray()

self.assertEqual(str(context.exception), "Mocked subprocess error")


if __name__ == '__main__':
unittest.main()
Loading