Skip to content

Commit 9e5cdc5

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 081e8de commit 9e5cdc5

File tree

1 file changed

+46
-46
lines changed

1 file changed

+46
-46
lines changed

scripts/worldcover/run.py

Lines changed: 46 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,25 @@
11
#!/usr/bin/env python3
22

33
import sys
4+
45
sys.path.append("/home/ubuntu/worldcover/model")
56

67
import os
78
import tempfile
89
from math import floor
910
from pathlib import Path
10-
import requests
1111

1212
import boto3
1313
import einops
1414
import geopandas as gpd
15-
import pandas as pd
1615
import numpy
17-
import pyarrow as pa
16+
import pandas as pd
1817
import rasterio
18+
import requests
1919
import shapely
2020
import torch
2121
import xarray as xr
2222
from rasterio.windows import Window
23-
from shapely import box
2423
from torchvision.transforms import v2
2524

2625
from src.datamodule import ClayDataset
@@ -152,53 +151,54 @@ def download_image(url):
152151
else:
153152
raise Exception("Failed to download the image")
154153

154+
155155
def patches_and_windows_from_url(url, chunk_size=(PATCH_SIZE, PATCH_SIZE)):
156156
# Download the image from the URL
157157
image_data = download_image(url)
158-
158+
159159
# Open the image using rasterio from memory
160160
with rasterio.io.MemoryFile(image_data) as memfile:
161161
with memfile.open() as src:
162162
# Read the image data and metadata
163163
img_data = src.read()
164164
img_meta = src.profile
165165
img_crs = src.crs
166-
166+
167167
# Convert raster data and metadata into an xarray DataArray
168168
img_da = xr.DataArray(img_data, dims=("band", "y", "x"), attrs=img_meta)
169-
169+
170170
# Tile the data
171171
ds_chunked = img_da.chunk({"y": chunk_size[0], "x": chunk_size[1]})
172-
172+
173173
# Get the geospatial information from the original dataset
174174
transform = img_meta["transform"]
175-
175+
176176
# Iterate over the chunks and compute the geospatial bounds for each chunk
177177
chunk_bounds = {}
178-
178+
179179
for x in range(ds_chunked.sizes["x"] // chunk_size[1]):
180180
for y in range(ds_chunked.sizes["y"] // chunk_size[0]):
181181
# Compute chunk coordinates
182182
x_start = x * chunk_size[1]
183183
y_start = y * chunk_size[0]
184184
x_end = min(x_start + chunk_size[1], ds_chunked.sizes["x"])
185185
y_end = min(y_start + chunk_size[0], ds_chunked.sizes["y"])
186-
186+
187187
# Compute chunk geospatial bounds
188188
lon_start, lat_start = transform * (x_start, y_start)
189189
lon_end, lat_end = transform * (x_end, y_end)
190-
190+
191191
# Store chunk bounds
192192
chunk_bounds[(x, y)] = {
193193
"lon_start": lon_start,
194194
"lat_start": lat_start,
195195
"lon_end": lon_end,
196196
"lat_end": lat_end,
197197
}
198-
198+
199199
return chunk_bounds, img_crs
200200

201-
201+
202202
def make_batch(result):
203203
pixels = []
204204
for url, win in result:
@@ -233,10 +233,10 @@ def make_batch(result):
233233
"timestep": torch.as_tensor(data=[ds.normalize_timestamp(f"{YEAR}-06-01")]).to(
234234
rgb_model.device
235235
),
236-
"date": f"{YEAR}-06-01"
237-
,
236+
"date": f"{YEAR}-06-01",
238237
}
239238

239+
240240
def get_pixels(result):
241241
pixels = []
242242
for url, win in result:
@@ -323,89 +323,89 @@ def get_pixels(result):
323323
)
324324

325325
yoff += CHIP_SIZE
326-
327-
328326

