Skip to content

Commit 8f5de9a

Browse files
authored
Merge pull request #2139 from AdeelH/backport-0.30.1
[BACKPORT] Backport changes to 0.30 branch for v0.30.1 release
2 parents db46a7e + 33fd2fb commit 8f5de9a

File tree

18 files changed

+209
-133
lines changed

.readthedocs.yml

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -40,6 +40,8 @@ python:
4040
path: rastervision_pytorch_learner/
4141
- method: pip
4242
path: rastervision_pytorch_backend/
43+
- method: pip
44+
path: rastervision_aws_sagemaker/
4345

4446
# https://docs.readthedocs.io/en/stable/config-file/v2.html#search
4547
search:

docs/framework/examples.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ The ``--tensorboard`` option should be used if running locally and you would lik
8383
export PROCESSED_URI="/opt/data/examples/spacenet/rio/processed-data"
8484
export ROOT_URI="/opt/data/examples/spacenet/rio/local-output"
8585
86-
rastervision run local rastervision.examples.chip_classification.spacenet_rio \
86+
rastervision run local rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \
8787
-a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \
8888
-a test True --splits 2
8989
@@ -104,7 +104,7 @@ To run the full experiment on GPUs using AWS Batch, use something like the follo
104104
export PROCESSED_URI="s3://mybucket/examples/spacenet/rio/processed-data"
105105
export ROOT_URI="s3://mybucket/examples/spacenet/rio/remote-output"
106106
107-
rastervision run batch rastervision.examples.chip_classification.spacenet_rio \
107+
rastervision run batch rastervision.pytorch_backend.examples.chip_classification.spacenet_rio \
108108
-a raw_uri $RAW_URI -a processed_uri $PROCESSED_URI -a root_uri $ROOT_URI \
109109
-a test False --splits 8
110110

rastervision_aws_s3/rastervision/aws_s3/s3_file_system.py

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from typing import Tuple
1+
from typing import Any, Iterator, Tuple
22
import io
33
import os
44
import subprocess
@@ -16,41 +16,38 @@
1616

1717

1818
# Code from https://alexwlchan.net/2017/07/listing-s3-keys/
19-
def get_matching_s3_objects(bucket, prefix='', suffix='',
20-
request_payer='None'):
21-
"""
22-
Generate objects in an S3 bucket.
23-
24-
:param bucket: Name of the S3 bucket.
25-
:param prefix: Only fetch objects whose key starts with
26-
this prefix (optional).
27-
:param suffix: Only fetch objects whose keys end with
28-
this suffix (optional).
19+
def get_matching_s3_objects(
20+
bucket: str,
21+
prefix: str = '',
22+
suffix: str = '',
23+
delimiter: str = '/',
24+
request_payer: str = 'None') -> Iterator[tuple[str, Any]]:
25+
"""Generate objects in an S3 bucket.
26+
27+
Args:
28+
bucket: Name of the S3 bucket.
29+
prefix: Only fetch objects whose key starts with this prefix.
30+
suffix: Only fetch objects whose keys end with this suffix.
2931
"""
3032
s3 = S3FileSystem.get_client()
31-
kwargs = {'Bucket': bucket, 'RequestPayer': request_payer}
32-
33-
# If the prefix is a single string (not a tuple of strings), we can
34-
# do the filtering directly in the S3 API.
35-
if isinstance(prefix, str):
36-
kwargs['Prefix'] = prefix
37-
33+
kwargs = dict(
34+
Bucket=bucket,
35+
RequestPayer=request_payer,
36+
Delimiter=delimiter,
37+
Prefix=prefix,
38+
)
3839
while True:
39-
40-
# The S3 API response is a large blob of metadata.
41-
# 'Contents' contains information about the listed objects.
42-
resp = s3.list_objects_v2(**kwargs)
43-
44-
try:
45-
contents = resp['Contents']
46-
except KeyError:
47-
return
48-
49-
for obj in contents:
40+
resp: dict = s3.list_objects_v2(**kwargs)
41+
dirs: list[dict] = resp.get('CommonPrefixes', {})
42+
files: list[dict] = resp.get('Contents', {})
43+
for obj in dirs:
44+
key = obj['Prefix']
45+
if key.startswith(prefix) and key.endswith(suffix):
46+
yield key, obj
47+
for obj in files:
5048
key = obj['Key']
5149
if key.startswith(prefix) and key.endswith(suffix):
52-
yield obj
53-
50+
yield key, obj
5451
# The S3 API is paginated, returning up to 1000 keys at a time.
5552
# Pass the continuation token into the next response, until we
5653
# reach the final page (when this field is missing).
@@ -60,16 +57,26 @@ def get_matching_s3_objects(bucket, prefix='', suffix='',
6057
break
6158

6259

63-
def get_matching_s3_keys(bucket, prefix='', suffix='', request_payer='None'):
64-
"""
65-
Generate the keys in an S3 bucket.
60+
def get_matching_s3_keys(bucket: str,
61+
prefix: str = '',
62+
suffix: str = '',
63+
delimiter: str = '/',
64+
request_payer: str = 'None') -> Iterator[str]:
65+
"""Generate the keys in an S3 bucket.
6666
67-
:param bucket: Name of the S3 bucket.
68-
:param prefix: Only fetch keys that start with this prefix (optional).
69-
:param suffix: Only fetch keys that end with this suffix (optional).
67+
Args:
68+
bucket: Name of the S3 bucket.
69+
prefix: Only fetch keys that start with this prefix.
70+
suffix: Only fetch keys that end with this suffix.
7071
"""
71-
for obj in get_matching_s3_objects(bucket, prefix, suffix, request_payer):
72-
yield obj['Key']
72+
obj_iterator = get_matching_s3_objects(
73+
bucket,
74+
prefix=prefix,
75+
suffix=suffix,
76+
delimiter=delimiter,
77+
request_payer=request_payer)
78+
out = (key for key, _ in obj_iterator)
79+
return out
7380

