Skip to content

Commit 4c663ea

Browse files
lerouxbgribnoysup
andauthored
feat(compass-assistant): add a system prompt to precede user prompts with the basic UI context COMPASS-10140 (#7629)
* simple assistant global state implementation * a basic system context prompt * remove the non-genuine connections hack for now * lint * move currentActiveConnections to ConnectionsComponent * use workspace-info, add tests * test that it added the sytem message * don't take context messages into account for confirmations * typeguard.. * more typeguards * clear the state when unmounting * check for existence of namespace rather * set collection tab related context in compass-collection rather * only send the context if it changed * TODO ticket number --------- Co-authored-by: Sergey Petushkov <[email protected]>
1 parent 4ba56a4 commit 4c663ea

File tree

16 files changed

+580
-33
lines changed

16 files changed

+580
-33
lines changed

package-lock.json

Lines changed: 4 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

packages/compass-assistant/package.json

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,10 @@
6161
"@mongodb-js/compass-logging": "^1.7.25",
6262
"@mongodb-js/compass-telemetry": "^1.19.5",
6363
"@mongodb-js/connection-info": "^0.24.0",
64+
"@mongodb-js/workspace-info": "^1.0.0",
6465
"ai": "^5.0.26",
6566
"compass-preferences-model": "^2.66.3",
67+
"mongodb-collection-model": "^5.37.0",
6668
"mongodb-connection-string-url": "^3.0.1",
6769
"react": "^17.0.2",
6870
"throttleit": "^2.1.0",
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import type { ConnectionInfo } from '@mongodb-js/connection-info';
2+
import type {
3+
WorkspaceTab,
4+
CollectionSubtab,
5+
} from '@mongodb-js/workspace-info';
6+
import type { CollectionMetadata } from 'mongodb-collection-model';
7+
import React, { useEffect } from 'react';
8+
9+
export type GlobalState = {
10+
activeConnections: ConnectionInfo[];
11+
activeWorkspace: WorkspaceTab | null;
12+
activeCollectionMetadata: CollectionMetadata | null;
13+
currentQuery: object | null;
14+
currentAggregation: object | null;
15+
activeCollectionSubTab: CollectionSubtab | null;
16+
};
17+
18+
const INITIAL_STATE: GlobalState = {
19+
activeConnections: [],
20+
activeWorkspace: null,
21+
activeCollectionMetadata: null,
22+
currentQuery: null,
23+
currentAggregation: null,
24+
activeCollectionSubTab: null,
25+
};
26+
27+
const AssistantGlobalStateContext = React.createContext<GlobalState>({
28+
...INITIAL_STATE,
29+
});
30+
31+
const AssistantGlobalSetStateContext = React.createContext<
32+
React.Dispatch<React.SetStateAction<GlobalState>>
33+
>(() => undefined);
34+
35+
export const AssistantGlobalStateProvider: React.FunctionComponent = ({
36+
children,
37+
}) => {
38+
const [globalState, setGlobalState] = React.useState({ ...INITIAL_STATE });
39+
return (
40+
<AssistantGlobalStateContext.Provider value={globalState}>
41+
<AssistantGlobalSetStateContext.Provider value={setGlobalState}>
42+
{children}
43+
</AssistantGlobalSetStateContext.Provider>
44+
</AssistantGlobalStateContext.Provider>
45+
);
46+
};
47+
48+
export function useSyncAssistantGlobalState<T extends keyof GlobalState>(
49+
stateKey: T,
50+
newState: GlobalState[T]
51+
) {
52+
const setState = React.useContext(AssistantGlobalSetStateContext);
53+
useEffect(() => {
54+
setState((prevState) => {
55+
const state = {
56+
...prevState,
57+
[stateKey]: newState,
58+
};
59+
60+
// Get rid of some non-sensical states incase the user switches away from
61+
// a collection tab to something that is not a collection tab.
62+
// activeConnections and activeWorkspace will get updated no matter
63+
// how/where the user navigates because those concepts are always
64+
// "present" in a way that an active collection is not.
65+
if (state.activeWorkspace?.type !== 'Collection') {
66+
state.activeCollectionMetadata = null;
67+
state.activeCollectionSubTab = null;
68+
}
69+
70+
return state;
71+
});
72+
}, [newState, setState, stateKey]);
73+
}
74+
75+
export function useAssistantGlobalState() {
76+
return React.useContext(AssistantGlobalStateContext);
77+
}

packages/compass-assistant/src/compass-assistant-provider.spec.tsx

Lines changed: 39 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -414,20 +414,48 @@ describe('CompassAssistantProvider', function () {
414414

415415
await renderOpenAssistantDrawer({ chat: mockChat });
416416

417-
userEvent.type(
418-
screen.getByPlaceholderText('Ask a question'),
419-
'Hello assistant!'
420-
);
421-
userEvent.click(screen.getByLabelText('Send message'));
417+
for (let i = 0; i < 2; i++) {
418+
userEvent.type(
419+
screen.getByPlaceholderText('Ask a question'),
420+
`Hello assistant! (${i})`
421+
);
422+
userEvent.click(screen.getByLabelText('Send message'));
422423

423-
await waitFor(() => {
424-
expect(sendMessageSpy.calledOnce).to.be.true;
425-
expect(sendMessageSpy.firstCall.args[0]).to.deep.include({
426-
text: 'Hello assistant!',
424+
await waitFor(() => {
425+
expect(sendMessageSpy.callCount).to.equal(i + 1);
426+
expect(sendMessageSpy.getCall(i).args[0]).to.deep.include({
427+
text: `Hello assistant! (${i})`,
428+
});
429+
430+
expect(screen.getByText(`Hello assistant! (${i})`)).to.exist;
427431
});
432+
}
428433

429-
expect(screen.getByText('Hello assistant!')).to.exist;
430-
});
434+
const contextMessages = mockChat.messages.filter(
435+
(message) => message.metadata?.isSystemContext
436+
);
437+
438+
for (const contextMessage of contextMessages) {
439+
// just clear it up so we can deep compare
440+
contextMessage.id = 'system-context';
441+
}
442+
443+
// it only sent one
444+
expect(contextMessages).to.deep.equal([
445+
{
446+
id: 'system-context',
447+
role: 'system',
448+
metadata: {
449+
isSystemContext: true,
450+
},
451+
parts: [
452+
{
453+
type: 'text',
454+
text: 'The user does not have any tabs open.',
455+
},
456+
],
457+
},
458+
]);
431459
});
432460

433461
it('will not send new messages if the user does not opt in', async function () {

packages/compass-assistant/src/compass-assistant-provider.tsx

Lines changed: 74 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,13 @@ import {
1313
} from '@mongodb-js/atlas-service/provider';
1414
import { DocsProviderTransport } from './docs-provider-transport';
1515
import {
16+
useCurrentValueRef,
1617
useDrawerActions,
1718
useInitialValue,
1819
} from '@mongodb-js/compass-components';
1920
import {
2021
buildConnectionErrorPrompt,
22+
buildContextPrompt,
2123
buildExplainPlanPrompt,
2224
buildProactiveInsightsPrompt,
2325
type EntryPointMessage,
@@ -31,7 +33,7 @@ import {
3133
createLoggerLocator,
3234
type Logger,
3335
} from '@mongodb-js/compass-logging/provider';
34-
import type { ConnectionInfo } from '@mongodb-js/connection-info';
36+
import { type ConnectionInfo } from '@mongodb-js/connection-info';
3537
import {
3638
telemetryLocator,
3739
type TrackFunction,
@@ -41,10 +43,15 @@ import type { AtlasAiService } from '@mongodb-js/compass-generative-ai/provider'
4143
import { atlasAiServiceLocator } from '@mongodb-js/compass-generative-ai/provider';
4244
import { buildConversationInstructionsPrompt } from './prompts';
4345
import { createOpenAI } from '@ai-sdk/openai';
46+
import {
47+
AssistantGlobalStateProvider,
48+
useAssistantGlobalState,
49+
} from './assistant-global-state';
4450

4551
export const ASSISTANT_DRAWER_ID = 'compass-assistant-drawer';
4652

4753
export type AssistantMessage = UIMessage & {
54+
role?: 'user' | 'assistant' | 'system';
4855
metadata?: {
4956
/** The text to display instead of the message text. */
5057
displayText?: string;
@@ -63,6 +70,11 @@ export type AssistantMessage = UIMessage & {
6370
instructions?: string;
6471
/** Excludes history if this message is the last message being sent */
6572
sendWithoutHistory?: boolean;
73+
/** Whether to send the current context along with the message if the context changed */
74+
sendContext?: boolean;
75+
76+
/** Whether this is a message to the model that we don't want to display to the user*/
77+
isSystemContext?: boolean;
6678
};
6779
};
6880

@@ -173,6 +185,16 @@ export type CompassAssistantService = {
173185
getIsAssistantEnabled: () => boolean;
174186
};
175187

188+
// Type guard to check if activeWorkspace has a connectionId property
189+
function hasConnectionId(obj: unknown): obj is { connectionId: string } {
190+
return (
191+
typeof obj === 'object' &&
192+
obj !== null &&
193+
'connectionId' in obj &&
194+
typeof (obj as any).connectionId === 'string'
195+
);
196+
}
197+
176198
export const AssistantProvider: React.FunctionComponent<
177199
PropsWithChildren<{
178200
appNameForPrompt: string;
@@ -183,12 +205,23 @@ export const AssistantProvider: React.FunctionComponent<
183205
const { openDrawer } = useDrawerActions();
184206
const track = useTelemetry();
185207

208+
const assistantGlobalStateRef = useCurrentValueRef(useAssistantGlobalState());
209+
210+
const lastContextPromptRef = useRef<string | null>(null);
211+
186212
const ensureOptInAndSend = useInitialValue(() => {
187213
return async function (
188214
message: SendMessage,
189215
options: SendOptions,
190216
callback: () => void
191217
) {
218+
const {
219+
activeWorkspace,
220+
activeConnections,
221+
activeCollectionMetadata,
222+
activeCollectionSubTab,
223+
} = assistantGlobalStateRef.current;
224+
192225
try {
193226
await atlasAiService.ensureAiFeatureAccess();
194227
} catch {
@@ -204,6 +237,36 @@ export const AssistantProvider: React.FunctionComponent<
204237
await chat.stop();
205238
}
206239

240+
const activeConnection =
241+
activeConnections.find((connInfo) => {
242+
return (
243+
hasConnectionId(activeWorkspace) &&
244+
connInfo.id === activeWorkspace.connectionId
245+
);
246+
}) ?? null;
247+
248+
const contextPrompt = buildContextPrompt({
249+
activeWorkspace,
250+
activeConnection,
251+
activeCollectionMetadata,
252+
activeCollectionSubTab,
253+
});
254+
255+
// use just the text so we have a stable reference to compare against
256+
const contextPromptText =
257+
contextPrompt.parts[0].type === 'text'
258+
? contextPrompt.parts[0].text
259+
: '';
260+
261+
const shouldSendContextPrompt =
262+
message?.metadata?.sendContext &&
263+
(!lastContextPromptRef.current ||
264+
lastContextPromptRef.current !== contextPromptText);
265+
if (shouldSendContextPrompt) {
266+
lastContextPromptRef.current = contextPromptText;
267+
chat.messages = [...chat.messages, contextPrompt];
268+
}
269+
207270
await chat.sendMessage(message, options);
208271
};
209272
});
@@ -224,6 +287,7 @@ export const AssistantProvider: React.FunctionComponent<
224287
metadata: {
225288
...metadata,
226289
source: entryPointName,
290+
sendContext: true,
227291
},
228292
},
229293
{},
@@ -285,13 +349,15 @@ export const CompassAssistantProvider = registerCompassPlugin(
285349
throw new Error('atlasAiService was not provided by the state');
286350
}
287351
return (
288-
<AssistantProvider
289-
appNameForPrompt={appNameForPrompt}
290-
chat={chat}
291-
atlasAiService={atlasAiService}
292-
>
293-
{children}
294-
</AssistantProvider>
352+
<AssistantGlobalStateProvider>
353+
<AssistantProvider
354+
appNameForPrompt={appNameForPrompt}
355+
chat={chat}
356+
atlasAiService={atlasAiService}
357+
>
358+
{children}
359+
</AssistantProvider>
360+
</AssistantGlobalStateProvider>
295361
);
296362
},
297363
activate: (

packages/compass-assistant/src/components/assistant-chat.tsx

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -263,11 +263,15 @@ export const AssistantChat: React.FunctionComponent<AssistantChatProps> = ({
263263
const trimmedMessageBody = messageBody.trim();
264264
if (trimmedMessageBody) {
265265
await chat.stop();
266-
void ensureOptInAndSend?.({ text: trimmedMessageBody }, {}, () => {
267-
track('Assistant Prompt Submitted', {
268-
user_input_length: trimmedMessageBody.length,
269-
});
270-
});
266+
void ensureOptInAndSend?.(
267+
{ text: trimmedMessageBody, metadata: { sendContext: true } },
268+
{},
269+
() => {
270+
track('Assistant Prompt Submitted', {
271+
user_input_length: trimmedMessageBody.length,
272+
});
273+
}
274+
);
271275
}
272276
},
273277
[track, ensureOptInAndSend, chat]
@@ -357,6 +361,10 @@ export const AssistantChat: React.FunctionComponent<AssistantChatProps> = ({
357361
[ensureOptInAndSend, setMessages, track]
358362
);
359363

364+
const visibleMessages = messages.filter(
365+
(message) => !message.metadata?.isSystemContext
366+
);
367+
360368
return (
361369
<div
362370
data-testid="assistant-chat"
@@ -374,7 +382,7 @@ export const AssistantChat: React.FunctionComponent<AssistantChatProps> = ({
374382
ref={messagesContainerRef}
375383
>
376384
<div className={messagesWrapStyles}>
377-
{messages.map((message, index) => {
385+
{visibleMessages.map((message, index) => {
378386
const { id, role, metadata, parts } = message;
379387
const seenTitles = new Set<string>();
380388
const sources = [];
@@ -395,7 +403,7 @@ export const AssistantChat: React.FunctionComponent<AssistantChatProps> = ({
395403
}
396404
if (metadata?.confirmation) {
397405
const { description, state } = metadata.confirmation;
398-
const isLastMessage = index === messages.length - 1;
406+
const isLastMessage = index === visibleMessages.length - 1;
399407

400408
return (
401409
<ConfirmationMessage

packages/compass-assistant/src/index.tsx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,4 @@ export {
77
export type { CompassAssistantService } from './compass-assistant-provider';
88
export type { ProactiveInsightsContext, EntryPointMessage } from './prompts';
99
export { APP_NAMES_FOR_PROMPT } from './prompts';
10+
export { useSyncAssistantGlobalState } from './assistant-global-state';

0 commit comments

Comments
 (0)