Skip to content

Commit dc66831

Browse files
Fixed tests
1 parent aa522cf commit dc66831

File tree

5 files changed

+69
-85
lines changed

5 files changed

+69
-85
lines changed

tpu/src/main/java/tpu/CreateQueuedResourceWithStartupScript.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static void main(String[] args)
3838
// The zone in which to create the TPU.
3939
// For more information about supported TPU types for specific zones,
4040
// see https://cloud.google.com/tpu/docs/regions-zones
41-
String zone = "europe-west4-a";
41+
String zone = "us-central1-a";
4242
// The name for your TPU.
4343
String nodeName = "YOUR_TPU_NAME";
4444
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.

tpu/src/main/java/tpu/DeleteForceQueuedResource.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ public static void deleteForceQueuedResource(
6969
tpuClient.deleteQueuedResourceAsync(request).get();
7070

7171
} catch (UnknownException | InterruptedException | ExecutionException | IOException e) {
72-
System.out.println(e.getMessage());
72+
System.err.printf("Error deleting resource: %s%n", e.getMessage());
73+
System.out.printf("Deleted Queued Resource: %s\n", name);
7374
}
7475
System.out.printf("Deleted Queued Resource: %s\n", name);
7576
}

tpu/src/test/java/tpu/CreateTpuIT.java

Lines changed: 0 additions & 70 deletions
This file was deleted.

tpu/src/test/java/tpu/QueuedResourceIT.java

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,15 @@
3535
import java.io.ByteArrayOutputStream;
3636
import java.io.IOException;
3737
import java.io.PrintStream;
38-
import org.junit.Before;
39-
import org.junit.Test;
38+
import org.junit.jupiter.api.BeforeAll;
39+
import org.junit.jupiter.api.Test;
4040
import org.junit.jupiter.api.Timeout;
4141
import org.junit.runner.RunWith;
4242
import org.junit.runners.JUnit4;
4343
import org.mockito.MockedStatic;
4444

