diff --git a/docs/source/guides/3_pruning.rst b/docs/source/guides/3_pruning.rst
index 238727384..786c286da 100644
--- a/docs/source/guides/3_pruning.rst
+++ b/docs/source/guides/3_pruning.rst
@@ -190,7 +190,7 @@ Following info will be printed before the pruning process is started:
     ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
     ┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
     ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-    │ flops        │ 274.34M      │ 1.28G        │ 4.59G        │ 16.73         │
+    │ flops        │ 548.68M      │ 2.56G        │ 9.18G        │ 16.73         │
     │ params       │ 2.70M        │ 9.75M        │ 25.50M       │ 9.43          │
     └──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
@@ -199,7 +199,7 @@ Following info will be printed before the pruning process is started:
     ┃              ┃              ┃ Satisfiable  ┃
     ┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
     ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-    │ flops        │ 2.75G        │ True         │
+    │ flops        │ 5.50G        │ True         │
     └──────────────┴──────────────┴──────────────┘
diff --git a/docs/source/guides/7_nas.rst b/docs/source/guides/7_nas.rst
index bf8198b80..888039fcd 100644
--- a/docs/source/guides/7_nas.rst
+++ b/docs/source/guides/7_nas.rst
@@ -109,8 +109,8 @@ the search space together with your deployment constraints using
     import torch
 
-    # Looking for a subnet with at most 2 GFLOPs
-    constraints = {"flops": 2.0e9}
+    # Looking for a subnet with at most 4 GFLOPs
+    constraints = {"flops": 4.0e9}
 
     # Measure FLOPs against dummy_input
     # Can be provided as a single tensor or tuple of input args to the model.
@@ -129,7 +129,7 @@ Following info will be printed:
     ┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
     ┃ Constraint   ┃ min          ┃ centroid     ┃ max          ┃ max/min ratio ┃
     ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
-    │ flops        │ 487.92M      │ 1.84G        │ 4.59G        │ 9.40          │
+    │ flops        │ 975.84M      │ 3.68G        │ 9.18G        │ 9.40          │
     │ params       │ 4.84M        │ 12.33M       │ 25.50M       │ 5.27          │
     └──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘
@@ -138,7 +138,7 @@ Following info will be printed:
     ┃              ┃              ┃ Satisfiable  ┃
     ┃ Constraint   ┃ Upper Bound  ┃ Upper Bound  ┃
     ┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩
-    │ flops        │ 2.00G        │ True         │
+    │ flops        │ 4.00G        │ True         │
     └──────────────┴──────────────┴──────────────┘
 
 Search Space Summary:
@@ -242,8 +242,8 @@ Below is an example of running search on an AutoNAS converted and trained model.
     # Specify the sample input including target data shape for FLOPs calculation.
     dummy_input = torch.randn(1, 3, 224, 224)
 
-    # Looking for a subnet with at most 2 GFLOPs
-    search_constraints = {"flops": 2.0e9}
+    # Looking for a subnet with at most 4 GFLOPs
+    search_constraints = {"flops": 4.0e9}
 
     # search_res (dict) contains state_dict / stats of the searcher
     searched_model, search_res = mtn.search(
diff --git a/examples/pruning/cifar_resnet.ipynb b/examples/pruning/cifar_resnet.ipynb
index 62d297ba6..1c6f10852 100644
--- a/examples/pruning/cifar_resnet.ipynb
+++ b/examples/pruning/cifar_resnet.ipynb
@@ -489,7 +489,7 @@
     "* prune the model;\n",
     "* obtain a valid pytorch model that can be used for fine-tuning.\n",
     "\n",
-    "Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
+    "Let's say we have the ResNet20 model as our base model to prune and we are looking for a model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` as an upper bound. The values can either be absolute numbers (e.g. `60e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for the pruning search space to get a more fine-grained selection of pruned nets.\n",
     "\n",
     "Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
     "\n",
@@ -529,7 +529,7 @@
     "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
     "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n",
     "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
-    "│ flops │ 24.33M │ 27.57M │ 40.55M │ 1.67 │\n",
+    "│ flops │ 48.66M │ 55.14M │ 81.10M │ 1.67 │\n",
     "│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │\n",
     "└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
     "\u001b[3m \u001b[0m\n",
@@ -538,7 +538,7 @@
     "┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n",
     "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n",
     "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
-    "│ flops │ 30.00M │ True │\n",
+    "│ flops │ 60.00M │ True │\n",
     "└──────────────┴──────────────┴──────────────┘\n",
     "\n",
     "\n",
@@ -618,7 +618,7 @@
    "name": "stdout",
    "output_type": "stream",
    "text": [
-    "[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
+    "[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n"
    ]
   },
   {
@@ -656,7 +656,7 @@
    "pruned_model, _ = mtp.prune(\n",
    "    model=resnet20(ckpt=\"resnet20.pth\"),\n",
    "    mode=[(\"fastnas\", config)],\n",
-   "    constraints={\"flops\": 30e6},\n",
+   "    constraints={\"flops\": 60e6},\n",
    "    dummy_input=dummy_input,\n",
    "    config={\n",
    "        \"data_loader\": train_loader,\n",
@@ -676,7 +676,7 @@
   "cell_type": "markdown",
   "metadata": {},
   "source": [
-   "As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
+   "As we can see, the best subnet (59.3M FLOPs) fits our constraint of 60M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning, and fine-tuning is necessary to recover the accuracy.\n",
    "\n",
    "#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
   ]
@@ -795,8 +795,8 @@
    "\n",
    "| Model           | FLOPs      | Params     | Test Accuracy     |\n",
    "| --------------- | ---------- | ---------- | ----------------- |\n",
-   "| ResNet20        | 40.6M      | 268k       | 90.9%             |\n",
-   "| FastNAS subnet  | 29.6M      | 174k       | 90.3%             |\n",
+   "| ResNet20        | 81.1M      | 268k       | 90.9%             |\n",
+   "| FastNAS subnet  | 59.3M      | 174k       | 90.3%             |\n",
    "\n",
    "As we see here, we have reduced the FLOPs and number of parameters, which would also result in an improvement in latency, with very little loss in accuracy. Good job!\n",
    "\n",
diff --git a/modelopt/torch/nas/utils.py b/modelopt/torch/nas/utils.py
index 56b103fec..51cb6456e 100644
--- a/modelopt/torch/nas/utils.py
+++ b/modelopt/torch/nas/utils.py
@@ -116,7 +116,7 @@ def inference_flops(
     with warnings.catch_warnings():
         warnings.simplefilter("ignore")
         with batch_norm_ignored_flops():
-            flops = profile.profile_macs(network, args=dummy_input)
+            flops = 2 * profile.profile_macs(network, args=dummy_input)
     network.train(is_training)
     if return_str:
         return num2hrb(flops)
diff --git a/tests/unit/torch/nas/test_evaluate_constraints.py b/tests/unit/torch/nas/test_evaluate_constraints.py
index fec02b696..4521e3a0f 100644
--- a/tests/unit/torch/nas/test_evaluate_constraints.py
+++ b/tests/unit/torch/nas/test_evaluate_constraints.py
@@ -47,7 +47,7 @@ def test_evaluate_constraints(model_and_input_func):
         # NOTE: using param_num here instead of param_num_from_forward to check
         # correctness of the function.
         "params": param_num(model, unit=1.0),
-        "flops": profile_macs(model, args) / 1.0,
+        "flops": 2 * profile_macs(model, args) / 1.0,
     }
 
     assert actual_results == expected_results
@@ -83,7 +83,7 @@ def test_percent_limits():
     cf = ConstraintsFunc(model, constraints=constraints, dummy_input=args)
 
     remove_bn(model)
-    max_flops = profile_macs(model, args)
+    max_flops = 2 * profile_macs(model, args)
     max_params = param_num(model, unit=1.0)
     expected_results = {
         # NOTE: using trainable_param_num here instead of trainable_param_num_from_forward to check
diff --git a/tests/unit/torch/nas/test_nas_utils.py b/tests/unit/torch/nas/test_nas_utils.py
index b16f4b315..3186912ec 100644
--- a/tests/unit/torch/nas/test_nas_utils.py
+++ b/tests/unit/torch/nas/test_nas_utils.py
@@ -36,5 +36,5 @@ def test_flops(in_channel, out_channel, kernel_size, groups, data_x, data_y) ->
     input_data_shape = (1, in_channel, data_x, data_y)
     out_elements = out_channel * data_x * data_y
     per_element_filters = in_channel * kernel_size * kernel_size // groups
-    desired_output = out_elements * per_element_filters
+    desired_output = 2 * out_elements * per_element_filters
     assert inference_flops(conv_module, data_shape=input_data_shape, unit=1) == desired_output
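
Note: `torchprofile.profile_macs` counts multiply-accumulate operations (MACs), and `inference_flops` now reports each MAC as 2 FLOPs (one multiply plus one add). This is why every documented FLOPs figure and constraint above doubles while the max/min ratios stay unchanged. The sketch below illustrates the convention on a single conv layer; it is an illustrative example only (the `count_conv_flops` helper and the layer shapes are made up for this note, not part of the diff):

    import torch
    import torch.nn as nn
    from torchprofile import profile_macs


    def count_conv_flops(conv: nn.Conv2d, out_h: int, out_w: int) -> int:
        """FLOPs of a conv layer under the 2-FLOPs-per-MAC convention."""
        out_elements = conv.out_channels * out_h * out_w
        macs_per_element = (conv.in_channels // conv.groups) * conv.kernel_size[0] * conv.kernel_size[1]
        return 2 * out_elements * macs_per_element  # each MAC = 1 multiply + 1 add


    conv = nn.Conv2d(16, 32, kernel_size=3, padding=1, bias=False)
    x = torch.randn(1, 16, 8, 8)

    # torchprofile reports MACs; doubling mirrors the utils.py change above.
    flops = 2 * profile_macs(conv, x)
    assert flops == count_conv_flops(conv, out_h=8, out_w=8)  # 589,824 FLOPs

This is the same arithmetic the updated `test_flops` expects: `desired_output = 2 * out_elements * per_element_filters`.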