|
| 1 | +"""Copyright (c) Microsoft Corporation. Licensed under the MIT license.""" |
| 2 | + |
| 3 | +import torch |
| 4 | + |
| 5 | +__all__ = ["area", "compute_patch_areas", "radius_earth"] |
| 6 | + |
| 7 | + |
| 8 | +radius_earth = 6378137 / 1000 |
| 9 | +"""float: Radius of the earth in kilometers.""" |
| 10 | + |
| 11 | + |
| 12 | +def area(polygon: torch.Tensor) -> torch.Tensor: |
| 13 | + """Compute the area of a polygon specified by latitudes and longitudes in degrees. |
| 14 | +
|
| 15 | + This function is a PyTorch port of the PyPI package `area`. In particular, it is heavily |
| 16 | + inspired by the following file: |
| 17 | +
|
| 18 | + https://github.com/scisco/area/blob/9d9549d6ebffcbe4bffe11b71efa2d406d1c9fe9/area/__init__.py |
| 19 | +
|
| 20 | + Args: |
| 21 | + polygon (:class:`torch.Tensor`): Polygon of the shape `(*b, n, 2)` where `b` is an optional |
| 22 | + multidimensional batch size, `n` is the number of points of the polygon, and 2 |
| 23 | + concatenates first latitudes and then longitudes. The polygon does not have be closed. |
| 24 | +
|
| 25 | + Returns: |
| 26 | + :class:`torch.Tensor`: Area in square kilometers. |
| 27 | + """ |
| 28 | + # Be sure to close the loop. |
| 29 | + polygon = torch.cat((polygon, polygon[..., -1:, :]), axis=-2) |
| 30 | + |
| 31 | + area = torch.zeros(polygon.shape[:-2], dtype=polygon.dtype, device=polygon.device) |
| 32 | + n = polygon.shape[-2] # Number of points of the polygon |
| 33 | + |
| 34 | + rad = torch.deg2rad # Convert degrees to radians. |
| 35 | + |
| 36 | + if n > 2: |
| 37 | + for i in range(n): |
| 38 | + i_lower = i |
| 39 | + i_middle = (i + 1) % n |
| 40 | + i_upper = (i + 2) % n |
| 41 | + |
| 42 | + lon_lower = polygon[..., i_lower, 1] |
| 43 | + lat_middle = polygon[..., i_middle, 0] |
| 44 | + lon_upper = polygon[..., i_upper, 1] |
| 45 | + |
| 46 | + area = area + (rad(lon_upper) - rad(lon_lower)) * torch.sin(rad(lat_middle)) |
| 47 | + |
| 48 | + area = area * radius_earth * radius_earth / 2 |
| 49 | + |
| 50 | + return torch.abs(area) |
| 51 | + |
| 52 | + |
| 53 | +def expand_matrix(matrix: torch.Tensor) -> torch.Tensor: |
| 54 | + """Expand matrix by adding one row and one column to each side, using |
| 55 | + linear interpolation. |
| 56 | +
|
| 57 | + Args: |
| 58 | + matrix (:class:`torch.Tensor`): Matrix to expand. |
| 59 | +
|
| 60 | + Returns: |
| 61 | + :class:`torch.Tensor`: `matrix`, but with two extra rows and two extra columns. |
| 62 | + """ |
| 63 | + # Add top and bottom rows. |
| 64 | + matrix = torch.cat( |
| 65 | + ( |
| 66 | + 2 * matrix[0:1] - matrix[1:2], |
| 67 | + matrix, |
| 68 | + 2 * matrix[-1:] - matrix[-2:-1], |
| 69 | + ), |
| 70 | + dim=0, |
| 71 | + ) |
| 72 | + |
| 73 | + # Add left and right columns. |
| 74 | + matrix = torch.cat( |
| 75 | + ( |
| 76 | + 2 * matrix[:, 0:1] - matrix[:, 1:2], |
| 77 | + matrix, |
| 78 | + 2 * matrix[:, -1:] - matrix[:, -2:-1], |
| 79 | + ), |
| 80 | + dim=1, |
| 81 | + ) |
| 82 | + |
| 83 | + return matrix |
| 84 | + |
| 85 | + |
| 86 | +def compute_patch_areas(lat: torch.Tensor, lon: torch.Tensor) -> torch.Tensor: |
| 87 | + """A pair of latitude and longitude matrices defines a number non-intersecting patches on the |
| 88 | + Earth. For a global grid, these patches span the entire surface of the Earth. For a local grid, |
| 89 | + the patches might span only a country or a continent. This function computes the area of every |
| 90 | + specified patch. |
| 91 | +
|
| 92 | + To divide the Earth into patches, the idea is to let a grid point be the _center_ of the |
| 93 | + corresponding patch. The vertices of this patch will then sit exactly inbetween the grid |
| 94 | + point and the grid points immediately diagonally and non-diagonally above, below, left, and |
| 95 | + right. For a grid point at the very top of the grid, for example, there is no immediately above |
| 96 | + grid point. In that case, we enlarge the grid by a row at the top by linearly interpolating the |
| 97 | + latitudinal progression. |
| 98 | +
|
| 99 | + Summary of algorithm: |
| 100 | + 1. Enlarge the latitude and longitude matrices by adding one row and one column to each side. |
| 101 | + 2. Calculate the patch vertices by averaging every 2x2 square in the enlarged grid. We also |
| 102 | + call these points the midpoints. |
| 103 | + 3. By using the vertices of the patches, i.e. the midpoints, compute the areas of the patches. |
| 104 | +
|
| 105 | + Args: |
| 106 | + lat (:class:`torch.Tensor`): Latitude matrix. Must be decreasing along rows. |
| 107 | + lon (:class:`torch.Tensor`): Longitude matrix. Must be increasing along columns. |
| 108 | +
|
| 109 | + Returns: |
| 110 | + :class:`torch.Tensor`: Areas in square kilometer. |
| 111 | + """ |
| 112 | + if not (lat.dim() == lon.dim() == 2): |
| 113 | + raise ValueError("`lat` and `lon` must both be matrices.") |
| 114 | + if lat.shape != lat.shape: |
| 115 | + raise ValueError("`lat` and `lon` must have the same shape.") |
| 116 | + |
| 117 | + # Check that the latitude matrix is decreasing in the appropriate way. |
| 118 | + if not torch.all(lat[1:] - lat[:-1] <= 0): |
| 119 | + raise ValueError("`lat` must be decreasing along rows.") |
| 120 | + |
| 121 | + # Check that the longitude matrix is increasing in the appropriate way. |
| 122 | + if not torch.all(lon[:, 1:] - lon[:, :-1] >= 0): |
| 123 | + raise ValueError("`lon` must be increasing along columns.") |
| 124 | + |
| 125 | + # Enlarge the latitude and longitude matrices for the midpoint computation. |
| 126 | + lat = expand_matrix(lat) |
| 127 | + lon = expand_matrix(lon) |
| 128 | + |
| 129 | + # Latitudes cannot expand beyond the poles. |
| 130 | + lat = torch.clamp(lat, -90, 90) |
| 131 | + |
| 132 | + # Calculate midpoints between entries in lat/lon. This is very important for symmetry of the |
| 133 | + # resulting areas. |
| 134 | + lat_midpoints = (lat[:-1, :-1] + lat[:-1, 1:] + lat[1:, :-1] + lat[1:, 1:]) / 4 |
| 135 | + lon_midpoints = (lon[:-1, :-1] + lon[:-1, 1:] + lon[1:, :-1] + lon[1:, 1:]) / 4 |
| 136 | + |
| 137 | + # Determine squares and return the area of those squares. |
| 138 | + top_left = torch.stack((lat_midpoints[1:, :-1], lon_midpoints[1:, :-1]), dim=-1) |
| 139 | + top_right = torch.stack((lat_midpoints[1:, 1:], lon_midpoints[1:, 1:]), dim=-1) |
| 140 | + bottom_left = torch.stack((lat_midpoints[:-1, :-1], lon_midpoints[:-1, :-1]), dim=-1) |
| 141 | + bottom_right = torch.stack((lat_midpoints[:-1, 1:], lon_midpoints[:-1, 1:]), dim=-1) |
| 142 | + polygon = torch.stack((top_left, top_right, bottom_right, bottom_left), dim=-2) |
| 143 | + |
| 144 | + return area(polygon) |
0 commit comments