Skip to content

Commit 3491a3b

Browse files
authored
fix: tweak the arg shapes to improve tool accuracy (#381)
1 parent 743cbfa commit 3491a3b

File tree

4 files changed

+13
-7
lines changed

4 files changed

+13
-7
lines changed

src/tools/mongodb/metadata/explain.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ export class ExplainTool extends MongoDBToolBase {
1616
...DbOperationArgs,
1717
method: z
1818
.array(
19-
z.union([
19+
z.discriminatedUnion("name", [
2020
z.object({
2121
name: z.literal("aggregate"),
2222
arguments: z.object(AggregateArgs),

src/tools/mongodb/read/aggregate.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import { EJSON } from "bson";
66
import { checkIndexUsage } from "../../../helpers/indexCheck.js";
77

88
export const AggregateArgs = {
9-
pipeline: z.array(z.record(z.string(), z.unknown())).describe("An array of aggregation stages to execute"),
9+
pipeline: z.array(z.object({}).passthrough()).describe("An array of aggregation stages to execute"),
1010
};
1111

1212
export class AggregateTool extends MongoDBToolBase {

src/tools/mongodb/read/count.ts

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ import { checkIndexUsage } from "../../../helpers/indexCheck.js";
66

77
export const CountArgs = {
88
query: z
9-
.record(z.string(), z.unknown())
9+
.object({})
10+
.passthrough()
1011
.optional()
1112
.describe(
1213
"A filter/query parameter. Allows users to filter the documents to count. Matches the syntax of the filter argument of db.collection.count()."

src/tools/mongodb/read/find.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,23 @@ import { checkIndexUsage } from "../../../helpers/indexCheck.js";
88

99
export const FindArgs = {
1010
filter: z
11-
.record(z.string(), z.unknown())
11+
.object({})
12+
.passthrough()
1213
.optional()
1314
.describe("The query filter, matching the syntax of the query argument of db.collection.find()"),
1415
projection: z
15-
.record(z.string(), z.unknown())
16+
.object({})
17+
.passthrough()
1618
.optional()
1719
.describe("The projection, matching the syntax of the projection argument of db.collection.find()"),
1820
limit: z.number().optional().default(10).describe("The maximum number of documents to return"),
1921
sort: z
20-
.record(z.string(), z.custom<SortDirection>())
22+
.object({})
23+
.catchall(z.custom<SortDirection>())
2124
.optional()
22-
.describe("A document, describing the sort order, matching the syntax of the sort argument of cursor.sort()"),
25+
.describe(
26+
"A document, describing the sort order, matching the syntax of the sort argument of cursor.sort(). The keys of the object are the fields to sort on, while the values are the sort directions (1 for ascending, -1 for descending)."
27+
),
2328
};
2429

2530
export class FindTool extends MongoDBToolBase {

0 commit comments

Comments
 (0)