Skip to content

Commit 6f9c16e

Browse files
committed
sketch algorithm support
1 parent 0b9c8eb commit 6f9c16e

File tree

7 files changed

+300
-24
lines changed

7 files changed

+300
-24
lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ dynamic = ["version"]
2020
dependencies = [
2121
"braceexpand",
2222
"rio-cogeo>=3.1",
23+
"rio-tiler>=3.1.5",
2324
"titiler.core>=0.5,<0.8",
2425
"starlette-cramjam>=0.3,<0.4",
2526
"uvicorn",

rio_viz/algorithm/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""rio_viz.algorithm."""
2+
3+
from typing import Dict, Type
4+
5+
from rio_viz.algorithm.base import AlgorithmMetadata, BaseAlgorithm # noqa
6+
from rio_viz.algorithm.dem import Contours, HillShade
7+
from rio_viz.algorithm.index import NormalizedIndex
8+
9+
AVAILABLE_ALGORITHM: Dict[str, Type[BaseAlgorithm]] = {
10+
"hillshade": HillShade,
11+
"contours": Contours,
12+
"normalizedIndex": NormalizedIndex,
13+
}

rio_viz/algorithm/base.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""Algorithm base class."""
2+
3+
import abc
4+
from typing import Dict, Optional, Sequence
5+
6+
from pydantic import BaseModel
7+
from rio_tiler.models import ImageData
8+
9+
10+
class BaseAlgorithm(BaseModel, metaclass=abc.ABCMeta):
11+
"""Algorithm baseclass."""
12+
13+
input_nbands: int
14+
15+
output_nbands: int
16+
output_dtype: str
17+
output_min: Optional[Sequence]
18+
output_max: Optional[Sequence]
19+
20+
@abc.abstractmethod
21+
def apply(self, img: ImageData) -> ImageData:
22+
"""Apply"""
23+
...
24+
25+
class Config:
26+
"""Config for model."""
27+
28+
extra = "allow"
29+
30+
31+
class AlgorithmMetadata(BaseModel):
32+
"""Algorithm metadata."""
33+
34+
name: str
35+
inputs: Dict
36+
outputs: Dict
37+
params: Dict

rio_viz/algorithm/dem.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
"""rio_viz.algorithm DEM."""
2+
3+
import numpy
4+
from rio_tiler.colormap import apply_cmap, cmap
5+
from rio_tiler.models import ImageData
6+
from rio_tiler.utils import linear_rescale
7+
8+
from rio_viz.algorithm.base import BaseAlgorithm
9+
10+
11+
class HillShade(BaseAlgorithm):
12+
"""Hillshade."""
13+
14+
azimuth: int = 90
15+
angle_altitude: float = 90
16+
17+
input_nbands: int = 1
18+
19+
output_nbands: int = 1
20+
output_dtype: str = "uint8"
21+
22+
def apply(self, img: ImageData) -> ImageData:
23+
"""Create hillshade from DEM dataset."""
24+
data = img.data[0]
25+
mask = img.mask
26+
27+
x, y = numpy.gradient(data)
28+
29+
slope = numpy.pi / 2.0 - numpy.arctan(numpy.sqrt(x * x + y * y))
30+
aspect = numpy.arctan2(-x, y)
31+
azimuthrad = self.azimuth * numpy.pi / 180.0
32+
altituderad = self.angle_altitude * numpy.pi / 180.0
33+
shaded = numpy.sin(altituderad) * numpy.sin(slope) + numpy.cos(
34+
altituderad
35+
) * numpy.cos(slope) * numpy.cos(azimuthrad - aspect)
36+
hillshade_array = 255 * (shaded + 1) / 2
37+
38+
# ImageData only accept image in form of (count, height, width)
39+
arr = numpy.expand_dims(hillshade_array, axis=0).astype(dtype=numpy.uint8)
40+
41+
return ImageData(
42+
arr,
43+
mask,
44+
assets=img.assets,
45+
crs=img.crs,
46+
bounds=img.bounds,
47+
)
48+
49+
50+
class Contours(BaseAlgorithm):
51+
"""Contours.
52+
53+
Original idea from https://custom-scripts.sentinel-hub.com/dem/contour-lines/
54+
"""
55+
56+
increment: int = 35
57+
thickness: int = 1
58+
minz: int = -12000
59+
maxz: int = 8000
60+
61+
input_nbands: int = 1
62+
63+
output_nbands: int = 3
64+
output_dtype: str = "uint8"
65+
66+
def apply(self, img: ImageData) -> ImageData:
67+
"""Add contours."""
68+
data = img.data
69+
70+
# Apply rescaling for minz,maxz to 1->255 and apply Terrain colormap
71+
arr = linear_rescale(data, (self.minz, self.maxz), (1, 255)).astype("uint8")
72+
arr, _ = apply_cmap(arr, cmap.get("terrain"))
73+
74+
# set black (0) for contour lines
75+
arr = numpy.where(data % self.increment < self.thickness, 0, arr)
76+
77+
return ImageData(
78+
arr,
79+
img.mask,
80+
assets=img.assets,
81+
crs=img.crs,
82+
bounds=img.bounds,
83+
)

rio_viz/algorithm/index.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
"""rio_viz.algorithm Normalized Index."""
2+
3+
from typing import Sequence
4+
5+
import numpy
6+
from rio_tiler.models import ImageData
7+
8+
from rio_viz.algorithm.base import BaseAlgorithm
9+
10+
11+
class NormalizedIndex(BaseAlgorithm):
12+
"""Normalized Difference Index."""
13+
14+
input_nbands: int = 2
15+
16+
output_nbands: int = 1
17+
output_dtype: str = "float32"
18+
output_min: Sequence[float] = [-1.0]
19+
output_max: Sequence[float] = [1.0]
20+
21+
def apply(self, img: ImageData) -> ImageData:
22+
"""Normalized difference."""
23+
b1 = img.data[0]
24+
b2 = img.data[1]
25+
26+
arr = numpy.where(img.mask, (b2 - b1) / (b2 + b1), 0)
27+
28+
# ImageData only accept image in form of (count, height, width)
29+
arr = numpy.expand_dims(arr, axis=0).astype(self.output_dtype)
30+
31+
return ImageData(
32+
arr,
33+
img.mask,
34+
assets=img.assets,
35+
crs=img.crs,
36+
bounds=img.bounds,
37+
)

rio_viz/app.py

Lines changed: 80 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
from starlette.types import ASGIApp
2323
from starlette_cramjam.middleware import CompressionMiddleware
2424

25+
from rio_viz.algorithm import AVAILABLE_ALGORITHM, AlgorithmMetadata
26+
from rio_viz.dependency import PostProcessParams
2527
from rio_viz.resources.enums import RasterFormat, VectorTileFormat, VectorTileType
2628

2729
from titiler.core.dependencies import (
@@ -36,7 +38,6 @@
3638
HistogramParams,
3739
ImageParams,
3840
ImageRenderingParams,
39-
PostProcessParams,
4041
StatisticsParams,
4142
)
4243
from titiler.core.models.mapbox import TileJSON
@@ -295,19 +296,26 @@ async def preview(
295296
# Adapt options for each reader type
296297
self._update_params(src_dst, layer_params)
297298

298-
data = await src_dst.preview(
299+
img = await src_dst.preview(
299300
**layer_params,
300301
**dataset_params,
301302
**img_params,
302303
)
303304
dst_colormap = getattr(src_dst, "colormap", None)
304305

305-
if not format:
306-
format = RasterFormat.jpeg if data.mask.all() else RasterFormat.png
306+
if postprocess_params.image_process:
307+
img = postprocess_params.image_process.apply(img)
308+
309+
if postprocess_params.rescale:
310+
img.rescale(postprocess_params.rescale)
311+
312+
if postprocess_params.color_formula:
313+
img.apply_color_formula(postprocess_params.color_formula)
307314

308-
image = data.post_process(**postprocess_params)
315+
if not format:
316+
format = RasterFormat.jpeg if img.mask.all() else RasterFormat.png
309317

310-
content = image.render(
318+
content = img.render(
311319
img_format=format.driver,
312320
colormap=colormap or dst_colormap,
313321
**format.profile,
@@ -360,17 +368,24 @@ async def part(
360368
# Adapt options for each reader type
361369
self._update_params(src_dst, layer_params)
362370

363-
data = await src_dst.part(
371+
img = await src_dst.part(
364372
[minx, miny, maxx, maxy],
365373
**layer_params,
366374
**dataset_params,
367375
**img_params,
368376
)
369377
dst_colormap = getattr(src_dst, "colormap", None)
370378

371-
image = data.post_process(**postprocess_params)
379+
if postprocess_params.image_process:
380+
img = postprocess_params.image_process.apply(img)
381+
382+
if postprocess_params.rescale:
383+
img.rescale(postprocess_params.rescale)
384+
385+
if postprocess_params.color_formula:
386+
img.apply_color_formula(postprocess_params.color_formula)
372387

373-
content = image.render(
388+
content = img.render(
374389
img_format=format.driver,
375390
colormap=colormap or dst_colormap,
376391
**format.profile,
@@ -415,17 +430,24 @@ async def geojson_part(
415430
# Adapt options for each reader type
416431
self._update_params(src_dst, layer_params)
417432

418-
data = await src_dst.feature(
433+
img = await src_dst.feature(
419434
geom.dict(exclude_none=True), **layer_params, **dataset_params
420435
)
421436
dst_colormap = getattr(src_dst, "colormap", None)
422437

423-
if not format:
424-
format = RasterFormat.jpeg if data.mask.all() else RasterFormat.png
438+
if postprocess_params.image_process:
439+
img = postprocess_params.image_process.apply(img)
440+
441+
if postprocess_params.rescale:
442+
img.rescale(postprocess_params.rescale)
425443

426-
image = data.post_process(**postprocess_params)
444+
if postprocess_params.color_formula:
445+
img.apply_color_formula(postprocess_params.color_formula)
427446

428-
content = image.render(
447+
if not format:
448+
format = RasterFormat.jpeg if img.mask.all() else RasterFormat.png
449+
450+
content = img.render(
429451
img_format=format.driver,
430452
colormap=colormap or dst_colormap,
431453
**format.profile,
@@ -475,7 +497,7 @@ async def tile(
475497
# Adapt options for each reader type
476498
self._update_params(src_dst, layer_params)
477499

478-
tile_data = await src_dst.tile(
500+
img = await src_dst.tile(
479501
x,
480502
y,
481503
z,
@@ -502,22 +524,27 @@ async def tile(
502524
_mvt_encoder = partial(run_in_threadpool, pixels_encoder)
503525

504526
content = await _mvt_encoder(
505-
tile_data.data,
506-
tile_data.mask,
507-
tile_data.band_names,
527+
img.data,
528+
img.mask,
529+
img.band_names,
508530
feature_type=feature_type.value,
509531
) # type: ignore
510532

511533
# Raster Tile
512534
else:
513-
if not format:
514-
format = (
515-
RasterFormat.jpeg if tile_data.mask.all() else RasterFormat.png
516-
)
535+
if postprocess_params.image_process:
536+
img = postprocess_params.image_process.apply(img)
537+
538+
if postprocess_params.rescale:
539+
img.rescale(postprocess_params.rescale)
540+
541+
if postprocess_params.color_formula:
542+
img.apply_color_formula(postprocess_params.color_formula)
517543

518-
image = tile_data.post_process(**postprocess_params)
544+
if not format:
545+
format = RasterFormat.jpeg if img.mask.all() else RasterFormat.png
519546

520-
content = image.render(
547+
content = img.render(
521548
img_format=format.driver,
522549
colormap=colormap or dst_colormap,
523550
**format.profile,
@@ -646,6 +673,35 @@ async def wmts(
646673
media_type="application/xml",
647674
)
648675

676+
@self.router.get(
677+
"/algorithm",
678+
response_model=List[AlgorithmMetadata],
679+
)
680+
def algo(request: Request):
681+
"""Handle /algorithm."""
682+
algos = []
683+
for k, v in AVAILABLE_ALGORITHM.items():
684+
props = v.schema()["properties"]
685+
ins = {
686+
k.replace("input_", ""): v
687+
for k, v in props.items()
688+
if k.startswith("input_")
689+
}
690+
outs = {
691+
k.replace("output_", ""): v
692+
for k, v in props.items()
693+
if k.startswith("output_")
694+
}
695+
params = {
696+
k: v
697+
for k, v in props.items()
698+
if not k.startswith("input_") and not k.startswith("output_")
699+
}
700+
algos.append(
701+
AlgorithmMetadata(name=k, inputs=ins, outputs=outs, params=params)
702+
)
703+
return algos
704+
649705
@self.router.get(
650706
"/",
651707
responses={200: {"description": "Simple COG viewer."}},

0 commit comments

Comments
 (0)