Skip to content

Commit 68c11ef

Browse files
authored
feat: add tpu_vm_create_topology/startup_script (#3902)
* feat: add tpu_vm_create_topology * feat: tpu_vm_create_startup_script * Use mocked TPUClient
1 parent 87b18af commit 68c11ef

File tree

4 files changed

+341
-0
lines changed

4 files changed

+341
-0
lines changed

tpu/createStartupScriptVM.js

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(tpuClient) {
20+
// [START tpu_vm_create_startup_script]
21+
// Import the TPUClient
22+
// TODO(developer): Uncomment below line before running the sample.
23+
// const {TpuClient} = require('@google-cloud/tpu').v2;
24+
25+
const {Node, NetworkConfig} =
26+
require('@google-cloud/tpu').protos.google.cloud.tpu.v2;
27+
28+
// Instantiate a tpuClient
29+
// TODO(developer): Uncomment below line before running the sample.
30+
// tpuClient = new TpuClient();
31+
32+
// TODO(developer): Update these variables before running the sample.
33+
// Project ID or project number of the Google Cloud project you want to create a node.
34+
const projectId = await tpuClient.getProjectId();
35+
36+
// The name of the network you want the TPU node to connect to. The network should be assigned to your project.
37+
const networkName = 'compute-tpu-network';
38+
39+
// The region of the network, that you want the TPU node to connect to.
40+
const region = 'europe-west4';
41+
42+
// The name for your TPU.
43+
const nodeName = 'node-name-1';
44+
45+
// The zone in which to create the TPU.
46+
// For more information about supported TPU types for specific zones,
47+
// see https://cloud.google.com/tpu/docs/regions-zones
48+
const zone = 'europe-west4-a';
49+
50+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
51+
// For more information about supported accelerator types for each TPU version,
52+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
53+
const tpuType = 'v2-8';
54+
55+
// Software version that specifies the version of the TPU runtime to install. For more information,
56+
// see https://cloud.google.com/tpu/docs/runtimes
57+
const tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
58+
59+
async function callCreateTpuVMStartupScript() {
60+
// Create a node
61+
const node = new Node({
62+
name: nodeName,
63+
zone,
64+
acceleratorType: tpuType,
65+
runtimeVersion: tpuSoftwareVersion,
66+
// Define network
67+
networkConfig: new NetworkConfig({
68+
enableExternalIps: true,
69+
network: `projects/${projectId}/global/networks/${networkName}`,
70+
subnetwork: `projects/${projectId}/regions/${region}/subnetworks/${networkName}`,
71+
}),
72+
metadata: {
73+
// The script updates numpy to the latest version and logs the output to a file.
74+
'startup-script': `#!/bin/bash
75+
echo "Hello World" > /var/log/hello.log
76+
sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1`,
77+
},
78+
});
79+
80+
const parent = `projects/${projectId}/locations/${zone}`;
81+
const request = {parent, node, nodeId: nodeName};
82+
83+
const [operation] = await tpuClient.createNode(request);
84+
85+
// Wait for the create operation to complete.
86+
const [response] = await operation.promise();
87+
88+
console.log(JSON.stringify(response));
89+
return response;
90+
}
91+
return await callCreateTpuVMStartupScript();
92+
// [END tpu_vm_create_startup_script]
93+
}
94+
95+
module.exports = main;
96+
97+
// TODO(developer): Uncomment below lines before running the sample.
98+
// main(...process.argv.slice(2)).catch(err => {
99+
// console.error(err);
100+
// process.exitCode = 1;
101+
// });

tpu/createTopologyVM.js

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(tpuClient) {
20+
// [START tpu_vm_create_topology]
21+
// Import the TPUClient
22+
// TODO(developer): Uncomment below line before running the sample.
23+
// const {TpuClient} = require('@google-cloud/tpu').v2;
24+
25+
const {Node, NetworkConfig, AcceleratorConfig} =
26+
require('@google-cloud/tpu').protos.google.cloud.tpu.v2;
27+
28+
// Instantiate a tpuClient
29+
// TODO(developer): Uncomment below line before running the sample.
30+
// tpuClient = new TpuClient();
31+
32+
/**
33+
* TODO(developer): Update these variables before running the sample.
34+
*/
35+
// Project ID or project number of the Google Cloud project you want to create a node.
36+
const projectId = await tpuClient.getProjectId();
37+
38+
// The name of the network you want the TPU node to connect to. The network should be assigned to your project.
39+
const networkName = 'compute-tpu-network';
40+
41+
// The region of the network, that you want the TPU node to connect to.
42+
const region = 'europe-west4';
43+
44+
// The name for your TPU.
45+
const nodeName = 'node-name-1';
46+
47+
// The zone in which to create the TPU.
48+
// For more information about supported TPU types for specific zones,
49+
// see https://cloud.google.com/tpu/docs/regions-zones
50+
const zone = 'europe-west4-a';
51+
52+
// Software version that specifies the version of the TPU runtime to install. For more information,
53+
// see https://cloud.google.com/tpu/docs/runtimes
54+
const tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
55+
56+
// The version of the Cloud TPU you want to create.
57+
// Available options: TYPE_UNSPECIFIED = 0, V2 = 2, V3 = 4, V4 = 7
58+
const tpuVersion = AcceleratorConfig.Type.V2;
59+
60+
// The physical topology of your TPU slice.
61+
// For more information about topology for each TPU version,
62+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
63+
const topology = '2x2';
64+
65+
async function callCreateTpuVMTopology() {
66+
// Create a node
67+
const node = new Node({
68+
name: nodeName,
69+
zone,
70+
// acceleratorType: tpuType,
71+
runtimeVersion: tpuSoftwareVersion,
72+
// Define network
73+
networkConfig: new NetworkConfig({
74+
enableExternalIps: true,
75+
network: `projects/${projectId}/global/networks/${networkName}`,
76+
subnetwork: `projects/${projectId}/regions/${region}/subnetworks/${networkName}`,
77+
}),
78+
acceleratorConfig: new AcceleratorConfig({
79+
type: tpuVersion,
80+
topology,
81+
}),
82+
});
83+
84+
const parent = `projects/${projectId}/locations/${zone}`;
85+
const request = {parent, node, nodeId: nodeName};
86+
87+
const [operation] = await tpuClient.createNode(request);
88+
89+
// Wait for the create operation to complete.
90+
const [response] = await operation.promise();
91+
92+
console.log(JSON.stringify(response));
93+
return response;
94+
}
95+
return await callCreateTpuVMTopology();
96+
// [END tpu_vm_create_topology]
97+
}
98+
99+
module.exports = main;
100+
101+
// TODO(developer): Uncomment below lines before running the sample.
102+
// main(...process.argv.slice(2)).catch(err => {
103+
// console.error(err);
104+
// process.exitCode = 1;
105+
// });
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const assert = require('node:assert/strict');
20+
const {beforeEach, afterEach, describe, it} = require('mocha');
21+
const sinon = require('sinon');
22+
const createStartupScriptVM = require('../createStartupScriptVM.js');
23+
24+
describe('Compute tpu', async () => {
25+
const nodeName = 'node-name-1';
26+
const zone = 'europe-west4-a';
27+
const projectId = 'project_id';
28+
let tpuClientMock;
29+
30+
beforeEach(() => {
31+
tpuClientMock = {
32+
getProjectId: sinon.stub().resolves(projectId),
33+
};
34+
});
35+
36+
afterEach(() => {
37+
sinon.restore();
38+
});
39+
40+
it('should create a new tpu with startup script', async () => {
41+
tpuClientMock.createNode = sinon.stub().resolves([
42+
{
43+
promise: sinon.stub().resolves([
44+
{
45+
name: nodeName,
46+
},
47+
]),
48+
},
49+
]);
50+
51+
const response = await createStartupScriptVM(tpuClientMock);
52+
53+
sinon.assert.calledWith(
54+
tpuClientMock.createNode,
55+
sinon.match({
56+
parent: `projects/${projectId}/locations/${zone}`,
57+
node: {
58+
name: nodeName,
59+
metadata: {
60+
'startup-script':
61+
'#!/bin/bash\n echo "Hello World" > /var/log/hello.log\n sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1',
62+
},
63+
},
64+
nodeId: nodeName,
65+
})
66+
);
67+
assert(response.name.includes(nodeName));
68+
});
69+
});

tpu/test/createTopologyVM.test.js

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const assert = require('node:assert/strict');
20+
const {beforeEach, afterEach, describe, it} = require('mocha');
21+
const sinon = require('sinon');
22+
const createTopologyVM = require('../createTopologyVM.js');
23+
24+
describe('Compute tpu with topology', async () => {
25+
const nodeName = 'node-name-1';
26+
const zone = 'europe-west4-a';
27+
const projectId = 'project_id';
28+
let tpuClientMock;
29+
30+
beforeEach(() => {
31+
tpuClientMock = {
32+
getProjectId: sinon.stub().resolves(projectId),
33+
};
34+
});
35+
36+
afterEach(() => {
37+
sinon.restore();
38+
});
39+
40+
it('should create a new tpu with topology', async () => {
41+
tpuClientMock.createNode = sinon.stub().resolves([
42+
{
43+
promise: sinon.stub().resolves([
44+
{
45+
name: nodeName,
46+
},
47+
]),
48+
},
49+
]);
50+
51+
const response = await createTopologyVM(tpuClientMock);
52+
53+
sinon.assert.calledWith(
54+
tpuClientMock.createNode,
55+
sinon.match({
56+
parent: `projects/${projectId}/locations/${zone}`,
57+
node: {
58+
name: nodeName,
59+
acceleratorConfig: {type: 2, topology: '2x2'},
60+
},
61+
nodeId: nodeName,
62+
})
63+
);
64+
assert(response.name.includes(nodeName));
65+
});
66+
});

0 commit comments

Comments
 (0)