 # Copyright (c) Facebook, Inc. and its affiliates.
 # -*- coding: utf-8 -*-
 
-import logging
 import typing
-import torch
 from fvcore.nn import activation_count, flop_count, parameter_count, parameter_count_table
 from torch import nn
 
-from detectron2.structures import BitMasks, Boxes, ImageList, Instances
-
-from .logger import log_first_n
+from detectron2.export import TracingAdapter
 
 __all__ = [
     "activation_count_operators",
@@ -64,11 +60,13 @@ def flop_count_operators(
         the flops of box & mask head depends on the number of proposals &
         the number of detected objects.
         Therefore, the flops counting using a single input may not accurately
-        reflect the computation cost of a model.
+        reflect the computation cost of a model. It's recommended to average
+        across a number of inputs.
 
     Args:
         model: a detectron2 model that takes `list[dict]` as input.
         inputs (list[dict]): inputs to model, in detectron2's standard format.
+            Only "image" key will be used.
     """
     return _wrapper_count_operators(model=model, inputs=inputs, mode=FLOPS_MODE, **kwargs)
 
@@ -90,71 +88,34 @@ def activation_count_operators(
     Args:
         model: a detectron2 model that takes `list[dict]` as input.
         inputs (list[dict]): inputs to model, in detectron2's standard format.
+            Only "image" key will be used.
     """
     return _wrapper_count_operators(model=model, inputs=inputs, mode=ACTIVATIONS_MODE, **kwargs)
 
 
-def _flatten_to_tuple(outputs):
-    result = []
-    if isinstance(outputs, torch.Tensor):
-        result.append(outputs)
-    elif isinstance(outputs, (list, tuple)):
-        for v in outputs:
-            result.extend(_flatten_to_tuple(v))
-    elif isinstance(outputs, dict):
-        for _, v in outputs.items():
-            result.extend(_flatten_to_tuple(v))
-    elif isinstance(outputs, Instances):
-        result.extend(_flatten_to_tuple(outputs.get_fields()))
-    elif isinstance(outputs, (Boxes, BitMasks, ImageList)):
-        result.append(outputs.tensor)
-    else:
-        log_first_n(
-            logging.WARN,
-            f"Output of type {type(outputs)} not included in flops/activations count.",
-            n=10,
-        )
-    return tuple(result)
-
-
 def _wrapper_count_operators(
     model: nn.Module, inputs: list, mode: str, **kwargs
 ) -> typing.DefaultDict[str, float]:
-
     # ignore some ops
     supported_ops = {k: lambda *args, **kwargs: {} for k in _IGNORED_OPS}
     supported_ops.update(kwargs.pop("supported_ops", {}))
     kwargs["supported_ops"] = supported_ops
 
     assert len(inputs) == 1, "Please use batch size=1"
     tensor_input = inputs[0]["image"]
-
-    class WrapModel(nn.Module):
-        def __init__(self, model):
-            super().__init__()
-            if isinstance(
-                model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)
-            ):
-                self.model = model.module
-            else:
-                self.model = model
-
-        def forward(self, image):
-            # jit requires the input/output to be Tensors
-            inputs = [{"image": image}]
-            outputs = self.model.forward(inputs)
-            # Only the subgraph that computes the returned tuple of tensor will be
-            # counted. So we flatten everything we found to tuple of tensors.
-            return _flatten_to_tuple(outputs)
+    inputs = [{"image": tensor_input}]  # remove other keys, in case there are any
 
     old_train = model.training
-    with torch.no_grad():
-        if mode == FLOPS_MODE:
-            ret = flop_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
-        elif mode == ACTIVATIONS_MODE:
-            ret = activation_count(WrapModel(model).train(False), (tensor_input,), **kwargs)
-        else:
-            raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
+    if isinstance(model, (nn.parallel.distributed.DistributedDataParallel, nn.DataParallel)):
+        model = model.module
+    wrapper = TracingAdapter(model, inputs)
+    wrapper.eval()
+    if mode == FLOPS_MODE:
+        ret = flop_count(wrapper, (tensor_input,), **kwargs)
+    elif mode == ACTIVATIONS_MODE:
+        ret = activation_count(wrapper, (tensor_input,), **kwargs)
+    else:
+        raise NotImplementedError("Count for mode {} is not supported yet.".format(mode))
     # compatible with change in fvcore
     if isinstance(ret, tuple):
         ret = ret[0]
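
For context, a minimal usage sketch of the counter touched by this diff (not part of the change itself); the config path, device, and input size are assumptions for illustration:

    import torch
    from detectron2.config import get_cfg
    from detectron2.modeling import build_model
    from detectron2.utils.analysis import flop_count_operators

    cfg = get_cfg()
    # hypothetical config path, for illustration only
    cfg.merge_from_file("configs/COCO-Detection/faster_rcnn_R_50_FPN_3x.yaml")
    cfg.MODEL.DEVICE = "cpu"  # keep the sketch runnable without a GPU
    model = build_model(cfg)

    # batch size must be 1; only the "image" key of each input dict is used
    inputs = [{"image": torch.rand(3, 800, 800)}]
    flops = flop_count_operators(model, inputs)  # defaultdict: operator name -> count
    print(flops)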