File tree Expand file tree Collapse file tree 1 file changed +5
-1
lines changed
Expand file tree Collapse file tree 1 file changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -511,8 +511,9 @@ def test_model_forward_intermediates(model_name, batch_size):
511511 spatial_axis = get_spatial_dim (output_fmt )
512512 import math
513513
514+ inpt = torch .randn ((batch_size , * input_size ))
514515 output , intermediates = model .forward_intermediates (
515- torch . randn (( batch_size , * input_size )) ,
516+ inpt ,
516517 output_fmt = output_fmt ,
517518 )
518519 assert len (expected_channels ) == len (intermediates )
@@ -524,6 +525,9 @@ def test_model_forward_intermediates(model_name, batch_size):
524525 assert o .shape [0 ] == batch_size
525526 assert not torch .isnan (o ).any ()
526527
528+ output2 = model .forward_features (inpt )
529+ assert torch .allclose (output , output2 )
530+
527531
528532def _create_fx_model (model , train = False ):
529533 # This block of code does a bit of juggling to handle any case where there are multiple outputs in train mode
You can’t perform that action at this time.
0 commit comments