Skip to content

Commit ebb2966

Browse files
authored
Arm backend: Use _pytree tree_flatten in model evaluator (#11133)
### Summary Updates model evaluator to use tree_flatten for output tensors when checking int8 deviation. * Removes in script function which flattens the tensors.
1 parent e833b5e commit ebb2966

File tree

1 file changed

+3
-16
lines changed

1 file changed

+3
-16
lines changed

backends/arm/util/arm_model_evaluator.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import torch
2020
from torch.nn.modules import Module
21+
from torch.utils._pytree import tree_flatten
2122
from torch.utils.data import DataLoader
2223
from torchvision import datasets, transforms # type: ignore[import-untyped]
2324

@@ -28,20 +29,6 @@
2829
logger.setLevel(logging.INFO)
2930

3031

31-
def flatten_args(args) -> tuple | list:
32-
flattened_args: list = []
33-
if isinstance(args, torch.Tensor):
34-
return [args]
35-
36-
for arg in args:
37-
if isinstance(arg, (tuple, list)):
38-
flattened_args.extend(arg)
39-
else:
40-
flattened_args.append(arg)
41-
42-
return tuple(flattened_args)
43-
44-
4532
class GenericModelEvaluator:
4633
REQUIRES_CONFIG = False
4734

@@ -72,8 +59,8 @@ def get_model_error(self) -> defaultdict:
7259
- Maximum percentage error
7360
- Mean absolute error
7461
"""
75-
fp32_outputs = flatten_args(self.fp32_model(*self.example_input))
76-
int8_outputs = flatten_args(self.int8_model(*self.example_input))
62+
fp32_outputs, _ = tree_flatten(self.fp32_model(*self.example_input))
63+
int8_outputs, _ = tree_flatten(self.int8_model(*self.example_input))
7764

7865
model_error_dict = defaultdict(list)
7966

0 commit comments

Comments
 (0)