329327
print(len(embeddings), len(results))
330-
#embeddings = numpy.vstack(embeddings)
328+
# embeddings = numpy.vstack(embeddings)
331329
embeddings_ = embeddings[0]
332330
print("Embeddings shape: ", embeddings_.shape)
333-
331+
334332
embeddings_ = embeddings_[:, :-2, :]
335-
336-
print(f"Embeddings have shape {embeddings_.shape}") #.mean(axis=1)
337-
333+
334+
print(f"Embeddings have shape {embeddings_.shape}") # .mean(axis=1)
335+
338336
# remove date and lat/lon and reshape to disaggregated patches
339337
embeddings_patch = embeddings_.reshape([2, 16, 16, 768])
340-
338+
341339
# average over the band groups
342340
embeddings_mean = embeddings_patch.mean(axis=0)
343-
344-
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
345341

342+
print(f"Average patch embeddings have shape {embeddings_mean.shape}")
346343

347344
if result is not None:
348345
print("result: ", result[0][0])
349346
pix = get_pixels(result)
350347
chunk_bounds, epsg = patches_and_windows_from_url(result[0][0])
351-
#print("chunk_bounds: ", chunk_bounds)
348+
# print("chunk_bounds: ", chunk_bounds)
352349
print("chunk bounds length:", len(chunk_bounds))
353-
350+
354351
# Iterate through each patch
355352
for i in range(embeddings_mean.shape[0]):
356353
for j in range(embeddings_mean.shape[1]):
357354
embeddings_output_patch = embeddings_mean[i, j]
358-
355+
359356
item_ = [
360-
element for element in list(chunk_bounds.items()) if element[0] == (i, j)
357+
element
358+
for element in list(chunk_bounds.items())
359+
if element[0] == (i, j)
361360
]
362361
box_ = [
363362
item_[0][1]["lon_start"],
364363
item_[0][1]["lat_start"],
365364
item_[0][1]["lon_end"],
366365
item_[0][1]["lat_end"],
367366
]
368-
#source_url = batch["source_url"]
367+
# source_url = batch["source_url"]
369368
date = batch["date"]
370369
date_as_timestamp = pd.to_datetime(date, format="%Y-%m-%d")
371370

372371
# Convert the Pandas Timestamp to the desired data type
373-
#date_as_date32 = date_as_timestamp.astype('datetime64[D]')
372+
# date_as_date32 = date_as_timestamp.astype('datetime64[D]')
374373

375-
#print(batch["date"])
374+
# print(batch["date"])
376375
data = {
377376
"date": date_as_timestamp,
378377
"embeddings": [numpy.ascontiguousarray(embeddings_output_patch)],
379378
}
380-
379+
381380
# Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
382381
# The box_ list is encoded as
383382
# [bottom left x, bottom left y, top right x, top right y]
384383
box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])
385384

386385
print(str(epsg)[-4:])
387-
386+
388387
# Create the GeoDataFrame
389-
gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}")
390-
388+
gdf = gpd.GeoDataFrame(
389+
data, geometry=[box_emb], crs=f"EPSG:{str(epsg)[-4:]}"
390+
)
391+
391392
# Reproject to WGS84 (lon/lat coordinates)
392393
gdf = gdf.to_crs(epsg=4326)
393-
394-
394+
395395
with tempfile.TemporaryDirectory() as tmp:
396396
# tmp = "/home/tam/Desktop/wcctmp"
397-
397+
398398
outpath = f"{tmp}/worldcover_patch_embeddings_{YEAR}_{index}_{i}_{j}_v{VERSION}.gpq"
399399
print(f"Uploading embeddings to {outpath}")
400-
#print(gdf)
401-
402-
gdf.to_parquet(path=outpath, compression="ZSTD", schema_version="1.0.0")
403-
400+
# print(gdf)
401+
402+
gdf.to_parquet(
403+
path=outpath, compression="ZSTD", schema_version="1.0.0"
404+
)
405+
404406
s3_client = boto3.client("s3")
405407
s3_client.upload_file(
406408
outpath,
407409
BUCKET,
408410
f"v{VERSION}/{YEAR}/{os.path.basename(outpath)}",
409411
)
410-
411-

0 commit comments

Comments
 (0)