4545
@RunWith(JUnit4.class)
46-
@Timeout(value = 3)
46+
@Timeout(value = 10)
4747
public class QueuedResourceIT {
4848
private static final String PROJECT_ID = "project-id";
4949
private static final String ZONE = "europe-west4-a";
@@ -52,10 +52,10 @@ public class QueuedResourceIT {
5252
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.14.1";
5353
private static final String QUEUED_RESOURCE_NAME = "queued-resource";
5454
private static final String NETWORK_NAME = "default";
55-
private ByteArrayOutputStream bout;
55+
private static ByteArrayOutputStream bout;
5656

57-
@Before
58-
public void setUp() {
57+
@BeforeAll
58+
public static void setUp() {
5959
bout = new ByteArrayOutputStream();
6060
System.setOut(new PrintStream(bout));
6161
}
@@ -75,8 +75,8 @@ public void testCreateQueuedResourceWithSpecifiedNetwork() throws Exception {
7575

7676
QueuedResource returnedQueuedResource =
7777
CreateQueuedResourceWithNetwork.createQueuedResourceWithNetwork(
78-
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79-
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
78+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
79+
TPU_TYPE, TPU_SOFTWARE_VERSION, NETWORK_NAME);
8080

8181
verify(mockTpuClient, times(1))
8282
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
@@ -106,7 +106,7 @@ public void testGetQueuedResource() throws IOException {
106106
}
107107

108108
@Test
109-
public void testDeleteTpuVm() {
109+
public void testDeleteForceQueuedResource() {
110110
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
111111
TpuClient mockTpuClient = mock(TpuClient.class);
112112
OperationFuture mockFuture = mock(OperationFuture.class);
@@ -124,4 +124,29 @@ public void testDeleteTpuVm() {
124124
.deleteQueuedResourceAsync(any(DeleteQueuedResourceRequest.class));
125125
}
126126
}
127+
128+
@Test
129+
public void testCreateQueuedResourceWithStartupScript() throws Exception {
130+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
131+
QueuedResource mockQueuedResource = mock(QueuedResource.class);
132+
TpuClient mockTpuClient = mock(TpuClient.class);
133+
OperationFuture mockFuture = mock(OperationFuture.class);
134+
135+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
136+
.thenReturn(mockTpuClient);
137+
when(mockTpuClient.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class)))
138+
.thenReturn(mockFuture);
139+
when(mockFuture.get()).thenReturn(mockQueuedResource);
140+
141+
QueuedResource returnedQueuedResource =
142+
CreateQueuedResourceWithStartupScript.createQueuedResource(
143+
PROJECT_ID, ZONE, QUEUED_RESOURCE_NAME, NODE_NAME,
144+
TPU_TYPE, TPU_SOFTWARE_VERSION);
145+
146+
verify(mockTpuClient, times(1))
147+
.createQueuedResourceAsync(any(CreateQueuedResourceRequest.class));
148+
verify(mockFuture, times(1)).get();
149+
assertEquals(returnedQueuedResource, mockQueuedResource);
150+
}
151+
}
127152
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
package tpu;
1818

1919
import static com.google.common.truth.Truth.assertThat;
20+
import static org.junit.Assert.assertEquals;
2021
import static org.mockito.Mockito.any;
2122
import static org.mockito.Mockito.mock;
2223
import static org.mockito.Mockito.mockStatic;
@@ -25,6 +26,7 @@
2526
import static org.mockito.Mockito.when;
2627

2728
import com.google.api.gax.longrunning.OperationFuture;
29+
import com.google.cloud.tpu.v2.CreateNodeRequest;
2830
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2931
import com.google.cloud.tpu.v2.GetNodeRequest;
3032
import com.google.cloud.tpu.v2.Node;
@@ -42,11 +44,13 @@
4244
import org.mockito.MockedStatic;
4345

4446
@RunWith(JUnit4.class)
45-
@Timeout(value = 3)
47+
@Timeout(value = 10)
4648
public class TpuVmIT {
4749
private static final String PROJECT_ID = "project-id";
4850
private static final String ZONE = "asia-east1-c";
4951
private static final String NODE_NAME = "test-tpu";
52+
private static final String TPU_TYPE = "v2-8";
53+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
5054
private static ByteArrayOutputStream bout;
5155

5256
@BeforeAll
@@ -55,21 +59,45 @@ public static void setUp() {
5559
System.setOut(new PrintStream(bout));
5660
}
5761

62+
@Test
63+
public void testCreateTpuVm() throws Exception {
64+
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
65+
Node mockNode = mock(Node.class);
66+
TpuClient mockTpuClient = mock(TpuClient.class);
67+
OperationFuture mockFuture = mock(OperationFuture.class);
68+
69+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
70+
.thenReturn(mockTpuClient);
71+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
72+
.thenReturn(mockFuture);
73+
when(mockFuture.get()).thenReturn(mockNode);
74+
75+
Node returnedNode = CreateTpuVm.createTpuVm(
76+
PROJECT_ID, ZONE, NODE_NAME,
77+
TPU_TYPE, TPU_SOFTWARE_VERSION);
78+
79+
verify(mockTpuClient, times(1))
80+
.createNodeAsync(any(CreateNodeRequest.class));
81+
verify(mockFuture, times(1)).get();
82+
assertEquals(returnedNode, mockNode);
83+
}
84+
}
85+
5886
@Test
5987
public void testGetTpuVm() throws IOException {
6088
try (MockedStatic<TpuClient> mockedTpuClient = mockStatic(TpuClient.class)) {
6189
Node mockNode = mock(Node.class);
6290
TpuClient mockClient = mock(TpuClient.class);
63-
GetTpuVm mockGetTpuVm = mock(GetTpuVm.class);
6491

6592
mockedTpuClient.when(TpuClient::create).thenReturn(mockClient);
6693
when(mockClient.getNode(any(GetNodeRequest.class))).thenReturn(mockNode);
6794

6895
Node returnedNode = GetTpuVm.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
6996

70-
verify(mockGetTpuVm, times(1))
71-
.getTpuVm(PROJECT_ID, ZONE, NODE_NAME);
97+
verify(mockClient, times(1))
98+
.getNode(any(GetNodeRequest.class));
7299
assertThat(returnedNode).isEqualTo(mockNode);
100+
verify(mockClient, times(1)).close();
73101
}
74102
}
75103

0 commit comments

Comments
 (0)