|
489 | 489 | "* prune the model;\n", |
490 | 490 | "* obtain a valid pytorch model that can be used for fine-tuning.\n", |
491 | 491 | "\n", |
492 | | - "Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 30M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `30e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n", |
| 492 | + "Let's say you have the ResNet20 model as our base model to prune from and we are looking for a model with at most 60M FLOPs. We can provide search constraints for `flops` and/or `params` by an upper bound. The values can either be absolute numbers (e.g. `60e6`) or a string percentage (e.g. `\"75%\"`). In addition, we should also provide our training data loader to [mtp.prune](../reference/generated/modelopt.torch.prune.pruning.rst#modelopt.torch.prune.pruning.prune). The training data loader will be used to calibrate the normalization layers in the model. Finally, we will also specify a custom config for configuring the pruning search space to get a more fine-grained selection of pruned nets.\n", |
493 | 493 | "\n", |
494 | 494 | "Finally, we can store the pruned architecture and weights using [mto.save](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.save).\n", |
495 | 495 | "\n", |
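A minimal sketch of how the constraint formats and the save step fit together, assuming `pruned_model` comes from the `mtp.prune` call shown further down and using a hypothetical checkpoint filename:

```python
import modelopt.torch.opt as mto

# A constraint upper bound can be an absolute number or a percentage string:
constraints = {"flops": 60e6}      # at most 60M FLOPs
# constraints = {"flops": "75%"}   # or: at most 75% of the base model's FLOPs

# Persist the pruned architecture together with its weights in one checkpoint
# (hypothetical filename) so the subnet can be restored later for fine-tuning.
mto.save(pruned_model, "modelopt_pruned_model.pth")
```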
|
529 | 529 | "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓\n", |
530 | 530 | "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmin \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mcentroid \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mmax/min ratio\u001b[0m\u001b[1m \u001b[0m┃\n", |
531 | 531 | "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩\n", |
532 | | - "│ flops │ 24.33M │ 27.57M │ 40.55M │ 1.67 │\n", |
| 532 | + "│ flops │ 48.66M │ 55.14M │ 81.10M │ 1.67 │\n", |
533 | 533 | "│ params │ 90.94K │ 141.63K │ 268.35K │ 2.95 │\n", |
534 | 534 | "└──────────────┴──────────────┴──────────────┴──────────────┴───────────────┘\n", |
535 | 535 | "\u001b[3m \u001b[0m\n", |
|
538 | 538 | "┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mSatisfiable \u001b[0m\u001b[1m \u001b[0m┃\n", |
539 | 539 | "┃\u001b[1m \u001b[0m\u001b[1mConstraint \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mUpper Bound \u001b[0m\u001b[1m \u001b[0m┃\n", |
540 | 540 | "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━┩\n", |
541 | | - "│ flops │ 30.00M │ True │\n", |
| 541 | + "│ flops │ 60.00M │ True │\n", |
542 | 542 | "└──────────────┴──────────────┴──────────────┘\n", |
543 | 543 | "\n", |
544 | 544 | "\n", |
|
618 | 618 | "name": "stdout", |
619 | 619 | "output_type": "stream", |
620 | 620 | "text": [ |
621 | | - "[best_subnet_constraints] = {'params': '173.88K', 'flops': '29.64M'}\n" |
| 621 | + "[best_subnet_constraints] = {'params': '173.88K', 'flops': '59.28M'}\n" |
622 | 622 | ] |
623 | 623 | }, |
624 | 624 | { |
|
656 | 656 | "pruned_model, _ = mtp.prune(\n", |
657 | 657 | " model=resnet20(ckpt=\"resnet20.pth\"),\n", |
658 | 658 | " mode=[(\"fastnas\", config)],\n", |
659 | | - " constraints={\"flops\": 30e6},\n", |
| 659 | + " constraints={\"flops\": 60e6},\n", |
660 | 660 | " dummy_input=dummy_input,\n", |
661 | 661 | " config={\n", |
662 | 662 | " \"data_loader\": train_loader,\n", |
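The cell above is only partially shown in this hunk; a self-contained sketch of the full call follows. The `score_func` and the checkpoint path are assumptions in the FastNAS config style, not necessarily the notebook's exact values:

```python
import modelopt.torch.prune as mtp

def score_func(model):
    # Assumed scoring function: rank candidate subnets by validation accuracy
    # (evaluate_top1_accuracy and val_loader are hypothetical helpers).
    return evaluate_top1_accuracy(model, val_loader)

pruned_model, _ = mtp.prune(
    model=resnet20(ckpt="resnet20.pth"),
    mode=[("fastnas", config)],
    constraints={"flops": 60e6},
    dummy_input=dummy_input,
    config={
        "data_loader": train_loader,  # calibrates the normalization layers
        "score_func": score_func,     # higher score = better candidate subnet
        "checkpoint": "modelopt_search_checkpoint.pth",  # hypothetical path
    },
)
```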
|
676 | 676 | "cell_type": "markdown", |
677 | 677 | "metadata": {}, |
678 | 678 | "source": [ |
679 | | - "As we can see, the best subnet (29.6M FLOPs) has fitted our constraint of 30M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n", |
| 679 | + "As we can see, the best subnet (59.3M FLOPs) has fitted our constraint of 60M FLOPs. We can also see a drop in validation accuracy for the searched model. This is very common after pruning and fine-tuning is necessary for this model.\n", |
680 | 680 | "\n", |
681 | 681 | "#### Restore the pruned subnet using [mto.restore](../reference/generated/modelopt.torch.opt.conversion.rst#modelopt.torch.opt.conversion.restore)" |
682 | 682 | ] |
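A minimal sketch of the restore step, assuming the hypothetical checkpoint filename used with `mto.save` above:

```python
import modelopt.torch.opt as mto

# Rebuild the unmodified base architecture, then restore the pruned subnet
# (architecture changes + weights) into it for fine-tuning.
pruned_model = mto.restore(resnet20(), "modelopt_pruned_model.pth")
```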
|
795 | 795 | "\n", |
796 | 796 | "| Model | FLOPs | Params | Test Accuracy |\n", |
797 | 797 | "| --------------- | ---------- | ---------- | ----------------- |\n", |
798 | | - "| ResNet20 | 40.6M | 268k | 90.9% |\n", |
799 | | - "| FastNAS subnet | 29.6M | 174k | 90.3% |\n", |
| 798 | + "| ResNet20 | 81.2M | 268k | 90.9% |\n", |
| 799 | + "| FastNAS subnet | 59.2M | 174k | 90.3% |\n", |
800 | 800 | "\n", |
801 | 801 | "As we see here, we have reduced the FLOPs and number of parameters which would also result in a improvement in latency with very little loss in accuracy. Good job!\n", |
802 | 802 | "\n", |
|