|
4 | 4 | "metadata": { |
5 | 5 | "colab": { |
6 | 6 | "provenance": [], |
7 | | - "gpuType": "T4" |
| 7 | + "gpuType": "T4", |
| 8 | + "include_colab_link": true |
8 | 9 | }, |
9 | 10 | "kernelspec": { |
10 | 11 | "name": "python3", |
|
2072 | 2073 | } |
2073 | 2074 | }, |
2074 | 2075 | "cells": [ |
| 2076 | + { |
| 2077 | + "cell_type": "markdown", |
| 2078 | + "metadata": { |
| 2079 | + "id": "view-in-github", |
| 2080 | + "colab_type": "text" |
| 2081 | + }, |
| 2082 | + "source": [ |
| 2083 | + "<a href=\"https://colab.research.google.com/github/mryab/efficient-dl-systems/blob/main/week05_large_models/practice_part2.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" |
| 2084 | + ] |
| 2085 | + }, |
2075 | 2086 | { |
2076 | 2087 | "cell_type": "markdown", |
2077 | 2088 | "source": [ |
|
2101 | 2112 | }, |
2102 | 2113 | "outputId": "02202d9e-12d4-4341-b31f-006c356bd2b7" |
2103 | 2114 | }, |
2104 | | - "execution_count": 1, |
| 2115 | + "execution_count": null, |
2105 | 2116 | "outputs": [ |
2106 | 2117 | { |
2107 | 2118 | "output_type": "stream", |
|
2706 | 2717 | " for name, param in tp_module.named_parameters():\n", |
2707 | 2718 | " print(f\"{name=},\\ttype={type(param.data)}\\tglobal shape={param.shape},\\tlocal shape={param._local_tensor.shape if hasattr(param, '_local_tensor') else param.shape}\")\n", |
2708 | 2719 | "\n", |
2709 | | - " dist.barrier() # test 1: forward pass\n", |
2710 | | - " # Convert input to DTensor with replicated placement TODO actually no\n", |
| 2720 | + " dist.barrier() # Test forward and backward pass with Tensor Parallelism\n", |
2711 | 2721 | " tp_input = input.detach().requires_grad_(True)\n", |
2712 | | - "\n", |
2713 | | - " # Test forward and backward pass with Tensor Parallelism\n", |
2714 | 2722 | " tp_output = tp_module(tp_input)\n", |
2715 | 2723 | " tp_output.sum().backward()\n", |
2716 | 2724 | " tp_output = tp_output.trigger_wait() # convert from AsyncCollectiveTensor to regular torch tensor\n", |
|
2734 | 2742 | "id": "CtvFF26mDWH2", |
2735 | 2743 | "outputId": "9ddddac6-65aa-4341-f12f-e884a26f0bea" |
2736 | 2744 | }, |
2737 | | - "execution_count": 176, |
| 2745 | + "execution_count": null, |
2738 | 2746 | "outputs": [ |
2739 | 2747 | { |
2740 | 2748 | "output_type": "stream", |
|
2757 | 2765 | "id": "rbGsyCMHDtXY", |
2758 | 2766 | "outputId": "337e86a1-d9fa-48e2-aff0-2f0089b56ff2" |
2759 | 2767 | }, |
2760 | | - "execution_count": 177, |
| 2768 | + "execution_count": null, |
2761 | 2769 | "outputs": [ |
2762 | 2770 | { |
2763 | 2771 | "output_type": "stream", |
|
2994 | 3002 | "id": "9ofi1_Kgusd8", |
2995 | 3003 | "outputId": "d456bd60-be30-4f81-917d-5ef3d9e04538" |
2996 | 3004 | }, |
2997 | | - "execution_count": 2, |
| 3005 | + "execution_count": null, |
2998 | 3006 | "outputs": [ |
2999 | 3007 | { |
3000 | 3008 | "output_type": "stream", |
|
3183 | 3191 | "metadata": { |
3184 | 3192 | "id": "xBWi66FkC1dz" |
3185 | 3193 | }, |
3186 | | - "execution_count": 2, |
| 3194 | + "execution_count": null, |
3187 | 3195 | "outputs": [] |
3188 | 3196 | }, |
3189 | 3197 | { |
|
0 commit comments