@@ -252,14 +252,15 @@ def forward_intermediates(
252252 """
253253 assert output_fmt in ('NCHW' ,), 'Output shape must be NCHW.'
254254 intermediates = []
255- take_indices , max_index = feature_take_indices (len (self .body ) - 1 , indices )
256-
255+ stage_ends = [1 , 2 , 3 , 4 , 5 ]
256+ take_indices , max_index = feature_take_indices (len (stage_ends ), indices )
257+ take_indices = [stage_ends [i ] for i in take_indices ]
258+ max_index = stage_ends [max_index ]
257259 # forward pass
258- x = self .body [0 ](x ) # s2d
259260 if torch .jit .is_scripting () or not stop_early : # can't slice blocks in torchscript
260- stages = [ self .body [ 1 ], self . body [ 2 ], self . body [ 3 ], self . body [ 4 ], self . body [ 5 ]]
261+ stages = self .body
261262 else :
262- stages = self .body [1 :max_index + 2 ]
263+ stages = self .body [:max_index + 1 ]
263264
264265 for feat_idx , stage in enumerate (stages ):
265266 x = stage (x )
@@ -279,8 +280,10 @@ def prune_intermediate_layers(
279280 ):
280281 """ Prune layers not required for specified intermediates.
281282 """
282- take_indices , max_index = feature_take_indices (len (self .body ) - 1 , indices )
283- self .body = self .body [1 :max_index + 2 ] # truncate blocks w/ stem as idx 0
283+ stage_ends = [1 , 2 , 3 , 4 , 5 ]
284+ take_indices , max_index = feature_take_indices (len (stage_ends ), indices )
285+ max_index = stage_ends [max_index ]
286+ self .body = self .body [:max_index + 1 ] # truncate blocks w/ stem as idx 0
284287 if prune_head :
285288 self .reset_classifier (0 , '' )
286289 return take_indices
0 commit comments