Skip to content

Commit 4599e41

Browse files
committed
Refactor tests to use Mocking for TPU sample operations
1 parent 115ba78 commit 4599e41

File tree

1 file changed

+114
-37
lines changed

1 file changed

+114
-37
lines changed

tpu/test_tpu.py

Lines changed: 114 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,11 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import os
15+
from unittest.mock import MagicMock, patch
16+
1517
import uuid
1618

19+
from google.cloud.tpu_v2.services.tpu.pagers import ListNodesPager
1720
from google.cloud.tpu_v2.types import AcceleratorConfig, Node
1821

1922
import pytest
@@ -28,68 +31,142 @@
2831
import stop_tpu
2932

3033

31-
TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:10]
34+
TPU_NAME = "test-tpu-" + uuid.uuid4().hex[:6]
3235
PROJECT_ID = os.getenv("GOOGLE_CLOUD_PROJECT")
3336
ZONE = "us-south1-a"
37+
FULL_TPU_NAME = f"projects/{PROJECT_ID}/locations/{ZONE}/nodes/{TPU_NAME}"
3438
TPU_TYPE = "v5litepod-1"
3539
TPU_VERSION = "tpu-vm-tf-2.17.0-pjrt"
40+
METADATA = {
41+
"startup-script": """#!/bin/bash
42+
echo "Hello World" > /var/log/hello.log
43+
sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1
44+
"""
45+
}
46+
47+
48+
@pytest.fixture
49+
def mock_tpu_client() -> MagicMock:
50+
with patch("google.cloud.tpu_v2.TpuClient") as mock_client:
51+
yield mock_client.return_value
52+
53+
54+
@pytest.fixture
55+
def operation() -> MagicMock:
56+
yield MagicMock()
57+
58+
59+
def test_creating_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None:
60+
mock_response = MagicMock(spec=Node)
61+
mock_response.state = Node.State.READY
62+
mock_response.name = FULL_TPU_NAME
63+
64+
mock_tpu_client.create_node.return_value = operation
65+
operation.result.return_value = mock_response
66+
67+
tpu = create_tpu.create_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION)
68+
69+
assert tpu.name == FULL_TPU_NAME
70+
assert tpu.state == Node.State.READY
71+
mock_tpu_client.create_node.assert_called_once()
72+
operation.result.assert_called_once()
73+
74+
75+
def test_delete_tpu(mock_tpu_client: MagicMock) -> None:
76+
delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
77+
mock_tpu_client.delete_node.called_once()
3678

3779

38-
# Instance of TPU
39-
@pytest.fixture(scope="session")
40-
def tpu_instance() -> Node:
41-
yield create_tpu.create_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION)
42-
try:
43-
delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
44-
except Exception as e:
45-
print(f"Error during cleanup: {e}")
80+
def test_creating_with_startup_script(
81+
mock_tpu_client: MagicMock, operation: MagicMock
82+
) -> None:
83+
mock_response = MagicMock(spec=Node)
84+
mock_response.metadata = METADATA
85+
mock_tpu_client.create_node.return_value = operation
86+
operation.result.return_value = mock_response
4687

88+
tpu_with_script = create_tpu_with_script.create_cloud_tpu_with_script(
89+
PROJECT_ID, ZONE, TPU_NAME, TPU_TYPE, TPU_VERSION
90+
)
4791

48-
def test_creating_tpu(tpu_instance: Node) -> None:
49-
assert tpu_instance.state == Node.State.READY
92+
mock_tpu_client.create_node.assert_called_once()
93+
operation.result.assert_called_once()
94+
assert "--upgrade numpy" in tpu_with_script.metadata["startup-script"]
5095

5196

52-
def test_creating_with_startup_script() -> None:
53-
tpu_name_with_script = "script-tpu-" + uuid.uuid4().hex[:5]
54-
try:
55-
tpu_with_script = create_tpu_with_script.create_cloud_tpu_with_script(
56-
PROJECT_ID, ZONE, tpu_name_with_script, TPU_TYPE, TPU_VERSION
57-
)
58-
assert "--upgrade numpy" in tpu_with_script.metadata["startup-script"]
59-
finally:
60-
print(f"\n\n ------------ Deleting TPU {TPU_NAME}\n ------------")
61-
delete_tpu.delete_cloud_tpu(PROJECT_ID, ZONE, tpu_name_with_script)
97+
def test_get_tpu(mock_tpu_client: MagicMock) -> None:
98+
mock_response = MagicMock(spec=Node)
99+
mock_response.name = FULL_TPU_NAME
100+
mock_response.state = Node.State.READY
62101

