Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/fast-mangos-chew.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@firebase/vertexai': patch
---

Pass `GenerativeModel`'s `BaseParams` to created chat sessions. This fixes an issue where `GenerationConfig` would not be inherited from `ChatSession`.
38 changes: 33 additions & 5 deletions packages/vertexai/src/models/generative-model.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -165,14 +165,33 @@ describe('GenerativeModel', () => {
);
restore();
});
it('overrides base model params with startChatParams', () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
generationConfig: {
topK: 1
}
});
const chatSession = genModel.startChat({
generationConfig: {
topK: 2
}
});
expect(chatSession.params?.generationConfig).to.deep.equal({
topK: 2
});
});
it('passes params through to chat.sendMessage', async () => {
const genModel = new GenerativeModel(fakeVertexAI, {
model: 'my-model',
tools: [
{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }
],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] },
generationConfig: {
topK: 1
}
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal(
Expand All @@ -196,7 +215,8 @@ describe('GenerativeModel', () => {
return (
value.includes('myfunc') &&
value.includes(FunctionCallingMode.NONE) &&
value.includes('be friendly')
value.includes('be friendly') &&
value.includes('topK')
);
}),
{}
Expand Down Expand Up @@ -236,7 +256,10 @@ describe('GenerativeModel', () => {
{ functionDeclarations: [{ name: 'myfunc', description: 'mydesc' }] }
],
toolConfig: { functionCallingConfig: { mode: FunctionCallingMode.NONE } },
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] }
systemInstruction: { role: 'system', parts: [{ text: 'be friendly' }] },
generationConfig: {
responseMimeType: 'image/jpeg'
}
});
expect(genModel.tools?.length).to.equal(1);
expect(genModel.toolConfig?.functionCallingConfig?.mode).to.equal(
Expand All @@ -262,7 +285,10 @@ describe('GenerativeModel', () => {
toolConfig: {
functionCallingConfig: { mode: FunctionCallingMode.AUTO }
},
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] }
systemInstruction: { role: 'system', parts: [{ text: 'be formal' }] },
generationConfig: {
responseMimeType: 'image/png'
}
})
.sendMessage('hello');
expect(makeRequestStub).to.be.calledWith(
Expand All @@ -274,7 +300,9 @@ describe('GenerativeModel', () => {
return (
value.includes('otherfunc') &&
value.includes(FunctionCallingMode.AUTO) &&
value.includes('be formal')
value.includes('be formal') &&
value.includes('image/png') &&
!value.includes('image/jpeg')
);
}),
{}
Expand Down
7 changes: 7 additions & 0 deletions packages/vertexai/src/models/generative-model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ export class GenerativeModel extends VertexAIModel {
tools: this.tools,
toolConfig: this.toolConfig,
systemInstruction: this.systemInstruction,
generationConfig: this.generationConfig,
safetySettings: this.safetySettings,
/**
* Overrides params inherited from GenerativeModel with those explicitly set in the
* StartChatParams. For example, if startChatParams.generationConfig is set, it'll override
* this.generationConfig.
*/
...startChatParams
},
this.requestOptions
Expand Down
Loading