diff --git a/src/methods/chat-session.test.ts b/src/methods/chat-session.test.ts index c09821030..ed36c67f7 100644 --- a/src/methods/chat-session.test.ts +++ b/src/methods/chat-session.test.ts @@ -46,6 +46,15 @@ describe("ChatSession", () => { match.any, ); }); + it("sendMessage() should reset messageInProgress flag after resolving promise", async () => { + const mockResponse = getMockResponse( + "unary-success-basic-reply-short.json", + ); + stub(request, "makeModelRequest").resolves(mockResponse as Response); + const chatSession = new ChatSession("MY_API_KEY", "a-model"); + await chatSession.sendMessage("hello"); + expect(chatSession["_messageInProgress"]).to.be.false; + }); }); describe("sendMessageRecitationErrorNotAddingResponseToHistory()", () => { it("generateContent errors should be catchable", async () => { @@ -75,6 +84,22 @@ describe("ChatSession", () => { expect(consoleStub).to.not.be.called; clock.restore(); }); + it("sendMessageStream() should reset messageInProgress flag after resolving promise", async () => { + const consoleStub = stub(console, "error"); + const generateContentStreamStub = stub( + generateContentMethods, + "generateContentStream", + ).resolves(); + const chatSession = new ChatSession("MY_API_KEY", "a-model"); + await expect(chatSession.sendMessageStream("hello")).to.be.fulfilled; + expect(generateContentStreamStub).to.be.calledWith( + "MY_API_KEY", + "a-model", + match.any, + ); + expect(consoleStub).to.not.be.called; + expect(chatSession["_messageInProgress"]).to.be.false; + }); it("downstream sendPromise errors should log but not throw", async () => { const clock = useFakeTimers(); const consoleStub = stub(console, "error"); diff --git a/src/methods/chat-session.ts b/src/methods/chat-session.ts index 4022e2e39..9dfc9b78f 100644 --- a/src/methods/chat-session.ts +++ b/src/methods/chat-session.ts @@ -45,6 +45,7 @@ export class ChatSession { private _apiKey: string; private _history: Content[] = []; private _sendPromise: Promise = Promise.resolve(); + private _messageInProgress: boolean = false; constructor( apiKey: string, @@ -81,6 +82,12 @@ export class ChatSession { request: string | Array, requestOptions: SingleRequestOptions = {}, ): Promise { + if (this._messageInProgress) { + throw new Error( + "sendMessage() was called while another message was in progress, this may lead to unexpected behavior.", + ); + } + this._messageInProgress = true; await this._sendPromise; const newContent = formatNewContent(request); const generateContentRequest: GenerateContentRequest = { @@ -126,6 +133,9 @@ export class ChatSession { } } finalResult = result; + }) + .finally(() => { + this._messageInProgress = false; }); await this._sendPromise; return finalResult; @@ -144,6 +154,12 @@ export class ChatSession { request: string | Array, requestOptions: SingleRequestOptions = {}, ): Promise { + if (this._messageInProgress) { + throw new Error( + "sendMessageStream() was called while another message was in progress, this may lead to unexpected behavior.", + ); + } + this._messageInProgress = true; await this._sendPromise; const newContent = formatNewContent(request); const generateContentRequest: GenerateContentRequest = { @@ -203,6 +219,7 @@ export class ChatSession { console.error(e); } }); + this._messageInProgress = false; return streamPromise; } }