Commit ef28179

[pt] fix processing named tuple (#3712)
### Changes

Support using named tuples as outputs of models and functions.

1 parent e33872a · commit ef28179
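The distinction matters because a namedtuple class cannot be rebuilt by passing it a single list of values, unlike torch structured return types. A minimal sketch of the difference (the `Point` namedtuple is a hypothetical illustration, not part of NNCF):

```python
from collections import namedtuple

import torch

Point = namedtuple("Point", ["x", "y"])  # hypothetical example type

values = [torch.tensor(1), torch.tensor([2])]

# Namedtuple classes take their fields as positional arguments, so the
# processed values have to be unpacked:
out = Point(*values)  # OK -> Point(x=tensor(1), y=tensor([2]))
# Point(values) would raise TypeError (missing positional argument "y").

# torch structured return types are built from a single sequence, which is
# what the code already did before this fix:
ret = torch.return_types.max((torch.tensor(1), torch.tensor([2])))
```

The patch branches on `hasattr(cls_tuple, "_fields")`, an attribute that namedtuple classes expose, to pick the matching constructor call.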

2 files changed: +43, -5 lines


src/nncf/torch/function_hook/hook_executor_mode.py

Lines changed: 8 additions & 2 deletions
@@ -452,7 +452,10 @@ def execute_post_hooks(self, output: Any, op_meta: OpMeta) -> Any:
             for idx, value in enumerate(output):
                 output[idx] = self.process_post_function_hooks_for_value(value, op_meta, idx)
             if cls_tuple is not None:
-                output = cls_tuple(output)
+                if hasattr(cls_tuple, "_fields"):  # likely a namedtuple
+                    output = cls_tuple(*output)
+                else:
+                    output = cls_tuple(output)
         else:
             output = self.process_post_function_hooks_for_value(output, op_meta, 0)
         return output
@@ -505,7 +508,10 @@ def process_model_outputs(self, outputs: Any) -> Any:
             if isinstance(val, Tensor):
                 outputs[idx] = self.execute_hooks_for_model_output(f"output_{idx}", val)
         if cls_tuple is not None:
-            outputs = cls_tuple(outputs)
+            if hasattr(cls_tuple, "_fields"):  # likely a namedtuple
+                outputs = cls_tuple(*outputs)
+            else:
+                outputs = cls_tuple(outputs)
         return outputs
 
     def execute_hooks_for_model_output(self, name: str, value: Any) -> Any:
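The essence of both patched branches, pulled out as a standalone helper for readability (a sketch only; `rebuild_container` is a hypothetical name, not a function in NNCF):

```python
from typing import Any, Sequence


def rebuild_container(cls_tuple: type, values: Sequence[Any]) -> Any:
    """Rebuild the original tuple-like container from the processed values."""
    if hasattr(cls_tuple, "_fields"):  # likely a namedtuple
        return cls_tuple(*values)  # fields are passed as positional arguments
    return cls_tuple(values)  # e.g. torch.return_types.* accept one sequence
```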

tests/torch2/function_hook/test_function_hook_mode.py

Lines changed: 35 additions & 3 deletions
@@ -10,6 +10,7 @@
 # limitations under the License.
 
 
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Optional, Union
 
@@ -92,16 +93,24 @@ def test_get_current_executed_op_name():
     assert hook_executor_mode.get_current_executed_op_name("foo") == "conv/post_hook__conv-conv2d-0__0[0]/foo/0"
 
 
-@pytest.fixture(params=["tensor", "list", "torch_return_type"])
-def example_outputs(request: FixtureRequest) -> Union[torch.Tensor, list[torch.Tensor], torch.return_types.max]:
+NamedTuple = namedtuple("NamedTuple", ["output1", "output2"])
+
+
+@pytest.fixture(params=["tensor", "list", "torch_return_type", "named_tuple"])
+def example_outputs(
+    request: FixtureRequest,
+) -> Union[torch.Tensor, list[torch.Tensor], torch.return_types.max, NamedTuple]:
     return {
         "tensor": torch.tensor(1),
         "list": [torch.tensor(1), torch.tensor([2])],
         "torch_return_type": torch.return_types.max((torch.tensor(1), torch.tensor([2]))),
+        "named_tuple": NamedTuple(torch.tensor(1), torch.tensor([2])),
     }.get(request.param)
 
 
-def test_execute_post_hooks(example_outputs: Union[torch.Tensor, list[torch.Tensor], torch.return_types.max]):
+def test_execute_post_hooks(
+    example_outputs: Union[torch.Tensor, list[torch.Tensor], torch.return_types.max, NamedTuple],
+):
     op_name = "/relu/0"
     hook_storage = HookStorage()
     hook_port_0 = CallCount()
@@ -112,6 +121,7 @@ def test_execute_post_hooks(example_outputs: Union[torch.Tensor, list[torch.Tens
     op_meta = OpMeta("/relu/0", torch.relu)
     ret_val = ctx.execute_post_hooks(example_outputs, op_meta)
     assert type(example_outputs) is type(ret_val)
+    assert example_outputs == ret_val
 
     assert hook_port_0.call_count == 1
     if isinstance(example_outputs, torch.Tensor):
@@ -120,6 +130,28 @@ def test_execute_post_hooks(example_outputs: Union[torch.Tensor, list[torch.Tens
         assert hook_port_1.call_count == 1
 
 
+def test_process_model_output(
+    example_outputs: Union[torch.Tensor, list[torch.Tensor], torch.return_types.max, NamedTuple],
+):
+    hook_storage = HookStorage()
+    hook_output = CallCount()
+    hook_output_0 = CallCount()
+    hook_output_1 = CallCount()
+    hook_storage.register_pre_function_hook("output", 0, hook_output)
+    hook_storage.register_pre_function_hook("output_0", 0, hook_output_0)
+    hook_storage.register_pre_function_hook("output_1", 0, hook_output_1)
+
+    ctx = FunctionHookMode(nn.Identity(), hook_storage)
+    ret_val = ctx.process_model_outputs(example_outputs)
+    assert type(example_outputs) is type(ret_val)
+    assert example_outputs == ret_val
+    if isinstance(example_outputs, torch.Tensor):
+        assert hook_output.call_count == 1
+    else:
+        assert hook_output_0.call_count == 1
+        assert hook_output_1.call_count == 1
+
+
 class ConcatModel(nn.Module):
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         return torch.cat([x, x], dim=0)
