1010# limitations under the License.
1111
1212
13+ from collections import namedtuple
1314from dataclasses import dataclass
1415from typing import Any , Optional , Union
1516
@@ -92,16 +93,24 @@ def test_get_current_executed_op_name():
9293 assert hook_executor_mode .get_current_executed_op_name ("foo" ) == "conv/post_hook__conv-conv2d-0__0[0]/foo/0"
9394
9495
95- @pytest .fixture (params = ["tensor" , "list" , "torch_return_type" ])
96- def example_outputs (request : FixtureRequest ) -> Union [torch .Tensor , list [torch .Tensor ], torch .return_types .max ]:
96+ NamedTuple = namedtuple ("NamedTuple" , ["output1" , "output2" ])
97+
98+
99+ @pytest .fixture (params = ["tensor" , "list" , "torch_return_type" , "named_tuple" ])
100+ def example_outputs (
101+ request : FixtureRequest ,
102+ ) -> Union [torch .Tensor , list [torch .Tensor ], torch .return_types .max , NamedTuple ]:
97103 return {
98104 "tensor" : torch .tensor (1 ),
99105 "list" : [torch .tensor (1 ), torch .tensor ([2 ])],
100106 "torch_return_type" : torch .return_types .max ((torch .tensor (1 ), torch .tensor ([2 ]))),
107+ "named_tuple" : NamedTuple (torch .tensor (1 ), torch .tensor ([2 ])),
101108 }.get (request .param )
102109
103110
104- def test_execute_post_hooks (example_outputs : Union [torch .Tensor , list [torch .Tensor ], torch .return_types .max ]):
111+ def test_execute_post_hooks (
112+ example_outputs : Union [torch .Tensor , list [torch .Tensor ], torch .return_types .max , NamedTuple ],
113+ ):
105114 op_name = "/relu/0"
106115 hook_storage = HookStorage ()
107116 hook_port_0 = CallCount ()
@@ -112,6 +121,7 @@ def test_execute_post_hooks(example_outputs: Union[torch.Tensor, list[torch.Tens
112121 op_meta = OpMeta ("/relu/0" , torch .relu )
113122 ret_val = ctx .execute_post_hooks (example_outputs , op_meta )
114123 assert type (example_outputs ) is type (ret_val )
124+ assert example_outputs == ret_val
115125
116126 assert hook_port_0 .call_count == 1
117127 if isinstance (example_outputs , torch .Tensor ):
@@ -120,6 +130,28 @@ def test_execute_post_hooks(example_outputs: Union[torch.Tensor, list[torch.Tens
120130 assert hook_port_1 .call_count == 1
121131
122132
133+ def test_process_model_output (
134+ example_outputs : Union [torch .Tensor , list [torch .Tensor ], torch .return_types .max , NamedTuple ],
135+ ):
136+ hook_storage = HookStorage ()
137+ hook_output = CallCount ()
138+ hook_output_0 = CallCount ()
139+ hook_output_1 = CallCount ()
140+ hook_storage .register_pre_function_hook ("output" , 0 , hook_output )
141+ hook_storage .register_pre_function_hook ("output_0" , 0 , hook_output_0 )
142+ hook_storage .register_pre_function_hook ("output_1" , 0 , hook_output_1 )
143+
144+ ctx = FunctionHookMode (nn .Identity (), hook_storage )
145+ ret_val = ctx .process_model_outputs (example_outputs )
146+ assert type (example_outputs ) is type (ret_val )
147+ assert example_outputs == ret_val
148+ if isinstance (example_outputs , torch .Tensor ):
149+ assert hook_output .call_count == 1
150+ else :
151+ assert hook_output_0 .call_count == 1
152+ assert hook_output_1 .call_count == 1
153+
154+
123155class ConcatModel (nn .Module ):
124156 def forward (self , x : torch .Tensor ) -> torch .Tensor :
125157 return torch .cat ([x , x ], dim = 0 )
0 commit comments