|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "markdown", |
| 5 | + "id": "501aa46e-5a51-4905-abba-9c9d162887ff", |
| 6 | + "metadata": {}, |
| 7 | + "source": [ |
| 8 | + "# PyTorch Custom Datasets Video Notebook\n", |
| 9 | + "\n", |
| 10 | + "We've used some datasets with PyTorch before.\n", |
| 11 | + "\n", |
| 12 | + "But how do you get your own data into PyTorch?\n", |
| 13 | + "\n", |
| 14 | + "One of the ways to do is via: custom Datasets\n", |
| 15 | + "\n", |
| 16 | + "## Domain libraries\n", |
| 17 | + "\n", |
| 18 | + "Depending on what you're working on, vision, text, audio, recommendation, you'll want to look into each of PyTorch domain libraries for existing data loading functions and customalizable functionsm" |
| 19 | + ] |
| 20 | + }, |
| 21 | + { |
| 22 | + "cell_type": "code", |
| 23 | + "execution_count": 1, |
| 24 | + "id": "27b3dc99-7f43-4eb6-8a74-fc1b3e4fa86b", |
| 25 | + "metadata": {}, |
| 26 | + "outputs": [ |
| 27 | + { |
| 28 | + "data": { |
| 29 | + "text/plain": [ |
| 30 | + "'1.9.1'" |
| 31 | + ] |
| 32 | + }, |
| 33 | + "execution_count": 1, |
| 34 | + "metadata": {}, |
| 35 | + "output_type": "execute_result" |
| 36 | + } |
| 37 | + ], |
| 38 | + "source": [ |
| 39 | + "import torch\n", |
| 40 | + "from torch import nn,optim\n", |
| 41 | + "torch.__version__" |
| 42 | + ] |
| 43 | + }, |
| 44 | + { |
| 45 | + "cell_type": "code", |
| 46 | + "execution_count": 2, |
| 47 | + "id": "f1cc4468-8832-4c43-b801-965f7a1575f2", |
| 48 | + "metadata": {}, |
| 49 | + "outputs": [], |
| 50 | + "source": [ |
| 51 | + "device = 'cuda' if torch.cuda.is_available() else 'cpu'" |
| 52 | + ] |
| 53 | + }, |
| 54 | + { |
| 55 | + "cell_type": "markdown", |
| 56 | + "id": "53fad586-2be6-4902-82ed-59465dbb01fc", |
| 57 | + "metadata": {}, |
| 58 | + "source": [ |
| 59 | + "## Get data\n", |
| 60 | + "\n", |
| 61 | + "Our dataset is a subset of the Food101 dataset.\n", |
| 62 | + "\n", |
| 63 | + "Food101 start 101 different classes of food. (750,250)\n", |
| 64 | + "\n", |
| 65 | + "Our dataset starts with 3 classes of fodd and only 10% of the images(75 and 25 testing)" |
| 66 | + ] |
| 67 | + }, |
| 68 | + { |
| 69 | + "cell_type": "code", |
| 70 | + "execution_count": 3, |
| 71 | + "id": "209c6a08-719f-4af8-a58c-377a1f873757", |
| 72 | + "metadata": {}, |
| 73 | + "outputs": [], |
| 74 | + "source": [ |
| 75 | + "import requests\n", |
| 76 | + "import zipfile\n", |
| 77 | + "from pathlib import Path\n", |
| 78 | + "\n", |
| 79 | + "# Setup path to a data folder\n", |
| 80 | + "\n", |
| 81 | + "data_path = Path(\"data/04/01\")\n", |
| 82 | + "\n", |
| 83 | + "data_path.mkdir(parents=True,exist_ok=True)\n", |
| 84 | + "\n", |
| 85 | + "with open(data_path / \"data.zip\",'wb') as f:\n", |
| 86 | + " request = requests.get(\"https://github.com/mrdbourke/pytorch-deep-learning/raw/main/data/pizza_steak_sushi.zip\")\n", |
| 87 | + " f.write(request.content)\n", |
| 88 | + "with zipfile.ZipFile(data_path / \"data.zip\", 'r') as zip_ref:\n", |
| 89 | + " zip_ref.extractall(data_path)" |
| 90 | + ] |
| 91 | + }, |
| 92 | + { |
| 93 | + "cell_type": "code", |
| 94 | + "execution_count": 4, |
| 95 | + "id": "88a16a48-0f37-4f6e-a179-f19319e83244", |
| 96 | + "metadata": {}, |
| 97 | + "outputs": [], |
| 98 | + "source": [ |
| 99 | + "import os\n", |
| 100 | + "def walk_through_dir(dir_path):\n", |
| 101 | + " \"\"\"Walks through dir_path returning its contents.\"\"\"\n", |
| 102 | + " for dirpath,dirnames,filenames in os.walk(dir_path):\n", |
| 103 | + " print(f\"There are {len(dirnames)} and {len(filenames)} images in {dirpath}\")" |
| 104 | + ] |
| 105 | + }, |
| 106 | + { |
| 107 | + "cell_type": "code", |
| 108 | + "execution_count": 5, |
| 109 | + "id": "cba4eb63-9135-46ff-8b77-eb5fec60cb85", |
| 110 | + "metadata": {}, |
| 111 | + "outputs": [ |
| 112 | + { |
| 113 | + "name": "stdout", |
| 114 | + "output_type": "stream", |
| 115 | + "text": [ |
| 116 | + "There are 2 and 1 images in data/04/01\n", |
| 117 | + "There are 3 and 0 images in data/04/01/train\n", |
| 118 | + "There are 0 and 75 images in data/04/01/train/steak\n", |
| 119 | + "There are 1 and 78 images in data/04/01/train/pizza\n", |
| 120 | + "There are 0 and 3 images in data/04/01/train/pizza/.ipynb_checkpoints\n", |
| 121 | + "There are 0 and 72 images in data/04/01/train/sushi\n", |
| 122 | + "There are 3 and 0 images in data/04/01/test\n", |
| 123 | + "There are 0 and 19 images in data/04/01/test/steak\n", |
| 124 | + "There are 0 and 25 images in data/04/01/test/pizza\n", |
| 125 | + "There are 0 and 31 images in data/04/01/test/sushi\n" |
| 126 | + ] |
| 127 | + } |
| 128 | + ], |
| 129 | + "source": [ |
| 130 | + "walk_through_dir(data_path)" |
| 131 | + ] |
| 132 | + }, |
| 133 | + { |
| 134 | + "cell_type": "markdown", |
| 135 | + "id": "e5eb494b-a379-4123-acef-ae52c1c83bd7", |
| 136 | + "metadata": {}, |
| 137 | + "source": [ |
| 138 | + "## Visualizing an image\n", |
| 139 | + "\n", |
| 140 | + "Let's write some code to:\n", |
| 141 | + "1. Get all of the image paths\n", |
| 142 | + "2. Pick a random image path using `random.choice()`\n", |
| 143 | + "3. Get the image class name using `pathlib.Path.parent.stem`\n", |
| 144 | + "4. Science we're working with images, let open it with PIL\n", |
| 145 | + "5. Show and tell metadata" |
| 146 | + ] |
| 147 | + }, |
| 148 | + { |
| 149 | + "cell_type": "code", |
| 150 | + "execution_count": null, |
| 151 | + "id": "a908ae02-c671-4567-84d7-5d25e4d37c03", |
| 152 | + "metadata": {}, |
| 153 | + "outputs": [], |
| 154 | + "source": [ |
| 155 | + "import random\n", |
| 156 | + "from PIL import Image" |
| 157 | + ] |
| 158 | + } |
| 159 | + ], |
| 160 | + "metadata": { |
| 161 | + "kernelspec": { |
| 162 | + "display_name": "Python 3 (ipykernel)", |
| 163 | + "language": "python", |
| 164 | + "name": "python3" |
| 165 | + }, |
| 166 | + "language_info": { |
| 167 | + "codemirror_mode": { |
| 168 | + "name": "ipython", |
| 169 | + "version": 3 |
| 170 | + }, |
| 171 | + "file_extension": ".py", |
| 172 | + "mimetype": "text/x-python", |
| 173 | + "name": "python", |
| 174 | + "nbconvert_exporter": "python", |
| 175 | + "pygments_lexer": "ipython3", |
| 176 | + "version": "3.8.8" |
| 177 | + } |
| 178 | + }, |
| 179 | + "nbformat": 4, |
| 180 | + "nbformat_minor": 5 |
| 181 | +} |
0 commit comments