Skip to content

Commit 5fe2068

Browse files
viduni94cqliu1
authored andcommitted
[Obs AI Assistant] Fix contextual insights scoring (elastic#214259)
Closes elastic#209572 ### Summary Scoring in contextual insights is broken because the `get_contextual_insight_instructions` tool call is not followed by the tool response. This happens because we replace the last user message (in this case tool response) with the user message related to scoring. ### Solution We should include the tool call name when replacing this message, so that it gets converted to inference messages correctly here: https://github.com/elastic/kibana/blob/07012811b29b487a3b4a664469c7a198355e44bf/x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts#L60-L81 ### Checklist - [x] The PR description includes the appropriate Release Notes section, and the correct `release_note:*` label is applied per the [guidelines](https://www.elastic.co/guide/en/kibana/master/contributing.html#kibana-release-notes-process)
1 parent 463d2c2 commit 5fe2068

File tree

28 files changed

+948
-214
lines changed

28 files changed

+948
-214
lines changed

x-pack/platform/plugins/shared/observability_ai_assistant/common/convert_messages_for_inference.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export function convertMessagesForInference(
6666
msg.role === InferenceMessageRole.Assistant &&
6767
msg.toolCalls?.[0]?.function.name === message.message.name
6868
) as AssistantMessage | undefined;
69+
6970
if (!toolCallRequest) {
7071
throw new Error(`Could not find tool call request for ${message.message.name}`);
7172
}

x-pack/platform/plugins/shared/observability_ai_assistant/server/functions/context.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,14 @@ export function registerContextFunction({
6363
);
6464

6565
const userPrompt = userMessage?.message.content!;
66+
const userMessageFunctionName = userMessage?.message.name;
6667

6768
const { scores, relevantDocuments, suggestions } = await recallAndScore({
6869
recall: client.recall,
6970
chat,
7071
logger: resources.logger,
7172
userPrompt,
73+
userMessageFunctionName,
7274
context: screenDescription,
7375
messages,
7476
signal,
Lines changed: 265 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,265 @@
1+
/*
2+
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
3+
* or more contributor license agreements. Licensed under the Elastic License
4+
* 2.0; you may not use this file except in compliance with the Elastic License
5+
* 2.0.
6+
*/
7+
8+
import { RecalledSuggestion, recallAndScore } from './recall_and_score';
9+
import { scoreSuggestions } from './score_suggestions';
10+
import { MessageRole, type Message } from '../../../common';
11+
import type { FunctionCallChatFunction } from '../../service/types';
12+
import { AnalyticsServiceStart } from '@kbn/core/server';
13+
import { Logger } from '@kbn/logging';
14+
import { recallRankingEventType } from '../../analytics/recall_ranking';
15+
16+
jest.mock('./score_suggestions', () => ({
17+
scoreSuggestions: jest.fn(),
18+
}));
19+
20+
export const sampleMessages: Message[] = [
21+
{
22+
'@timestamp': '2025-03-13T14:53:11.240Z',
23+
message: { role: MessageRole.User, content: 'test' },
24+
},
25+
];
26+
27+
export const normalConversationMessages: Message[] = [
28+
{
29+
'@timestamp': '2025-03-12T21:00:13.980Z',
30+
message: { role: MessageRole.User, content: 'What is my favourite color?' },
31+
},
32+
{
33+
'@timestamp': '2025-03-12T21:00:14.920Z',
34+
message: {
35+
function_call: { name: 'context', trigger: MessageRole.Assistant },
36+
role: MessageRole.Assistant,
37+
content: '',
38+
},
39+
},
40+
];
41+
42+
export const contextualInsightsMessages: Message[] = [
43+
{
44+
'@timestamp': '2025-03-12T21:01:21.111Z',
45+
message: {
46+
role: MessageRole.User,
47+
content: "I'm looking at an alert and trying to understand why it was triggered",
48+
},
49+
},
50+
{
51+
'@timestamp': '2025-03-12T21:01:21.111Z',
52+
message: {
53+
role: MessageRole.Assistant,
54+
function_call: {
55+
name: 'get_contextual_insight_instructions',
56+
trigger: MessageRole.Assistant,
57+
arguments: '{}',
58+
},
59+
},
60+
},
61+
{
62+
'@timestamp': '2025-03-12T21:01:21.111Z',
63+
message: {
64+
role: MessageRole.User,
65+
content:
66+
'{"instructions":"I\'m an SRE. I am looking at an alert that was triggered. I want to understand why it was triggered......}',
67+
name: 'get_contextual_insight_instructions',
68+
},
69+
},
70+
{
71+
'@timestamp': '2025-03-12T21:01:21.984Z',
72+
message: {
73+
function_call: { name: 'context', trigger: MessageRole.Assistant },
74+
role: MessageRole.Assistant,
75+
content: '',
76+
},
77+
},
78+
];
79+
80+
describe('recallAndScore', () => {
81+
const mockRecall = jest.fn();
82+
const mockChat = jest.fn() as unknown as FunctionCallChatFunction;
83+
const mockLogger = { error: jest.fn(), debug: jest.fn() } as unknown as Logger;
84+
const mockAnalytics = { reportEvent: jest.fn() } as unknown as AnalyticsServiceStart;
85+
const signal = new AbortController().signal;
86+
87+
beforeEach(() => {
88+
jest.clearAllMocks();
89+
});
90+
91+
describe('when no documents are recalled', () => {
92+
let result: {
93+
relevantDocuments?: RecalledSuggestion[];
94+
scores?: Array<{ id: string; score: number }>;
95+
suggestions: RecalledSuggestion[];
96+
};
97+
98+
beforeEach(async () => {
99+
mockRecall.mockResolvedValue([]);
100+
101+
result = await recallAndScore({
102+
recall: mockRecall,
103+
chat: mockChat,
104+
analytics: mockAnalytics,
105+
userPrompt: 'What is my favorite color?',
106+
context: 'Some context',
107+
messages: sampleMessages,
108+
logger: mockLogger,
109+
signal,
110+
});
111+
});
112+
113+
it('returns empty suggestions', async () => {
114+
expect(result).toEqual({ relevantDocuments: [], scores: [], suggestions: [] });
115+
});
116+
117+
it('invokes recall with user prompt and screen context', async () => {
118+
expect(mockRecall).toHaveBeenCalledWith({
119+
queries: [
120+
{ text: 'What is my favorite color?', boost: 3 },
121+
{ text: 'Some context', boost: 1 },
122+
],
123+
});
124+
});
125+
126+
it('does not score the suggestions', async () => {
127+
expect(scoreSuggestions).not.toHaveBeenCalled();
128+
});
129+
});
130+
131+
it('handles errors when scoring fails', async () => {
132+
mockRecall.mockResolvedValue([{ id: 'doc1', text: 'Hello world', score: 0.5 }]);
133+
(scoreSuggestions as jest.Mock).mockRejectedValue(new Error('Scoring failed'));
134+
135+
const result = await recallAndScore({
136+
recall: mockRecall,
137+
chat: mockChat,
138+
analytics: mockAnalytics,
139+
userPrompt: 'test',
140+
context: 'context',
141+
messages: sampleMessages,
142+
logger: mockLogger,
143+
signal,
144+
});
145+
146+
expect(mockLogger.error).toHaveBeenCalledWith(
147+
expect.stringContaining('Error scoring documents: Scoring failed'),
148+
expect.any(Object)
149+
);
150+
expect(result.suggestions.length).toBe(1);
151+
expect(result.suggestions[0].id).toBe('doc1');
152+
});
153+
154+
it('calls scoreSuggestions with correct arguments', async () => {
155+
const recalledDocs = [{ id: 'doc1', text: 'Hello world', score: 0.8 }];
156+
mockRecall.mockResolvedValue(recalledDocs);
157+
(scoreSuggestions as jest.Mock).mockResolvedValue({
158+
scores: [{ id: 'doc1', score: 7 }],
159+
relevantDocuments: recalledDocs,
160+
});
161+
162+
await recallAndScore({
163+
recall: mockRecall,
164+
chat: mockChat,
165+
analytics: mockAnalytics,
166+
userPrompt: 'test',
167+
context: 'context',
168+
messages: sampleMessages,
169+
logger: mockLogger,
170+
signal,
171+
});
172+
173+
expect(scoreSuggestions).toHaveBeenCalledWith({
174+
suggestions: recalledDocs,
175+
logger: mockLogger,
176+
messages: sampleMessages,
177+
userPrompt: 'test',
178+
userMessageFunctionName: undefined,
179+
context: 'context',
180+
signal,
181+
chat: mockChat,
182+
});
183+
});
184+
185+
it('handles the normal conversation flow correctly', async () => {
186+
mockRecall.mockResolvedValue([
187+
{ id: 'fav_color', text: 'My favourite color is blue.', score: 0.9 },
188+
]);
189+
(scoreSuggestions as jest.Mock).mockResolvedValue({
190+
scores: [{ id: 'fav_color', score: 7 }],
191+
relevantDocuments: [{ id: 'fav_color', text: 'My favourite color is blue.' }],
192+
});
193+
194+
const result = await recallAndScore({
195+
recall: mockRecall,
196+
chat: mockChat,
197+
analytics: mockAnalytics,
198+
userPrompt: "What's my favourite color?",
199+
context: '',
200+
messages: normalConversationMessages,
201+
logger: mockLogger,
202+
signal,
203+
});
204+
205+
expect(result.relevantDocuments).toEqual([
206+
{ id: 'fav_color', text: 'My favourite color is blue.' },
207+
]);
208+
expect(mockRecall).toHaveBeenCalled();
209+
expect(scoreSuggestions).toHaveBeenCalled();
210+
});
211+
212+
it('handles contextual insights conversation flow correctly', async () => {
213+
mockRecall.mockResolvedValue([
214+
{ id: 'alert_cause', text: 'The alert was triggered due to high CPU usage.', score: 0.85 },
215+
]);
216+
(scoreSuggestions as jest.Mock).mockResolvedValue({
217+
scores: [{ id: 'alert_cause', score: 6 }],
218+
relevantDocuments: [
219+
{ id: 'alert_cause', text: 'The alert was triggered due to high CPU usage.' },
220+
],
221+
});
222+
223+
const result = await recallAndScore({
224+
recall: mockRecall,
225+
chat: mockChat,
226+
analytics: mockAnalytics,
227+
userPrompt: "I'm looking at an alert and trying to understand why it was triggered",
228+
context: 'User is analyzing an alert',
229+
messages: contextualInsightsMessages,
230+
logger: mockLogger,
231+
signal,
232+
});
233+
234+
expect(result.relevantDocuments).toEqual([
235+
{ id: 'alert_cause', text: 'The alert was triggered due to high CPU usage.' },
236+
]);
237+
expect(mockRecall).toHaveBeenCalled();
238+
expect(scoreSuggestions).toHaveBeenCalled();
239+
});
240+
241+
it('reports analytics with the correct structure', async () => {
242+
const recalledDocs = [{ id: 'doc1', text: 'Hello world', score: 0.8 }];
243+
mockRecall.mockResolvedValue(recalledDocs);
244+
(scoreSuggestions as jest.Mock).mockResolvedValue({
245+
scores: [{ id: 'doc1', score: 7 }],
246+
relevantDocuments: recalledDocs,
247+
});
248+
249+
await recallAndScore({
250+
recall: mockRecall,
251+
chat: mockChat,
252+
analytics: mockAnalytics,
253+
userPrompt: 'test',
254+
context: 'context',
255+
messages: sampleMessages,
256+
logger: mockLogger,
257+
signal,
258+
});
259+
260+
expect(mockAnalytics.reportEvent).toHaveBeenCalledWith(
261+
recallRankingEventType,
262+
expect.objectContaining({ scoredDocuments: [{ elserScore: 0.8, llmScore: 7 }] })
263+
);
264+
});
265+
});

x-pack/platform/plugins/shared/observability_ai_assistant/server/utils/recall/recall_and_score.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ export async function recallAndScore({
2121
chat,
2222
analytics,
2323
userPrompt,
24+
userMessageFunctionName,
2425
context,
2526
messages,
2627
logger,
@@ -30,6 +31,7 @@ export async function recallAndScore({
3031
chat: FunctionCallChatFunction;
3132
analytics: AnalyticsServiceStart;
3233
userPrompt: string;
34+
userMessageFunctionName?: string;
3335
context: string;
3436
messages: Message[];
3537
logger: Logger;
@@ -62,6 +64,7 @@ export async function recallAndScore({
6264
logger,
6365
messages,
6466
userPrompt,
67+
userMessageFunctionName,
6568
context,
6669
signal,
6770
chat,

0 commit comments

Comments
 (0)