4 changes: 2 additions & 2 deletions docs/source/guides/3_pruning.rst
@@ -190,7 +190,7 @@ Following info will be printed before the pruning process is started:
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-│ flops │ 274.34M │ 1.28G │ 4.59G │ 16.73 │
+│ flops │ 548.68M │ 2.56G │ 9.18G │ 16.73 │
│ params │ 2.70M │ 9.75M │ 25.50M │ 9.43 │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

@@ -199,7 +199,7 @@ Following info will be printed before the pruning process is started:
┃ ┃ ┃ Satisfiable ┃
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-│ flops │ 2.75G │ True │
+│ flops │ 5.50G │ True │
└──────────────┴──────────────┴──────────────┘
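The doubled numbers in this hunk reflect the convention the PR adopts: one multiply-accumulate (MAC) now counts as two floating-point operations. A minimal sketch (not part of the diff) of how an existing budget translates, using the 2.75G upper bound from the table above:

# FLOPs = 2 * MACs: one multiply plus one add per multiply-accumulate.
macs_budget = 2.75e9            # upper bound under the old MAC-based convention
flops_budget = 2 * macs_budget  # equivalent bound under the new convention: 5.5e9
constraints = {"flops": flops_budget}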


12 changes: 6 additions & 6 deletions docs/source/guides/7_nas.rst
@@ -109,8 +109,8 @@ the search space together with your deployment constraints using

import torch

-# Looking for a subnet with at most 2 GFLOPs
-constraints = {"flops": 2.0e9}
+# Looking for a subnet with at most 4 GFLOPs
+constraints = {"flops": 4.0e9}

# Measure FLOPs against dummy_input
# Can be provided as a single tensor or tuple of input args to the model.
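The hunk boundary cuts this snippet off before the measurement input is defined. A hedged sketch of how the pieces plausibly fit together, assuming mtn.profile is the profiling entry point the guide's surrounding text refers to and that the input shape matches the search example later in the guide:

import modelopt.torch.nas as mtn

# Assumed input shape, copied from the search example further down this guide.
dummy_input = torch.randn(1, 3, 224, 224)

# Hedged: profiles the search space against the constraints and prints the
# tables shown in the next hunk.
mtn.profile(model, dummy_input, constraints=constraints)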
@@ -129,7 +129,7 @@ Following info will be printed:
┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Constraint ┃ min ┃ centroid ┃ max ┃ max/min ratio ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-│ flops │ 487.92M │ 1.84G │ 4.59G │ 9.40 │
+│ flops │ 975.84M │ 3.68G │ 9.18G │ 9.40 │
│ params │ 4.84M │ 12.33M │ 25.50M │ 5.27 │
└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘

@@ -138,7 +138,7 @@ Following info will be printed:
┃ ┃ ┃ Satisfiable ┃
┃ Constraint ┃ Upper Bound ┃ Upper Bound ┃
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-│ flops │ 2.00G │ True │
+│ flops │ 4.00G │ True │
└──────────────┴──────────────┴──────────────┘

Search Space Summary:
@@ -242,8 +242,8 @@ Below is an example of running search on an AutoNAS converted and trained model.
# Specify the sample input including target data shape for FLOPs calculation.
dummy_input = torch.randn(1, 3, 224, 224)

-# Looking for a subnet with at most 2 GFLOPs
-search_constraints = {"flops": 2.0e9}
+# Looking for a subnet with at most 4 GFLOPs
+search_constraints = {"flops": 4.0e9}

# search_res (dict) contains state_dict / stats of the searcher
searched_model, search_res = mtn.search(
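The hunk ends mid-call; a hedged sketch of the complete invocation, with the config contents assumed for illustration rather than taken from the guide:

import modelopt.torch.nas as mtn

# Hedged sketch: argument names follow the snippet above; score_func and
# val_loader are assumptions standing in for a real validation callable
# and data loader.
searched_model, search_res = mtn.search(
    model=model,                     # the AutoNAS-converted and trained model
    constraints=search_constraints,  # {"flops": 4.0e9} from above
    dummy_input=dummy_input,         # used for FLOPs measurement
    config={"score_func": score_func, "data_loader": val_loader},
)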
16 changes: 8 additions & 8 deletions examples/pruning/cifar_resnet.ipynb
@@ -489,7 +489,7 @@
"* prune the model;\n",
"* obtain a valid pytorch model that can be used for fine-tuning.\n",
"\n",
"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
"Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `60e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
"\n",
"Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
"\n",
@@ -529,7 +529,7 @@
"┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
"┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
"│ flops │ 24.33M27.57M40.55M │ 1.67 │\n",
"│ flops │ 48.66M55.14M81.10M │ 1.67 │\n",
"│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │\n",
"└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
"\u001b[3m \u001b[0m\n",
@@ -538,7 +538,7 @@
"┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n",
"┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n",
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
"│ flops │ 30.00M │ True │\n",
"│ flops │ 60.00M │ True │\n",
"└──────────────┴──────────────┴──────────────┘\n",
"\n",
"\n",
@@ -618,7 +618,7 @@
"name": "stdout",
"output_type": "stream",
"text": [
"[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
"[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n"
]
},
{
@@ -656,7 +656,7 @@
"pruned_model, _ = mtp.prune(\n",
" model=resnet20(ckpt=\"resnet20.pth\"),\n",
" mode=[(\"fastnas\", config)],\n",
" constraints={\"flops\": 30e6},\n",
" constraints={\"flops\": 60e6},\n",
" dummy_input=dummy_input,\n",
" config={\n",
" \"data_loader\": train_loader,\n",
@@ -676,7 +676,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
"As we can see, the best subnet (59.3M FLOPs) has fitted our constraint of 60M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
"\n",
"#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
]
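A hedged sketch of the restore step the heading above points to; the model builder is reused from earlier cells and the checkpoint path is hypothetical:

import modelopt.torch.opt as mto

# Rebuilds the pruned architecture on a fresh base model and loads its weights.
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model.pth")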
@@ -795,8 +795,8 @@
"\n",
"| Model | FLOPs | Params | Test Accuracy |\n",
"| --------------- | ---------- | ---------- | ----------------- |\n",
"| ResNet20 | 40.6M | 268k | 90.9% |\n",
"| FastNAS subnet | 29.6M | 174k | 90.3% |\n",
"| ResNet20 | 81.2M | 268k | 90.9% |\n",
"| FastNAS subnet | 59.2M | 174k | 90.3% |\n",
"\n",
"As we see here, we have reduced the FLOPs and number of parameters which would also result in a improvement in latency with very little loss in accuracy. Good job!\n",
"\n",
2 changes: 1 addition & 1 deletion modelopt/torch/nas/utils.py
@@ -116,7 +116,7 @@ def inference_flops(
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        with batch_norm_ignored_flops():
-            flops = profile.profile_macs(network, args=dummy_input)
+            flops = 2 * profile.profile_macs(network, args=dummy_input)
    network.train(is_training)
    if return_str:
        return num2hrb(flops)
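This changed line is where the convention lands in code: the profiler returns MACs and inference_flops now reports twice that. A standalone sketch of the same computation, assuming the wrapped profiler is torchprofile:

import torch
from torchprofile import profile_macs  # assumption: the `profile` module used above

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, bias=False)
x = torch.randn(1, 3, 32, 32)

macs = profile_macs(conv, x)  # multiply-accumulate count from the profiler
flops = 2 * macs              # one multiply + one add per MAC, as in the diff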
4 changes: 2 additions & 2 deletions tests/unit/torch/nas/test_evaluate_constraints.py
@@ -47,7 +47,7 @@ def test_evaluate_constraints(model_and_input_func):
        # NOTE: using param_num here instead of param_num_from_forward to check
        # correctness of the function.
        "params": param_num(model, unit=1.0),
-        "flops": profile_macs(model, args) / 1.0,
+        "flops": 2 * profile_macs(model, args) / 1.0,
    }

    assert actual_results == expected_results
@@ -83,7 +83,7 @@ def test_percent_limits():
    cf = ConstraintsFunc(model, constraints=constraints, dummy_input=args)

    remove_bn(model)
-    max_flops = profile_macs(model, args)
+    max_flops = 2 * profile_macs(model, args)
    max_params = param_num(model, unit=1.0)
    expected_results = {
        # NOTE: using trainable_param_num here instead of trainable_param_num_from_forward to check
2 changes: 1 addition & 1 deletion tests/unit/torch/nas/test_nas_utils.py
@@ -36,5 +36,5 @@ def test_flops(in_channel, out_channel, kernel_size, groups, data_x, data_y) ->
    input_data_shape = (1, in_channel, data_x, data_y)
    out_elements = out_channel * data_x * data_y
    per_element_filters = in_channel * kernel_size * kernel_size // groups
-    desired_output = out_elements * per_element_filters
+    desired_output = 2 * out_elements * per_element_filters
    assert inference_flops(conv_module, data_shape=input_data_shape, unit=1) == desired_output
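A worked instance of the test's arithmetic (a sketch with one assumed parametrization):

# Assumed values: in_channel=3, out_channel=8, kernel_size=3, groups=1,
# data_x = data_y = 32 (shape-preserving convolution).
out_elements = 8 * 32 * 32            # 8192 output activations
per_element_filters = 3 * 3 * 3 // 1  # 27 MACs per output element
desired_output = 2 * out_elements * per_element_filters  # 442368 FLOPs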