102+
mock_tpu_client.get_node.return_value = mock_response
63103

64-
def test_get_tpu() -> None:
65104
tpu = get_tpu.get_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
105+
66106
assert tpu.state == Node.State.READY
67-
assert tpu.name == f"projects/{PROJECT_ID}/locations/{ZONE}/nodes/{TPU_NAME}"
107+
assert tpu.name == FULL_TPU_NAME
108+
mock_tpu_client.get_node.assert_called_once()
109+
110+
111+
def test_list_tpu(mock_tpu_client: MagicMock) -> None:
112+
mock_pager = MagicMock(spec=ListNodesPager)
113+
nodes = [
114+
Node(name="Node1", state=Node.State.READY),
115+
Node(name="Node2", state=Node.State.CREATING),
116+
]
117+
mock_pager.__iter__.return_value = nodes
68118

119+
mock_tpu_client.list_nodes.return_value = mock_pager
69120

70-
def test_list_tpu() -> None:
71121
nodes = list_tpu.list_cloud_tpu(PROJECT_ID, ZONE)
72122
assert len(list(nodes)) > 0
123+
mock_tpu_client.list_nodes.assert_called_once()
73124

74125

75-
def test_stop_tpu() -> None:
126+
def test_stop_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None:
127+
mock_response = MagicMock(spec=Node)
128+
mock_response.state = Node.State.STOPPED
129+
130+
mock_tpu_client.stop_node.return_value = operation
131+
operation.result.return_value = mock_response
132+
76133
node = stop_tpu.stop_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
134+
135+
mock_tpu_client.stop_node.assert_called_once()
136+
operation.result.assert_called_once()
77137
assert node.state == Node.State.STOPPED
78138

79139

80-
def test_start_tpu() -> None:
140+
def test_start_tpu(mock_tpu_client: MagicMock, operation: MagicMock) -> None:
141+
mock_response = MagicMock(spec=Node)
142+
mock_response.state = Node.State.READY
143+
144+
mock_tpu_client.start_node.return_value = operation
145+
operation.result.return_value = mock_response
146+
81147
node = start_tpu.start_cloud_tpu(PROJECT_ID, ZONE, TPU_NAME)
148+
149+
mock_tpu_client.start_node.assert_called_once()
150+
operation.result.assert_called_once()
82151
assert node.state == Node.State.READY
83152

84153

85-
def test_with_topology() -> None:
86-
topology_tpu_name = "topology-tpu-" + uuid.uuid4().hex[:5]
87-
topology_zone = "us-central1-a"
88-
try:
89-
topology_tpu = create_tpu_topology.create_cloud_tpu_with_topology(
90-
PROJECT_ID, topology_zone, topology_tpu_name, TPU_VERSION
91-
)
92-
assert topology_tpu.accelerator_config.type_ == AcceleratorConfig.Type.V3
93-
assert topology_tpu.accelerator_config.topology == "2x2"
94-
finally:
95-
delete_tpu.delete_cloud_tpu(PROJECT_ID, topology_zone, topology_tpu_name)
154+
def test_with_topology(mock_tpu_client: MagicMock, operation: MagicMock) -> None:
155+
from google.cloud import tpu_v2
156+
157+
mock_response = MagicMock(spec=Node)
158+
mock_response.accelerator_config = tpu_v2.AcceleratorConfig(
159+
type_=tpu_v2.AcceleratorConfig.Type.V3,
160+
topology="2x2",
161+
)
162+
163+
mock_tpu_client.create_node.return_value = operation
164+
operation.result.return_value = mock_response
165+
166+
topology_tpu = create_tpu_topology.create_cloud_tpu_with_topology(
167+
PROJECT_ID, ZONE, TPU_NAME, TPU_VERSION
168+
)
169+
assert topology_tpu.accelerator_config.type_ == AcceleratorConfig.Type.V3
170+
assert topology_tpu.accelerator_config.topology == "2x2"
171+
mock_tpu_client.create_node.assert_called_once()
172+
operation.result.assert_called_once()

0 commit comments

Comments
 (0)