|
1 |
| -import { describeAccuracyTests } from "./sdk/describe-accuracy-tests.js"; |
| 1 | +import { describeAccuracyTests, describeSuite } from "./sdk/describe-accuracy-tests.js"; |
2 | 2 | import { getAvailableModels } from "./sdk/models.js";
|
3 | 3 | import { AccuracyTestConfig } from "./sdk/describe-accuracy-tests.js";
|
4 |
| -import { findResponse } from "../../src/tools/mongodb/read/find.js"; |
5 |
| -import { MockedTools } from "./sdk/test-tools.js"; |
6 |
| -import { collectionSchemaResponse } from "../../src/tools/mongodb/metadata/collectionSchema.js"; |
7 |
| -import { getSimplifiedSchema } from "mongodb-schema"; |
8 | 4 |
|
9 |
| -const documents = [ |
10 |
| - { |
11 |
| - title: "book1", |
12 |
| - author: "author1", |
13 |
| - date_of_publish: "01.01.1990", |
14 |
| - }, |
15 |
| - { |
16 |
| - title: "book2", |
17 |
| - author: "author1", |
18 |
| - date_of_publish: "01.01.1992", |
19 |
| - }, |
20 |
| - { |
21 |
| - title: "book3", |
22 |
| - author: "author2", |
23 |
| - date_of_publish: "01.01.1990", |
24 |
| - }, |
25 |
| -]; |
26 |
| - |
27 |
| -function callsFindNoFilter(prompt: string): AccuracyTestConfig { |
| 5 | +function callsFindNoFilter(prompt: string, database = "mflix", collection = "movies"): AccuracyTestConfig { |
28 | 6 | return {
|
29 | 7 | injectConnectedAssumption: true,
|
30 | 8 | prompt: prompt,
|
31 |
| - mockedTools: { |
32 |
| - "collection-schema": async () => |
33 |
| - collectionSchemaResponse("db1", "coll1", await getSimplifiedSchema(documents)), |
34 |
| - find: () => findResponse("coll1", documents), |
35 |
| - }, |
| 9 | + mockedTools: {}, |
36 | 10 | expectedToolCalls: [
|
37 | 11 | {
|
38 | 12 | toolName: "find",
|
39 | 13 | parameters: {
|
40 |
| - database: "db1", |
41 |
| - collection: "coll1", |
| 14 | + database, |
| 15 | + collection, |
42 | 16 | },
|
43 | 17 | },
|
44 | 18 | ],
|
45 | 19 | };
|
46 | 20 | }
|
47 | 21 |
|
48 |
| -function callsFindWithFilter(prompt: string): AccuracyTestConfig { |
| 22 | +function callsFindWithFilter(prompt: string, filter: Record<string, unknown>): AccuracyTestConfig { |
49 | 23 | return {
|
50 | 24 | injectConnectedAssumption: true,
|
51 | 25 | prompt: prompt,
|
52 |
| - mockedTools: { |
53 |
| - "collection-schema": async () => |
54 |
| - collectionSchemaResponse("db1", "coll1", await getSimplifiedSchema(documents)), |
55 |
| - find: () => |
56 |
| - findResponse( |
57 |
| - "coll1", |
58 |
| - documents.filter((doc) => doc.author === "author1") |
59 |
| - ), |
60 |
| - }, |
| 26 | + mockedTools: {}, |
61 | 27 | expectedToolCalls: [
|
62 | 28 | {
|
63 | 29 | toolName: "find",
|
64 | 30 | parameters: {
|
65 |
| - database: "db1", |
66 |
| - collection: "coll1", |
67 |
| - filter: { author: "author1" }, |
| 31 | + database: "mflix", |
| 32 | + collection: "movies", |
| 33 | + filter: filter, |
68 | 34 | },
|
69 | 35 | },
|
70 | 36 | ],
|
71 | 37 | };
|
72 | 38 | }
|
73 | 39 |
|
74 |
| -function callsFindWithProjection(prompt: string): AccuracyTestConfig { |
| 40 | +function callsFindWithProjection(prompt: string, projection: Record<string, number>): AccuracyTestConfig { |
75 | 41 | return {
|
76 | 42 | injectConnectedAssumption: true,
|
77 | 43 | prompt: prompt,
|
78 |
| - mockedTools: { |
79 |
| - "collection-schema": async () => |
80 |
| - collectionSchemaResponse("db1", "coll1", await getSimplifiedSchema(documents)), |
81 |
| - find: () => findResponse("coll1", documents), |
82 |
| - }, |
| 44 | + mockedTools: {}, |
83 | 45 | expectedToolCalls: [
|
84 | 46 | {
|
85 | 47 | toolName: "find",
|
86 | 48 | parameters: {
|
87 |
| - database: "db1", |
88 |
| - collection: "coll1", |
89 |
| - projection: { title: 1 }, |
| 49 | + database: "mflix", |
| 50 | + collection: "movies", |
| 51 | + projection, |
90 | 52 | },
|
91 | 53 | },
|
92 | 54 | ],
|
93 | 55 | };
|
94 | 56 | }
|
95 | 57 |
|
96 |
| -function callsFindWithProjectionAndFilters(prompt: string): AccuracyTestConfig { |
| 58 | +function callsFindWithProjectionAndFilters( |
| 59 | + prompt: string, |
| 60 | + filter: Record<string, unknown>, |
| 61 | + projection: Record<string, number> |
| 62 | +): AccuracyTestConfig { |
97 | 63 | return {
|
98 | 64 | injectConnectedAssumption: true,
|
99 | 65 | prompt: prompt,
|
100 |
| - mockedTools: { |
101 |
| - "collection-schema": async () => |
102 |
| - collectionSchemaResponse("db1", "coll1", await getSimplifiedSchema(documents)), |
103 |
| - find: () => |
104 |
| - findResponse( |
105 |
| - "coll1", |
106 |
| - documents.filter((doc) => doc.date_of_publish === "01.01.1992") |
107 |
| - ), |
108 |
| - }, |
| 66 | + mockedTools: {}, |
109 | 67 | expectedToolCalls: [
|
110 | 68 | {
|
111 | 69 | toolName: "find",
|
112 | 70 | parameters: {
|
113 |
| - database: "db1", |
114 |
| - collection: "coll1", |
115 |
| - filter: { date_of_publish: "01.01.1992" }, |
116 |
| - projection: { title: 1 }, |
| 71 | + database: "mflix", |
| 72 | + collection: "movies", |
| 73 | + filter, |
| 74 | + projection, |
117 | 75 | },
|
118 | 76 | },
|
119 | 77 | ],
|
120 | 78 | };
|
121 | 79 | }
|
122 | 80 |
|
123 |
| -function callsFindWithSortAndLimit(prompt: string): AccuracyTestConfig { |
| 81 | +function callsFindWithFilterSortAndLimit( |
| 82 | + prompt: string, |
| 83 | + filter: Record<string, unknown>, |
| 84 | + sort: Record<string, number>, |
| 85 | + limit: number |
| 86 | +): AccuracyTestConfig { |
124 | 87 | return {
|
125 | 88 | injectConnectedAssumption: true,
|
126 | 89 | prompt: prompt,
|
127 |
| - mockedTools: { |
128 |
| - "collection-schema": async () => |
129 |
| - collectionSchemaResponse("db1", "coll1", await getSimplifiedSchema(documents)), |
130 |
| - find: () => findResponse("coll1", [documents[0], documents[1]]), |
131 |
| - }, |
| 90 | + mockedTools: {}, |
132 | 91 | expectedToolCalls: [
|
133 | 92 | {
|
134 | 93 | toolName: "find",
|
135 | 94 | parameters: {
|
136 |
| - database: "db1", |
137 |
| - collection: "coll1", |
138 |
| - sort: { date_of_publish: 1 }, |
139 |
| - limit: 2, |
| 95 | + database: "mflix", |
| 96 | + collection: "movies", |
| 97 | + filter, |
| 98 | + sort, |
| 99 | + limit, |
140 | 100 | },
|
141 | 101 | },
|
142 | 102 | ],
|
143 | 103 | };
|
144 | 104 | }
|
145 | 105 |
|
146 |
| -describeAccuracyTests("find", getAvailableModels(), [ |
147 |
| - callsFindNoFilter("List all the documents in 'db1.coll1' namespace"), |
148 |
| - callsFindNoFilter("Find all the documents from collection coll1 in database db1"), |
149 |
| - callsFindWithFilter("Find all the books published by author name 'author1' in db1.coll1 namespace"), |
150 |
| - callsFindWithFilter("Find all the documents in coll1 collection and db1 database where author is 'author1'"), |
151 |
| - callsFindWithProjection("Give me all the title of the books available in 'db1.coll1' namespace"), |
152 |
| - callsFindWithProjection("Give me all the title of the books published in available in 'db1.coll1' namespace"), |
153 |
| - callsFindWithProjectionAndFilters( |
154 |
| - "Find all the book titles from 'db1.coll1' namespace where date_of_publish is '01.01.1992'" |
155 |
| - ), |
156 |
| - callsFindWithSortAndLimit("List first two books sorted by the field date_of_publish in namespace db1.coll1"), |
157 |
| -]); |
| 106 | +describeAccuracyTests(getAvailableModels(), { |
| 107 | + ...describeSuite("should only call find tool", [ |
| 108 | + callsFindNoFilter("List all the movies in 'mflix.movies' namespace."), |
| 109 | + callsFindNoFilter("List all the documents in 'comics.books' namespace.", "comics", "books"), |
| 110 | + callsFindWithFilter("Find all the movies in 'mflix.movies' namespace with runtime less than 100.", { |
| 111 | + runtime: { $lt: 100 }, |
| 112 | + }), |
| 113 | + callsFindWithFilter("Find all movies in 'mflix.movies' collection where director is 'Christina Collins'", { |
| 114 | + director: "Christina Collins", |
| 115 | + }), |
| 116 | + callsFindWithProjection("Give me all the movie titles available in 'mflix.movies' namespace", { title: 1 }), |
| 117 | + callsFindWithProjectionAndFilters( |
| 118 | + "Use 'mflix.movies' namespace to answer who were casted in the movie 'Certain Fish'", |
| 119 | + { title: "Certain Fish" }, |
| 120 | + { cast: 1 } |
| 121 | + ), |
| 122 | + callsFindWithFilterSortAndLimit( |
| 123 | + "From the mflix.movies namespace, give me first 2 movies of Horror genre sorted ascending by their runtime", |
| 124 | + { genres: "Horror" }, |
| 125 | + { runtime: 1 }, |
| 126 | + 2 |
| 127 | + ), |
| 128 | + ]), |
| 129 | +}); |
0 commit comments