diff --git a/src/models/generative-model.test.ts b/src/models/generative-model.test.ts index e6c6fd85d..75c382bcb 100644 --- a/src/models/generative-model.test.ts +++ b/src/models/generative-model.test.ts @@ -28,6 +28,7 @@ import { import { getMockResponse } from "../../test-utils/mock-response"; import { match, restore, stub } from "sinon"; import * as request from "../requests/request"; +import { Task } from '../requests/request'; use(sinonChai); @@ -463,3 +464,120 @@ describe("GenerativeModel", () => { restore(); }); }); + +describe('GenerativeModel - generateAnswer()', () => { + const API_KEY = 'test-api-key'; + const MODEL_NAME = 'models/aqa'; + const TEST_REQUEST = { + input: "What's the capital of France?", + sources: [{ + title: "World Facts", + url: "https://example.com", + content: "Paris is the capital of France." + }] + }; + + afterEach(() => { + restore(); + }); + + it('should make correct API request', async () => { + const mockResponse = { + answer: "Paris", + attributedPassages: [{ + text: "Paris is the capital of France.", + source: { + title: "World Facts", + url: "https://example.com", + content: "Paris is the capital of France." + } + }], + confidenceScore: 0.95 + }; + + const makeRequestStub = stub(request, 'makeModelRequest') + .resolves({ + json: () => Promise.resolve(mockResponse) + } as Response); + + const model = new GenerativeModel(API_KEY, { model: MODEL_NAME }); + const result = await model.generateAnswer(TEST_REQUEST); + + expect(makeRequestStub).calledWith( + 'models/aqa', + Task.GENERATE_ANSWER, + API_KEY, + false, + JSON.stringify(TEST_REQUEST), + {} + ); + + expect(result).to.deep.equal({ + answer: "Paris", + attributedPassages: [{ + text: "Paris is the capital of France.", + source: { + title: "World Facts", + url: "https://example.com", + content: "Paris is the capital of France." + } + }], + confidenceScore: 0.95 + }); + }); + + it('should handle API errors', async () => { + const errorResponse = { + error: { + message: "Invalid input format", + code: 400 + } + }; + + stub(request, 'makeModelRequest').resolves({ + json: () => Promise.resolve(errorResponse) + } as Response); + + const model = new GenerativeModel(API_KEY, { model: MODEL_NAME }); + + try { + await model.generateAnswer(TEST_REQUEST); + expect.fail('Should have thrown error'); + } catch (e) { + expect(e.message).to.include('AQA API Error'); + expect(e.response).to.deep.equal(errorResponse); + } + }); + + it('should validate required input field', async () => { + const model = new GenerativeModel(API_KEY, { model: MODEL_NAME }); + + try { + // @ts-expect-error Testing invalid input + await model.generateAnswer({ sources: [] }); + expect.fail('Should have thrown error'); + } catch (e) { + expect(e.message).to.include("must contain 'input' field"); + } + }); + + it('should handle partial response data', async () => { + const partialResponse = { + answer: "Paris", + // Missing attributedPassages and confidenceScore + }; + + stub(request, 'makeModelRequest').resolves({ + json: () => Promise.resolve(partialResponse) + } as Response); + + const model = new GenerativeModel(API_KEY, { model: MODEL_NAME }); + const result = await model.generateAnswer(TEST_REQUEST); + + expect(result).to.deep.equal({ + answer: "Paris", + attributedPassages: [], + confidenceScore: 0 + }); + }); +}); diff --git a/src/models/generative-model.ts b/src/models/generative-model.ts index 7cd3fe622..77af364da 100644 --- a/src/models/generative-model.ts +++ b/src/models/generative-model.ts @@ -28,6 +28,8 @@ import { CountTokensResponse, EmbedContentRequest, EmbedContentResponse, + GenerateAnswerRequest, + GenerateAnswerResponse, GenerateContentRequest, GenerateContentResult, GenerateContentStreamResult, @@ -50,6 +52,9 @@ import { formatGenerateContentInput, formatSystemInstruction, } from "../requests/request-helpers"; +import { Task, makeModelRequest } from "../requests/request"; +import { processAqaResponse } from "../requests/response-helpers"; +import { GoogleGenerativeAIRequestInputError, GoogleGenerativeAIResponseError } from "../errors"; /** * Class for generative model APIs. @@ -175,6 +180,85 @@ export class GenerativeModel { ); } + /** + * Generates an attributed answer based on provided sources and input question. + * + * @public + * + * @param request - The request parameters containing: + * - `input`: Required. The question to be answered + * - `sources`: Array of attributed sources to reference + * - `temperature`: Optional. Controls randomness (0.0-1.0) + * + * @param requestOptions - Optional. Overrides for request configuration + * + * @returns Promise resolving to {@link GenerateAnswerResponse} containing: + * - `answer`: The generated answer text + * - `attributedPassages`: Array of source passages used + * - `confidenceScore`: Model's confidence in the answer (0.0-1.0) + * + * @throws {GoogleGenerativeAIRequestInputError} If invalid request format + * @throws {GoogleGenerativeAIResponseError} If API returns error + * + * @example + * ``` + * const model = genAI.getGenerativeModel({ model: 'models/aqa' }); + * const request = { + * input: "What's the capital of France?", + * sources: [{ + * title: "World Factbook", + * url: "https://example.com/factbook", + * content: "Paris is the administrative and cultural center of France." + * }] + * }; + * + * try { + * const result = await model.generateAnswer(request); + * console.log(result.answer); // "Paris" + * console.log(result.attributedPassages[0].source.title); // "World Factbook" + * } catch (err) { + * console.error(err); + * } + * ``` + */ + async generateAnswer( + request: GenerateAnswerRequest, + requestOptions: SingleRequestOptions = {}, + ): Promise { + const generativeModelRequestOptions: SingleRequestOptions = { + ...this._requestOptions, + ...requestOptions, + }; + + if (!request.input) { + throw new GoogleGenerativeAIRequestInputError( + "GenerateAnswerRequest must contain 'input' field" + ); + } + + try { + const response = await makeModelRequest( + 'models/aqa', // Hardcode AQA model path + Task.GENERATE_ANSWER, + this.apiKey, + false, + JSON.stringify(request), + generativeModelRequestOptions + ); + + const responseJson = await response.json(); + return processAqaResponse(responseJson); + } catch (e) { + if (e instanceof GoogleGenerativeAIResponseError) { + throw new GoogleGenerativeAIResponseError( + `AQA API Error: ${e.message}`, + e.response + ); + } + throw e; + } + } + /** * Counts the tokens in the provided request. * diff --git a/src/requests/request.ts b/src/requests/request.ts index 64c3703f9..649bd34db 100644 --- a/src/requests/request.ts +++ b/src/requests/request.ts @@ -37,6 +37,7 @@ const PACKAGE_LOG_HEADER = "genai-js"; export enum Task { GENERATE_CONTENT = "generateContent", STREAM_GENERATE_CONTENT = "streamGenerateContent", + GENERATE_ANSWER = "generateAnswer", COUNT_TOKENS = "countTokens", EMBED_CONTENT = "embedContent", BATCH_EMBED_CONTENTS = "batchEmbedContents", diff --git a/src/requests/response-helpers.ts b/src/requests/response-helpers.ts index 2554d001a..4524aac09 100644 --- a/src/requests/response-helpers.ts +++ b/src/requests/response-helpers.ts @@ -16,9 +16,11 @@ */ import { + AttributedPassage, EnhancedGenerateContentResponse, FinishReason, FunctionCall, + GenerateAnswerResponse, GenerateContentCandidate, GenerateContentResponse, } from "../../types"; @@ -206,3 +208,25 @@ export function formatBlockErrorMessage( } return message; } + +export function processAqaResponse(response: GenerateAnswerResponse): GenerateAnswerResponse { + if (!response.answer) { + throw new GoogleGenerativeAIResponseError( + "Invalid AQA response format - missing 'answer' field", + response + ); + } + + return { + answer: response.answer, + attributedPassages: (response.attributedPassages || []).map((p: AttributedPassage) => ({ + text: p.text, + source: { + title: p.source?.title || "Unknown", + url: p.source?.url || "", + content: p.source?.content || "" + } + })), + confidenceScore: response.confidenceScore ?? 0 + }; +} \ No newline at end of file diff --git a/types/requests.ts b/types/requests.ts index 81285bc20..a66378e99 100644 --- a/types/requests.ts +++ b/types/requests.ts @@ -245,3 +245,30 @@ export interface CodeExecutionTool { */ codeExecution: {}; } + +/** + * Request sent to `generateAnswer` endpoint. + * @public + */ +export interface GenerateAnswerRequest { + input: string; + sources?: AttributedSource[]; + temperature?: number; +} + +export interface AttributedSource { + title: string; + url: string; + content: string; +} + +export interface GenerateAnswerResponse { + answer: string; + attributedPassages: AttributedPassage[]; + confidenceScore: number; +} + +export interface AttributedPassage { + text: string; + source: AttributedSource; +} \ No newline at end of file