
Commit 911b091

ppwwyyxx authored and facebook-github-bot committed
Add TracingAdapter module to flatten module outputs automatically
Summary: This adapter allows tracing models with rich input/output formats, as long as the formats are recognized (dict/list/tuple, and d2 builtin structures). It simplifies code for tracing export, tracing evaluation, and flop counting.

Reviewed By: theschnitz

Differential Revision: D26298375

fbshipit-source-id: d0b20c26c13f69c80752caa921efc6011c2651bc
1 parent 42a2473 commit 911b091
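In short, the adapter sits between `torch.jit.trace` and a model with structured inputs/outputs. A minimal sketch of the idea, using a hypothetical toy model with a dict output (the model and names below are illustrative, not part of this commit):

import torch
from torch import nn
from detectron2.export import TracingAdapter

class ToyModel(nn.Module):
    # Hypothetical model: returns a dict, which plain tracing does not handle well.
    def forward(self, x):
        return {"double": x * 2, "square": x * x}

model = ToyModel().eval()
x = torch.rand(4)
adapter = TracingAdapter(model, x)  # flattens I/O and records the schemas
traced = torch.jit.trace(adapter, adapter.flattened_inputs)
out = adapter.outputs_schema(traced(x))  # rebuild the dict from the traced tuple
assert torch.allclose(out["double"], x * 2)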

9 files changed: +206 −123 lines changed

detectron2/evaluation/evaluator.py

Lines changed: 11 additions & 5 deletions
@@ -3,8 +3,9 @@
 import logging
 import time
 from collections import OrderedDict
-from contextlib import contextmanager
+from contextlib import ExitStack, contextmanager
 import torch
+from torch import nn

 from detectron2.utils.comm import get_world_size, is_main_process
 from detectron2.utils.logger import log_every_n_seconds

@@ -101,13 +102,14 @@ def evaluate(self):
 def inference_on_dataset(model, data_loader, evaluator):
     """
     Run model on the data_loader and evaluate the metrics with evaluator.
-    Also benchmark the inference speed of `model.forward` accurately.
+    Also benchmark the inference speed of `model.__call__` accurately.
     The model will be used in eval mode.

     Args:
-        model (nn.Module): a module which accepts an object from
-            `data_loader` and returns some outputs. It will be temporarily set to `eval` mode.
+        model (callable): a callable which takes an object from
+            `data_loader` and returns some outputs.
+
+            If it's an nn.Module, it will be temporarily set to `eval` mode.
             If you wish to evaluate a model in `training` mode instead, you can
             wrap the given model and override its behavior of `.eval()` and `.train()`.
         data_loader: an iterable object with a length.

@@ -131,7 +133,11 @@ def inference_on_dataset(model, data_loader, evaluator):
     num_warmup = min(5, total - 1)
     start_time = time.perf_counter()
     total_compute_time = 0
-    with inference_context(model), torch.no_grad():
+    with ExitStack() as stack:
+        if isinstance(model, nn.Module):
+            stack.enter_context(inference_context(model))
+        stack.enter_context(torch.no_grad())
+
         for idx, inputs in enumerate(data_loader):
             if idx == num_warmup:
                 start_time = time.perf_counter()
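With this change, `inference_on_dataset` no longer needs an `nn.Module`; any callable with the same input/output contract works, e.g. a wrapper around an exported model. A standalone sketch of the conditional-context pattern used above, with illustrative names (not part of the commit):

import torch
from contextlib import ExitStack
from torch import nn

def call_in_eval_mode(model, inputs):
    # Enter eval mode only when we actually have an nn.Module;
    # plain callables skip that step but still run under no_grad.
    with ExitStack() as stack:
        if isinstance(model, nn.Module):
            was_training = model.training
            model.eval()
            stack.callback(model.train, was_training)  # restore the previous mode on exit
        stack.enter_context(torch.no_grad())
        return model(inputs)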

detectron2/export/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -1,5 +1,6 @@
 # -*- coding: utf-8 -*-

 from .api import *
+from .flatten import TracingAdapter

 __all__ = [k for k in globals().keys() if not k.startswith("_")]

detectron2/export/flatten.py

Lines changed: 98 additions & 1 deletion
@@ -1,7 +1,8 @@
 import collections
 from dataclasses import dataclass
-from typing import List
+from typing import Callable, List, Optional, Tuple
 import torch
+from torch import nn

 from detectron2.structures import Boxes, Instances

@@ -171,3 +172,99 @@ def flatten_to_tuple(obj):
         F = IdentitySchema

     return F.flatten(obj)
+
+
+class TracingAdapter(nn.Module):
+    """
+    A model may take rich input/output formats (e.g. dict or custom classes).
+    This adapter flattens the input/output format of a model so it becomes traceable.
+
+    It also records the necessary schema to rebuild the model's inputs/outputs from flattened
+    inputs/outputs.
+
+    Example:
+    ::
+        outputs = model(inputs)  # inputs/outputs may be rich structure
+        adapter = TracingAdapter(model, inputs)
+
+        # can now trace the model, with adapter.flattened_inputs, or another
+        # tuple of tensors with the same length and meaning
+        traced = torch.jit.trace(adapter, adapter.flattened_inputs)
+
+        # traced model can only produce flattened outputs (tuple of tensors)
+        flattened_outputs = traced(*adapter.flattened_inputs)
+        # adapter knows the schema to convert it back (new_outputs == outputs)
+        new_outputs = adapter.outputs_schema(flattened_outputs)
+    """
+
+    flattened_inputs: Tuple[torch.Tensor] = None
+    """
+    Flattened version of inputs given to this class's constructor.
+    """
+
+    inputs_schema: Schema = None
+    """
+    Schema of the inputs given to this class's constructor.
+    """
+
+    outputs_schema: Schema = None
+    """
+    Schema of the output produced by calling the given model with inputs.
+    """
+
+    def __init__(self, model: nn.Module, inputs, inference_func: Optional[Callable] = None):
+        """
+        Args:
+            model: an nn.Module
+            inputs: An input argument or a tuple of input arguments used to call model.
+                After flattening, it has to only consist of tensors.
+            inference_func: a callable that takes (model, *inputs), calls the
+                model with inputs, and returns outputs. By default it
+                is ``lambda model, *inputs: model(*inputs)``. Can be overridden
+                if you need to call the model differently.
+        """
+        super().__init__()
+        if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
+            model = model.module
+        self.model = model
+        if not isinstance(inputs, tuple):
+            inputs = (inputs,)
+        self.inputs = inputs
+
+        if inference_func is None:
+            inference_func = lambda model, *inputs: model(*inputs)  # noqa
+        self.inference_func = inference_func
+
+        self.flattened_inputs, self.inputs_schema = flatten_to_tuple(inputs)
+        for input in self.flattened_inputs:
+            if not isinstance(input, torch.Tensor):
+                raise ValueError(
+                    f"Inputs for tracing must only contain tensors. Got a {type(input)} instead."
+                )
+
+    def forward(self, *args: torch.Tensor):
+        with torch.no_grad():
+            inputs_orig_format = self.inputs_schema(args)
+            outputs = self.inference_func(self.model, *inputs_orig_format)
+            flattened_outputs, schema = flatten_to_tuple(outputs)
+            if self.outputs_schema is None:
+                self.outputs_schema = schema
+            else:
+                assert (
+                    self.outputs_schema == schema
+                ), "Model should always return outputs with the same structure so it can be traced!"
+            return flattened_outputs
+
+    def _create_wrapper(self, traced_model):
+        """
+        Return a function that has the same input/output interface as the
+        original model, but calls the given traced model under the hood.
+        """
+
+        def forward(*args):
+            flattened_inputs, _ = flatten_to_tuple(args)
+            flattened_outputs = traced_model(*flattened_inputs)
+            return self.outputs_schema(flattened_outputs)
+
+        return forward
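For a detectron2 model whose `forward` expects `list[dict]`, the calling convention can be wrapped in an `inference_func`, as the updated test in this commit does. A condensed sketch of that pattern (the test traces with a real COCO sample; the random tensor here is only a placeholder):

import torch
from detectron2 import model_zoo
from detectron2.export import TracingAdapter
from detectron2.export.torchscript_patch import patch_builtin_len

def inference_func(model, image):
    # Adapt the raw tensor interface to detectron2's list[dict] convention.
    inputs = [{"image": image}]
    return model.inference(inputs, do_postprocess=False)[0]

model = model_zoo.get("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", trained=True)
image = torch.rand(3, 800, 800)  # placeholder sample input
adapter = TracingAdapter(model, image, inference_func)
adapter.eval()
with torch.no_grad(), patch_builtin_len():
    traced = torch.jit.trace(adapter, adapter.flattened_inputs)
outputs = adapter.outputs_schema(traced(image))  # an Instances object again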

detectron2/utils/analysis.py

Lines changed: 16 additions & 55 deletions
@@ -1,15 +1,11 @@
 # Copyright (c) Facebook, Inc. and its affiliates.
 # -*- coding: utf-8 -*-

-import logging
 import typing
-import torch
 from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
 from torch import nn

-from detectron2.structures import BitMasks, Boxes, ImageList, Instances
-
-from .logger import log_first_n
+from detectron2.export import TracingAdapter

 __all__ = [
     "activation_count_operators",

@@ -64,11 +60,13 @@ def flop_count_operators(
        the flops of box & mask head depend on the number of proposals &
        the number of detected objects.
        Therefore, the flops counting using a single input may not accurately
-       reflect the computation cost of a model.
+       reflect the computation cost of a model. It's recommended to average
+       across a number of inputs.

     Args:
         model: a detectron2 model that takes `list[dict]` as input.
         inputs (list[dict]): inputs to model, in detectron2's standard format.
+            Only "image" key will be used.
     """
     return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)

@@ -90,71 +88,34 @@ def activation_count_operators(
     Args:
         model: a detectron2 model that takes `list[dict]` as input.
         inputs (list[dict]): inputs to model, in detectron2's standard format.
+            Only "image" key will be used.
     """
     return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)


-def _flatten_to_tuple(outputs):
-    result = []
-    if isinstance(outputs, torch.Tensor):
-        result.append(outputs)
-    elif isinstance(outputs, (list, tuple)):
-        for v in outputs:
-            result.extend(_flatten_to_tuple(v))
-    elif isinstance(outputs, dict):
-        for _, v in outputs.items():
-            result.extend(_flatten_to_tuple(v))
-    elif isinstance(outputs, Instances):
-        result.extend(_flatten_to_tuple(outputs.get_fields()))
-    elif isinstance(outputs, (Boxes, BitMasks, ImageList)):
-        result.append(outputs.tensor)
-    else:
-        log_first_n(
-            logging.WARN,
-            f"Output of type {type(outputs)} not included in flops/activations count.",
-            n=10,
-        )
-    return tuple(result)
-
-
 def _wrapper_count_operators(
     model: nn.Module, inputs: list, mode: str, **kwargs
 ) -> typing.DefaultDict[str, float]:
-
     # ignore some ops
     supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
     supported_ops.update(kwargs.pop("supported_ops", {}))
     kwargs["supported_ops"] = supported_ops

     assert len(inputs) == 1, "Please use batch size=1"
     tensor_input = inputs[0]["image"]
-
-    class WrapModel(nn.Module):
-        def __init__(self, model):
-            super().__init__()
-            if isinstance(
-                model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
-            ):
-                self.model = model.module
-            else:
-                self.model = model
-
-        def forward(self, image):
-            # jit requires the input/output to be Tensors
-            inputs = [{"image": image}]
-            outputs = self.model.forward(inputs)
-            # Only the subgraph that computes the returned tuple of tensor will be
-            # counted. So we flatten everything we found to tuple of tensors.
-            return _flatten_to_tuple(outputs)
+    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any

     old_train = model.training
-    with torch.no_grad():
-        if mode == FLOPS_MODE:
-            ret = flop_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
-        elif mode == ACTIVATIONS_MODE:
-            ret = activation_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
-        else:
-            raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
+    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
+        model = model.module
+    wrapper = TracingAdapter(model, inputs)
+    wrapper.eval()
+    if mode == FLOPS_MODE:
+        ret = flop_count(wrapper, (tensor_input,), **kwargs)
+    elif mode == ACTIVATIONS_MODE:
+        ret = activation_count(wrapper, (tensor_input,), **kwargs)
+    else:
+        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
     # compatible with change in fvcore
     if isinstance(ret, tuple):
         ret = ret[0]
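From the caller's side the counting API is unchanged; keys other than "image" in the inputs are now simply ignored. A sketch mirroring the updated test in tests/test_model_analysis.py (the RetinaNet config path is an assumption, as a stand-in for any built detectron2 model):

import torch
from detectron2 import model_zoo
from detectron2.utils.analysis import flop_count_operators

# Build any detectron2 model; RetinaNet is what the test uses.
model = model_zoo.get("COCO-Detection/retinanet_R_50_FPN_1x.yaml", trained=False)
# Batch size must be 1; extra keys besides "image" are dropped by the wrapper.
inputs = [{"image": torch.rand(3, 800, 800), "test_unused": "abcd"}]
counts = flop_count_operators(model, inputs)  # typing.DefaultDict[str, float] per operator
print(counts["conv"])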

tests/test_export_torchscript.py

Lines changed: 7 additions & 16 deletions
@@ -8,7 +8,7 @@

 from detectron2 import model_zoo
 from detectron2.config import get_cfg
-from detectron2.export.flatten import flatten_to_tuple
+from detectron2.export.flatten import TracingAdapter, flatten_to_tuple
 from detectron2.export.torchscript import dump_torchscript_IR, export_torchscript_with_instances
 from detectron2.export.torchscript_patch import patch_builtin_len
 from detectron2.layers import ShapeSpec

@@ -86,8 +86,7 @@ class TestTracing(unittest.TestCase):
     def testMaskRCNN(self):
         def inference_func(model, image):
             inputs = [{"image": image}]
-            outputs = model.inference(inputs, do_postprocess=False)[0]
-            return outputs
+            return model.inference(inputs, do_postprocess=False)[0]

         self._test_model("COCO-InstanceSegmentation/mask_rcnn_R_50_FPN_3x.yaml", inference_func)

@@ -102,26 +101,15 @@ def _test_model(self, config_path, inference_func):
         model = model_zoo.get(config_path, trained=True)
         image = get_sample_coco_image()

-        class Wrapper(nn.ModuleList):  # a wrapper to make the model traceable
-            def forward(self, image):
-                outputs = inference_func(self[0], image)
-                flattened_outputs, schema = flatten_to_tuple(outputs)
-                if not hasattr(self, "schema"):
-                    self.schema = schema
-                return flattened_outputs
-
-            def rebuild(self, flattened_outputs):
-                return self.schema(flattened_outputs)
-
-        wrapper = Wrapper([model])
+        wrapper = TracingAdapter(model, image, inference_func)
         wrapper.eval()
         with torch.no_grad(), patch_builtin_len():
             small_image = nn.functional.interpolate(image, scale_factor=0.5)
             # trace with a different image, and the trace must still work
             traced_model = torch.jit.trace(wrapper, (small_image,))

         output = inference_func(model, image)
-        traced_output = wrapper.rebuild(traced_model(image))
+        traced_output = wrapper.outputs_schema(traced_model(image))
         assert_instances_allclose(output, traced_output, size_as_tensor=True)

     def testKeypointHead(self):

@@ -191,6 +179,9 @@ def test_flatten_basic(self):
         new_obj = schema(res)
         self.assertEqual(new_obj, obj)

+        _, new_schema = flatten_to_tuple(new_obj)
+        self.assertEqual(schema, new_schema)  # test __eq__
+
     def test_flatten_instances_boxes(self):
         inst = Instances(
             torch.tensor([5, 8]), pred_masks=torch.tensor([3]), pred_boxes=Boxes(torch.ones((1, 4)))
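The schema round-trip that `test_flatten_basic` exercises looks like this in isolation (the nested object below is made up for illustration):

import torch
from detectron2.export.flatten import flatten_to_tuple

# A made-up nested structure with tensor leaves:
obj = {"scores": torch.tensor([0.9]), "extras": [torch.zeros(2), (torch.ones(1),)]}
flat, schema = flatten_to_tuple(obj)  # flat: a plain tuple of the leaf tensors
rebuilt = schema(flat)                # the schema is callable and rebuilds the structure
assert torch.equal(rebuilt["scores"], obj["scores"])
# Flattening the rebuilt object yields an equal schema, as the new assertion checks:
_, schema2 = flatten_to_tuple(rebuilt)
assert schema == schema2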

tests/test_model_analysis.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@ def setUp(self):

     def test_flop(self):
         # RetinaNet supports flop-counting with random inputs
-        inputs = [{"image": torch.rand(3, 800, 800)}]
+        inputs = [{"image": torch.rand(3, 800, 800), "test_unused": "abcd"}]
         res = flop_count_operators(self.model, inputs)
         self.assertTrue(int(res["conv"]), 146)  # 146B flops

tools/deploy/README.md

Lines changed: 6 additions & 6 deletions
@@ -65,12 +65,12 @@ We show a few example commands to export and execute a Mask R-CNN model in C++.
 ## Notes:

 1. Tracing/Caffe2-tracing requires valid weights & sample inputs.
-   Therefore the above commands require [setting up the COCO dataset](https://detectron2.readthedocs.io/tutorials/builtin_datasets.html).
+   Therefore the above commands require pre-trained models and the [COCO dataset](https://detectron2.readthedocs.io/tutorials/builtin_datasets.html).
    You can modify the script to obtain sample inputs in other ways instead of from COCO.

-2. The `--run-eval` flag is supported with caffe2 format.
-   This flag will evaluate the converted models to verify their accuracy.
-   The accuracy is typically slightly different (within 0.1 AP) from the original model due to
-   numerical precision between different implementations.
+2. The `--run-eval` flag can be used under certain modes
+   (caffe2_tracing with caffe2 format, or tracing with torchscript format)
+   to evaluate the exported model using the dataset in the config.
    It's recommended to always verify the accuracy in case the conversion is not successful.
-   Evaluation can be slow if the model is exported to CPU.
+   Evaluation can be slow if the model is exported to CPU or the dataset is too large ("coco_2017_val_100" is a small subset of COCO useful for evaluation).
+   Caffe2 accuracy may be slightly different (within 0.1 AP) from the original model due to numerical precision between different runtimes.
