Skip to content

Commit 34dd207

Browse files
chore: simplified the update pipeline and added tool call serialization
1 parent 3530453 commit 34dd207

File tree

1 file changed

+94
-107
lines changed

1 file changed

+94
-107
lines changed

tests/accuracy/sdk/accuracy-result-storage/mongodb-storage.ts

Lines changed: 94 additions & 107 deletions
Original file line numberDiff line numberDiff line change
@@ -5,19 +5,38 @@ import {
55
AccuracyRunStatus,
66
AccuracyRunStatuses,
77
ExpectedToolCall,
8+
LLMToolCall,
89
ModelResponse,
10+
PromptResult,
911
} from "./result-storage.js";
1012

1113
// Omitting these as they might contain large chunks of text
1214
const OMITTED_MODEL_RESPONSE_FIELDS: (keyof ModelResponse)[] = ["messages", "text"];
1315

16+
// The LLMToolCalls and ExpectedToolCalls are expected to have MongoDB operators
17+
// nested in the objects. This interferes with the update operation that we do
18+
// on the accuracy result document to save the model responses, which is why we
19+
// serialize them before saving and deserialize them on fetch.
20+
/** Stored form of `AccuracyResult`: identical except that `promptResults` holds `SavedPromptResult` entries. */
type SavedAccuracyResult = Omit<AccuracyResult, "promptResults"> & {
21+
promptResults: SavedPromptResult[];
22+
};
23+
24+
/**
 * Stored form of `PromptResult`: `expectedToolCalls` is kept as a string
 * (produced with `JSON.stringify` when saving) and each model response uses
 * the `SavedModelResponse` shape.
 */
