Skip to content

Commit 276b837

Browse files
author
Joanna Grycz
committed
feat: tpu_vm_create_startup_script
1 parent e0c0d8a commit 276b837

File tree

7 files changed

+380
-34
lines changed

7 files changed

+380
-34
lines changed
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(projectId, inputUri, outputUri, jobName) {
20+
// [START generativeaionvertexai_embedding_batch]
21+
// Imports the aiplatform library
22+
const aiplatformLib = require('@google-cloud/aiplatform');
23+
const aiplatform = aiplatformLib.protos.google.cloud.aiplatform.v1;
24+
25+
/**
26+
* TODO(developer): Uncomment/update these variables before running the sample.
27+
*/
28+
// projectId = 'YOUR_PROJECT_ID';
29+
30+
// Optional: URI of the input dataset.
31+
// Could be a BigQuery table or a Google Cloud Storage file.
32+
// E.g. "gs://[BUCKET]/[DATASET].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
33+
// inputUri =
34+
// 'gs://cloud-samples-data/generative-ai/embeddings/embeddings_input.jsonl';
35+
36+
// Optional: URI where the output will be stored.
37+
// Could be a BigQuery table or a Google Cloud Storage file.
38+
// E.g. "gs://[BUCKET]/[OUTPUT].jsonl" OR "bq://[PROJECT].[DATASET].[TABLE]"
39+
// outputUri = 'gs://your_backet/embedding_batch_output';
40+
41+
// The name of the job
42+
// jobName = `Batch embedding job: ${new Date().getMilliseconds()}`;
43+
44+
const textEmbeddingModel = 'text-embedding-005';
45+
const location = 'us-central1';
46+
47+
// Configure the parent resource
48+
const parent = `projects/${projectId}/locations/${location}`;
49+
const modelName = `projects/${projectId}/locations/${location}/publishers/google/models/${textEmbeddingModel}`;
50+
51+
// Specifies the location of the api endpoint
52+
const clientOptions = {
53+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
54+
};
55+
56+
// Instantiates a client
57+
const jobServiceClient = new aiplatformLib.JobServiceClient(clientOptions);
58+
59+
// Generates embeddings from text using batch processing.
60+
// Read more: https://cloud.google.com/vertex-ai/generative-ai/docs/embeddings/batch-prediction-genai-embeddings
61+
async function callBatchEmbedding() {
62+
const gcsSource = new aiplatform.GcsSource({
63+
uris: [inputUri],
64+
});
65+
66+
const inputConfig = new aiplatform.BatchPredictionJob.InputConfig({
67+
gcsSource,
68+
instancesFormat: 'jsonl',
69+
});
70+
71+
const gcsDestination = new aiplatform.GcsDestination({
72+
outputUriPrefix: outputUri,
73+
});
74+
75+
const outputConfig = new aiplatform.BatchPredictionJob.OutputConfig({
76+
gcsDestination,
77+
predictionsFormat: 'jsonl',
78+
});
79+
80+
const batchPredictionJob = new aiplatform.BatchPredictionJob({
81+
displayName: jobName,
82+
model: modelName,
83+
inputConfig,
84+
outputConfig,
85+
});
86+
87+
const request = {
88+
parent,
89+
batchPredictionJob,
90+
};
91+
92+
// Create batch prediction job request
93+
const [response] = await jobServiceClient.createBatchPredictionJob(request);
94+
95+
console.log('Raw response: ', JSON.stringify(response, null, 2));
96+
}
97+
98+
await callBatchEmbedding();
99+
// [END generativeaionvertexai_embedding_batch]
100+
}
101+
102+
main(...process.argv.slice(2)).catch(err => {
103+
console.error(err.message);
104+
process.exitCode = 1;
105+
});
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const {assert} = require('chai');
20+
const {after, describe, it} = require('mocha');
21+
const uuid = require('uuid').v4;
22+
const cp = require('child_process');
23+
const {JobServiceClient} = require('@google-cloud/aiplatform');
24+
25+
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26+
27+
describe('Batch embedding', async () => {
28+
const displayName = `batch-embedding-job-${uuid()}`;
29+
const location = 'us-central1';
30+
const inputUri =
31+
'gs://ucaip-samples-test-output/inputs/batch_predict_TCN/tcn_inputs.jsonl';
32+
const outputUri = 'gs://ucaip-samples-test-output/';
33+
const jobServiceClient = new JobServiceClient({
34+
apiEndpoint: `${location}-aiplatform.googleapis.com`,
35+
});
36+
const projectId = process.env.CAIP_PROJECT_ID;
37+
let batchPredictionJobId;
38+
39+
after(async () => {
40+
const name = jobServiceClient.batchPredictionJobPath(
41+
projectId,
42+
location,
43+
batchPredictionJobId
44+
);
45+
46+
const cancelRequest = {
47+
name,
48+
};
49+
50+
jobServiceClient.cancelBatchPredictionJob(cancelRequest).then(() => {
51+
const deleteRequest = {
52+
name,
53+
};
54+
55+
return jobServiceClient.deleteBatchPredictionJob(deleteRequest);
56+
});
57+
});
58+
59+
it('should create job with text prediction', async () => {
60+
const response = execSync(
61+
`node ./create-batch-embedding.js ${projectId} ${inputUri} ${outputUri} ${displayName}`
62+
);
63+
64+
assert.match(response, new RegExp(displayName));
65+
batchPredictionJobId = response
66+
.split('/locations/us-central1/batchPredictionJobs/')[1]
67+
.split('\n')[0];
68+
});
69+
});

