Skip to content

Commit 70abfb4

Browse files
authored
Fix FLOPs calculation (#388)
Signed-off-by: Mohammed Yasin <[email protected]> Signed-off-by: Y-T-G <[email protected]>
1 parent 8d924d3 commit 70abfb4

File tree

6 files changed

+20
-20
lines changed

6 files changed

+20
-20
lines changed

docs/source/guides/3_pruning.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ Following info will be printed before the pruning process is started:
190190
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
191191
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
192192
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
193-
│ flops │ 274.34M │ 1.28G │ 4.59G │ 16.73 │
193+
│ flops │ 548.68M │ 2.56G │ 9.18G │ 16.73 │
194194
│ params │ 2.70M │ 9.75M │ 25.50M │ 9.43 │
195195
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
196196
@@ -199,7 +199,7 @@ Following info will be printed before the pruning process is started:
199199
┃ ┃ ┃ Satisfiable ┃
200200
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
201201
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
202-
│ flops │ 2.75G │ True │
202+
│ flops │ 5.50G │ True │
203203
└──────────────┴──────────────┴──────────────┘
204204
205205

docs/source/guides/7_nas.rst

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -109,8 +109,8 @@ the search space together with your deployment constraints using
109109
110110
import torch
111111
112-
# Looking for a subnet with at most 2 GFLOPs
113-
constraints = {"flops": 2.0e9}
112+
# Looking for a subnet with at most 4 GFLOPs
113+
constraints = {"flops": 4.0e9}
114114
115115
# Measure FLOPs against dummy_input
116116
# Can be provided as a single tensor or tuple of input args to the model.
@@ -129,7 +129,7 @@ Following info will be printed:
129129
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
130130
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
131131
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
132-
│ flops │ 487.92M │ 1.84G │ 4.59G │ 9.40 │
132+
│ flops │ 975.84M │ 3.68G │ 9.18G │ 9.40 │
133133
│ params │ 4.84M │ 12.33M │ 25.50M │ 5.27 │
134134
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
135135
@@ -138,7 +138,7 @@ Following info will be printed:
138138
┃ ┃ ┃ Satisfiable ┃
139139
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
140140
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
141-
│ flops │ 2.00G │ True │
141+
│ flops │ 4.00G │ True │
142142
└──────────────┴──────────────┴──────────────┘
143143
144144
Search Space Summary:
@@ -242,8 +242,8 @@ Below is an example of running search on an AutoNAS converted and trained model.
242242
# Specify the sample input including target data shape for FLOPs calculation.
243243
dummy_input = torch.randn(1, 3, 224, 224)
244244
245-
# Looking for a subnet with at most 2 GFLOPs
246-
search_constraints = {"flops": 2.0e9}
245+
# Looking for a subnet with at most 4 GFLOPs
246+
search_constraints = {"flops": 4.0e9}
247247
248248
# search_res (dict) contains state_dict / stats of the searcher
249249
searched_model, search_res = mtn.search(

examples/pruning/cifar_resnet.ipynb

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -489,7 +489,7 @@
489489
"* prune the model;\n",
490490
"* obtain a valid pytorch model that can be used for fine-tuning.\n",
491491
"\n",
492-
"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
492+
"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `60e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
493493
"\n",
494494
"Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
495495
"\n",
@@ -529,7 +529,7 @@
529529
"┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
530530
"\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n",
531531
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
532-
"│ flops │ 24.33M │ 27.57M │ 40.55M │ 1.67 │\n",
532+
"│ flops │ 48.66M │ 55.14M │ 81.10M │ 1.67 │\n",
533533
"│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │\n",
534534
"└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
535535
"\u001b[3m \u001b[0m\n",
@@ -538,7 +538,7 @@
538538
"\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n",
539539
"\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n",
540540
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
541-
"│ flops │ 30.00M │ True │\n",
541+
"│ flops │ 60.00M │ True │\n",
542542
"└──────────────┴──────────────┴──────────────┘\n",
543543
"\n",
544544
"\n",
@@ -618,7 +618,7 @@
618618
"name": "stdout",
619619
"output_type": "stream",
620620
"text": [
621-
"[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
621+
"[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n"
622622
]
623623
},
624624
{
@@ -656,7 +656,7 @@
656656
"pruned_model, _ = mtp.prune(\n",
657657
" model=resnet20(ckpt=\"resnet20.pth\"),\n",
658658
" mode=[(\"fastnas\", config)],\n",
659-
" constraints={\"flops\": 30e6},\n",
659+
" constraints={\"flops\": 60e6},\n",
660660
" dummy_input=dummy_input,\n",
661661
" config={\n",
662662
" \"data_loader\": train_loader,\n",
@@ -676,7 +676,7 @@
676676
"cell_type": "markdown",
677677
"metadata": {},
678678
"source": [
679-
"As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
679+
"As we can see, the best subnet (59.3M FLOPs) has fitted our constraint of 60M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
680680
"\n",
681681
"#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
682682
]
@@ -795,8 +795,8 @@
795795
"\n",
796796
"| Model | FLOPs | Params | Test Accuracy |\n",
797797
"| --------------- | ---------- | ---------- | ----------------- |\n",
798-
"| ResNet20 | 40.6M | 268k | 90.9% |\n",
799-
"| FastNAS subnet | 29.6M | 174k | 90.3% |\n",
798+
"| ResNet20 | 81.1M | 268k | 90.9% |\n",
799+
"| FastNAS subnet | 59.3M | 174k | 90.3% |\n",
800800
"\n",
801801
"As we see here, we have reduced the FLOPs and number of parameters which would also result in an improvement in latency with very little loss in accuracy. Good job!\n",
802802
"\n",

modelopt/torch/nas/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ def inference_flops(
116116
with warnings.catch_warnings():
117117
warnings.simplefilter("ignore")
118118
with batch_norm_ignored_flops():
119-
flops = profile.profile_macs(network, args=dummy_input)
119+
flops = 2 * profile.profile_macs(network, args=dummy_input)
120120
network.train(is_training)
121121
if return_str:
122122
return num2hrb(flops)

tests/unit/torch/nas/test_evaluate_constraints.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def test_evaluate_constraints(model_and_input_func):
4747
# NOTE: using param_num here instead of param_num_from_forward to check
4848
# correctness of the function.
4949
"params": param_num(model, unit=1.0),
50-
"flops": profile_macs(model, args) / 1.0,
50+
"flops": 2 * profile_macs(model, args) / 1.0,
5151
}
5252

5353
assert actual_results == expected_results
@@ -83,7 +83,7 @@ def test_percent_limits():
8383
cf = ConstraintsFunc(model, constraints=constraints, dummy_input=args)
8484

8585
remove_bn(model)
86-
max_flops = profile_macs(model, args)
86+
max_flops = 2 * profile_macs(model, args)
8787
max_params = param_num(model, unit=1.0)
8888
expected_results = {
8989
# NOTE: using trainable_param_num here instead of trainable_param_num_from_forward to check

tests/unit/torch/nas/test_nas_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,5 +36,5 @@ def test_flops(in_channel, out_channel, kernel_size, groups, data_x, data_y) ->
3636
input_data_shape = (1, in_channel, data_x, data_y)
3737
out_elements = out_channel * data_x * data_y
3838
per_element_filters = in_channel * kernel_size * kernel_size // groups
39-
desired_output = out_elements * per_element_filters
39+
desired_output = 2 * out_elements * per_element_filters
4040
assert inference_flops(conv_module, data_shape=input_data_shape, unit=1) == desired_output

0 commit comments

Comments
 (0)