type SavedPromptResult = Omit<PromptResult, "expectedToolCalls" | "modelResponses"> & {
25+
expectedToolCalls: string;
26+
modelResponses: SavedModelResponse[];
27+
};
28+
29+
/** Stored form of `ModelResponse`: `llmToolCalls` is kept as a string (produced with `JSON.stringify` when saving). */
type SavedModelResponse = Omit<ModelResponse, "llmToolCalls"> & {
30+
llmToolCalls: string;
31+
};
32+
1433
export class MongoDBBasedResultStorage implements AccuracyResultStorage {
1534
private client: MongoClient;
16-
private resultCollection: Collection<AccuracyResult>;
35+
private resultCollection: Collection<SavedAccuracyResult>;
1736

1837
/**
 * Creates the MongoDB client and resolves the handle of the collection that
 * holds accuracy results.
 *
 * @param connectionString - Passed as-is to the `MongoClient` constructor.
 * @param database - Name of the database that contains the result collection.
 * @param collection - Name of the result collection.
 */
constructor(connectionString: string, database: string, collection: string) {
1938
this.client = new MongoClient(connectionString);
20-
this.resultCollection = this.client.db(database).collection<AccuracyResult>(collection);
39+
this.resultCollection = this.client.db(database).collection<SavedAccuracyResult>(collection);
2140
}
2241

2342
async getAccuracyResult(commitSHA: string, runId?: string): Promise<AccuracyResult | null> {
@@ -28,11 +47,14 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
2847
// for commit is when you want the last successful run of that
2948
// particular commit.
3049
{ commitSHA, runStatus: AccuracyRunStatus.Done };
31-
return await this.resultCollection.findOne(filters, {
50+
51+
const result = await this.resultCollection.findOne(filters, {
3252
sort: {
3353
createdOn: -1,
3454
},
3555
});
56+
57+
return result ? this.deserializeSavedResult(result) : result;
3658
}
3759

3860
async updateRunStatus(commitSHA: string, runId: string, status: AccuracyRunStatuses): Promise<void> {
@@ -59,130 +81,77 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
5981
expectedToolCalls: ExpectedToolCall[];
6082
modelResponse: ModelResponse;
6183
}): Promise<void> {
62-
const savedModelResponse: ModelResponse = { ...modelResponse };
63-
for (const field of OMITTED_MODEL_RESPONSE_FIELDS) {
64-
delete savedModelResponse[field];
65-
}
66-
67-
await this.resultCollection.updateOne(
68-
{ commitSHA, runId },
69-
{
70-
$setOnInsert: {
71-
runStatus: AccuracyRunStatus.InProgress,
72-
createdOn: Date.now(),
73-
commitSHA,
74-
runId,
75-
promptResults: [],
76-
},
77-
},
78-
{ upsert: true }
79-
);
80-
81-
await this.resultCollection.updateOne(
82-
{
83-
commitSHA,
84-
runId,
85-
"promptResults.prompt": { $ne: prompt },
86-
},
87-
{
88-
$push: {
89-
promptResults: { prompt, expectedToolCalls, modelResponses: [] },
90-
},
91-
}
92-
);
84+
const expectedToolCallsToSave = JSON.stringify(expectedToolCalls);
85+
const modelResponseToSave: SavedModelResponse = {
86+
...modelResponse,
87+
llmToolCalls: JSON.stringify(modelResponse.llmToolCalls),
88+
};
9389

94-
await this.resultCollection.updateOne(
95-
{ commitSHA, runId },
96-
{
97-
$push: {
98-
"promptResults.$[promptElement].modelResponses": savedModelResponse,
99-
},
100-
},
101-
{
102-
arrayFilters: [{ "promptElement.prompt": prompt }],
103-
}
104-
);
105-
}
106-
107-
async saveModelResponseForPromptAtomic({
108-
commitSHA,
109-
runId,
110-
prompt,
111-
expectedToolCalls,
112-
modelResponse,
113-
}: {
114-
commitSHA: string;
115-
runId: string;
116-
prompt: string;
117-
expectedToolCalls: ExpectedToolCall[];
118-
modelResponse: ModelResponse;
119-
}): Promise<void> {
120-
const savedModelResponse: ModelResponse = { ...modelResponse };
12190
for (const field of OMITTED_MODEL_RESPONSE_FIELDS) {
122-
delete savedModelResponse[field];
91+
delete modelResponseToSave[field];
12392
}
12493

12594
await this.resultCollection.updateOne(
12695
{ commitSHA, runId },
12796
[
12897
{
12998
$set: {
130-
runStatus: {
131-
$ifNull: ["$runStatus", AccuracyRunStatus.InProgress],
132-
},
133-
createdOn: {
134-
$ifNull: ["$createdOn", Date.now()],
99+
runStatus: { $ifNull: ["$runStatus", AccuracyRunStatus.InProgress] },
100+
createdOn: { $ifNull: ["$createdOn", Date.now()] },
101+
commitSHA: { $ifNull: ["$commitSHA", commitSHA] },
102+
runId: { $ifNull: ["$runId", runId] },
103+
promptResults: {
104+
$ifNull: ["$promptResults", []],
135105
},
136-
commitSHA: commitSHA,
137-
runId: runId,
106+
},
107+
},
108+
{
109+
$set: {
138110
promptResults: {
139111
$let: {
140112
vars: {
141-
existingPrompts: { $ifNull: ["$promptResults", []] },
142-
promptExists: {
143-
$in: [
144-
prompt,
145-
{
146-
$ifNull: [
147-
{ $map: { input: "$promptResults", as: "pr", in: "$$pr.prompt" } },
148-
[],
149-
],
150-
},
151-
],
113+
existingPromptIndex: {
114+
$indexOfArray: ["$promptResults.prompt", prompt],
152115
},
153116
},
154117
in: {
155-
$map: {
156-
input: {
157-
$cond: {
158-
if: "$$promptExists",
159-
then: "$$existingPrompts",
160-
else: {
161-
$concatArrays: [
162-
"$$existingPrompts",
163-
[{ prompt, expectedToolCalls, modelResponses: [] }],
164-
],
165-
},
166-
},
167-
},
168-
as: "promptResult",
169-
in: {
170-
$cond: {
171-
if: { $eq: ["$$promptResult.prompt", prompt] },
172-
then: {
173-
prompt: "$$promptResult.prompt",
174-
expectedToolCalls: "$$promptResult.expectedToolCalls",
175-
modelResponses: {
176-
$concatArrays: [
177-
"$$promptResult.modelResponses",
178-
[savedModelResponse],
179-
],
118+
$cond: [
119+
{ $eq: ["$$existingPromptIndex", -1] },
120+
{
121+
$concatArrays: [
122+
"$promptResults",
123+
[
124+
{
125+
prompt,
126+
expectedToolCalls: expectedToolCallsToSave,
127+
modelResponses: [modelResponseToSave],
180128
},
129+
],
130+
],
131+
},
132+
{
133+
$map: {
134+
input: "$promptResults",
135+
as: "promptResult",
136+
in: {
137+
$cond: [
138+
{ $eq: ["$$promptResult.prompt", prompt] },
139+
{
140+
prompt: "$$promptResult.prompt",
141+
expectedToolCalls: expectedToolCallsToSave,
142+
modelResponses: {
143+
$concatArrays: [
144+
"$$promptResult.modelResponses",
145+
[modelResponseToSave],
146+
],
147+
},
148+
},
149+
"$$promptResult",
150+
],
181151
},
182-
else: "$$promptResult",
183152
},
184153
},
185-
},
154+
],
186155
},
187156
},
188157
},
@@ -193,6 +162,24 @@ export class MongoDBBasedResultStorage implements AccuracyResultStorage {
193162
);
194163
}
195164

165+
/**
 * Converts a stored document back into the public `AccuracyResult` shape by
 * parsing the JSON-serialized `expectedToolCalls` and `llmToolCalls` fields.
 *
 * Documents written by the previous implementation hold these fields as plain
 * arrays rather than strings; such values are passed through untouched instead
 * of being handed to `JSON.parse`, which would throw on them.
 *
 * @param result - Document as read from the result collection.
 * @returns The result with tool calls restored to arrays.
 * @throws SyntaxError when a stored string is not valid JSON.
 */
private deserializeSavedResult(result: SavedAccuracyResult): AccuracyResult {
    // Only strings are parsed: the declared type says `string`, but older
    // documents may carry the un-serialized array in the same field.
    const parseIfSerialized = <T>(value: unknown): T =>
        (typeof value === "string" ? JSON.parse(value) : value) as T;

    return {
        ...result,
        promptResults: result.promptResults.map<PromptResult>((promptResult) => ({
            ...promptResult,
            expectedToolCalls: parseIfSerialized<ExpectedToolCall[]>(promptResult.expectedToolCalls),
            modelResponses: promptResult.modelResponses.map<ModelResponse>((response) => ({
                ...response,
                llmToolCalls: parseIfSerialized<LLMToolCall[]>(response.llmToolCalls),
            })),
        })),
    };
}
182+
196183
/** Closes the underlying `MongoClient`; resolves once the client's `close()` call has completed. */
async close(): Promise<void> {
197184
await this.client.close();
198185
}

0 commit comments

Comments
 (0)