tpu/createStartupScriptVM.js

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
async function main(nodeName, zone, tpuType, tpuSoftwareVersion) {
20+
// [START tpu_vm_create_startup_script]
21+
// Import the TPU library
22+
const {TpuClient} = require('@google-cloud/tpu').v2;
23+
const {Node, NetworkConfig} =
24+
require('@google-cloud/tpu').protos.google.cloud.tpu.v2;
25+
26+
// Instantiate a tpuClient
27+
const tpuClient = new TpuClient();
28+
29+
/**
30+
* TODO(developer): Update/uncomment these variables before running the sample.
31+
*/
32+
// Project ID or project number of the Google Cloud project you want to create a node.
33+
const projectId = await tpuClient.getProjectId();
34+
35+
// The name of the network you want the TPU node to connect to. The network should be assigned to your project.
36+
const networkName = 'compute-tpu-network';
37+
38+
// The region of the network, that you want the TPU node to connect to.
39+
const region = 'europe-west4';
40+
41+
// The name for your TPU.
42+
// nodeName = 'node-name-1';
43+
44+
// The zone in which to create the TPU.
45+
// For more information about supported TPU types for specific zones,
46+
// see https://cloud.google.com/tpu/docs/regions-zones
47+
// zone = 'europe-west4-a';
48+
49+
// The accelerator type that specifies the version and size of the Cloud TPU you want to create.
50+
// For more information about supported accelerator types for each TPU version,
51+
// see https://cloud.google.com/tpu/docs/system-architecture-tpu-vm#versions.
52+
// tpuType = 'v2-8';
53+
54+
// Software version that specifies the version of the TPU runtime to install. For more information,
55+
// see https://cloud.google.com/tpu/docs/runtimes
56+
// tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
57+
58+
async function callCreateTpuVMStartupScript() {
59+
// Create a node
60+
const node = new Node({
61+
name: nodeName,
62+
zone,
63+
acceleratorType: tpuType,
64+
runtimeVersion: tpuSoftwareVersion,
65+
// Define network
66+
networkConfig: new NetworkConfig({
67+
enableExternalIps: true,
68+
network: `projects/${projectId}/global/networks/${networkName}`,
69+
subnetwork: `projects/${projectId}/regions/${region}/subnetworks/${networkName}`,
70+
}),
71+
metadata: {
72+
// The script updates numpy to the latest version and logs the output to a file.
73+
'startup-script': `#!/bin/bash
74+
echo "Hello World" > /var/log/hello.log
75+
sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1`,
76+
},
77+
});
78+
79+
const parent = `projects/${projectId}/locations/${zone}`;
80+
const request = {parent, node, nodeId: nodeName};
81+
82+
const [operation] = await tpuClient.createNode(request);
83+
84+
// Wait for the create operation to complete.
85+
const [response] = await operation.promise();
86+
87+
console.log(JSON.stringify(response));
88+
}
89+
await callCreateTpuVMStartupScript();
90+
// [END tpu_vm_create_startup_script]
91+
}
92+
93+
main(...process.argv.slice(2)).catch(err => {
94+
console.error(err);
95+
process.exitCode = 1;
96+
});

