Skip to content

Commit a54fcc8

Browse files
Created separated test for CreateTpuVm
1 parent 6390ada commit a54fcc8

File tree

2 files changed

+65
-24
lines changed

2 files changed

+65
-24
lines changed
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package tpu;
18+
19+
import static org.mockito.Mockito.any;
20+
import static org.mockito.Mockito.mock;
21+
import static org.mockito.Mockito.times;
22+
import static org.mockito.Mockito.verify;
23+
import static org.mockito.Mockito.when;
24+
25+
import com.google.api.gax.longrunning.OperationFuture;
26+
import com.google.cloud.tpu.v2.CreateNodeRequest;
27+
import com.google.cloud.tpu.v2.TpuClient;
28+
import com.google.cloud.tpu.v2.TpuSettings;
29+
import java.util.concurrent.TimeUnit;
30+
import org.junit.Test;
31+
import org.junit.jupiter.api.Timeout;
32+
import org.junit.runner.RunWith;
33+
import org.junit.runners.JUnit4;
34+
import org.mockito.MockedStatic;
35+
import org.mockito.Mockito;
36+
37+
@RunWith(JUnit4.class)
38+
@Timeout(value = 3, unit = TimeUnit.MINUTES)
39+
public class TpuCreateIT {
40+
private static final String PROJECT_ID = "project-id";
41+
private static final String ZONE = "asia-east1-c";
42+
private static final String NODE_NAME = "test-tpu";
43+
private static final String TPU_TYPE = "v2-8";
44+
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
45+
46+
@Test
47+
public void testCreateTpuVm() throws Exception {
48+
TpuClient mockTpuClient = mock(TpuClient.class);
49+
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
50+
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
51+
.thenReturn(mockTpuClient);
52+
53+
OperationFuture mockFuture = mock(OperationFuture.class);
54+
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
55+
.thenReturn(mockFuture);
56+
CreateTpuVm.createTpuVm(
57+
PROJECT_ID, ZONE, NODE_NAME,
58+
TPU_TYPE, TPU_SOFTWARE_VERSION);
59+
60+
verify(mockTpuClient, times(1))
61+
.createNodeAsync(any(CreateNodeRequest.class));
62+
verify(mockTpuClient, times(1)).close();
63+
}
64+
}
65+
}

tpu/src/test/java/tpu/TpuVmIT.java

Lines changed: 0 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
import static org.mockito.Mockito.when;
2424

2525
import com.google.api.gax.longrunning.OperationFuture;
26-
import com.google.cloud.tpu.v2.CreateNodeRequest;
2726
import com.google.cloud.tpu.v2.DeleteNodeRequest;
2827
import com.google.cloud.tpu.v2.Node;
2928
import com.google.cloud.tpu.v2.NodeName;
@@ -32,10 +31,7 @@
3231
import java.io.IOException;
3332
import java.util.concurrent.ExecutionException;
3433
import java.util.concurrent.TimeUnit;
35-
import org.junit.jupiter.api.MethodOrderer;
36-
import org.junit.jupiter.api.Order;
3734
import org.junit.jupiter.api.Test;
38-
import org.junit.jupiter.api.TestMethodOrder;
3935
import org.junit.jupiter.api.Timeout;
4036
import org.junit.runner.RunWith;
4137
import org.mockito.MockedStatic;
@@ -44,30 +40,10 @@
4440

4541
@RunWith(PowerMockRunner.class)
4642
@Timeout(value = 3, unit = TimeUnit.MINUTES)
47-
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
4843
public class TpuVmIT {
4944
private static final String PROJECT_ID = "project-id";
5045
private static final String ZONE = "asia-east1-c";
5146
private static final String NODE_NAME = "test-tpu";
52-
private static final String TPU_TYPE = "v2-8";
53-
private static final String TPU_SOFTWARE_VERSION = "tpu-vm-tf-2.12.1";
54-
55-
@Test
56-
@Order(1)
57-
public void testCreateTpuVm() throws IOException, ExecutionException, InterruptedException {
58-
TpuClient mockTpuClient = mock(TpuClient.class);
59-
try (MockedStatic<TpuClient> mockedTpuClient = Mockito.mockStatic(TpuClient.class)) {
60-
mockedTpuClient.when(() -> TpuClient.create(any(TpuSettings.class)))
61-
.thenReturn(mockTpuClient);
62-
63-
OperationFuture mockFuture = mock(OperationFuture.class);
64-
when(mockTpuClient.createNodeAsync(any(CreateNodeRequest.class)))
65-
.thenReturn(mockFuture);
66-
CreateTpuVm.createTpuVm(PROJECT_ID, ZONE, NODE_NAME, TPU_TYPE, TPU_SOFTWARE_VERSION);
67-
68-
verify(mockTpuClient, times(1)).createNodeAsync(any(CreateNodeRequest.class));
69-
}
70-
}
7147

7248
@Test
7349
public void testGetTpuVm() throws IOException {

0 commit comments

Comments
 (0)