32 changes: 21 additions & 11 deletions ai-platform/snippets/predict-text-embeddings.js
@@ -20,7 +20,7 @@
// [START generativeaionvertexai_sdk_embedding]
async function main(
project,
model = 'text-embedding-005',
model = 'gemini-embedding-001',
texts = 'banana bread?;banana muffins?',
task = 'QUESTION_ANSWERING',
dimensionality = 0,
@@ -37,19 +37,29 @@ async function main(
const instances = texts
.split(';')
.map(e => helpers.toValue({content: e, task_type: task}));

const client = new PredictionServiceClient(clientOptions);
const parameters = helpers.toValue(
dimensionality > 0 ? {outputDimensionality: parseInt(dimensionality)} : {}
);
const request = {endpoint, instances, parameters};
const client = new PredictionServiceClient(clientOptions);
const [response] = await client.predict(request);
const predictions = response.predictions;
const embeddings = predictions.map(p => {
const embeddingsProto = p.structValue.fields.embeddings;
const valuesProto = embeddingsProto.structValue.fields.values;
return valuesProto.listValue.values.map(v => v.numberValue);
});
console.log('Got embeddings: \n' + JSON.stringify(embeddings));
const allEmbeddings = []
// gemini-embedding-001 takes one input at a time.
for (const instance of instances) {
const request = {endpoint, instances: [instance], parameters};
const [response] = await client.predict(request);
const predictions = response.predictions;

const embeddings = predictions.map(p => {
const embeddingsProto = p.structValue.fields.embeddings;
const valuesProto = embeddingsProto.structValue.fields.values;
return valuesProto.listValue.values.map(v => v.numberValue);
});

allEmbeddings.push(embeddings[0])
}
Comment on lines +47 to +59
Contributor

Severity: medium

The current implementation iterates through each text instance to get its embedding, making an API call for each. If one of these client.predict() calls fails, the entire callPredict function will reject, and the loop will terminate, preventing subsequent instances from being processed.

Could you clarify whether this fail-fast behavior is intended for this snippet? It may be acceptable, especially for a debugging/testing branch, but if more resilience is desired (e.g., collecting embeddings for the instances that succeed even when some fail), we might consider adding a try...catch block inside the for...of loop.

For example, you could adapt the loop like this:

    const allEmbeddings = [];
    // gemini-embedding-001 takes one input at a time.
    for (const instance of instances) {
      try {
        const request = {endpoint, instances: [instance], parameters};
        const [response] = await client.predict(request);
        const predictions = response.predictions;

        // Assuming predictions[0] contains the embedding structure for the single instance
        const embeddingData = predictions[0].structValue.fields.embeddings;
        const valuesProto = embeddingData.structValue.fields.values;
        const singleEmbedding = valuesProto.listValue.values.map(v => v.numberValue);

        allEmbeddings.push(singleEmbedding);
      } catch (error) {
        // Safely access instance content for logging, if available
        const contentValue = instance.structValue?.fields?.content?.stringValue;
        const instanceIdentifier = contentValue ? `"${contentValue}"` : 'current instance';
        console.error(`Failed to get embedding for ${instanceIdentifier}. Error: ${error.message}`);
        // Optional: Push a placeholder (e.g., null) or skip, or collect errors
      }
    }

What are your thoughts on enhancing error handling for individual instances in this context?
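If concurrent requests are acceptable for this snippet, another option is `Promise.allSettled`, which records a per-instance result or error without a `try...catch` inside the loop. This is only a sketch, not the snippet's actual code: `embedAll` and `predictOne` are hypothetical names, and `predictOne` stands in for a single-instance `client.predict()` call.

```javascript
// Sketch (hypothetical helpers): collect one result or one error per
// instance, so a single failed call does not stop the others.
// `predictOne` stands in for a single-instance client.predict() call.
async function embedAll(instances, predictOne) {
  const results = await Promise.allSettled(
    instances.map(instance => predictOne(instance))
  );
  const embeddings = [];
  const errors = [];
  results.forEach((result, i) => {
    if (result.status === 'fulfilled') {
      embeddings.push(result.value);
    } else {
      errors.push({index: i, message: result.reason.message});
      embeddings.push(null); // placeholder keeps positions aligned with inputs
    }
  });
  return {embeddings, errors};
}

// Usage with a stubbed predictor: the second instance fails, the rest succeed.
const fakePredict = async instance =>
  instance === 'bad' ? Promise.reject(new Error('quota')) : [0.1, 0.2];

embedAll(['a', 'bad', 'c'], fakePredict).then(({embeddings, errors}) => {
  console.log(JSON.stringify(embeddings)); // [[0.1,0.2],null,[0.1,0.2]]
  console.log(errors.length); // 1
});
```

One trade-off versus the sequential loop: `Promise.allSettled` fires all requests concurrently, which may matter for quota or rate limits, though it still returns results in input order.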



console.log('Got embeddings: \n' + JSON.stringify(allEmbeddings));
}

callPredict();
@@ -37,7 +37,7 @@ const texts = [
describe('predict text embeddings', () => {
it('should get text embeddings using the latest model', async () => {
const stdout = execSync(
`node ./predict-text-embeddings.js ${project} text-embedding-004 '${texts.join(';')}' QUESTION_ANSWERING ${dimensionality}`,
`node ./predict-text-embeddings.js ${project} gemini-embedding-001 '${texts.join(';')}' QUESTION_ANSWERING ${dimensionality}`,
{cwd}
);
const embeddings = JSON.parse(stdout.trimEnd().split('\n').at(-1));
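The last-line parse in the test is worth noting: the snippet prints a label line before the JSON, so the test must take only the final line of stdout. A minimal sketch of that parsing with a stubbed stdout string (the contents are illustrative, not real model output):

```javascript
// The snippet logs 'Got embeddings: ' on one line and the JSON on the next,
// so trim trailing whitespace and parse only the last line of stdout.
const stdout = 'Got embeddings: \n[[0.1,0.2],[0.3,0.4]]\n';
const embeddings = JSON.parse(stdout.trimEnd().split('\n').at(-1));
console.log(embeddings.length); // 2
```

This keeps the test robust even if extra log lines precede the final JSON, as long as the JSON is printed last on its own line.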