-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathimage_regression.py
More file actions
96 lines (81 loc) · 4.62 KB
/
image_regression.py
File metadata and controls
96 lines (81 loc) · 4.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
"""implementation of the ``image_regression`` fixture."""
import os
from typing import Optional
import ee
import requests
from pytest_regressions.image_regression import ImageRegressionFixture
from .utils import build_fullpath, check_serialized
class ImageFixture(ImageRegressionFixture):
    """Fixture for regression testing of :py:class:`ee.Image`."""

    def check(
        self,
        data_image: ee.Image,
        diff_threshold: float = 0.1,
        expect_equal: bool = True,
        basename: Optional[str] = None,
        fullpath: Optional[os.PathLike] = None,
        scale: Optional[int] = 30,
        viz_params: Optional[dict] = None,
    ):
        """Check the given image against a previously recorded version, or generate a new file.

        A serialized (``.yml``) version of the rescaled image is checked first. If it matches the
        recorded one, the check passes immediately. Otherwise, a thumbnail of the image is downloaded
        and compared against the reference ``.png``. If that comparison passes, the serialized file is
        regenerated.

        The thumbnail lets a human user check the result of the computation. It is computed on the fly
        by Earth Engine, which means the test must stay reasonable in size and scale. No feasibility
        checks are performed, and the computation may fail if the request is too greedy.

        If ``viz_params`` is omitted, visualization parameters are derived from the image:

        - images with 3 or more bands are displayed using their first 3 bands as false-color RGB;
        - any other image is displayed using its first band only.

        In both cases, the min/max of each displayed band over the image geometry is used as the
        stretch.

        Parameters:
            data_image: The image to check. The image needs to be clipped to a geometry or have an
                existing footprint.
            diff_threshold: The threshold for the difference between the expected and obtained images.
            expect_equal: If ``True`` the images are expected to be equal, otherwise they are expected
                to be different.
            basename: The basename of the file to test/record. If not given, the name of the test is
                used.
            fullpath: Complete path to use as a reference file. This option will ignore the
                ``datadir`` fixture when reading *expected* files but will still use it to write
                *obtained* files. Useful if a reference file is located in the session data dir, for
                example.
            scale: The scale used to rescale the image (``clipToBoundsAndScale``) and to compute the
                default min/max visualization values.
            viz_params: The visualization parameters to use for the thumbnail. If not given, they are
                derived from the image as described above.

        Raises:
            requests.HTTPError: If the Earth Engine thumbnail request returns an HTTP error status.
        """
        # rescale the original image so every check below works on the same product
        geometry = data_image.geometry()
        data_image = data_image.clipToBoundsAndScale(geometry, scale=scale)

        # build the different filenames to be consistent between our checks
        data_name = build_fullpath(
            datadir=self.original_datadir,
            request=self.request,
            extension=".png",
            basename=basename,
            fullpath=fullpath,
            with_test_class_names=self.with_test_class_names,
        )
        serialized_name = data_name.with_stem(f"serialized_{data_name.stem}").with_suffix(".yml")

        is_serialized_equal = check_serialized(
            object=data_image,
            path=serialized_name,
            datadir=self.datadir,
            request=self.request,
        )
        if is_serialized_equal:
            # serialized is equal -> pass the test without downloading the thumbnail
            # TODO: add proper logging
            return

        # create default visualization parameters from the per-band min/max values
        if viz_params is None:
            min_max = data_image.reduceRegion(ee.Reducer.minMax(), geometry, scale)
            nb_bands = ee.Algorithms.If(data_image.bandNames().size().gte(3), 3, 1)
            bands = data_image.bandNames().slice(0, ee.Number(nb_bands))
            band_min = bands.map(lambda b: min_max.get(ee.String(b).cat("_min")))
            band_max = bands.map(lambda b: min_max.get(ee.String(b).cat("_max")))
            viz_params = ee.Dictionary(
                {"bands": bands, "min": band_min, "max": band_max}
            ).getInfo()

        # get the thumbnail image; fail loudly on HTTP errors so an error payload is never
        # compared against (or recorded as) the reference PNG
        thumb_url = data_image.getThumbURL(params=viz_params)
        response = requests.get(thumb_url)
        response.raise_for_status()
        byte_data = response.content

        # compare the thumbnail bytes against the reference PNG stored at data_name
        super().check(byte_data, diff_threshold, expect_equal, fullpath=data_name)

        # if we are here the thumbnail check passed but the serialized one did not
        # -> regenerate the serialized file
        serialized_name.unlink(missing_ok=True)
        check_serialized(
            object=data_image,
            path=serialized_name,
            datadir=self.datadir,
            request=self.request,
        )