489 | 489 | "* prune the model;\n",
490 | 490 | "* obtain a valid PyTorch model that can be used for fine-tuning.\n",
491 | 491 | "\n",
492 |     | - "Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n",
    | 492 | + "Let's say we have the ResNet20 model as our base model and we are looking for a pruned model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` as an upper bound. The values can either be absolute numbers (e.g. `60e6`) or percentage strings (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune); it is used to calibrate the normalization layers in the model. Finally, we also specify a custom config for the pruning search space to get a more fine-grained selection of pruned nets.\n",
493 | 493 | "\n",
494 | 494 | "Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n",
495 | 495 | "\n",
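
For reference, a minimal sketch of the save step mentioned in this cell. The variable name `pruned_model` and the file name are illustrative assumptions, not part of this diff:

```python
import modelopt.torch.opt as mto

# Persist the pruned architecture together with its weights so the
# subnet can be restored later for fine-tuning (file name is illustrative).
mto.save(pruned_model, "modelopt_pruned_model.pth")
```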

529 | 529 | "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n",
530 | 530 | "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n",
531 | 531 | "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n",
532 |     | - "│ flops │ 24.33M │ 27.57M │ 40.55M │ 1.67 │\n",
    | 532 | + "│ flops │ 48.66M │ 55.14M │ 81.10M │ 1.67 │\n",
533 | 533 | "│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │\n",
534 | 534 | "└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n",
535 | 535 | "\u001b[3m \u001b[0m\n",

538 | 538 | "┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n",
539 | 539 | "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n",
540 | 540 | "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n",
541 |     | - "│ flops │ 30.00M │ True │\n",
    | 541 | + "│ flops │ 60.00M │ True │\n",
542 | 542 | "└──────────────┴──────────────┴──────────────┘\n",
543 | 543 | "\n",
544 | 544 | "\n",

618 | 618 | "name": "stdout",
619 | 619 | "output_type": "stream",
620 | 620 | "text": [
621 |     | - "[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n"
    | 621 | + "[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n"
622 | 622 | ]
623 | 623 | },
624 | 624 | {

656 | 656 | "pruned_model, _ = mtp.prune(\n",
657 | 657 | "    model=resnet20(ckpt=\"resnet20.pth\"),\n",
658 | 658 | "    mode=[(\"fastnas\", config)],\n",
659 |     | - "    constraints={\"flops\": 30e6},\n",
    | 659 | + "    constraints={\"flops\": 60e6},\n",
660 | 660 | "    dummy_input=dummy_input,\n",
661 | 661 | "    config={\n",
662 | 662 | "        \"data_loader\": train_loader,\n",

676 | 676 | "cell_type": "markdown",
677 | 677 | "metadata": {},
678 | 678 | "source": [
679 |     | - "As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n",
    | 679 | + "As we can see, the best subnet (59.3M FLOPs) fits our constraint of 60M FLOPs. We also see a drop in validation accuracy for the searched model; this is very common after pruning, and fine-tuning is necessary to recover the accuracy.\n",
680 | 680 | "\n",
681 | 681 | "#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)"
682 | 682 | ]
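
The restore step referenced by the heading above can look like the following minimal sketch (the checkpoint name matches the illustrative one used for `mto.save` earlier and is an assumption, not from the diff):

```python
import modelopt.torch.opt as mto

# Re-create the base architecture, then restore the pruned subnet's
# architecture and weights from the checkpoint written by mto.save.
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model.pth")
```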

795 | 795 | "\n",
796 | 796 | "| Model           | FLOPs      | Params     | Test Accuracy     |\n",
797 | 797 | "| --------------- | ---------- | ---------- | ----------------- |\n",
798 |     | - "| ResNet20        | 40.6M      | 268k       | 90.9%             |\n",
799 |     | - "| FastNAS subnet  | 29.6M      | 174k       | 90.3%             |\n",
    | 798 | + "| ResNet20        | 81.2M      | 268k       | 90.9%             |\n",
    | 799 | + "| FastNAS subnet  | 59.2M      | 174k       | 90.3%             |\n",
800 | 800 | "\n",
801 | 801 | "As we can see, we have reduced both the FLOPs and the number of parameters, which would also result in an improvement in latency, with very little loss in accuracy. Good job!\n",
802 | 802 | "\n",