tpu/vmCreateTopology.js renamed to tpu/createTopologyVM.js

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,6 @@ async function main(nodeName, zone, tpuSoftwareVersion) {
8787
const [response] = await operation.promise();
8888

8989
console.log(JSON.stringify(response));
90-
console.log(`TPU VM: ${nodeName} created.`);
9190
}
9291
await callCreateTpuVMTopology();
9392
// [END tpu_vm_create_topology]
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright 2024 Google LLC
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
'use strict';
18+
19+
const path = require('path');
20+
const assert = require('node:assert/strict');
21+
const {after, describe, it} = require('mocha');
22+
const cp = require('child_process');
23+
const {getStaleNodes, deleteNode} = require('./util');
24+
25+
const execSync = cmd => cp.execSync(cmd, {encoding: 'utf-8'});
26+
const cwd = path.join(__dirname, '..');
27+
28+
describe('Compute tpu', async () => {
29+
const nodePrefix = 'node-name-startup-script-2a2b3c';
30+
const nodeName = `${nodePrefix}${Math.floor(Math.random() * 1000 + 1)}`;
31+
const zone = 'us-east1-d';
32+
const tpuType = 'v3-32';
33+
const tpuSoftwareVersion = 'tpu-vm-base';
34+
35+
after(async () => {
36+
// Clean-up resources
37+
const nodes = await getStaleNodes(nodePrefix);
38+
await Promise.all(nodes.map(node => deleteNode(node.zone, node.nodeName)));
39+
});
40+
41+
it('should create a new tpu with startup script', () => {
42+
const metadata = {
43+
'startup-script':
44+
'#!/bin/bash\n echo "Hello World" > /var/log/hello.log\n sudo pip3 install --upgrade numpy >> /var/log/hello.log 2>&1',
45+
};
46+
47+
const response = JSON.parse(
48+
execSync(
49+
`node ./createStartupScriptVM.js ${nodeName} ${zone} ${tpuType} ${tpuSoftwareVersion}`,
50+
{
51+
cwd,
52+
}
53+
)
54+
);
55+
56+
assert.deepEqual(response.metadata, metadata);
57+
});
58+
});

tpu/test/vmTopology.test.js renamed to tpu/test/createTopologyVM.test.js

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
const path = require('path');
2020
const assert = require('node:assert/strict');
21-
const {before, after, describe, it} = require('mocha');
21+
const {after, describe, it} = require('mocha');
2222
const cp = require('child_process');
2323
const {getStaleNodes, deleteNode} = require('./util');
2424

@@ -28,27 +28,27 @@ const cwd = path.join(__dirname, '..');
2828
describe('Compute tpu with topology', async () => {
2929
const nodePrefix = 'topology-node-name-2a2b3c';
3030
const nodeName = `${nodePrefix}${Math.floor(Math.random() * 1000 + 1)}`;
31-
const zone = 'europe-west4-a';
32-
const tpuSoftwareVersion = 'tpu-vm-tf-2.17.0-pod-pjrt';
33-
34-
before(async () => {
35-
// Cleanup resources
36-
const nodes = await getStaleNodes(nodePrefix, zone);
37-
await Promise.all(nodes.map(node => deleteNode(zone, node.nodeName)));
38-
});
31+
const zone = 'asia-east1-c';
32+
const tpuSoftwareVersion = 'tpu-vm-tf-2.12.1';
3933

4034
after(async () => {
41-
// Delete node
42-
await deleteNode(zone, nodeName);
35+
// Clean-up resources
36+
const nodes = await getStaleNodes(nodePrefix);
37+
await Promise.all(nodes.map(node => deleteNode(node.zone, node.nodeName)));
4338
});
4439

45-
it('should create a new tpu', () => {
46-
const response = execSync(
47-
`node ./vmCreateTopology.js ${nodeName} ${zone} ${tpuSoftwareVersion}`,
48-
{
49-
cwd,
50-
}
40+
it('should create a new tpu with topology', () => {
41+
const acceleratorConfig = {type: 'V2', topology: '2x2'};
42+
43+
const response = JSON.parse(
44+
execSync(
45+
`node ./createTopologyVM.js ${nodeName} ${zone} ${tpuSoftwareVersion}`,
46+
{
47+
cwd,
48+
}
49+
)
5150
);
52-
assert(response.includes(`TPU VM: ${nodeName} created.`));
51+
52+
assert.deepEqual(response.acceleratorConfig, acceleratorConfig);
5353
});
5454
});

0 commit comments

Comments
 (0)