11#!/usr/bin/env python3
22
33import sys
4+
45sys .path .append ("/home/ubuntu/worldcover/model" )
56
67import os
78import tempfile
89from math import floor
910from pathlib import Path
10- import requests
1111
1212import boto3
1313import einops
1414import geopandas as gpd
15- import pandas as pd
1615import numpy
17- import pyarrow as pa
16+ import pandas as pd
1817import rasterio
18+ import requests
1919import shapely
2020import torch
2121import xarray as xr
2222from rasterio .windows import Window
23- from shapely import box
2423from torchvision .transforms import v2
2524
2625from src .datamodule import ClayDataset
@@ -152,53 +151,54 @@ def download_image(url):
152151 else :
153152 raise Exception ("Failed to download the image" )
154153
154+
155155def patches_and_windows_from_url (url , chunk_size = (PATCH_SIZE , PATCH_SIZE )):
156156 # Download the image from the URL
157157 image_data = download_image (url )
158-
158+
159159 # Open the image using rasterio from memory
160160 with rasterio .io .MemoryFile (image_data ) as memfile :
161161 with memfile .open () as src :
162162 # Read the image data and metadata
163163 img_data = src .read ()
164164 img_meta = src .profile
165165 img_crs = src .crs
166-
166+
167167 # Convert raster data and metadata into an xarray DataArray
168168 img_da = xr .DataArray (img_data , dims = ("band" , "y" , "x" ), attrs = img_meta )
169-
169+
170170 # Tile the data
171171 ds_chunked = img_da .chunk ({"y" : chunk_size [0 ], "x" : chunk_size [1 ]})
172-
172+
173173 # Get the geospatial information from the original dataset
174174 transform = img_meta ["transform" ]
175-
175+
176176 # Iterate over the chunks and compute the geospatial bounds for each chunk
177177 chunk_bounds = {}
178-
178+
179179 for x in range (ds_chunked .sizes ["x" ] // chunk_size [1 ]):
180180 for y in range (ds_chunked .sizes ["y" ] // chunk_size [0 ]):
181181 # Compute chunk coordinates
182182 x_start = x * chunk_size [1 ]
183183 y_start = y * chunk_size [0 ]
184184 x_end = min (x_start + chunk_size [1 ], ds_chunked .sizes ["x" ])
185185 y_end = min (y_start + chunk_size [0 ], ds_chunked .sizes ["y" ])
186-
186+
187187 # Compute chunk geospatial bounds
188188 lon_start , lat_start = transform * (x_start , y_start )
189189 lon_end , lat_end = transform * (x_end , y_end )
190-
190+
191191 # Store chunk bounds
192192 chunk_bounds [(x , y )] = {
193193 "lon_start" : lon_start ,
194194 "lat_start" : lat_start ,
195195 "lon_end" : lon_end ,
196196 "lat_end" : lat_end ,
197197 }
198-
198+
199199 return chunk_bounds , img_crs
200200
201-
201+
202202def make_batch (result ):
203203 pixels = []
204204 for url , win in result :
@@ -233,10 +233,10 @@ def make_batch(result):
233233 "timestep" : torch .as_tensor (data = [ds .normalize_timestamp (f"{ YEAR } -06-01" )]).to (
234234 rgb_model .device
235235 ),
236- "date" : f"{ YEAR } -06-01"
237- ,
236+ "date" : f"{ YEAR } -06-01" ,
238237 }
239238
239+
240240def get_pixels (result ):
241241 pixels = []
242242 for url , win in result :
@@ -323,89 +323,89 @@ def get_pixels(result):
323323 )
324324
325325 yoff += CHIP_SIZE
326-
327-
328326
329327 print (len (embeddings ), len (results ))
330- #embeddings = numpy.vstack(embeddings)
328+ # embeddings = numpy.vstack(embeddings)
331329 embeddings_ = embeddings [0 ]
332330 print ("Embeddings shape: " , embeddings_ .shape )
333-
331+
334332 embeddings_ = embeddings_ [:, :- 2 , :]
335-
336- print (f"Embeddings have shape { embeddings_ .shape } " ) # .mean(axis=1)
337-
333+
334+ print (f"Embeddings have shape { embeddings_ .shape } " ) # .mean(axis=1)
335+
338336 # remove date and lat/lon and reshape to disaggregated patches
339337 embeddings_patch = embeddings_ .reshape ([2 , 16 , 16 , 768 ])
340-
338+
341339 # average over the band groups
342340 embeddings_mean = embeddings_patch .mean (axis = 0 )
343-
344- print (f"Average patch embeddings have shape { embeddings_mean .shape } " )
345341
342+ print (f"Average patch embeddings have shape { embeddings_mean .shape } " )
346343
347344 if result is not None :
348345 print ("result: " , result [0 ][0 ])
349346 pix = get_pixels (result )
350347 chunk_bounds , epsg = patches_and_windows_from_url (result [0 ][0 ])
351- #print("chunk_bounds: ", chunk_bounds)
348+ # print("chunk_bounds: ", chunk_bounds)
352349 print ("chunk bounds length:" , len (chunk_bounds ))
353-
350+
354351 # Iterate through each patch
355352 for i in range (embeddings_mean .shape [0 ]):
356353 for j in range (embeddings_mean .shape [1 ]):
357354 embeddings_output_patch = embeddings_mean [i , j ]
358-
355+
359356 item_ = [
360- element for element in list (chunk_bounds .items ()) if element [0 ] == (i , j )
357+ element
358+ for element in list (chunk_bounds .items ())
359+ if element [0 ] == (i , j )
361360 ]
362361 box_ = [
363362 item_ [0 ][1 ]["lon_start" ],
364363 item_ [0 ][1 ]["lat_start" ],
365364 item_ [0 ][1 ]["lon_end" ],
366365 item_ [0 ][1 ]["lat_end" ],
367366 ]
368- #source_url = batch["source_url"]
367+ # source_url = batch["source_url"]
369368 date = batch ["date" ]
370369 date_as_timestamp = pd .to_datetime (date , format = "%Y-%m-%d" )
371370
372371 # Convert the Pandas Timestamp to the desired data type
373- #date_as_date32 = date_as_timestamp.astype('datetime64[D]')
372+ # date_as_date32 = date_as_timestamp.astype('datetime64[D]')
374373
375- #print(batch["date"])
374+ # print(batch["date"])
376375 data = {
377376 "date" : date_as_timestamp ,
378377 "embeddings" : [numpy .ascontiguousarray (embeddings_output_patch )],
379378 }
380-
379+
381380 # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)
382381 # The box_ list is encoded as
383382 # [bottom left x, bottom left y, top right x, top right y]
384383 box_emb = shapely .geometry .box (box_ [0 ], box_ [1 ], box_ [2 ], box_ [3 ])
385384
386385 print (str (epsg )[- 4 :])
387-
386+
388387 # Create the GeoDataFrame
389- gdf = gpd .GeoDataFrame (data , geometry = [box_emb ], crs = f"EPSG:{ str (epsg )[- 4 :]} " )
390-
388+ gdf = gpd .GeoDataFrame (
389+ data , geometry = [box_emb ], crs = f"EPSG:{ str (epsg )[- 4 :]} "
390+ )
391+
391392 # Reproject to WGS84 (lon/lat coordinates)
392393 gdf = gdf .to_crs (epsg = 4326 )
393-
394-
394+
395395 with tempfile .TemporaryDirectory () as tmp :
396396 # tmp = "/home/tam/Desktop/wcctmp"
397-
397+
398398 outpath = f"{ tmp } /worldcover_patch_embeddings_{ YEAR } _{ index } _{ i } _{ j } _v{ VERSION } .gpq"
399399 print (f"Uploading embeddings to { outpath } " )
400- #print(gdf)
401-
402- gdf .to_parquet (path = outpath , compression = "ZSTD" , schema_version = "1.0.0" )
403-
400+ # print(gdf)
401+
402+ gdf .to_parquet (
403+ path = outpath , compression = "ZSTD" , schema_version = "1.0.0"
404+ )
405+
404406 s3_client = boto3 .client ("s3" )
405407 s3_client .upload_file (
406408 outpath ,
407409 BUCKET ,
408410 f"v{ VERSION } /{ YEAR } /{ os .path .basename (outpath )} " ,
409411 )
410-
411-
0 commit comments