@@ -200,30 +200,20 @@ impl<'schema> SchemaTreeShaker<'schema> {
200200
201201 /// Return the set of types retained after tree shaking.
202202 pub fn shaken ( & mut self ) -> Result < Schema , Box < WithErrors < Schema > > > {
203- let mut filtered_root_operations = self
204- . schema
205- . schema_definition
206- . query
207- . clone ( )
208- . map ( |query_name| vec ! [ Node :: new( ( OperationType :: Query , query_name. name) ) ] )
209- . unwrap_or_default ( ) ;
210- if self . operation_types . contains ( & OperationType :: Mutation ) {
211- if let Some ( mutation_name) = self . schema . schema_definition . mutation . clone ( ) {
212- filtered_root_operations
213- . push ( Node :: new ( ( OperationType :: Mutation , mutation_name. name ) ) ) ;
214- }
215- }
216- if self . operation_types . contains ( & OperationType :: Subscription ) {
217- if let Some ( subscription_name) = self . schema . schema_definition . subscription . clone ( ) {
218- filtered_root_operations. push ( Node :: new ( (
219- OperationType :: Subscription ,
220- subscription_name. name ,
221- ) ) ) ;
222- }
223- }
203+ let root_operations = self
204+ . operation_types
205+ . iter ( )
206+ . filter_map ( |operation_type| {
207+ self . schema
208+ . root_operation ( * operation_type)
209+ . cloned ( )
210+ . map ( |operation_name| Node :: new ( ( * operation_type, operation_name) ) )
211+ } )
212+ . collect ( ) ;
213+
224214 let schema_definition =
225215 Definition :: SchemaDefinition ( apollo_compiler:: Node :: new ( SchemaDefinition {
226- root_operations : filtered_root_operations ,
216+ root_operations,
227217 description : self . schema . schema_definition . description . clone ( ) ,
228218 directives : DirectiveList (
229219 self . schema
@@ -927,8 +917,8 @@ fn retain_directive(
927917
928918#[ cfg( test) ]
929919mod test {
930-
931920 use apollo_compiler:: { ast:: OperationType , parser:: Parser } ;
921+ use rstest:: { fixture, rstest} ;
932922
933923 use crate :: {
934924 operations:: { MutationMode , operation_defs} ,
@@ -1070,4 +1060,110 @@ mod test {
10701060 "type Query {\n id: UsedInQuery\n }\n \n scalar UsedInQuery\n "
10711061 ) ;
10721062 }
1063+
1064+ #[ fixture]
1065+ fn nested_schema ( ) -> apollo_compiler:: Schema {
1066+ Parser :: new ( )
1067+ . parse_ast (
1068+ r#"
1069+ type Query { level1: Level1 }
1070+ type Level1 { level2: Level2 }
1071+ type Level2 { level3: Level3 }
1072+ type Level3 { level4: Level4 }
1073+ type Level4 { id: String }
1074+ "# ,
1075+ "schema.graphql" ,
1076+ )
1077+ . unwrap ( )
1078+ . to_schema_validate ( )
1079+ . unwrap ( )
1080+ . into_inner ( )
1081+ }
1082+
1083+ #[ rstest]
1084+ fn should_respect_depth_limit ( nested_schema : apollo_compiler:: Schema ) {
1085+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1086+
1087+ // Get the Query type to start from
1088+ let query_type = nested_schema. types . get ( "Query" ) . unwrap ( ) ;
1089+
1090+ // Test with depth limit of 1
1091+ shaker. retain_type ( query_type, DepthLimit :: Limited ( 1 ) ) ;
1092+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1093+
1094+ // Should retain only Query, not Level1, Level2, Level3, or Level4
1095+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1096+ assert ! ( !shaken_schema. types. contains_key( "Level1" ) ) ;
1097+ assert ! ( !shaken_schema. types. contains_key( "Level2" ) ) ;
1098+ assert ! ( !shaken_schema. types. contains_key( "Level3" ) ) ;
1099+ assert ! ( !shaken_schema. types. contains_key( "Level4" ) ) ;
1100+
1101+ // Test with depth limit of 2
1102+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1103+ shaker. retain_type ( query_type, DepthLimit :: Limited ( 2 ) ) ;
1104+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1105+
1106+ // Should retain Query and Level1, but not deeper levels
1107+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1108+ assert ! ( shaken_schema. types. contains_key( "Level1" ) ) ;
1109+ assert ! ( !shaken_schema. types. contains_key( "Level2" ) ) ;
1110+ assert ! ( !shaken_schema. types. contains_key( "Level3" ) ) ;
1111+ assert ! ( !shaken_schema. types. contains_key( "Level4" ) ) ;
1112+
1113+ // Test with depth limit of 1 starting from Level2
1114+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1115+ let level2_type = nested_schema. types . get ( "Level2" ) . unwrap ( ) ;
1116+ shaker. retain_type ( level2_type, DepthLimit :: Limited ( 1 ) ) ;
1117+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1118+
1119+ // Should retain only Level2 - note that a stub Query is always added so the schema is valid
1120+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1121+ assert ! ( !shaken_schema. types. contains_key( "Level1" ) ) ;
1122+ assert ! ( shaken_schema. types. contains_key( "Level2" ) ) ;
1123+ assert ! ( !shaken_schema. types. contains_key( "Level3" ) ) ;
1124+ assert ! ( !shaken_schema. types. contains_key( "Level4" ) ) ;
1125+
1126+ // Test with depth limit of 2 starting from Level2
1127+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1128+ shaker. retain_type ( level2_type, DepthLimit :: Limited ( 2 ) ) ;
1129+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1130+
1131+ // Should retain Level2 and Level3 - note that a stub Query is always added so the schema is valid
1132+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1133+ assert ! ( !shaken_schema. types. contains_key( "Level1" ) ) ;
1134+ assert ! ( shaken_schema. types. contains_key( "Level2" ) ) ;
1135+ assert ! ( shaken_schema. types. contains_key( "Level3" ) ) ;
1136+ assert ! ( !shaken_schema. types. contains_key( "Level4" ) ) ;
1137+
1138+ // Test with depth limit of 5 starting from Level2
1139+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1140+ shaker. retain_type ( level2_type, DepthLimit :: Limited ( 5 ) ) ;
1141+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1142+
1143+ // Should retain Level2 and deeper types - note that a stub Query is always added so the schema is valid
1144+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1145+ assert ! ( !shaken_schema. types. contains_key( "Level1" ) ) ;
1146+ assert ! ( shaken_schema. types. contains_key( "Level2" ) ) ;
1147+ assert ! ( shaken_schema. types. contains_key( "Level3" ) ) ;
1148+ assert ! ( shaken_schema. types. contains_key( "Level4" ) ) ;
1149+ }
1150+
1151+ #[ rstest]
1152+ fn should_retain_all_types_with_unlimited_depth ( nested_schema : apollo_compiler:: Schema ) {
1153+ let mut shaker = SchemaTreeShaker :: new ( & nested_schema) ;
1154+
1155+ // Get the Query type to start from
1156+ let query_type = nested_schema. types . get ( "Query" ) . unwrap ( ) ;
1157+
1158+ // Test with unlimited depth
1159+ shaker. retain_type ( query_type, DepthLimit :: Unlimited ) ;
1160+ let shaken_schema = shaker. shaken ( ) . unwrap ( ) ;
1161+
1162+ // Should retain all types
1163+ assert ! ( shaken_schema. types. contains_key( "Query" ) ) ;
1164+ assert ! ( shaken_schema. types. contains_key( "Level1" ) ) ;
1165+ assert ! ( shaken_schema. types. contains_key( "Level2" ) ) ;
1166+ assert ! ( shaken_schema. types. contains_key( "Level3" ) ) ;
1167+ assert ! ( shaken_schema. types. contains_key( "Level4" ) ) ;
1168+ }
10731169}
0 commit comments