7481

7582
def progressbar(total_size: int, desc: str):
@@ -180,8 +187,9 @@ def read_bytes(uri: str) -> bytes:
180187
bucket, key = S3FileSystem.parse_uri(uri)
181188
with io.BytesIO() as file_buffer:
182189
try:
183-
file_size = s3.head_object(
184-
Bucket=bucket, Key=key)['ContentLength']
190+
obj = s3.head_object(
191+
Bucket=bucket, Key=key, RequestPayer=request_payer)
192+
file_size = obj['ContentLength']
185193
with progressbar(file_size, desc='Downloading') as bar:
186194
s3.download_fileobj(
187195
Bucket=bucket,
@@ -256,7 +264,9 @@ def copy_from(src_uri: str, dst_path: str) -> None:
256264
request_payer = S3FileSystem.get_request_payer()
257265
bucket, key = S3FileSystem.parse_uri(src_uri)
258266
try:
259-
file_size = s3.head_object(Bucket=bucket, Key=key)['ContentLength']
267+
obj = s3.head_object(
268+
Bucket=bucket, Key=key, RequestPayer=request_payer)
269+
file_size = obj['ContentLength']
260270
with progressbar(file_size, desc=f'Downloading') as bar:
261271
s3.download_file(
262272
Bucket=bucket,
@@ -284,11 +294,16 @@ def last_modified(uri: str) -> datetime:
284294
return head_data['LastModified']
285295

286296
@staticmethod
287-
def list_paths(uri, ext=''):
297+
def list_paths(uri: str, ext: str = '', delimiter: str = '/') -> list[str]:
288298
request_payer = S3FileSystem.get_request_payer()
289299
parsed_uri = urlparse(uri)
290300
bucket = parsed_uri.netloc
291301
prefix = os.path.join(parsed_uri.path[1:])
292302
keys = get_matching_s3_keys(
293-
bucket, prefix, suffix=ext, request_payer=request_payer)
294-
return [os.path.join('s3://', bucket, key) for key in keys]
303+
bucket,
304+
prefix,
305+
suffix=ext,
306+
delimiter=delimiter,
307+
request_payer=request_payer)
308+
paths = [os.path.join('s3://', bucket, key) for key in keys]
309+
return paths

rastervision_core/rastervision/core/data/dataset_config.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,12 @@ def get_split_config(self, split_ind, num_splits):
9090
@property
9191
def all_scenes(self) -> List[SceneConfig]:
9292
return self.train_scenes + self.validation_scenes + self.test_scenes
93+
94+
def __repr__(self):
95+
num_train = len(self.train_scenes)
96+
num_val = len(self.validation_scenes)
97+
num_test = len(self.test_scenes)
98+
out = (f'DatasetConfig(train_scenes=<{num_train} scenes>, '
99+
f'validation_scenes=<{num_val} scenes>, '
100+
f'test_scenes=<{num_test} scenes>)')
101+
return out

rastervision_core/rastervision/core/data/raster_source/raster_source.py

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -147,24 +147,20 @@ def get_chip(self,
147147

148148
return chip
149149

150-
def get_chip_by_map_window(
151-
self,
152-
window_map_coords: 'Box',
153-
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
154-
"""Same as get_chip(), but input is a window in map coords. """
150+
def get_chip_by_map_window(self, window_map_coords: 'Box', *args,
151+
**kwargs) -> 'np.ndarray':
152+
"""Same as get_chip(), but input is a window in map coords."""
155153
window_pixel_coords = self.crs_transformer.map_to_pixel(
156154
window_map_coords, bbox=self.bbox).normalize()
157-
chip = self.get_chip(window_pixel_coords, out_shape=out_shape)
155+
chip = self.get_chip(window_pixel_coords, *args, **kwargs)
158156
return chip
159157

160-
def _get_chip_by_map_window(
161-
self,
162-
window_map_coords: 'Box',
163-
out_shape: Optional[Tuple[int, int]] = None) -> 'np.ndarray':
164-
"""Same as _get_chip(), but input is a window in map coords. """
158+
def _get_chip_by_map_window(self, window_map_coords: 'Box', *args,
159+
**kwargs) -> 'np.ndarray':
160+
"""Same as _get_chip(), but input is a window in map coords."""
165161
window_pixel_coords = self.crs_transformer.map_to_pixel(
166162
window_map_coords, bbox=self.bbox)
167-
chip = self._get_chip(window_pixel_coords, out_shape=out_shape)
163+
chip = self._get_chip(window_pixel_coords, *args, **kwargs)
168164
return chip
169165

170166
def get_raw_chip(self,

rastervision_core/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ rastervision_pipeline==0.30.0
22
shapely==2.0.2
33
geopandas==0.14.3
44
numpy==1.26.3
5-
pillow==10.2.0
5+
pillow==10.3.0
66
pyproj==3.6.1
77
rasterio==1.3.9
88
pystac==1.9.0

rastervision_pipeline/rastervision/pipeline/cli.py

Lines changed: 50 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,20 @@
1+
from typing import TYPE_CHECKING
12
import sys
23
import os
34
import logging
4-
import importlib
5-
import importlib.util
6-
from typing import List, Dict, Optional, Tuple
75

86
import click
97

108
from rastervision.pipeline import (registry_ as registry, rv_config_ as
119
rv_config)
1210
from rastervision.pipeline.file_system import (file_to_json, get_tmp_dir)
13-
from rastervision.pipeline.config import build_config, save_pipeline_config
11+
from rastervision.pipeline.config import (build_config, Config,
12+
save_pipeline_config)
1413
from rastervision.pipeline.pipeline_config import PipelineConfig
1514

15+
if TYPE_CHECKING:
16+
from rastervision.pipeline.runner import Runner
17+
1618
log = logging.getLogger(__name__)
1719

1820

@@ -40,8 +42,9 @@ def convert_bool_args(args: dict) -> dict:
4042
return new_args
4143

4244

43-
def get_configs(cfg_module_path: str, runner: str,
44-
args: Dict[str, any]) -> List[PipelineConfig]:
45+
def get_configs(cfg_module_path: str,
46+
runner: str | None = None,
47+
args: dict[str, any] | None = None) -> list[PipelineConfig]:
4548
"""Get PipelineConfigs from a module.
4649
4750
Calls a get_config(s) function with some arguments from the CLI
@@ -55,6 +58,26 @@ def get_configs(cfg_module_path: str, runner: str,
5558
args: CLI args to pass to the get_config(s) function that comes from
5659
the --args option
5760
"""
61+
if cfg_module_path.endswith('.json'):
62+
cfgs_json = file_to_json(cfg_module_path)
63+
if not isinstance(cfgs_json, list):
64+
cfgs_json = [cfgs_json]
65+
cfgs = [Config.deserialize(json) for json in cfgs_json]
66+
else:
67+
cfgs = get_configs_from_module(cfg_module_path, runner, args)
68+
69+
for cfg in cfgs:
70+
if not issubclass(type(cfg), PipelineConfig):
71+
raise TypeError('All objects returned by get_configs in '
72+
f'{cfg_module_path} must be PipelineConfigs.')
73+
return cfgs
74+
75+
76+
def get_configs_from_module(cfg_module_path: str, runner: str,
77+
args: dict[str, any]) -> list[PipelineConfig]:
78+
import importlib
79+
import importlib.util
80+
5881
if cfg_module_path.endswith('.py'):
5982
# From https://stackoverflow.com/questions/67631/how-to-import-a-module-given-the-full-path # noqa
6083
spec = importlib.util.spec_from_file_location('rastervision.pipeline',
@@ -65,20 +88,14 @@ def get_configs(cfg_module_path: str, runner: str,
6588
cfg_module = importlib.import_module(cfg_module_path)
6689

6790
_get_config = getattr(cfg_module, 'get_config', None)
68-
_get_configs = _get_config
69-
if _get_config is None:
70-
_get_configs = getattr(cfg_module, 'get_configs', None)
91+
_get_configs = getattr(cfg_module, 'get_configs', _get_config)
7192
if _get_configs is None:
72-
raise Exception('There must be a get_config or get_configs function '
73-
f'in {cfg_module_path}.')
93+
raise ImportError('There must be a get_config() or get_configs() '
94+
f'function in {cfg_module_path}.')
95+
7496
cfgs = _get_configs(runner, **args)
7597
if not isinstance(cfgs, list):
7698
cfgs = [cfgs]
77-
78-
for cfg in cfgs:
79-
if not issubclass(type(cfg), PipelineConfig):
80-
raise Exception('All objects returned by get_configs in '
81-
f'{cfg_module_path} must be PipelineConfigs.')
8299
return cfgs
83100

84101

@@ -89,8 +106,7 @@ def get_configs(cfg_module_path: str, runner: str,
89106
@click.option(
90107
'-v', '--verbose', help='Increment the verbosity level.', count=True)
91108
@click.option('--tmpdir', help='Root of temporary directories to use.')
92-
def main(ctx: click.Context, profile: Optional[str], verbose: int,
93-
tmpdir: str):
109+
def main(ctx: click.Context, profile: str | None, verbose: int, tmpdir: str):
94110
"""The main click command.
95111
96112
Sets the profile, verbosity, and tmp_dir in RVConfig.
@@ -103,20 +119,22 @@ def main(ctx: click.Context, profile: Optional[str], verbose: int,
103119
rv_config.set_everett_config(profile=profile)
104120

105121

106-
def _run_pipeline(cfg,
107-
runner,
108-
tmp_dir,
109-
splits=1,
110-
commands=None,
122+
def _run_pipeline(cfg: PipelineConfig,
123+
runner: 'Runner',
124+
tmp_dir: str,
125+
splits: int = 1,
126+
commands: list[str] | None = None,
111127
pipeline_run_name: str = 'raster-vision'):
112128
cfg.update()
113129
cfg.recursive_validate_config()
114-
# This is to run the validation again to check any fields that may have changed
115-
# after the Config was constructed, possibly by the update method.
130+
131+
# This is to run the validation again to check any fields that may have
132+
# changed after the Config was constructed, possibly by the update method.
116133
build_config(cfg.dict())
117134
cfg_json_uri = cfg.get_config_uri()
118135
save_pipeline_config(cfg, cfg_json_uri)
119136
pipeline = cfg.build(tmp_dir)
137+
120138
if not commands:
121139
commands = pipeline.commands
122140

@@ -150,8 +168,8 @@ def _run_pipeline(cfg,
150168
'--pipeline-run-name',
151169
default='raster-vision',
152170
help='The name for this run of the pipeline.')
153-
def run(runner: str, cfg_module: str, commands: List[str],
154-
arg: List[Tuple[str, str]], splits: int, pipeline_run_name: str):
171+
def run(runner: str, cfg_module: str, commands: list[str],
172+
arg: list[tuple[str, str]], splits: int, pipeline_run_name: str):
155173
"""Run COMMANDS within pipelines in CFG_MODULE using RUNNER.
156174
157175
RUNNER: name of the Runner to use
@@ -178,9 +196,9 @@ def run(runner: str, cfg_module: str, commands: List[str],
178196

179197
def _run_command(cfg_json_uri: str,
180198
command: str,
181-
split_ind: Optional[int] = None,
182-
num_splits: Optional[int] = None,
183-
runner: Optional[str] = None):
199+
split_ind: int | None = None,
200+
num_splits: int | None = None,
201+
runner: str | None = None):
184202
"""Run a single command using a serialized PipelineConfig.
185203
186204
Args:
@@ -229,8 +247,8 @@ def _run_command(cfg_json_uri: str,
229247
help='The number of processes to use for running splittable commands')
230248
@click.option(
231249
'--runner', type=str, help='Name of runner to use', default='inprocess')
232-
def run_command(cfg_json_uri: str, command: str, split_ind: Optional[int],
233-
num_splits: Optional[int], runner: str):
250+
def run_command(cfg_json_uri: str, command: str, split_ind: int | None,
251+
num_splits: int | None, runner: str):
234252
"""Run a single COMMAND using a serialized PipelineConfig in CFG_JSON_URI."""
235253
_run_command(
236254
cfg_json_uri,

0 commit comments

Comments
 (0)