diff --git a/notebooks/autoencoder.ipynb b/notebooks/autoencoder.ipynb index 4f8d4881..134c8e5c 100644 --- a/notebooks/autoencoder.ipynb +++ b/notebooks/autoencoder.ipynb @@ -14,6 +14,34 @@ "---" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Overview\n", + "\n", + "In previous notebooks we have demonstrated how `xbatcher` converts both toy `xarray` objects into tensors and back again. In this notebook we incorporate these functions in an end-to-end workflow training an autoencoder on an elevation dataset. Once trained, the model is used to reconstruct two datasets:\n", + "\n", + " - The overall elevation tile\n", + " - A data cube of the autoencoder's latent dimension\n", + "\n", + "## Prerequisites\n", + "\n", + "This notebook assumes familiarity with `xarray`, `xbatcher`, and `torch`. You don't have to know how autoencoders work - we explain that when necessary.\n", + "\n", + "| Concepts | Importance | Notes |\n", + "| --- | --- | --- |\n", + "| [Intro to Xarray](https://tutorial.xarray.dev/overview/xarray-in-45-min.html) | Necessary | Array indexing |\n", + "| [Xbatcher fundamentals](https://projectpythia.org/xbatcher-deep-learning/notebooks/xbatcher-dataloading/) | Necessary | Passing data to models |\n", + "| [PyTorch fundamentals](https://docs.pytorch.org/tutorials/beginner/basics/intro.html) | Helpful | Model training loop |\n", + "| [Autoencoders](https://lightning.ai/docs/pytorch/stable/notebooks/course_UvA-DL/08-deep-autoencoders.html) | Helpful | More information on how autoencoders work.\n", + "\n", + "- **Time to learn**: 30 minutes.\n", + "- **System requirements**:\n", + " - Windows users may hit an import error on `rioxarray` ([link](https://gis.stackexchange.com/questions/417733/unable-to-import-python-rasterio-package-even-though-it-is-installed)). If that happens, add `import osgeo` above `import rioxarray` and that seems to fix the issue.\n", + "\n" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -23,21 +51,19 @@ }, { "cell_type": "code", - "execution_count": null, + "execution_count": 21, "metadata": {}, "outputs": [], "source": [ "import os\n", + "if os.name == 'nt':\n", + " import osgeo\n", + "\n", "from importlib import reload\n", "\n", "# DL stuff\n", - "import matplotlib\n", "import matplotlib.pyplot as plt\n", "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "import torch.optim as optim\n", - "import torch.utils.data as data\n", "\n", "torch.set_default_dtype(torch.float64)\n", "\n", @@ -45,14 +71,12 @@ "import xarray as xr\n", "import xbatcher\n", "import rioxarray\n", - "import xbatcher\n", "from xbatcher.loaders.torch import MapDataset\n", "\n", "# Etc\n", "import numpy as np\n", "from numpy.linalg import norm\n", "from matplotlib import pyplot as plt\n", - "from tqdm.autonotebook import tqdm\n", "\n", "# Locals\n", "import functions\n", @@ -70,14 +94,539 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We will start by pulling a segment of NASADEM for Washington's Olympic peninsula." + "We will start by pulling a segment of NASADEM for Washington's Olympic peninsula. The entire DEM is also available on NASA Earthdata and Planetary Computer." ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, - "outputs": [], + "outputs": [ + { + "data": { + "text/html": [ + "
<xarray.DataArray (band: 1, y: 3600, x: 3600)> Size: 104MB\n", + "array([[[0.23102867, 0.23608769, 0.23988196, ..., 0.10581788,\n", + " 0.10961214, 0.11298482],\n", + " [0.23229342, 0.2386172 , 0.24283305, ..., 0.11045531,\n", + " 0.11340641, 0.1163575 ],\n", + " [0.23102867, 0.23819562, 0.24325464, ..., 0.11467116,\n", + " 0.1163575 , 0.11804384],\n", + " ...,\n", + " [0.00252951, 0.00252951, 0.00252951, ..., 0.03752108,\n", + " 0.03625632, 0.03752108],\n", + " [0.00252951, 0.00252951, 0.00252951, ..., 0.03625632,\n", + " 0.03499157, 0.03878583],\n", + " [0.0029511 , 0.0029511 , 0.0029511 , ..., 0.03667791,\n", + " 0.03625632, 0.04005059]]], shape=(1, 3600, 3600))\n", + "Coordinates:\n", + " * band (band) int64 8B 1\n", + " * x (x) float64 29kB -124.0 -124.0 -124.0 ... -123.0 -123.0 -123.0\n", + " * y (y) float64 29kB 48.0 48.0 48.0 48.0 ... 47.0 47.0 47.0 47.0\n", + " spatial_ref int64 8B 0
<xarray.DataArray (y: 112, x: 112, channel: 64)> Size: 6MB\n", + "array([[[-2.98959292e-01, 1.64430934e-01, 1.15438255e-02, ...,\n", + " -8.08710732e-03, 3.30118317e-02, -5.46855937e-04],\n", + " [-2.24567760e-03, 9.47532182e-02, -3.34303832e-02, ...,\n", + " -6.48077535e-03, 8.64085204e-04, -1.95042637e-02],\n", + " [-2.29584158e-02, -1.93482539e-01, -2.64102204e-02, ...,\n", + " -5.21691757e-02, -2.33358999e-03, -1.43376856e-03],\n", + " ...,\n", + " [-7.14939087e-01, 3.34519677e-02, -1.22026729e-01, ...,\n", + " -4.24666117e-02, 1.09958167e-01, -2.81837341e-02],\n", + " [-1.44703043e+00, -6.82025516e-03, -1.32837495e-01, ...,\n", + " 1.01037129e-01, 1.24011525e-01, 2.93925781e-02],\n", + " [-1.88280569e+00, 1.58894039e-01, 1.94406669e-01, ...,\n", + " 2.02963516e-01, -1.02580176e-01, -2.27122382e-02]],\n", + "\n", + " [[-3.11284028e-01, 4.57395907e-02, -5.93765944e-02, ...,\n", + " 9.89407912e-02, 6.95664992e-03, -2.08061733e-02],\n", + " [-8.16732126e-02, 2.07674481e-01, 5.54563779e-02, ...,\n", + " -3.88096709e-02, 3.14765790e-02, -7.57243798e-04],\n", + " [-1.09194994e-01, -5.78870165e-03, 8.61930280e-02, ...,\n", + " 2.35724309e-02, 4.35084485e-02, 6.65596484e-04],\n", + "...\n", + " -1.84319850e-01, 5.72853962e-02, 1.05324888e-01],\n", + " [-2.57522762e+00, 4.06043765e-01, 2.18213096e-01, ...,\n", + " 2.16262191e-01, 2.44147664e-01, -2.79635558e-02],\n", + " [-4.53882925e+00, -3.61320565e-01, 9.88348282e-02, ...,\n", + " -9.99270161e-02, -1.89315409e-01, 7.27320120e-02]],\n", + "\n", + " [[-5.32383745e+00, -2.59905275e-01, -9.42536151e-02, ...,\n", + " -1.14359197e-01, -7.32243057e-02, -1.17779599e-01],\n", + " [-4.61458548e+00, -3.59844762e-01, 1.96055190e-01, ...,\n", + " -9.51545225e-02, -2.07306574e-01, -1.21820868e-01],\n", + " [-4.43024026e+00, -1.50307160e-01, -2.68508102e-01, ...,\n", + " 2.49489361e-02, -2.64294322e-01, -1.70129536e-02],\n", + " ...,\n", + " [-1.79788969e+00, -4.62374302e-01, 3.78449074e-02, ...,\n", + " -1.45970427e-01, -1.10146726e-01, -3.16941288e-02],\n", + " [-1.89689242e+00, 1.96203432e-01, 1.83354633e-02, ...,\n", + " -1.92172321e-01, 7.78421414e-02, -5.74109845e-02],\n", + " [-4.11928259e+00, -7.79838400e-01, -1.65173175e-01, ...,\n", + " 1.25174990e-01, -2.02087485e-01, -2.00104178e-02]]],\n", + " shape=(112, 112, 64))\n", + "Coordinates:\n", + " * y (y) float64 896B 48.0 47.99 47.98 47.97 ... 47.04 47.03 47.02 47.01\n", + " * x (x) float64 896B -124.0 -124.0 -124.0 ... -123.0 -123.0 -123.0\n", + "Dimensions without coordinates: channel
<xarray.DataArray (channel: 64)> Size: 512B\n", + "array([ 3.05193448, 0.07984048, -0.19406716, 0.23198933, -0.28311083,\n", + " 0.06425693, -0.41277254, 0.10719062, 0.21293048, 0.18220075,\n", + " 0.02312008, 0.26989496, -3.81439401, 0.01594482, 0.21865739,\n", + " 2.11210263, -0.34720822, -0.01163836, 0.0337085 , -0.14257216,\n", + " -2.99752761, -0.09423765, 0.09816827, -1.24537382, -2.30954931,\n", + " -0.71140406, 0.2917199 , -0.0769781 , -0.03020979, -0.22612203,\n", + " -2.74899594, 0.38115255, -0.82641506, -0.48060534, 0.07423552,\n", + " -0.33775832, 0.22184855, 0.06239392, 1.01446021, 0.05642529,\n", + " 0.01650721, 1.78647155, 0.32631236, -0.45510718, 0.24257852,\n", + " 0.43398264, -0.04324041, 0.19464324, 0.64706841, -0.03532742,\n", + " -0.28180135, -4.13436094, 3.57349098, -1.55306568, -4.02541322,\n", + " -0.01129879, -0.14259338, 0.12843601, -0.26321372, -0.1023512 ,\n", + " -3.86097112, 0.10304318, 0.17730129, 0.16089438])\n", + "Coordinates:\n", + " y float64 8B 47.8\n", + " x float64 8B -123.7\n", + "Dimensions without coordinates: channel