|
162 | 162 | "\n", |
163 | 163 | "# Plotting the histogram.\n", |
164 | 164 | "plt.figure(figsize=(16, 10))\n", |
165 | | - "plt.hist(raw_data, density=True, bins=128, alpha=0.8, label=\"y\")\n", |
| 165 | + "plt.hist(raw_data, density=True, bins=128, alpha=0.8, label='y')\n", |
166 | 166 | "#plt.legend(loc='upper right')\n", |
167 | 167 | "plt.xlabel(\"Data\")\n", |
168 | 168 | "plt.ylabel(\"density\")\n", |
|
220 | 220 | "isClipped=np.logical_or(raw_data>clip_max, raw_data<clip_min)\n", |
221 | 221 | "idx_clipped_elements=np.where( isClipped )[0]\n", |
222 | 222 | "pd.DataFrame( \n", |
223 | | - " {\"idx\":idx_clipped_elements[:5], \n", |
224 | | - " \"raw\": raw_data[ idx_clipped_elements[:5] ],\n", |
225 | | - " \"clipped\": clipped_data[idx_clipped_elements[:5]] }\n", |
| 223 | + " {'idx':idx_clipped_elements[:5], \n", |
| 224 | + " 'raw': raw_data[ idx_clipped_elements[:5] ],\n", |
| 225 | + " 'clipped': clipped_data[idx_clipped_elements[:5]] }\n", |
226 | 226 | " )" |
227 | 227 | ] |
228 | 228 | }, |
|
236 | 236 | "# Plot the distribution and the clipped data to visualize\n", |
237 | 237 | "\n", |
238 | 238 | "plt.figure(figsize=(16, 10))\n", |
239 | | - "plt.hist(raw_data, density=True, bins=64, label=\"y (raw values)\", histtype=\"step\", linewidth=3.5),\n", |
240 | | - "plt.hist(clipped_data, density=True, bins=64, color=[\"#33b1ff\"], alpha=0.8,label=\"y_clamp (clipped edges)\"), \n", |
| 239 | + "plt.hist(raw_data, density=True, bins=64, label=\"y (raw values)\", histtype='step', linewidth=3.5),\n", |
| 240 | + "plt.hist(clipped_data, density=True, bins=64, color=['#33b1ff'], alpha=0.8,label=\"y_clamp (clipped edges)\"), \n", |
241 | 241 | "plt.legend(fancybox=True, ncol=2)\n", |
242 | 242 | "plt.xlabel(\"Data\")\n", |
243 | 243 | "plt.ylabel(\"density\")\n", |
|
294 | 294 | "outputs": [], |
295 | 295 | "source": [ |
296 | 296 | "plt.figure(figsize=(16, 10))\n", |
297 | | - "plt.hist(raw_data, density=True, bins=64, alpha=0.8,label=\"y (raw values)\", histtype=\"step\", linewidth=3.5)\n", |
298 | | - "plt.hist(y_scaled, density=True, bins=64, color=[\"#33b1ff\"], alpha=0.6,label=\"scale+shift\")\n", |
299 | | - "plt.hist(y_int, density=True, bins=64, color=[\"#007d79\"],alpha=0.8,label=\"quantize\")\n", |
300 | | - "plt.legend(loc=\"upper left\", fancybox=True, ncol=3)\n", |
| 297 | + "plt.hist(raw_data, density=True, bins=64, alpha=0.8,label=\"y (raw values)\", histtype='step', linewidth=3.5)\n", |
| 298 | + "plt.hist(y_scaled, density=True, bins=64, color=['#33b1ff'], alpha=0.6,label=\"scale+shift\")\n", |
| 299 | + "plt.hist(y_int, density=True, bins=64, color=['#007d79'],alpha=0.8,label=\"quantize\")\n", |
| 300 | + "plt.legend(loc='upper left', fancybox=True, ncol=3)\n", |
301 | 301 | "plt.xlabel(\"Data\")\n", |
302 | 302 | "plt.ylabel(\"density\")\n", |
303 | 303 | "#plt.yscale('log')\n", |
|
340 | 340 | "yq = y_int * stepsize + zp\n", |
341 | 341 | "\n", |
342 | 342 | "plt.figure(figsize=(16, 10))\n", |
343 | | - "plt.hist(raw_data, density=True, bins=64, label=\"original y\", histtype=\"step\", linewidth=2.5)#alpha=0.8,\n", |
344 | | - "plt.hist(yq, density=True, color=[\"#33b1ff\"], bins=64, label=\"quantized y\")#alpha=0.7,\n", |
| 343 | + "plt.hist(raw_data, density=True, bins=64, label=\"original y\", histtype='step', linewidth=2.5)#alpha=0.8,\n", |
| 344 | + "plt.hist(yq, density=True, color=['#33b1ff'], bins=64, label=\"quantized y\")#alpha=0.7,\n", |
345 | 345 | "plt.legend(fancybox=True, ncol=2)\n", |
346 | 346 | "plt.xlabel(\"Data\")\n", |
347 | 347 | "plt.ylabel(\"density\")\n", |
|
367 | 367 | "source": [ |
368 | 368 | "plt.subplots(3,1, figsize=(16, 12), sharex=True)\n", |
369 | 369 | "\n", |
370 | | - "arstyle=dict(facecolor=\"C1\",alpha=0.5, shrink=0.05)\n", |
| 370 | + "arstyle=dict(facecolor='C1',alpha=0.5, shrink=0.05)\n", |
371 | 371 | "\n", |
372 | 372 | "n_bit = 4\n", |
373 | 373 | "clip_min, clip_max = -2.5, 2.5\n", |
374 | 374 | "asym_raw_data = np.abs(raw_data)\n", |
375 | | - "for i, (raw_i, lbl_i) in enumerate([(raw_data, \"Case 1: sym data, sym Q\"), \n", |
376 | | - " (asym_raw_data, \"Case 2: asym data, asym Q\"),\n", |
377 | | - " (asym_raw_data, \"Case 3: asym data sym Q\") ]):\n", |
378 | | - " if \"asym Q\" in lbl_i:\n", |
| 375 | + "for i, (raw_i, lbl_i) in enumerate([(raw_data, 'Case 1: sym data, sym Q'), \n", |
| 376 | + " (asym_raw_data, 'Case 2: asym data, asym Q'),\n", |
| 377 | + " (asym_raw_data, 'Case 3: asym data sym Q') ]):\n", |
| 378 | + " if 'asym Q' in lbl_i:\n", |
379 | 379 | " # asym quantization for range [0, clip_max]\n", |
380 | 380 | " clip_min_i = np.min(raw_i)\n", |
381 | 381 | " nbins = 2**n_bit -1\n", |
|
396 | 396 | " max_bin_i = np.round( (clip_max-zp)/scale)*scale + zp\n", |
397 | 397 | "\n", |
398 | 398 | " plt.subplot(311+i)\n", |
399 | | - " plt.hist(raw_i, density=False, bins=64, label=\"original y\", histtype=\"step\", linewidth=2.5)\n", |
400 | | - " plt.hist(yq_i, density=False, color=[\"#33b1ff\"], bins=64, label=\"y_q\")\n", |
| 399 | + " plt.hist(raw_i, density=False, bins=64, label=\"original y\", histtype='step', linewidth=2.5)\n", |
| 400 | + " plt.hist(yq_i, density=False, color=['#33b1ff'], bins=64, label='y_q')\n", |
401 | 401 | " plt.legend(fancybox=True, ncol=2, fontsize=14)\n", |
402 | 402 | "\n", |
403 | 403 | " plt.ylabel(\"Count\")\n", |
404 | | - " plt.annotate(\"upper clip bound\", xy=(max_bin_i, 0), xytext=(max_bin_i, 1e5), arrowprops=arstyle) \n", |
405 | | - " plt.annotate(\"lower clip bound\", xy=(clip_min_i, 0), xytext=(clip_min_i, 1e5), arrowprops=arstyle) \n", |
| 404 | + " plt.annotate('upper clip bound', xy=(max_bin_i, 0), xytext=(max_bin_i, 1e5), arrowprops=arstyle) \n", |
| 405 | + " plt.annotate('lower clip bound', xy=(clip_min_i, 0), xytext=(clip_min_i, 1e5), arrowprops=arstyle) \n", |
406 | 406 | " plt.title(lbl_i)\n", |
407 | 407 | "\n", |
408 | 408 | "plt.tight_layout()\n", |
|
478 | 478 | "# Generate 1 sample\n", |
479 | 479 | "input = torch.randn(N,C,H,W)\n", |
480 | 480 | "\n", |
481 | | - "print(\"Input Shape: \", input.shape)\n", |
482 | | - "print(\"Number of unique input values: \", input.detach().unique().size()[0])\n", |
483 | | - "print(f\"Expected: {N * C * H * W} (Based on randomly generated values for shape {N} x {C} x {H} x {W})\")" |
| 481 | + "print('Input Shape: ', input.shape)\n", |
| 482 | + "print('Number of unique input values: ', input.detach().unique().size()[0])\n", |
| 483 | + "print(f'Expected: {N * C * H * W} (Based on randomly generated values for shape {N} x {C} x {H} x {W})')" |
484 | 484 | ] |
485 | 485 | }, |
486 | 486 | { |
|
508 | 508 | "# Quantize the input data\n", |
509 | 509 | "input_quant = simpleQuantizer(input, n_bit, clip_min, clip_max)\n", |
510 | 510 | "\n", |
511 | | - "print(\"Quantized input Shape: \", input_quant.shape)\n", |
512 | | - "print(\"Number of unique quantized input values: \", input_quant.detach().unique().size()[0])\n", |
513 | | - "print(f\"Expected: {2 ** n_bit} (Based on 2 ^ {n_bit})\")" |
| 511 | + "print('Quantized input Shape: ', input_quant.shape)\n", |
| 512 | + "print('Number of unique quantized input values: ', input_quant.detach().unique().size()[0])\n", |
| 513 | + "print(f'Expected: {2 ** n_bit} (Based on 2 ^ {n_bit})')" |
514 | 514 | ] |
515 | 515 | }, |
516 | 516 | { |
|
577 | 577 | "# ignore bias for now \n", |
578 | 578 | "net.conv.bias = torch.nn.Parameter(bias)\n", |
579 | 579 | "\n", |
580 | | - "print(\"Weight Shape: \", weight.shape)\n", |
581 | | - "print(\"Number of unique weight values: \", weight.detach().unique().size()[0])\n", |
582 | | - "print(f\"Expected: {weight.numel()} (Based on randomly generated values for shape {weight.shape[0]} x {weight.shape[1]} x {weight.shape[2]} x {weight.shape[3]})\")" |
| 580 | + "print('Weight Shape: ', weight.shape)\n", |
| 581 | + "print('Number of unique weight values: ', weight.detach().unique().size()[0])\n", |
| 582 | + "print(f'Expected: {weight.numel()} (Based on randomly generated values for shape {weight.shape[0]} x {weight.shape[1]} x {weight.shape[2]} x {weight.shape[3]})')" |
583 | 583 | ] |
584 | 584 | }, |
585 | 585 | { |
|
605 | 605 | "# Quantize the weights (similar to input)\n", |
606 | 606 | "weight_quant = simpleQuantizer(weight, n_bit, clip_min, clip_max)\n", |
607 | 607 | "\n", |
608 | | - "print(\"Quantized weight Shape: \", weight_quant.shape)\n", |
609 | | - "print(\"Number of unique quantized weight values: \", weight_quant.detach().unique().size()[0])\n", |
610 | | - "print(f\"Expected: {2 ** n_bit} (Based on 2 ^ {n_bit})\")\n", |
611 | | - "print(\"First Channel of Quantized Weight\", weight_quant[0])\n" |
| 608 | + "print('Quantized weight Shape: ', weight_quant.shape)\n", |
| 609 | + "print('Number of unique quantized weight values: ', weight_quant.detach().unique().size()[0])\n", |
| 610 | + "print(f'Expected: {2 ** n_bit} (Based on 2 ^ {n_bit})')\n", |
| 611 | + "print('First Channel of Quantized Weight', weight_quant[0])\n" |
612 | 612 | ] |
613 | 613 | }, |
614 | 614 | { |
|
635 | 635 | "# Generate quantized output y, NOTE, this net is currently using non-quantized weight \n", |
636 | 636 | "y_quant = net(input_quant)\n", |
637 | 637 | "\n", |
638 | | - "print(\"Number of unique output values: \", y.detach().unique().size()[0])\n", |
639 | | - "print(\"Expected maximum unique output values: \", y.flatten().size()[0])\n", |
640 | | - "print(\"Number of unique quantized output values: \", y_quant.detach().unique().size()[0])\n" |
| 638 | + "print('Number of unique output values: ', y.detach().unique().size()[0])\n", |
| 639 | + "print('Expected maximum unique output values: ', y.flatten().size()[0])\n", |
| 640 | + "print('Number of unique quantized output values: ', y_quant.detach().unique().size()[0])\n" |
641 | 641 | ] |
642 | 642 | }, |
643 | 643 | { |
|
662 | 662 | "outputs": [], |
663 | 663 | "source": [ |
664 | 664 | "def PlotAndCompare(d1, d2, labels, title):\n", |
665 | | - " mse = nn.functional.mse_loss(d1, d2, reduction=\"mean\" )\n", |
| 665 | + " mse = nn.functional.mse_loss(d1, d2, reduction='mean' )\n", |
666 | 666 | " plt.hist( d1.flatten().detach().numpy(), bins=64, alpha = 0.7, density=True, label=labels[0])\n", |
667 | | - " plt.hist( d2.flatten().detach().numpy(), bins=64, color=[\"#33b1ff\"], alpha = 0.8, density=True, label=labels[1], histtype=\"step\", linewidth=3.5)\n", |
668 | | - " plt.yscale(\"log\")\n", |
669 | | - " plt.legend(loc=\"upper center\", bbox_to_anchor=(0.5, -0.1), fancybox=True, ncol=2)\n", |
| 667 | + " plt.hist( d2.flatten().detach().numpy(), bins=64, color=['#33b1ff'], alpha = 0.8, density=True, label=labels[1], histtype='step', linewidth=3.5)\n", |
| 668 | + " plt.yscale('log')\n", |
| 669 | + " plt.legend(loc='upper center', bbox_to_anchor=(0.5, -0.1), fancybox=True, ncol=2)\n", |
670 | 670 | " plt.title(f\"{title}, MSE={mse:.3f}\")\n", |
671 | 671 | "\n", |
672 | 672 | "\n", |
673 | 673 | "\n", |
674 | | - "titles=[\"inputs\", \"weights\", \"outputs\"]\n", |
675 | | - "isQ = [\"not quantized\", \"quantized\"]\n", |
| 674 | + "titles=['inputs', 'weights', 'outputs']\n", |
| 675 | + "isQ = ['not quantized', 'quantized']\n", |
676 | 676 | "for i, inp in enumerate([input, input_quant]):\n", |
677 | 677 | " for j, W in enumerate([weight, weight_quant]):\n", |
678 | 678 | " plt.subplots(1,3,figsize=(18,5))\n", |
679 | | - " plt.suptitle(f\"Case {i*2+j+1}: Input {isQ[i]}, Weight {isQ[j]}\", fontsize=20, ha=\"center\", va=\"bottom\")\n", |
680 | | - " plt.subplot(131); PlotAndCompare(input, inp, [\"raw\", isQ[i]], f\"input, {isQ[i]}\")\n", |
681 | | - " plt.subplot(132); PlotAndCompare(weight, W, [\"raw\", isQ[j]], f\"weight, {isQ[j]}\")\n", |
| 679 | + " plt.suptitle(f'Case {i*2+j+1}: Input {isQ[i]}, Weight {isQ[j]}', fontsize=20, ha='center', va='bottom')\n", |
| 680 | + " plt.subplot(131); PlotAndCompare(input, inp, ['raw', isQ[i]], f\"input, {isQ[i]}\")\n", |
| 681 | + " plt.subplot(132); PlotAndCompare(weight, W, ['raw', isQ[j]], f\"weight, {isQ[j]}\")\n", |
682 | 682 | " net.conv.weight = torch.nn.Parameter(W)\n", |
683 | 683 | " y_quant = net(inp)\n", |
684 | | - " plt.subplot(133); PlotAndCompare(y, y_quant, [\"raw\", f\"A={isQ[j]}, W={isQ[i]}\"], \"conv output\")\n", |
| 684 | + " plt.subplot(133); PlotAndCompare(y, y_quant, ['raw', f'A={isQ[j]}, W={isQ[i]}'], \"conv output\")\n", |
685 | 685 | " plt.show()\n", |
686 | 686 | "\n", |
687 | 687 | "\n" |
|
737 | 737 | "qcfg = qconfig_init()\n", |
738 | 738 | "\n", |
739 | 739 | "# set bits for quantization (nbits_a needs to be set to quantize input regardless of bias)\n", |
740 | | - "qcfg[\"nbits_w\"] = 4\n", |
741 | | - "qcfg[\"nbits_a\"] = 4\n", |
| 740 | + "qcfg['nbits_w'] = 4\n", |
| 741 | + "qcfg['nbits_a'] = 4\n", |
742 | 742 | "\n", |
743 | 743 | "# just to be consistent with our \"simple Quantizer\" (normally align_zero is True)\n", |
744 | | - "qcfg[\"align_zero\"] = False\n", |
| 744 | + "qcfg['align_zero'] = False\n", |
745 | 745 | "\n", |
746 | 746 | "# Quantization Mode here means which quantizers we would like to use,\n", |
747 | 747 | "# There are many quantizers available in fms_mo, such as PArameterized Clipping acTivation (PACT),\n", |
748 | 748 | "# Statstics-Aware Weight Binning (SAWB).\n", |
749 | | - "qcfg[\"qw_mode\"] = \"pact\"\n", |
750 | | - "qcfg[\"qa_mode\"] = \"pact\"\n", |
| 749 | + "qcfg['qw_mode'] = 'pact'\n", |
| 750 | + "qcfg['qa_mode'] = 'pact'\n", |
751 | 751 | "\n", |
752 | 752 | "# Set weight and input (activation) clip vals\n", |
753 | | - "qcfg[\"w_clip_init_valn\"], qcfg[\"w_clip_init_val\"] = -2.5, 2.5\n", |
754 | | - "qcfg[\"act_clip_init_valn\"], qcfg[\"act_clip_init_val\"] = -2.5, 2.5\n", |
| 753 | + "qcfg['w_clip_init_valn'], qcfg['w_clip_init_val'] = -2.5, 2.5\n", |
| 754 | + "qcfg['act_clip_init_valn'], qcfg['act_clip_init_val'] = -2.5, 2.5\n", |
755 | 755 | "\n", |
756 | 756 | "\n", |
757 | 757 | "# This parameter is usually False, but for Demo purposes we quantize the first/only layer\n", |
758 | | - "qcfg[\"q1stlastconv\"] = True\n", |
| 758 | + "qcfg['q1stlastconv'] = True\n", |
759 | 759 | "\n", |
760 | 760 | "\n", |
761 | 761 | "if path.exists(\"results\"):\n", |
762 | 762 | " print(\"results folder exists!\")\n", |
763 | 763 | "else:\n", |
764 | | - " os.makedirs(\"results\")\n", |
| 764 | + " os.makedirs('results')\n", |
765 | 765 | " \n", |
766 | 766 | "# Step 2: Prepare the model to convert layer to add Quantizers\n", |
767 | | - "qmodel_prep(net_fms_mo, input, qcfg, save_fname=\"./results/temp.pt\")\n", |
| 767 | + "qmodel_prep(net_fms_mo, input, qcfg, save_fname='./results/temp.pt')\n", |
768 | 768 | "\n" |
769 | 769 | ] |
770 | 770 | }, |
|
780 | 780 | "y_quant = net(input_quant) \n", |
781 | 781 | "\n", |
782 | 782 | "plt.figure(figsize=(16, 10))\n", |
783 | | - "PlotAndCompare(y_quant_fms_mo, y_quant, [\"fms_mo\",\"manual\"],\"quantized Conv output by different methods\")\n", |
| 783 | + "PlotAndCompare(y_quant_fms_mo, y_quant, ['fms_mo','manual'],'quantized Conv output by different methods')\n", |
784 | 784 | "plt.show()\n" |
785 | 785 | ] |
786 | 786 | }, |
|
804 | 804 | "metadata": {}, |
805 | 805 | "outputs": [], |
806 | 806 | "source": [ |
807 | | - "import os\n", |
808 | | - "import wget\n", |
809 | | - "IMG_FILE_NAME = \"lion.png\"\n", |
810 | | - "url = \"https://raw.githubusercontent.com/foundation-model-stack/fms-model-optimizer/main/tutorials/images/\" + IMG_FILE_NAME\n", |
| 807 | + "import os, wget\n", |
| 808 | + "IMG_FILE_NAME = 'lion.png'\n", |
| 809 | + "url = 'https://raw.githubusercontent.com/foundation-model-stack/fms-model-optimizer/main/tutorials/images/' + IMG_FILE_NAME\n", |
811 | 810 | "\n", |
812 | 811 | "if not os.path.isfile(IMG_FILE_NAME):\n", |
813 | 812 | " wget.download(url, out=IMG_FILE_NAME)\n", |
|
864 | 863 | "\n", |
865 | 864 | "plt.subplots(3,1,figsize=(16,25))\n", |
866 | 865 | "plt.subplot(311)\n", |
867 | | - "plt.title(\"Output from non-quantized model\", fontsize=20)\n", |
868 | | - "plt.imshow(feature_map, cmap=\"RdBu\")\n", |
| 866 | + "plt.title('Output from non-quantized model', fontsize=20)\n", |
| 867 | + "plt.imshow(feature_map, cmap='RdBu')\n", |
869 | 868 | "plt.clim(0,255)\n", |
870 | 869 | "plt.colorbar()\n", |
871 | 870 | "\n", |
872 | 871 | "plt.subplot(312)\n", |
873 | | - "plt.title(\"Output from quantized model\", fontsize=20)\n", |
874 | | - "plt.imshow(feature_map_quant, cmap=\"RdBu\")\n", |
| 872 | + "plt.title('Output from quantized model', fontsize=20)\n", |
| 873 | + "plt.imshow(feature_map_quant, cmap='RdBu')\n", |
875 | 874 | "plt.clim(0,255)\n", |
876 | 875 | "plt.colorbar()\n", |
877 | 876 | "\n", |
878 | 877 | "plt.subplot(313)\n", |
879 | | - "PlotAndCompare(y_img_tensor, y_img_quant, [\"raw\",\"quantized\"],\"Conv output\")\n", |
| 878 | + "PlotAndCompare(y_img_tensor, y_img_quant, ['raw','quantized'],'Conv output')\n", |
880 | 879 | "\n", |
881 | 880 | "plt.tight_layout()\n", |
882 | 881 | "plt.show()\n" |
|
0 commit comments