3
3
use crate :: errors:: McpError ;
4
4
use crate :: schema_from_type;
5
5
use crate :: schema_tree_shake:: { DepthLimit , SchemaTreeShaker } ;
6
- use apollo_compiler:: Schema ;
7
- use apollo_compiler:: ast:: OperationType as AstOperationType ;
6
+ use apollo_compiler:: ast:: { Field , OperationType as AstOperationType , Selection } ;
8
7
use apollo_compiler:: validation:: Valid ;
8
+ use apollo_compiler:: { Name , Node , Schema } ;
9
9
use apollo_schema_index:: { OperationType , Options , SchemaIndex } ;
10
10
use rmcp:: model:: { CallToolResult , Content , ErrorCode , Tool } ;
11
11
use rmcp:: schemars:: JsonSchema ;
@@ -20,15 +20,16 @@ use tracing::debug;
20
20
/// The name of the tool to search a GraphQL schema.
21
21
pub const SEARCH_TOOL_NAME : & str = "search" ;
22
22
23
- /// The depth of nested types to include for leaf nodes on matching root paths .
24
- pub const LEAF_DEPTH : DepthLimit = DepthLimit :: Limited ( 1 ) ;
23
+ /// The maximum number of search results to consider .
24
+ const MAX_SEARCH_RESULTS : usize = 5 ;
25
25
26
26
/// A tool to search a GraphQL schema.
27
27
#[ derive( Clone ) ]
28
28
pub struct Search {
29
29
schema : Arc < Mutex < Valid < Schema > > > ,
30
30
index : SchemaIndex ,
31
31
allow_mutations : bool ,
32
+ leaf_depth : usize ,
32
33
pub tool : Tool ,
33
34
}
34
35
@@ -53,6 +54,8 @@ impl Search {
53
54
pub fn new (
54
55
schema : Arc < Mutex < Valid < Schema > > > ,
55
56
allow_mutations : bool ,
57
+ leaf_depth : usize ,
58
+ index_memory_bytes : usize ,
56
59
) -> Result < Self , IndexingError > {
57
60
let root_types = if allow_mutations {
58
61
OperationType :: Query | OperationType :: Mutation
@@ -62,8 +65,9 @@ impl Search {
62
65
let locked = & schema. try_lock ( ) ?;
63
66
Ok ( Self {
64
67
schema : schema. clone ( ) ,
65
- index : SchemaIndex :: new ( locked, root_types) ?,
68
+ index : SchemaIndex :: new ( locked, root_types, index_memory_bytes ) ?,
66
69
allow_mutations,
70
+ leaf_depth,
67
71
tool : Tool :: new (
68
72
SEARCH_TOOL_NAME ,
69
73
"Search a GraphQL schema" ,
@@ -84,7 +88,7 @@ impl Search {
84
88
)
85
89
} ) ?;
86
90
87
- root_paths. truncate ( 5 ) ;
91
+ root_paths. truncate ( MAX_SEARCH_RESULTS ) ;
88
92
debug ! (
89
93
"Root paths for search terms: {}\n {}" ,
90
94
input. terms. join( ", " ) ,
@@ -98,16 +102,32 @@ impl Search {
98
102
let schema = self . schema . lock ( ) . await ;
99
103
let mut tree_shaker = SchemaTreeShaker :: new ( & schema) ;
100
104
for root_path in root_paths {
101
- let types = root_path. inner . types . clone ( ) ;
102
- let path_len = types. len ( ) ;
103
- for ( i, type_name) in types. into_iter ( ) . enumerate ( ) {
104
- if let Some ( extended_type) = schema. types . get ( type_name. as_ref ( ) ) {
105
- let depth = if i == path_len - 1 {
106
- LEAF_DEPTH
105
+ let path_len = root_path. inner . len ( ) ;
106
+ for ( i, path_node) in root_path. inner . into_iter ( ) . enumerate ( ) {
107
+ if let Some ( extended_type) = schema. types . get ( path_node. node_type . as_str ( ) ) {
108
+ let ( selection_set, depth) = if i == path_len - 1 {
109
+ ( None , DepthLimit :: Limited ( self . leaf_depth ) )
107
110
} else {
108
- DepthLimit :: Limited ( 1 )
111
+ (
112
+ path_node. field_name . as_ref ( ) . map ( |field_name| {
113
+ vec ! [ Selection :: Field ( Node :: from( Field {
114
+ alias: Default :: default ( ) ,
115
+ name: Name :: new_unchecked( field_name) ,
116
+ arguments: Default :: default ( ) ,
117
+ selection_set: Default :: default ( ) ,
118
+ directives: Default :: default ( ) ,
119
+ } ) ) ]
120
+ } ) ,
121
+ DepthLimit :: Limited ( 1 ) ,
122
+ )
109
123
} ;
110
- tree_shaker. retain_type ( extended_type, depth)
124
+ tree_shaker. retain_type ( extended_type, selection_set. as_ref ( ) , depth)
125
+ }
126
+ for field_arg in path_node. field_args {
127
+ if let Some ( extended_type) = schema. types . get ( field_arg. as_str ( ) ) {
128
+ // Retain input types with unlimited depth because all input must be given
129
+ tree_shaker. retain_type ( extended_type, None , DepthLimit :: Unlimited ) ;
130
+ }
111
131
}
112
132
}
113
133
}
@@ -171,7 +191,8 @@ mod tests {
171
191
#[ tokio:: test]
172
192
async fn test_search_tool ( schema : Valid < Schema > ) {
173
193
let schema = Arc :: new ( Mutex :: new ( schema) ) ;
174
- let search = Search :: new ( schema. clone ( ) , false ) . expect ( "Failed to create search tool" ) ;
194
+ let search = Search :: new ( schema. clone ( ) , false , 1 , 15_000_000 )
195
+ . expect ( "Failed to create search tool" ) ;
175
196
176
197
let result = search
177
198
. execute ( Input {
@@ -188,7 +209,8 @@ mod tests {
188
209
#[ tokio:: test]
189
210
async fn test_referencing_types_are_collected ( schema : Valid < Schema > ) {
190
211
let schema = Arc :: new ( Mutex :: new ( schema) ) ;
191
- let search = Search :: new ( schema. clone ( ) , true ) . expect ( "Failed to create search tool" ) ;
212
+ let search =
213
+ Search :: new ( schema. clone ( ) , true , 1 , 15_000_000 ) . expect ( "Failed to create search tool" ) ;
192
214
193
215
// Search for a type that should have references
194
216
let result = search
0 commit comments