|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | import os |
| 15 | +from unittest.mock import MagicMock, patch |
| 16 | + |
15 | 17 | import uuid |
16 | 18 |
|
| 19 | +from google.cloud.tpu_v2.services.tpu.pagers import ListNodesPager |
17 | 20 | from google.cloud.tpu_v2.types import AcceleratorConfig, Node |
18 | 21 |
|
19 | 22 | import pytest |
|
28 | 31 | import stop_tpu |
29 | 32 |
|
30 | 33 |
|
31 | | -TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:10] |
| 34 | +TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:6] |
32 | 35 | PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT") |
33 | 36 | ZONE = "us-south1-a" |
| 37 | +FULL_TPU_NAME = f"projects/{PROJECT_ID}/locations/{ZONE}/nodes/{TPU_NAME}" |
34 | 38 | TPU_TYPE = "v5litepod-1" |
35 | 39 | TPU_VERSION = "tpu-vm-tf-2.17.0-pjrt" |
| 40 | +METADATA = { |
| 41 | + "startup-script": """#!/bin/bash |
| 42 | + echo "Hello World" > /var/log/hello.log |
| 43 | + sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1 |
| 44 | + """ |
| 45 | +} |
| 46 | + |
| 47 | + |
| 48 | +@pytest.fixture |
| 49 | +def mock_tpu_client() -> MagicMock: |
| 50 | + with patch("google.cloud.tpu_v2.TpuClient") as mock_client: |
| 51 | + yield mock_client.return_value |
| 52 | + |
| 53 | + |
| 54 | +@pytest.fixture |
| 55 | +def operation() -> MagicMock: |
| 56 | + yield MagicMock() |
| 57 | + |
| 58 | + |
| 59 | +def test_creating_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None: |
| 60 | + mock_response = MagicMock(spec=Node) |
| 61 | + mock_response.state = Node.State.READY |
| 62 | + mock_response.name = FULL_TPU_NAME |
| 63 | + |
| 64 | + mock_tpu_client.create_node.return_value = operation |
| 65 | + operation.result.return_value = mock_response |
| 66 | + |
| 67 | + tpu = create_tpu.create_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION) |
| 68 | + |
| 69 | + assert tpu.name == FULL_TPU_NAME |
| 70 | + assert tpu.state == Node.State.READY |
| 71 | + mock_tpu_client.create_node.assert_called_once() |
| 72 | + operation.result.assert_called_once() |
| 73 | + |
| 74 | + |
| 75 | +def test_delete_tpu(mock_tpu_client: MagicMock) -> None: |
| 76 | + delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME) |
| 77 | + mock_tpu_client.delete_node.called_once() |
36 | 78 |
|
37 | 79 |
|
38 | | -# Instance of TPU |
39 | | -@pytest.fixture(scope="session") |
40 | | -def tpu_instance() -> Node: |
41 | | - yield create_tpu.create_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION) |
42 | | - try: |
43 | | - delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME) |
44 | | - except Exception as e: |
45 | | - print(f"Error during cleanup: {e}") |
| 80 | +def test_creating_with_startup_script( |
| 81 | + mock_tpu_client: MagicMock, operation: MagicMock |
| 82 | +) -> None: |
| 83 | + mock_response = MagicMock(spec=Node) |
| 84 | + mock_response.metadata = METADATA |
| 85 | + mock_tpu_client.create_node.return_value = operation |
| 86 | + operation.result.return_value = mock_response |
46 | 87 |
|
| 88 | + tpu_with_script = create_tpu_with_script.create_cloud_tpu_with_script( |
| 89 | + PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION |
| 90 | + ) |
47 | 91 |
|
48 | | -def test_creating_tpu(tpu_instance: Node) -> None: |
49 | | - assert tpu_instance.state == Node.State.READY |
| 92 | + mock_tpu_client.create_node.assert_called_once() |
| 93 | + operation.result.assert_called_once() |
| 94 | + assert "--upgrade numpy" in tpu_with_script.metadata["startup-script"] |
50 | 95 |
|
51 | 96 |
|
52 | | -def test_creating_with_startup_script() -> None: |
53 | | - tpu_name_with_script = "script-tpu-" + uuid.uuid4().hex[:5] |
54 | | - try: |
55 | | - tpu_with_script = create_tpu_with_script.create_cloud_tpu_with_script( |
56 | | - PROJECT_ID, ZONE, tpu_name_with_script, TPU_TYPE, TPU_VERSION |
57 | | - ) |
58 | | - assert "--upgrade numpy" in tpu_with_script.metadata["startup-script"] |
59 | | - finally: |
60 | | - print(f"\n\n ------------ Deleting TPU {TPU_NAME}\n ------------") |
61 | | - delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, tpu_name_with_script) |
| 97 | +def test_get_tpu(mock_tpu_client: MagicMock) -> None: |
| 98 | + mock_response = MagicMock(spec=Node) |
| 99 | + mock_response.name = FULL_TPU_NAME |
| 100 | + mock_response.state = Node.State.READY |
62 | 101 |
|
| 102 | + mock_tpu_client.get_node.return_value = mock_response |
63 | 103 |
|
64 | | -def test_get_tpu() -> None: |
65 | 104 | tpu = get_tpu.get_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME) |
| 105 | + |
66 | 106 | assert tpu.state == Node.State.READY |
67 | | - assert tpu.name == f"projects/{PROJECT_ID}/locations/{ZONE}/nodes/{TPU_NAME}" |
| 107 | + assert tpu.name == FULL_TPU_NAME |
| 108 | + mock_tpu_client.get_node.assert_called_once() |
| 109 | + |
| 110 | + |
| 111 | +def test_list_tpu(mock_tpu_client: MagicMock) -> None: |
| 112 | + mock_pager = MagicMock(spec=ListNodesPager) |
| 113 | + nodes = [ |
| 114 | + Node(name="Node1", state=Node.State.READY), |
| 115 | + Node(name="Node2", state=Node.State.CREATING), |
| 116 | + ] |
| 117 | + mock_pager.__iter__.return_value = nodes |
68 | 118 |
|
| 119 | + mock_tpu_client.list_nodes.return_value = mock_pager |
69 | 120 |
|
70 | | -def test_list_tpu() -> None: |
71 | 121 | nodes = list_tpu.list_cloud_tpu(PROJECT_ID, ZONE) |
72 | 122 | assert len(list(nodes)) > 0 |
| 123 | + mock_tpu_client.list_nodes.assert_called_once() |
73 | 124 |
|
74 | 125 |
|
75 | | -def test_stop_tpu() -> None: |
| 126 | +def test_stop_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None: |
| 127 | + mock_response = MagicMock(spec=Node) |
| 128 | + mock_response.state = Node.State.STOPPED |
| 129 | + |
| 130 | + mock_tpu_client.stop_node.return_value = operation |
| 131 | + operation.result.return_value = mock_response |
| 132 | + |
76 | 133 | node = stop_tpu.stop_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME) |
| 134 | + |
| 135 | + mock_tpu_client.stop_node.assert_called_once() |
| 136 | + operation.result.assert_called_once() |
77 | 137 | assert node.state == Node.State.STOPPED |
78 | 138 |
|
79 | 139 |
|
80 | | -def test_start_tpu() -> None: |
| 140 | +def test_start_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None: |
| 141 | + mock_response = MagicMock(spec=Node) |
| 142 | + mock_response.state = Node.State.READY |
| 143 | + |
| 144 | + mock_tpu_client.start_node.return_value = operation |
| 145 | + operation.result.return_value = mock_response |
| 146 | + |
81 | 147 | node = start_tpu.start_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME) |
| 148 | + |
| 149 | + mock_tpu_client.start_node.assert_called_once() |
| 150 | + operation.result.assert_called_once() |
82 | 151 | assert node.state == Node.State.READY |
83 | 152 |
|
84 | 153 |
|
85 | | -def test_with_topology() -> None: |
86 | | - topology_tpu_name = "topology-tpu-" + uuid.uuid4().hex[:5] |
87 | | - topology_zone = "us-central1-a" |
88 | | - try: |
89 | | - topology_tpu = create_tpu_topology.create_cloud_tpu_with_topology( |
90 | | - PROJECT_ID, topology_zone, topology_tpu_name, TPU_VERSION |
91 | | - ) |
92 | | - assert topology_tpu.accelerator_config.type_ == AcceleratorConfig.Type.V3 |
93 | | - assert topology_tpu.accelerator_config.topology == "2x2" |
94 | | - finally: |
95 | | - delete_tpu.delete_cloud_tpu(PROJECT_ID, topology_zone, topology_tpu_name) |
| 154 | +def test_with_topology(mock_tpu_client: MagicMock, operation: MagicMock) -> None: |
| 155 | + from google.cloud import tpu_v2 |
| 156 | + |
| 157 | + mock_response = MagicMock(spec=Node) |
| 158 | + mock_response.accelerator_config = tpu_v2.AcceleratorConfig( |
| 159 | + type_=tpu_v2.AcceleratorConfig.Type.V3, |
| 160 | + topology="2x2", |
| 161 | + ) |
| 162 | + |
| 163 | + mock_tpu_client.create_node.return_value = operation |
| 164 | + operation.result.return_value = mock_response |
| 165 | + |
| 166 | + topology_tpu = create_tpu_topology.create_cloud_tpu_with_topology( |
| 167 | + PROJECT_ID, ZONE, TPU_NAME, TPU_VERSION |
| 168 | + ) |
| 169 | + assert topology_tpu.accelerator_config.type_ == AcceleratorConfig.Type.V3 |
| 170 | + assert topology_tpu.accelerator_config.topology == "2x2" |
| 171 | + mock_tpu_client.create_node.assert_called_once() |
| 172 | + operation.result.assert_called_once() |
0 commit comments