Skip to content

Commit a5d0ca1

Browse files
author
Joanna Grycz
committed
Use mocked TPUClient
1 parent b575bdb commit a5d0ca1

File tree

6 files changed

+125
-174
lines changed

6 files changed

+125
-174
lines changed

tpu/createStartupScriptVM.js

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -16,19 +16,20 @@
1616

1717
'use strict';
1818

19-
async function main(nodeName, zone, tpuType, tpuSoftwareVersion) {
19+
async function main(tpuClient) {
2020
// [START tpu_vm_create_startup_script]
21-
// Import the TPU library
22-
const {TpuClient} = require('@google-cloud/tpu').v2;
21+
// Import the TPUClient
22+
// TODO(developer): Uncomment below line before running the sample.
23+
// const {TpuClient} = require('@google-cloud/tpu').v2;
24+
2325
const {Node, NetworkConfig} =
2426
require('@google-cloud/tpu').protos.google.cloud.tpu.v2;
2527

2628
// Instantiate a tpuClient
27-
const tpuClient = new TpuClient();
29+
// TODO(developer): Uncomment below line before running the sample.
30+
// tpuClient = new TpuClient();
2831

29-
/**
30-
* TODO(developer): Update/uncomment these variables before running the sample.
31-
*/
32+
// TODO(developer): Update these variables before running the sample.
3233
// Project ID or project number of the Google Cloud project you want to create a node.
3334
const projectId = await tpuClient.getProjectId();
3435

@@ -39,21 +40,21 @@ async function main(nodeName, zone, tpuType, tpuSoftwareVersion) {
3940
const region = 'europe-west4';
4041

4142
// The name for your TPU.
42-
// nodeName = 'node-name-1';
43+
const nodeName = 'node-name-1';
4344

4445
// The zone in which to create the TPU.
4546
// For more information about supported TPU types for specific zones,
4647
// see https://cloud.google.com/tpu/docs/regions-zones
47-
// zone = 'europe-west4-a';
48+
const zone = 'europe-west4-a';
4849

4950
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
5051
// For more information about supported accelerator types for each TPU version,
5152
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
52-
// tpuType = 'v2-8';
53+
const tpuType = 'v2-8';
5354

5455
// Software version that specifies the version of the TPU runtime to install. For more information,
5556
// see https://cloud.google.com/tpu/docs/runtimes
56-
// tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
57+
const tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
5758

5859
async function callCreateTpuVMStartupScript() {
5960
// Create a node
@@ -85,12 +86,16 @@ async function main(nodeName, zone, tpuType, tpuSoftwareVersion) {
8586
const [response] = await operation.promise();
8687

8788
console.log(JSON.stringify(response));
89+
return response;
8890
}
89-
await callCreateTpuVMStartupScript();
91+
return await callCreateTpuVMStartupScript();
9092
// [END tpu_vm_create_startup_script]
9193
}
9294

93-
main(...process.argv.slice(2)).catch(err => {
94-
console.error(err);
95-
process.exitCode = 1;
96-
});
95+
module.exports = main;
96+
97+
// TODO(developer): Uncomment below lines before running the sample.
98+
// main(...process.argv.slice(2)).catch(err => {
99+
// console.error(err);
100+
// process.exitCode = 1;
101+
// });

tpu/createTopologyVM.js

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,21 @@
1616

1717
'use strict';
1818

19-
async function main(nodeName, zone, tpuSoftwareVersion) {
19+
async function main(tpuClient) {
2020
// [START tpu_vm_create_topology]
21-
// Import the TPU library
22-
const {TpuClient} = require('@google-cloud/tpu').v2;
21+
// Import the TPUClient
22+
// TODO(developer): Uncomment below line before running the sample.
23+
// const {TpuClient} = require('@google-cloud/tpu').v2;
24+
2325
const {Node, NetworkConfig, AcceleratorConfig} =
2426
require('@google-cloud/tpu').protos.google.cloud.tpu.v2;
2527

2628
// Instantiate a tpuClient
27-
const tpuClient = new TpuClient();
29+
// TODO(developer): Uncomment below line before running the sample.
30+
// tpuClient = new TpuClient();
2831

2932
/**
30-
* TODO(developer): Update/uncomment these variables before running the sample.
33+
* TODO(developer): Update these variables before running the sample.
3134
*/
3235
// Project ID or project number of the Google Cloud project you want to create a node.
3336
const projectId = await tpuClient.getProjectId();
@@ -39,16 +42,16 @@ async function main(nodeName, zone, tpuSoftwareVersion) {
3942
const region = 'europe-west4';
4043

4144
// The name for your TPU.
42-
// nodeName = 'node-name-1';
45+
const nodeName = 'node-name-1';
4346

4447
// The zone in which to create the TPU.
4548
// For more information about supported TPU types for specific zones,
4649
// see https://cloud.google.com/tpu/docs/regions-zones
47-
// zone = 'europe-west4-a';
50+
const zone = 'europe-west4-a';
4851

4952
// Software version that specifies the version of the TPU runtime to install. For more information,
5053
// see https://cloud.google.com/tpu/docs/runtimes
51-
// tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
54+
const tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
5255

5356
// The version of the Cloud TPU you want to create.
5457
// Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
@@ -74,7 +77,7 @@ async function main(nodeName, zone, tpuSoftwareVersion) {
7477
}),
7578
acceleratorConfig: new AcceleratorConfig({
7679
type: tpuVersion,
77-
topology: topology,
80+
topology,
7881
}),
7982
});
8083

@@ -87,12 +90,16 @@ async function main(nodeName, zone, tpuSoftwareVersion) {
8790
const [response] = await operation.promise();
8891

8992
console.log(JSON.stringify(response));
93+
return response;
9094
}
91-
await callCreateTpuVMTopology();
95+
return await callCreateTpuVMTopology();
9296
// [END tpu_vm_create_topology]
9397
}
9498

95-
main(...process.argv.slice(2)).catch(err => {
96-
console.error(err);
97-
process.exitCode = 1;
98-
});
99+
module.exports = main;
100+
101+
// TODO(developer): Uncomment below lines before running the sample.
102+
// main(...process.argv.slice(2)).catch(err => {
103+
// console.error(err);
104+
// process.exitCode = 1;
105+
// });

tpu/package.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
"test": "c8 mocha -p -j 2 test --timeout 1200000"
1515
},
1616
"dependencies": {
17-
"@google-cloud/tpu": "^3.5.0"
17+
"@google-cloud/tpu": "^3.5.0",
18+
"sinon": "^19.0.2"
1819
},
1920
"devDependencies": {
2021
"c8": "^10.0.0",

tpu/test/createStartupScriptVM.test.js

Lines changed: 42 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -16,43 +16,54 @@
1616

1717
'use strict';
1818

19-
const path = require('path');
2019
const assert = require('node:assert/strict');
21-
const {after, describe, it} = require('mocha');
22-
const cp = require('child_process');
23-
const {getStaleNodes, deleteNode} = require('./util');
24-
25-
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26-
const cwd = path.join(__dirname, '..');
20+
const {beforeEach, afterEach, describe, it} = require('mocha');
21+
const sinon = require('sinon');
22+
const createStartupScriptVM = require('../createStartupScriptVM.js');
2723

2824
describe('Compute tpu', async () => {
29-
const nodePrefix = 'node-name-startup-script-2a2b3c';
30-
const nodeName = `${nodePrefix}${Math.floor(Math.random() * 1000 + 1)}`;
31-
const zone = 'us-east1-d';
32-
const tpuType = 'v3-32';
33-
const tpuSoftwareVersion = 'tpu-vm-base';
34-
35-
after(async () => {
36-
// Clean-up resources
37-
const nodes = await getStaleNodes(nodePrefix);
38-
await Promise.all(nodes.map(node => deleteNode(node.zone, node.nodeName)));
39-
});
25+
const nodeName = 'node-name-1';
26+
const zone = 'europe-west4-a';
27+
const projectId = 'project_id';
28+
let tpuClientMock;
4029

41-
it('should create a new tpu with startup script', () => {
42-
const metadata = {
43-
'startup-script':
44-
'#!/bin/bash\n echo "Hello World" > /var/log/hello.log\n sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1',
30+
beforeEach(() => {
31+
tpuClientMock = {
32+
getProjectId: sinon.stub().resolves(projectId),
4533
};
34+
});
4635

47-
const response = JSON.parse(
48-
execSync(
49-
`node ./createStartupScriptVM.js ${nodeName} ${zone} ${tpuType} ${tpuSoftwareVersion}`,
50-
{
51-
cwd,
52-
}
53-
)
54-
);
36+
afterEach(() => {
37+
sinon.restore();
38+
});
39+
40+
it('should create a new tpu with startup script', async () => {
41+
tpuClientMock.createNode = sinon.stub().resolves([
42+
{
43+
promise: sinon.stub().resolves([
44+
{
45+
name: nodeName,
46+
},
47+
]),
48+
},
49+
]);
5550

56-
assert.deepEqual(response.metadata, metadata);
51+
const response = await createStartupScriptVM(tpuClientMock);
52+
53+
sinon.assert.calledWith(
54+
tpuClientMock.createNode,
55+
sinon.match({
56+
parent: `projects/${projectId}/locations/${zone}`,
57+
node: {
58+
name: nodeName,
59+
metadata: {
60+
'startup-script':
61+
'#!/bin/bash\n echo "Hello World" > /var/log/hello.log\n sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1',
62+
},
63+
},
64+
nodeId: nodeName,
65+
})
66+
);
67+
assert(response.name.includes(nodeName));
5768
});
5869
});

tpu/test/createTopologyVM.test.js

Lines changed: 39 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -16,39 +16,51 @@
1616

1717
'use strict';
1818

19-
const path = require('path');
2019
const assert = require('node:assert/strict');
21-
const {after, describe, it} = require('mocha');
22-
const cp = require('child_process');
23-
const {getStaleNodes, deleteNode} = require('./util');
24-
25-
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26-
const cwd = path.join(__dirname, '..');
20+
const {beforeEach, afterEach, describe, it} = require('mocha');
21+
const sinon = require('sinon');
22+
const createTopologyVM = require('../createTopologyVM.js');
2723

2824
describe('Compute tpu with topology', async () => {
29-
const nodePrefix = 'topology-node-name-2a2b3c';
30-
const nodeName = `${nodePrefix}${Math.floor(Math.random() * 1000 + 1)}`;
31-
const zone = 'asia-east1-c';
32-
const tpuSoftwareVersion = 'tpu-vm-tf-2.12.1';
33-
34-
after(async () => {
35-
// Clean-up resources
36-
const nodes = await getStaleNodes(nodePrefix);
37-
await Promise.all(nodes.map(node => deleteNode(node.zone, node.nodeName)));
25+
const nodeName = 'node-name-1';
26+
const zone = 'europe-west4-a';
27+
const projectId = 'project_id';
28+
let tpuClientMock;
29+
30+
beforeEach(() => {
31+
tpuClientMock = {
32+
getProjectId: sinon.stub().resolves(projectId),
33+
};
34+
});
35+
36+
afterEach(() => {
37+
sinon.restore();
3838
});
3939

40-
it('should create a new tpu with topology', () => {
41-
const acceleratorConfig = {type: 'V2', topology: '2x2'};
40+
it('should create a new tpu with topology', async () => {
41+
tpuClientMock.createNode = sinon.stub().resolves([
42+
{
43+
promise: sinon.stub().resolves([
44+
{
45+
name: nodeName,
46+
},
47+
]),
48+
},
49+
]);
4250

43-
const response = JSON.parse(
44-
execSync(
45-
`node ./createTopologyVM.js ${nodeName} ${zone} ${tpuSoftwareVersion}`,
46-
{
47-
cwd,
48-
}
49-
)
50-
);
51+
const response = await createTopologyVM(tpuClientMock);
5152

52-
assert.deepEqual(response.acceleratorConfig, acceleratorConfig);
53+
sinon.assert.calledWith(
54+
tpuClientMock.createNode,
55+
sinon.match({
56+
parent: `projects/${projectId}/locations/${zone}`,
57+
node: {
58+
name: nodeName,
59+
acceleratorConfig: {type: 2, topology: '2x2'},
60+
},
61+
nodeId: nodeName,
62+
})
63+
);
64+
assert(response.name.includes(nodeName));
5365
});
5466
});

0 commit comments

Comments
 (0)