diff --git a/docs/map_widgets.md b/docs/map_widgets.md new file mode 100644 index 0000000000..33169afd06 --- /dev/null +++ b/docs/map_widgets.md @@ -0,0 +1,3 @@ +# map_widgets module + +::: geemap.map_widgets diff --git a/geemap/ee_tile_layers.py b/geemap/ee_tile_layers.py index 86a2c9b5ef..11f22939c8 100644 --- a/geemap/ee_tile_layers.py +++ b/geemap/ee_tile_layers.py @@ -60,6 +60,8 @@ def _ee_object_to_image(ee_object, vis_params): def _validate_palette(palette): + if isinstance(palette, tuple): + palette = list(palette) if isinstance(palette, box.Box): if "default" not in palette: raise ValueError("The provided palette Box object is invalid.") @@ -92,7 +94,9 @@ def __init__( shown (bool, optional): A flag indicating whether the layer should be on by default. Defaults to True. opacity (float, optional): The layer's opacity represented as a number between 0 and 1. Defaults to 1. """ - self.url_format = _get_tile_url_format(ee_object, _validate_vis_params(vis_params)) + self.url_format = _get_tile_url_format( + ee_object, _validate_vis_params(vis_params) + ) super().__init__( tiles=self.url_format, attr="Google Earth Engine", @@ -127,7 +131,9 @@ def __init__( shown (bool, optional): A flag indicating whether the layer should be on by default. Defaults to True. opacity (float, optional): The layer's opacity represented as a number between 0 and 1. Defaults to 1. """ - self.url_format = _get_tile_url_format(ee_object, _validate_vis_params(vis_params)) + self.url_format = _get_tile_url_format( + ee_object, _validate_vis_params(vis_params) + ) super().__init__( url=self.url_format, attribution="Google Earth Engine", diff --git a/geemap/geemap.py b/geemap/geemap.py index f330df851b..439442421f 100644 --- a/geemap/geemap.py +++ b/geemap/geemap.py @@ -26,14 +26,79 @@ from .common import * from .conversion import * from .ee_tile_layers import * -from .timelapse import * +from . import map_widgets from .plot import * +from .timelapse import * from . import examples basemaps = Box(xyz_to_leaflet(), frozen_box=True) +class MapDrawControl(ipyleaflet.DrawControl, map_widgets.AbstractDrawControl): + """"Implements the AbstractDrawControl for the map.""" + _roi_start = False + _roi_end = False + + def __init__(self, host_map, **kwargs): + super(MapDrawControl,self).__init__(host_map=host_map, **kwargs) + + @property + def user_roi(self): + return self.last_geometry + + @property + def user_rois(self): + return self.collection + + # NOTE: Overridden for backwards compatibility, where edited geometries are + # added to the layer instead of modified in place. Remove when + # https://github.com/jupyter-widgets/ipyleaflet/issues/1119 is fixed to + # allow geometry edits to be reflected on the tile layer. + def _handle_geometry_edited(self, geo_json): + return self._handle_geometry_created(geo_json) + + def _get_synced_geojson_from_draw_control(self): + return [data.copy() for data in self.data] + + def _bind_to_draw_control(self): + # Handles draw events + def handle_draw(_, action, geo_json): + try: + self._roi_start = True + if action == "created": + self._handle_geometry_created(geo_json) + elif action == "edited": + self._handle_geometry_edited(geo_json) + elif action == "deleted": + self._handle_geometry_deleted(geo_json) + self._roi_end = True + self._roi_start = False + except Exception as e: + self.reset(clear_draw_control=False) + self._roi_start = False + self._roi_end = False + print("There was an error creating Earth Engine Feature.") + raise Exception(e) + self.on_draw(handle_draw) + # NOTE: Uncomment the following code once + # https://github.com/jupyter-widgets/ipyleaflet/issues/1119 is fixed + # to allow edited geometries to be reflected instead of added. + # def handle_data_update(_): + # self._sync_geometries() + # self.observe(handle_data_update, 'data') + + def _remove_geometry_at_index_on_draw_control(self, index): + # NOTE: Uncomment the following code once + # https://github.com/jupyter-widgets/ipyleaflet/issues/1119 is fixed to + # remove drawn geometries with `remove_last_drawn()`. + # del self.data[index] + # self.send_state(key='data') + pass + + def _clear_draw_control(self): + return self.clear() + class Map(ipyleaflet.Map): """The Map class inherits the ipyleaflet Map class. The arguments you can pass to the Map initialization @@ -45,6 +110,23 @@ class Map(ipyleaflet.Map): object: ipyleaflet map object. """ + # Map attributes for drawing features + @property + def draw_features(self): + return self.draw_control.features if self.draw_control else [] + @property + def draw_last_feature(self): + return self.draw_control.last_feature if self.draw_control else None + @property + def draw_layer(self): + return self.draw_control.layer if self.draw_control else None + @property + def user_roi(self): + return self.draw_control.user_roi if self.draw_control else None + @property + def user_rois(self): + return self.draw_control.user_rois if self.draw_control else None + def __init__(self, **kwargs): """Initialize a map object. The following additional parameters can be passed in addition to the ipyleaflet.Map parameters: @@ -83,6 +165,8 @@ def __init__(self, **kwargs): center = [20, 0] zoom = 2 + self.inspector_control = None + # Set map width and height if "height" not in kwargs: kwargs["height"] = "600px" @@ -167,13 +251,6 @@ def __init__(self, **kwargs): if kwargs.get(control, True): self.add_controls(control, position="bottomright") - # Map attributes for drawing features - self.draw_features = [] - self.draw_last_feature = None - self.draw_layer = None - self.user_roi = None - self.user_rois = None - # Map attributes for layers self.geojson_layers = [] self.ee_layers = [] @@ -1014,7 +1091,7 @@ def add_colorbar( layer_name=None, font_size=9, axis_off=False, - max_width="270px", + max_width=None, **kwargs, ): """Add a matplotlib colorbar to the map @@ -1030,142 +1107,32 @@ def add_colorbar( layer_name (str, optional): The layer name associated with the colorbar. Defaults to None. font_size (int, optional): Font size for the colorbar. Defaults to 9. axis_off (bool, optional): Whether to turn off the axis. Defaults to False. - max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to "300px". + max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to None. Raises: TypeError: If the vis_params is not a dictionary. ValueError: If the orientation is not either horizontal or vertical. - ValueError: If the provided min value is not scalar type. - ValueError: If the provided max value is not scalar type. - ValueError: If the provided opacity value is not scalar type. - ValueError: If cmap or palette is not provided. + TypeError: If the provided min value is not scalar type. + TypeError: If the provided max value is not scalar type. + TypeError: If the provided opacity value is not scalar type. + TypeError: If cmap or palette is not provided. """ - import matplotlib as mpl - import matplotlib.pyplot as plt - import numpy as np - - if isinstance(vis_params, list): - vis_params = {"palette": vis_params} - elif isinstance(vis_params, tuple): - vis_params = {"palette": list(vis_params)} - elif vis_params is None: - vis_params = {} - - if "colors" in kwargs and isinstance(kwargs["colors"], list): - vis_params["palette"] = kwargs["colors"] - - if "colors" in kwargs and isinstance(kwargs["colors"], tuple): - vis_params["palette"] = list(kwargs["colors"]) - - if "vmin" in kwargs: - vis_params["min"] = kwargs["vmin"] - del kwargs["vmin"] - - if "vmax" in kwargs: - vis_params["max"] = kwargs["vmax"] - del kwargs["vmax"] - - if "caption" in kwargs: - label = kwargs["caption"] - del kwargs["caption"] - - if not isinstance(vis_params, dict): - raise TypeError("The vis_params must be a dictionary.") - - if orientation not in ["horizontal", "vertical"]: - raise ValueError("The orientation must be either horizontal or vertical.") - - if orientation == "horizontal": - width, height = 3.0, 0.3 - else: - width, height = 0.3, 3.0 - - if "width" in kwargs: - width = kwargs["width"] - kwargs.pop("width") - - if "height" in kwargs: - height = kwargs["height"] - kwargs.pop("height") - - vis_keys = list(vis_params.keys()) - - if "min" in vis_params: - vmin = vis_params["min"] - if type(vmin) not in (int, float): - raise ValueError("The provided min value must be scalar type.") - else: - vmin = 0 - if "max" in vis_params: - vmax = vis_params["max"] - if type(vmax) not in (int, float): - raise ValueError("The provided max value must be scalar type.") - else: - vmax = 1 - - if "opacity" in vis_params: - alpha = vis_params["opacity"] - if type(alpha) not in (int, float): - raise ValueError("The provided opacity value must be type scalar.") - elif "alpha" in kwargs: - alpha = kwargs["alpha"] - else: - alpha = 1 - - if cmap is not None: - cmap = mpl.pyplot.get_cmap(cmap) - norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) - - if "palette" in vis_keys: - hexcodes = to_hex_colors(check_cmap(vis_params["palette"])) - if discrete: - cmap = mpl.colors.ListedColormap(hexcodes) - vals = np.linspace(vmin, vmax, cmap.N + 1) - norm = mpl.colors.BoundaryNorm(vals, cmap.N) - - else: - cmap = mpl.colors.LinearSegmentedColormap.from_list( - "custom", hexcodes, N=256 - ) - norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) - - elif cmap is not None: - cmap = mpl.pyplot.get_cmap(cmap) - norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax) - - else: - raise ValueError( - 'cmap keyword or "palette" key in vis_params must be provided.' - ) - - fig, ax = plt.subplots(figsize=(width, height)) - cb = mpl.colorbar.ColorbarBase( - ax, norm=norm, alpha=alpha, cmap=cmap, orientation=orientation, **kwargs + colorbar = map_widgets.Colorbar( + vis_params, + cmap, + discrete, + label, + orientation, + transparent_bg, + font_size, + axis_off, + max_width, + **kwargs, ) - - if label is not None: - cb.set_label(label, fontsize=font_size) - elif "bands" in vis_keys: - cb.set_label(vis_params["bands"], fontsize=font_size) - - if axis_off: - ax.set_axis_off() - ax.tick_params(labelsize=font_size) - - # set the background color to transparent - if transparent_bg: - fig.patch.set_alpha(0.0) - - output = widgets.Output(layout=widgets.Layout(width=max_width)) colormap_ctrl = ipyleaflet.WidgetControl( - widget=output, - position=position, - transparent_bg=transparent_bg, + widget=colorbar, position=position, transparent_bg=transparent_bg ) - with output: - output.outputs = () - plt.show() self._colorbar = colormap_ctrl if layer_name in self.ee_layer_names: @@ -2545,200 +2512,6 @@ def close_btn_clicked(b): return vis_widget - def _point_info(self, latlon, decimals=3, return_node=False): - """Create the ipytree widget for displaying the mouse clicking info. - - Args: - latlon (list | tuple): The coordinates (lat, lon) of the point. - decimals (int, optional): Number of decimals to round the coordinates to. Defaults to 3. - return_node (bool, optional): If True, return the ipytree node. - Otherwise, return the ipytree tree widget. Defaults to False. - - Returns: - ipytree.Node | ipytree.Tree: The ipytree node or tree widget. - """ - from ipytree import Node, Tree - - point_nodes = [ - Node(f"Longitude: {latlon[1]}"), - Node(f"Latitude: {latlon[0]}"), - Node(f"Zoom Level: {self.zoom}"), - Node(f"Scale (approx. m/px): {self.get_scale()}"), - ] - label = f"Point ({latlon[1]:.{decimals}f}, {latlon[0]:.{decimals}f}) at {int(self.get_scale())}m/px" - root_node = Node( - label, nodes=point_nodes, icon="map", opened=self._expand_point - ) - - root_node.open_icon = "plus-square" - root_node.open_icon_style = "success" - root_node.close_icon = "minus-square" - root_node.close_icon_style = "info" - - if return_node: - return root_node - else: - return Tree(nodes=[root_node]) - - def _pixels_info( - self, latlon, names=None, visible=True, decimals=2, return_node=False - ): - """Create the ipytree widget for displaying the pixel values at the mouse clicking point. - - Args: - latlon (list | tuple): The coordinates (lat, lon) of the point. - names (str | list, optional): The names of the layers to be included. Defaults to None. - visible (bool, optional): Whether to inspect visible layers only. Defaults to True. - decimals (int, optional): Number of decimals to round the pixel values. Defaults to 2. - return_node (bool, optional): If True, return the ipytree node. - Otherwise, return the ipytree tree widget. Defaults to False. - - Returns: - ipytree.Node | ipytree.Tree: The ipytree node or tree widget. - """ - from ipytree import Node, Tree - - if names is not None: - if isinstance(names, str): - names = [names] - layers = {} - for name in names: - if name in self.ee_layer_names: - layers[name] = self.ee_layer_dict[name] - else: - layers = self.ee_layer_dict - xy = ee.Geometry.Point(latlon[::-1]) - sample_scale = self.getScale() - - root_node = Node("Pixels", icon="archive") - - nodes = [] - - for layer in layers: - layer_name = layer - ee_object = layers[layer]["ee_object"] - object_type = ee_object.__class__.__name__ - - if visible: - if not self.ee_layer_dict[layer_name]["ee_layer"].visible: - continue - - try: - if isinstance(ee_object, ee.ImageCollection): - ee_object = ee_object.mosaic() - - if isinstance(ee_object, ee.Image): - item = ee_object.reduceRegion( - ee.Reducer.first(), xy, sample_scale - ).getInfo() - b_name = "band" - if len(item) > 1: - b_name = "bands" - - label = f"{layer_name}: {object_type} ({len(item)} {b_name})" - layer_node = Node(label, opened=self._expand_pixels) - - keys = sorted(item.keys()) - for key in keys: - value = item[key] - if isinstance(value, float): - value = round(value, decimals) - layer_node.add_node(Node(f"{key}: {value}", icon="file")) - - nodes.append(layer_node) - except: - pass - - root_node.nodes = nodes - - root_node.open_icon = "plus-square" - root_node.open_icon_style = "success" - root_node.close_icon = "minus-square" - root_node.close_icon_style = "info" - - if return_node: - return root_node - else: - return Tree(nodes=[root_node]) - - def _objects_info(self, latlon, names=None, visible=True, return_node=False): - """Create the ipytree widget for displaying the Earth Engine objects at the mouse clicking point. - - Args: - latlon (list | tuple): The coordinates (lat, lon) of the point. - names (str | list, optional): The names of the layers to be included. Defaults to None. - visible (bool, optional): Whether to inspect visible layers only. Defaults to True. - return_node (bool, optional): If True, return the ipytree node. - Otherwise, return the ipytree tree widget. Defaults to False. - - Returns: - ipytree.Node | ipytree.Tree: The ipytree node or tree widget. - """ - from ipytree import Node, Tree - - if names is not None: - if isinstance(names, str): - names = [names] - layers = {} - for name in names: - if name in self.ee_layer_names: - layers[name] = self.ee_layer_dict[name] - else: - layers = self.ee_layer_dict - - xy = ee.Geometry.Point(latlon[::-1]) - root_node = Node("Objects", icon="archive") - - nodes = [] - - for layer in layers: - layer_name = layer - ee_object = layers[layer]["ee_object"] - - if visible: - if not self.ee_layer_dict[layer_name]["ee_layer"].visible: - continue - - if isinstance(ee_object, ee.FeatureCollection): - # Check geometry type - geom_type = ee.Feature(ee_object.first()).geometry().type() - lat, lon = latlon - delta = 0.005 - bbox = ee.Geometry.BBox( - lon - delta, - lat - delta, - lon + delta, - lat + delta, - ) - # Create a bounding box to filter points - xy = ee.Algorithms.If( - geom_type.compareTo(ee.String("Point")), - xy, - bbox, - ) - - ee_object = ee_object.filterBounds(xy).first() - - try: - node = get_info( - ee_object, layer_name, opened=self._expand_objects, return_node=True - ) - nodes.append(node) - except: - pass - - root_node.nodes = nodes - - root_node.open_icon = "plus-square" - root_node.open_icon_style = "success" - root_node.close_icon = "minus-square" - root_node.close_icon_style = "info" - - if return_node: - return root_node - else: - return Tree(nodes=[root_node]) - def inspect(self, latlon): """Create the Inspector GUI. @@ -2763,13 +2536,7 @@ def inspect(self, latlon): return tree def add_inspector( - self, - names=None, - visible=True, - decimals=2, - position="topright", - opened=True, - show_close_button=True, + self, names=None, visible=True, decimals=2, position="topright", opened=True ): """Add the Inspector GUI to the map. @@ -2779,13 +2546,24 @@ def add_inspector( decimals (int, optional): The number of decimal places to round the coordinates. Defaults to 2. position (str, optional): The position of the Inspector GUI. Defaults to "topright". opened (bool, optional): Whether the control is opened. Defaults to True. - """ - from .toolbar import ee_inspector_gui + if self.inspector_control: + return - ee_inspector_gui( - self, names, visible, decimals, position, opened, show_close_button + def _on_close(): + self.toolbar_reset() + if self.inspector_control: + if self.inspector_control in self.controls: + self.remove_control(self.inspector_control) + self.inspector_control.close() + self.inspector_control = None + + inspector = map_widgets.Inspector(self, names, visible, decimals, opened) + inspector.on_close = _on_close + self.inspector_control = ipyleaflet.WidgetControl( + widget=inspector, position=position ) + self.add(self.inspector_control) def add_layer_manager( self, position="topright", opened=True, show_close_button=True @@ -2813,8 +2591,8 @@ def add_draw_control(self, position="topleft"): Args: position (str, optional): The position of the draw control. Defaults to "topleft". """ - - draw_control = ipyleaflet.DrawControl( + draw_control = MapDrawControl( + host_map=self, marker={"shapeOptions": {"color": "#3388ff"}}, rectangle={"shapeOptions": {"color": "#3388ff"}}, # circle={"shapeOptions": {"color": "#3388ff"}}, @@ -2823,50 +2601,6 @@ def add_draw_control(self, position="topleft"): remove=True, position=position, ) - - # Handles draw events - def handle_draw(target, action, geo_json): - try: - self._roi_start = True - geom = geojson_to_ee(geo_json, False) - self.user_roi = geom - feature = ee.Feature(geom) - self.draw_last_feature = feature - if not hasattr(self, "_draw_count"): - self._draw_count = 0 - if action == "deleted" and len(self.draw_features) > 0: - self.draw_features.remove(feature) - self._draw_count -= 1 - else: - self.draw_features.append(feature) - self._draw_count += 1 - collection = ee.FeatureCollection(self.draw_features) - self.user_rois = collection - ee_draw_layer = EELeafletTileLayer( - collection, {"color": "blue"}, "Drawn Features", False, 0.5 - ) - draw_layer_index = self.find_layer_index("Drawn Features") - - if draw_layer_index == -1: - self.add(ee_draw_layer) - self.draw_layer = ee_draw_layer - else: - self.substitute_layer(self.draw_layer, ee_draw_layer) - self.draw_layer = ee_draw_layer - self._roi_end = True - self._roi_start = False - except Exception as e: - self._draw_count = 0 - self.draw_features = [] - self.draw_last_feature = None - self.draw_layer = None - self.user_roi = None - self._roi_start = False - self._roi_end = False - print("There was an error creating Earth Engine Feature.") - raise Exception(e) - - draw_control.on_draw(handle_draw) self.add(draw_control) self.draw_control = draw_control @@ -4375,46 +4109,36 @@ def add_remote_tile( else: raise Exception("The source must be a URL.") + def remove_draw_control(self): + controls = [] + old_draw_control = None + for control in self.controls: + if isinstance(control, MapDrawControl): + old_draw_control = control + + else: + controls.append(control) + + self.controls = tuple(controls) + if old_draw_control: + old_draw_control.close() + def remove_drawn_features(self): """Removes user-drawn geometries from the map""" - if self.draw_layer is not None: - self.remove_layer(self.draw_layer) - self._draw_count = 0 - self.draw_features = [] - self.draw_last_feature = None - self.draw_layer = None - self.user_roi = None - self.user_rois = None - self._chart_values = [] - self._chart_points = [] - self._chart_labels = None if self.draw_control is not None: - self.draw_control.clear() + self.draw_control.reset() def remove_last_drawn(self): - """Removes user-drawn geometries from the map""" - if self.draw_layer is not None: - collection = ee.FeatureCollection(self.draw_features[:-1]) - ee_draw_layer = EELeafletTileLayer( - collection, {"color": "blue"}, "Drawn Features", True, 0.5 - ) - if self._draw_count == 1: + """Removes last user-drawn geometry from the map""" + if self.draw_control is not None: + if self.draw_control.count == 1: self.remove_drawn_features() - else: - self.substitute_layer(self.draw_layer, ee_draw_layer) - self.draw_layer = ee_draw_layer - self._draw_count -= 1 - self.draw_features = self.draw_features[:-1] - self.draw_last_feature = self.draw_features[-1] - self.draw_layer = ee_draw_layer - self.user_roi = ee.Feature( - collection.toList(collection.size()).get( - collection.size().subtract(1) - ) - ).geometry() - self.user_rois = collection - self._chart_values = self._chart_values[:-1] - self._chart_points = self._chart_points[:-1] + elif self.draw_control.count: + self.draw_control.remove_geometry(self.draw_control.geometries[-1]) + if hasattr(self, '_chart_values'): + self._chart_values = self._chart_values[:-1] + if hasattr(self, '_chart_points'): + self._chart_points = self._chart_points[:-1] # self._chart_labels = None def extract_values_to_points(self, filename): diff --git a/geemap/map_widgets.py b/geemap/map_widgets.py new file mode 100644 index 0000000000..b989b782ff --- /dev/null +++ b/geemap/map_widgets.py @@ -0,0 +1,535 @@ +"""Various ipywidgets that can be added to a map.""" + +import enum + +import ipywidgets + +from IPython.core.display import display +import ee +import ipytree + +from . import common +from .ee_tile_layers import EELeafletTileLayer + + +class Colorbar(ipywidgets.Output): + """A matplotlib colorbar widget that can be added to the map.""" + + def __init__( + self, + vis_params=None, + cmap="gray", + discrete=False, + label=None, + orientation="horizontal", + transparent_bg=False, + font_size=9, + axis_off=False, + max_width=None, + **kwargs, + ): + """Add a matplotlib colorbar to the map. + + Args: + vis_params (dict): Visualization parameters as a dictionary. See https://developers.google.com/earth-engine/guides/image_visualization for options. + cmap (str, optional): Matplotlib colormap. Defaults to "gray". See https://matplotlib.org/3.3.4/tutorials/colors/colormaps.html#sphx-glr-tutorials-colors-colormaps-py for options. + discrete (bool, optional): Whether to create a discrete colorbar. Defaults to False. + label (str, optional): Label for the colorbar. Defaults to None. + orientation (str, optional): Orientation of the colorbar, such as "vertical" and "horizontal". Defaults to "horizontal". + transparent_bg (bool, optional): Whether to use transparent background. Defaults to False. + font_size (int, optional): Font size for the colorbar. Defaults to 9. + axis_off (bool, optional): Whether to turn off the axis. Defaults to False. + max_width (str, optional): Maximum width of the colorbar in pixels. Defaults to None. + + Raises: + TypeError: If the vis_params is not a dictionary. + ValueError: If the orientation is not either horizontal or vertical. + ValueError: If the provided min value is not scalar type. + ValueError: If the provided max value is not scalar type. + ValueError: If the provided opacity value is not scalar type. + ValueError: If cmap or palette is not provided. + """ + + import matplotlib # pylint: disable=import-outside-toplevel + import numpy # pylint: disable=import-outside-toplevel + + if max_width is None: + if orientation == "horizontal": + max_width = "270px" + else: + max_width = "100px" + + if isinstance(vis_params, (list, tuple)): + vis_params = {"palette": list(vis_params)} + elif not vis_params: + vis_params = {} + + if not isinstance(vis_params, dict): + raise TypeError("The vis_params must be a dictionary.") + + if isinstance(kwargs.get("colors"), (list, tuple)): + vis_params["palette"] = list(kwargs["colors"]) + + width, height = self._get_dimensions(orientation, kwargs) + + vmin = vis_params.get("min", kwargs.pop("vmin", 0)) + if type(vmin) not in (int, float): + raise TypeError("The provided min value must be scalar type.") + + vmax = vis_params.get("max", kwargs.pop("mvax", 1)) + if type(vmax) not in (int, float): + raise TypeError("The provided max value must be scalar type.") + + alpha = vis_params.get("opacity", kwargs.pop("alpha", 1)) + if type(alpha) not in (int, float): + raise TypeError("The provided opacity or alpha value must be type scalar.") + + if "palette" in vis_params.keys(): + hexcodes = common.to_hex_colors(common.check_cmap(vis_params["palette"])) + if discrete: + cmap = matplotlib.colors.ListedColormap(hexcodes) + linspace = numpy.linspace(vmin, vmax, cmap.N + 1) + norm = matplotlib.colors.BoundaryNorm(linspace, cmap.N) + else: + cmap = matplotlib.colors.LinearSegmentedColormap.from_list( + "custom", hexcodes, N=256 + ) + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + elif cmap: + cmap = matplotlib.pyplot.get_cmap(cmap) + norm = matplotlib.colors.Normalize(vmin=vmin, vmax=vmax) + else: + raise ValueError( + 'cmap keyword or "palette" key in vis_params must be provided.' + ) + + fig, ax = matplotlib.pyplot.subplots(figsize=(width, height)) + cb = matplotlib.colorbar.ColorbarBase( + ax, + norm=norm, + alpha=alpha, + cmap=cmap, + orientation=orientation, + **kwargs, + ) + + label = label or vis_params.get("bands") or kwargs.pop("caption", None) + if label: + cb.set_label(label, fontsize=font_size) + + if axis_off: + ax.set_axis_off() + ax.tick_params(labelsize=font_size) + + # Set the background color to transparent. + if transparent_bg: + fig.patch.set_alpha(0.0) + + super().__init__(layout=ipywidgets.Layout(width=max_width)) + with self: + self.outputs = () + matplotlib.pyplot.show() + + def _get_dimensions(self, orientation, kwargs): + default_dims = {"horizontal": (3.0, 0.3), "vertical": (0.3, 3.0)} + if orientation in default_dims: + default = default_dims[orientation] + return ( + kwargs.get("width", default[0]), + kwargs.get("height", default[1]), + ) + raise ValueError( + f"orientation must be one of [{', '.join(default_dims.keys())}]." + ) + + +class Inspector(ipywidgets.VBox): + host_map = None + expand_point = False + expand_pixels = True + expand_objects = False + on_close = None + + def __init__(self, host_map, names=None, visible=True, decimals=2, opened=True): + self.host_map = host_map + self.names = names + self.visible = visible + self.decimals = decimals + self.opened = opened + + host_map.default_style = {"cursor": "crosshair"} + + left_padded_square = ipywidgets.Layout( + width="28px", height="28px", padding="0px 0px 0px 4px" + ) + + self.toolbar_button = ipywidgets.ToggleButton( + value=opened, tooltip="Inspector", icon="info", layout=left_padded_square + ) + self.toolbar_button.observe(self._on_toolbar_btn_click, "value") + + close_button = ipywidgets.ToggleButton( + value=False, + tooltip="Close the tool", + icon="times", + button_style="primary", + layout=left_padded_square, + ) + close_button.observe(self._on_close_btn_click, "value") + + self.inspector_output = ipywidgets.Output( + layout=ipywidgets.Layout( + max_width="600px", max_height="500px", overflow="auto" + ) + ) + expand_point = self._create_expand_checkbox("Point", self.expand_point) + expand_pixels = self._create_expand_checkbox("Pixels", self.expand_pixels) + expand_objects = self._create_expand_checkbox("Objects", self.expand_objects) + expand_point.observe(self._on_expand_point_changed, "value") + expand_pixels.observe(self._on_expand_pixels_changed, "value") + expand_objects.observe(self._on_expand_objects_changed, "value") + self.inspector_checks = ipywidgets.HBox( + children=[ + ipywidgets.Label( + "Expand", layout=ipywidgets.Layout(padding="0px 8px 0px 4px") + ), + expand_point, + expand_pixels, + expand_objects, + ] + ) + self._clear_inspector_output() + + self.toolbar_header = ipywidgets.HBox( + children=[close_button, self.toolbar_button] + ) + self.toolbar_footer = ipywidgets.VBox(children=[self.inspector_output]) + + host_map.on_interaction(self._on_map_interaction) + self.toolbar_button.value = opened + + super().__init__(children=[self.toolbar_header, self.toolbar_footer]) + + def _create_expand_checkbox(self, title, checked): + layout = ipywidgets.Layout(width="auto", padding="0px 6px 0px 0px") + return ipywidgets.Checkbox( + description=title, indent=False, value=checked, layout=layout + ) + + def _on_map_interaction(self, **kwargs): + latlon = kwargs.get("coordinates") + if kwargs.get("type") == "click" and self.toolbar_button.value: + self.host_map.default_style = {"cursor": "wait"} + self._clear_inspector_output() + tree = ipytree.Tree() + nodes = [] + nodes.append(self._point_info(latlon, return_node=True)) + pixels_node = self._pixels_info(latlon, return_node=True) + if pixels_node.nodes: + nodes.append(pixels_node) + objects_node = self._objects_info(latlon, return_node=True) + if objects_node.nodes: + nodes.append(objects_node) + tree.nodes = nodes + with self.inspector_output: + display(tree) + self.host_map.default_style = {"cursor": "crosshair"} + + def _clear_inspector_output(self): + with self.inspector_output: + self.inspector_output.clear_output(wait=True) + display(self.inspector_checks) + + def _on_expand_point_changed(self, change): + self.expand_point = change["new"] + + def _on_expand_pixels_changed(self, change): + self.expand_pixels = change["new"] + + def _on_expand_objects_changed(self, change): + self.expand_objects = change["new"] + + def _on_toolbar_btn_click(self, change): + if change["new"]: + self.host_map.default_style = {"cursor": "crosshair"} + self.children = [self.toolbar_header, self.toolbar_footer] + self._clear_inspector_output() + else: + self.children = [self.toolbar_button] + self.host_map.default_style = {"cursor": "default"} + + def _on_close_btn_click(self, change): + if change["new"]: + if self.host_map: + self.host_map.default_style = {"cursor": "default"} + self.host_map.on_interaction(self._on_map_interaction, remove=True) + if self.on_close is not None: + self.on_close() + + def _get_visible_map_layers(self): + layers = {} + if self.names: + names = [names] if isinstance(names, str) else self.names + for name in names: + if name in self.host_map.ee_layer_names: + layers[name] = self.host_map.ee_layer_dict[name] + else: + layers = self.host_map.ee_layer_dict + return {k: v for k, v in layers.items() if v["ee_layer"].visible} + + def _root_node(self, title, nodes, return_node, **kwargs): + root_node = ipytree.Node( + title, + icon="archive", + nodes=nodes, + open_icon="plus-square", + open_icon_style="success", + close_icon="minus-square", + close_icon_style="info", + **kwargs, + ) + return root_node if return_node else ipytree.Tree(nodes=[root_node]) + + def _point_info(self, latlon, return_node=False): + scale = self.host_map.get_scale() + label = f"Point ({latlon[1]:.{self.decimals}f}, {latlon[0]:.{self.decimals}f}) at {int(scale)}m/px" + nodes = [ + ipytree.Node(f"Longitude: {latlon[1]}"), + ipytree.Node(f"Latitude: {latlon[0]}"), + ipytree.Node(f"Zoom Level: {self.host_map.zoom}"), + ipytree.Node(f"Scale (approx. m/px): {scale}"), + ] + return self._root_node(label, nodes, return_node, opened=self.expand_point) + + def _pixels_info(self, latlon, return_node=False): + if not self.visible: + return self._root_node("Pixels", [], return_node) + + layers = self._get_visible_map_layers() + xy = ee.Geometry.Point(latlon[::-1]) + scale = self.host_map.getScale() + nodes = [] + for layer_name, layer in layers.items(): + obj = layer["ee_object"] + if isinstance(obj, ee.ImageCollection): + obj = obj.mosaic() + if isinstance(obj, ee.Image): + try: + item = obj.reduceRegion(ee.Reducer.first(), xy, scale).getInfo() + except: + continue + b_name = "band" if len(item) == 1 else "bands" + obj_type = obj.__class__.__name__ + label = f"{layer_name}: {obj_type} ({len(item)} {b_name})" + layer_node = ipytree.Node(label, opened=self.expand_pixels) + for key, value in sorted(item.items()): + if isinstance(value, float): + value = round(value, self.decimals) + layer_node.add_node(ipytree.Node(f"{key}: {value}", icon="file")) + nodes.append(layer_node) + + return self._root_node("Pixels", nodes, return_node) + + def _objects_info(self, latlon, return_node=False): + if not self.visible: + return self._root_node("Objects", [], return_node) + + layers = self._get_visible_map_layers() + xy = ee.Geometry.Point(latlon[::-1]) + nodes = [] + for layer_name, layer in layers.items(): + obj = layer["ee_object"] + if isinstance(obj, ee.FeatureCollection): + geom_type = ee.Feature(obj.first()).geometry().type() + lat, lon = latlon + delta = 0.005 + bbox = ee.Geometry.BBox( + lon - delta, lat - delta, lon + delta, lat + delta + ) + xy = ee.Algorithms.If(geom_type.compareTo(ee.String("Point")), xy, bbox) + obj = obj.filterBounds(xy).first() + try: + nodes.append( + common.get_info( + obj, layer_name, opened=self.expand_objects, return_node=True + ) + ) + except: + pass + + return self._root_node("Objects", nodes, return_node) + + +class DrawActions(enum.StrEnum): + CREATED='created' + EDITED='edited' + DELETED='deleted' + REMOVED_LAST='removed-last' + + +class AbstractDrawControl(object): + host_map = None + layer = None + geometries = [] + properties = [] + last_geometry = None + last_draw_action = None + _geometry_create_dispatcher = ipywidgets.CallbackDispatcher() + _geometry_edit_dispatcher = ipywidgets.CallbackDispatcher() + _geometry_delete_dispatcher = ipywidgets.CallbackDispatcher() + + def __init__(self, host_map): + self.host_map = host_map + self.layer = None + self.geometries = [] + self.properties = [] + self.last_geometry = None + self.last_draw_action = None + self._geometry_create_dispatcher = ipywidgets.CallbackDispatcher() + self._geometry_edit_dispatcher = ipywidgets.CallbackDispatcher() + self._geometry_delete_dispatcher = ipywidgets.CallbackDispatcher() + self._bind_to_draw_control() + + @property + def features(self): + if self.count: + return [ + ee.Feature(geometry, self.properties[i]) for i, geometry in enumerate(self.geometries) + ] + else: + return [] + + @property + def collection(self): + return ee.FeatureCollection(self.features) if self.count else None + + @property + def last_feature(self): + property = self.get_geometry_properties(self.last_geometry) + return ee.Feature(self.last_geometry, property) if self.last_geometry else None + + @property + def count(self): + return len(self.geometries) + + def reset(self, clear_draw_control=True): + """Resets the draw controls.""" + if self.layer is not None: + self.host_map.remove_layer(self.layer) + self.geometries = [] + self.properties = [] + self.last_geometry = None + self.layer = None + if clear_draw_control: + self._clear_draw_control() + + def remove_geometry(self, geometry): + index = self.geometries.index(geometry) + if index >= 0: + del self.geometries[index] + del self.properties[index] + self._remove_geometry_at_index_on_draw_control(index) + if index == self.count and geometry == self.last_geometry: + # Treat this like an "undo" of the last drawn geometry. + self.last_geometry = self.geometries[-1] + self.last_draw_action = DrawActions.REMOVED_LAST + if self.layer is not None: + self._redraw_layer() + + def get_geometry_properties(self, geometry): + index = self.geometries.index(geometry) + if index >= 0: + return self.properties[index] + else: + return None + + def set_geometry_properties(self, geometry, property): + index = self.geometries.index(geometry) + if index >= 0: + self.properties[index] = property + + def on_geometry_create(self, callback, remove=False): + self._geometry_create_dispatcher.register_callback(callback, remove=remove) + + def on_geometry_edit(self, callback, remove=False): + self._geometry_edit_dispatcher.register_callback(callback, remove=remove) + + def on_geometry_delete(self, callback, remove=False): + self._geometry_delete_dispatcher.register_callback(callback, remove=remove) + + def _bind_to_draw_control(self): + """Set up draw control event handling like create, edit, and delete.""" + raise NotImplementedError() + + def _remove_geometry_at_index_on_draw_control(self): + """Remove the geometry at the given index on the draw control.""" + raise NotImplementedError() + + def _clear_draw_control(self): + """Clears the geometries from the draw control.""" + raise NotImplementedError() + + def _get_synced_geojson_from_draw_control(self): + """Returns an up-to-date of GeoJSON from the draw control.""" + raise NotImplementedError() + + def _sync_geometries(self): + """Sync the local geometries with those from the draw control.""" + if not self.count: + return + # The current geometries from the draw_control. + test_geojsons = self._get_synced_geojson_from_draw_control() + i = 0 + while i < self.count and i < len(test_geojsons): + local_geometry = None + test_geometry = None + while i < self.count and i < len(test_geojsons): + local_geometry = self.geometries[i] + test_geometry = common.geojson_to_ee(test_geojsons[i], False) + if test_geometry == local_geometry: + i += 1 + else: + break + if i < self.count and test_geometry is not None: + self.geometries[i] = test_geometry + if self.layer is not None: + self._redraw_layer() + + def _redraw_layer(self): + layer = EELeafletTileLayer( + self.collection, {"color": "blue"}, "Drawn Features", False, 0.5 + ) + if self.host_map: + layer_index = self.host_map.find_layer_index("Drawn Features") + if layer_index == -1: + self.host_map.add_layer(layer) + else: + self.host_map.substitute(self.host_map.layers[layer_index], layer) + self.layer = layer + + def _handle_geometry_created(self, geo_json): + geometry = common.geojson_to_ee(geo_json, False) + self.last_geometry = geometry + self.last_draw_action = DrawActions.CREATED + self.geometries.append(geometry) + self.properties.append(None) + self._redraw_layer() + self._geometry_create_dispatcher(self, geometry=geometry) + + def _handle_geometry_edited(self, geo_json): + geometry = common.geojson_to_ee(geo_json, False) + self.last_geometry = geometry + self.last_draw_action = DrawActions.EDITED + self._sync_geometries() + self._redraw_layer() + self._geometry_edit_dispatcher(self, geometry=geometry) + + def _handle_geometry_deleted(self, geo_json): + geometry = common.geojson_to_ee(geo_json, False) + self.last_geometry = geometry + self.last_draw_action = DrawActions.DELETED + i = self.geometries.index(geometry) + del self.geometries[i] + del self.properties[i] + self._redraw_layer() + self._geometry_delete_dispatcher(self, geometry=geometry) \ No newline at end of file diff --git a/geemap/toolbar.py b/geemap/toolbar.py index 73b9f67317..6b51f3c088 100644 --- a/geemap/toolbar.py +++ b/geemap/toolbar.py @@ -19,6 +19,7 @@ from .common import * from .timelapse import * +from .geemap import MapDrawControl def main_toolbar(m, position="topright", **kwargs): @@ -158,8 +159,7 @@ def tool_callback(change): m.remove_drawn_features() tool.value = False elif tool_name == "inspector": - if not hasattr(m, "inspector_control"): - m.add_inspector() + m.add_inspector() tool.value = False elif tool_name == "plotting": ee_plot_gui(m) @@ -717,201 +717,6 @@ def handle_interaction(**kwargs): return toolbar_widget -def ee_inspector_gui( - m, - names=None, - visible=True, - decimals=2, - position="topright", - opened=True, - show_close_button=True, - max_width="300px", -): - """Earth Engine Inspector GUI. - - Args: - m (geemap.Map): The geemap.Map object. - names (str | list, optional): The names of the layers to be included. Defaults to None. - visible (bool, optional): Whether to inspect visible layers only. Defaults to True. - decimals (int, optional): The number of decimal places to round the values. Defaults to 2. - position (str, optional): The position of the control. Defaults to "topright". - opened (bool, optional): Whether the control is opened. Defaults to True. - show_close_button (bool, optional): Whether to show the close button. Defaults to True. - max_width - - """ - from ipytree import Tree - - m._expand_point = False - m._expand_pixels = True - m._expand_objects = False - m.default_style = {"cursor": "crosshair"} - - toolbar_button = widgets.ToggleButton( - value=True, - tooltip="Inspector", - icon="info", - layout=widgets.Layout(width="28px", height="28px", padding="0px 0px 0px 4px"), - ) - - close_button = widgets.ToggleButton( - value=False, - tooltip="Close the tool", - icon="times", - button_style="primary", - layout=widgets.Layout(height="28px", width="28px", padding="0px 0px 0px 4px"), - ) - - layout = { - "border": "1px solid black", - "max_width": max_width, - "max_height": "500px", - "overflow": "auto", - } - inspector_output = widgets.Output(layout=layout) - - expand_label = widgets.Label( - "Expand ", - layout=widgets.Layout(padding="0px 0px 0px 4px"), - ) - - expand_point = widgets.Checkbox( - description="Point", - indent=False, - value=m._expand_point, - layout=widgets.Layout(width="65px"), - ) - - expand_pixels = widgets.Checkbox( - description="Pixels", - indent=False, - value=m._expand_pixels, - layout=widgets.Layout(width="65px"), - ) - - expand_objects = widgets.Checkbox( - description="Objects", - indent=False, - value=m._expand_objects, - layout=widgets.Layout(width="70px"), - ) - - def expand_point_changed(change): - m._expand_point = change["new"] - - def expand_pixels_changed(change): - m._expand_pixels = change["new"] - - def expand_objects_changed(change): - m._expand_objects = change["new"] - - expand_point.observe(expand_point_changed, "value") - expand_pixels.observe(expand_pixels_changed, "value") - expand_objects.observe(expand_objects_changed, "value") - - inspector_checks = widgets.HBox() - inspector_checks.children = [ - expand_label, - widgets.Label(""), - expand_point, - expand_pixels, - expand_objects, - ] - - with inspector_output: - inspector_output.outputs = () - display(inspector_checks) - - toolbar_header = widgets.HBox() - if show_close_button: - toolbar_header.children = [close_button, toolbar_button] - else: - toolbar_header.children = [toolbar_button] - toolbar_footer = widgets.VBox() - toolbar_footer.children = [inspector_output] - toolbar_widget = widgets.VBox() - toolbar_widget.children = [toolbar_header, toolbar_footer] - - def handle_interaction(**kwargs): - latlon = kwargs.get("coordinates") - if kwargs.get("type") == "click" and toolbar_button.value: - m.default_style = {"cursor": "wait"} - ###################################### Temporary fix for Solara - inspector_output = widgets.Output(layout=layout) - toolbar_footer.children = [inspector_output] - ###################################### - with inspector_output: - inspector_output.outputs = () - display(inspector_checks) - - tree = Tree() - nodes = [] - point_node = m._point_info(latlon, return_node=True) - nodes.append(point_node) - pixels_node = m._pixels_info( - latlon, names, visible, decimals, return_node=True - ) - if pixels_node.nodes: - nodes.append(pixels_node) - objects_node = m._objects_info(latlon, names, visible, return_node=True) - if objects_node.nodes: - nodes.append(objects_node) - tree.nodes = nodes - - display(tree) - m.default_style = {"cursor": "crosshair"} - - m.on_interaction(handle_interaction) - - def toolbar_btn_click(change): - if change["new"]: - m.default_style = {"cursor": "crosshair"} - # close_button.value = False - toolbar_widget.children = [toolbar_header, toolbar_footer] - ###################################### Temporary fix for Solara - inspector_output = widgets.Output(layout=layout) - toolbar_footer.children = [inspector_output] - ###################################### - with inspector_output: - inspector_output.outputs = () - display(inspector_checks) - else: - toolbar_widget.children = [toolbar_button] - m.default_style = {"cursor": "default"} - - toolbar_button.observe(toolbar_btn_click, "value") - - def close_btn_click(change): - if change["new"]: - m.default_style = {"cursor": "default"} - toolbar_button.value = False - if m is not None: - m.toolbar_reset() - m.on_interaction(handle_interaction, remove=True) - if ( - m.inspector_control is not None - and m.inspector_control in m.controls - ): - m.remove_control(m.inspector_control) - m.inspector_control = None - delattr(m, "inspector_control") - toolbar_widget.close() - - close_button.observe(close_btn_click, "value") - - toolbar_button.value = opened - if m is not None: - inspector_control = ipyleaflet.WidgetControl( - widget=toolbar_widget, position=position - ) - - if inspector_control not in m.controls: - m.add(inspector_control) - m.inspector_control = inspector_control - else: - return toolbar_widget - - def layer_manager_gui( m, position="topright", opened=True, return_widget=False, show_close_button=True ): @@ -2423,7 +2228,8 @@ def button_clicked(change): if change["new"] == "Apply": if len(color.value) != 7: color.value = "#3388ff" - draw_control = ipyleaflet.DrawControl( + draw_control = MapDrawControl( + host_map = m, marker={"shapeOptions": {"color": color.value}, "repeatMode": False}, rectangle={"shapeOptions": {"color": color.value}, "repeatMode": False}, polygon={"shapeOptions": {"color": color.value}, "repeatMode": False}, @@ -2432,19 +2238,8 @@ def button_clicked(change): edit=False, remove=False, ) - - controls = [] - old_draw_control = None - for control in m.controls: - if isinstance(control, ipyleaflet.DrawControl): - controls.append(draw_control) - old_draw_control = control - - else: - controls.append(control) - - m.controls = tuple(controls) - old_draw_control.close() + m.remove_draw_control() + m.add(draw_control) m.draw_control = draw_control train_props = {} @@ -2463,52 +2258,10 @@ def button_clicked(change): train_props["color"] = color.value # Handles draw events - def handle_draw(target, action, geo_json): - from .ee_tile_layers import EELeafletTileLayer - - try: - geom = geojson_to_ee(geo_json, False) - m.user_roi = geom - - if len(train_props) > 0: - feature = ee.Feature(geom, train_props) - else: - feature = ee.Feature(geom) - m.draw_last_feature = feature - if not hasattr(m, "_draw_count"): - m._draw_count = 0 - if action == "deleted" and len(m.draw_features) > 0: - m.draw_features.remove(feature) - m._draw_count -= 1 - else: - m.draw_features.append(feature) - m._draw_count += 1 - collection = ee.FeatureCollection(m.draw_features) - m.user_rois = collection - ee_draw_layer = EELeafletTileLayer( - collection, {"color": "blue"}, "Drawn Features", False, 0.5 - ) - draw_layer_index = m.find_layer_index("Drawn Features") - - if draw_layer_index == -1: - m.add_layer(ee_draw_layer) - m.draw_layer = ee_draw_layer - else: - m.substitute_layer(m.draw_layer, ee_draw_layer) - m.draw_layer = ee_draw_layer - - except Exception as e: - m._draw_count = 0 - m.draw_features = [] - m.draw_last_feature = None - m.draw_layer = None - m.user_roi = None - m._roi_start = False - m._roi_end = False - print("There was an error creating Earth Engine Feature.") - raise Exception(e) - - draw_control.on_draw(handle_draw) + def set_properties(_, geometry): + if len(train_props) > 0: + draw_control.set_geometry_properties(geometry, train_props) + draw_control.on_geometry_create(set_properties) elif change["new"] == "Clear": prop_text1.value = "" @@ -2521,6 +2274,9 @@ def handle_draw(target, action, geo_json): if m.training_ctrl is not None and m.training_ctrl in m.controls: m.remove_control(m.training_ctrl) full_widget.close() + # Restore default draw control. + m.remove_draw_control() + m.add_draw_control() buttons.value = None buttons.observe(button_clicked, "value") diff --git a/mkdocs.yml b/mkdocs.yml index 8520fc3ddb..b72e1d4e98 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -116,6 +116,7 @@ nav: - geemap module: geemap.md - kepler module: kepler.md - legends module: legends.md + - map_widgets module: map_widgets.md - ml module: ml.md - osm module: osm.md - plot module: plot.md diff --git a/tests/test_map_widgets.py b/tests/test_map_widgets.py new file mode 100644 index 0000000000..caabe3e2e5 --- /dev/null +++ b/tests/test_map_widgets.py @@ -0,0 +1,222 @@ +#!/usr/bin/env python + +"""Tests for `map_widgets` module.""" + + +import unittest +from unittest.mock import patch, MagicMock, ANY +from geemap import map_widgets + + +class TestColorbar(unittest.TestCase): + """Tests for the Colorbar class in the `map_widgets` module.""" + + TEST_COLORS = ["blue", "red", "green"] + TEST_COLORS_HEX = ["#0000ff", "#ff0000", "#008000"] + + def setUp(self): + self.fig_mock = MagicMock() + self.ax_mock = MagicMock() + self.subplots_mock = patch("matplotlib.pyplot.subplots").start() + self.subplots_mock.return_value = (self.fig_mock, self.ax_mock) + + self.colorbar_base_mock = MagicMock() + self.colorbar_base_class_mock = patch( + "matplotlib.colorbar.ColorbarBase" + ).start() + self.colorbar_base_class_mock.return_value = self.colorbar_base_mock + + self.normalize_mock = MagicMock() + self.normalize_class_mock = patch("matplotlib.colors.Normalize").start() + self.normalize_class_mock.return_value = self.normalize_mock + + self.boundary_norm_mock = MagicMock() + self.boundary_norm_class_mock = patch("matplotlib.colors.BoundaryNorm").start() + self.boundary_norm_class_mock.return_value = self.boundary_norm_mock + + self.listed_colormap = MagicMock() + self.listed_colormap_class_mock = patch( + "matplotlib.colors.ListedColormap" + ).start() + self.listed_colormap_class_mock.return_value = self.listed_colormap + + self.linear_segmented_colormap_mock = MagicMock() + self.colormap_from_list_mock = patch( + "matplotlib.colors.LinearSegmentedColormap.from_list" + ).start() + self.colormap_from_list_mock.return_value = self.linear_segmented_colormap_mock + + check_cmap_mock = patch("geemap.common.check_cmap").start() + check_cmap_mock.side_effect = lambda x: x + + self.cmap_mock = MagicMock() + self.get_cmap_mock = patch("matplotlib.pyplot.get_cmap").start() + self.get_cmap_mock.return_value = self.cmap_mock + + def tearDown(self): + patch.stopall() + + def test_colorbar_no_args(self): + map_widgets.Colorbar() + self.normalize_class_mock.assert_called_with(vmin=0, vmax=1) + self.get_cmap_mock.assert_called_with("gray") + self.subplots_mock.assert_called_with(figsize=(3.0, 0.3)) + self.ax_mock.set_axis_off.assert_not_called() + self.ax_mock.tick_params.assert_called_with(labelsize=9) + self.fig_mock.patch.set_alpha.assert_not_called() + self.colorbar_base_mock.set_label.assert_not_called() + self.colorbar_base_class_mock.assert_called_with( + self.ax_mock, + norm=self.normalize_mock, + alpha=1, + cmap=self.cmap_mock, + orientation="horizontal", + ) + + def test_colorbar_orientation_horizontal(self): + map_widgets.Colorbar(orientation="horizontal") + self.subplots_mock.assert_called_with(figsize=(3.0, 0.3)) + + def test_colorbar_orientation_vertical(self): + map_widgets.Colorbar(orientation="vertical") + self.subplots_mock.assert_called_with(figsize=(0.3, 3.0)) + + def test_colorbar_orientation_override(self): + map_widgets.Colorbar(orientation="horizontal", width=2.0) + self.subplots_mock.assert_called_with(figsize=(2.0, 0.3)) + + def test_colorbar_invalid_orientation(self): + with self.assertRaisesRegex(ValueError, "orientation must be one of"): + map_widgets.Colorbar(orientation="not an orientation") + + def test_colorbar_label(self): + map_widgets.Colorbar(label="Colorbar lbl", font_size=42) + self.colorbar_base_mock.set_label.assert_called_with( + "Colorbar lbl", fontsize=42 + ) + + def test_colorbar_label_as_bands(self): + map_widgets.Colorbar(vis_params={"bands": "b1"}) + self.colorbar_base_mock.set_label.assert_called_with("b1", fontsize=9) + + def test_colorbar_label_with_caption(self): + map_widgets.Colorbar(caption="Colorbar caption") + self.colorbar_base_mock.set_label.assert_called_with( + "Colorbar caption", fontsize=9 + ) + + def test_colorbar_label_precedence(self): + map_widgets.Colorbar( + label="Colorbar lbl", + vis_params={"bands": "b1"}, + caption="Colorbar caption", + font_size=21, + ) + self.colorbar_base_mock.set_label.assert_called_with( + "Colorbar lbl", fontsize=21 + ) + + def test_colorbar_axis(self): + map_widgets.Colorbar(axis_off=True, font_size=24) + self.ax_mock.set_axis_off.assert_called() + self.ax_mock.tick_params.assert_called_with(labelsize=24) + + def test_colorbar_transparent_bg(self): + map_widgets.Colorbar(transparent_bg=True) + self.fig_mock.patch.set_alpha.assert_called_with(0.0) + + def test_colorbar_vis_params_palette(self): + map_widgets.Colorbar( + vis_params={ + "palette": self.TEST_COLORS, + "min": 11, + "max": 21, + "opacity": 0.2, + } + ) + self.normalize_class_mock.assert_called_with(vmin=11, vmax=21) + self.colormap_from_list_mock.assert_called_with( + "custom", self.TEST_COLORS_HEX, N=256 + ) + self.colorbar_base_class_mock.assert_called_with( + self.ax_mock, + norm=self.normalize_mock, + alpha=0.2, + cmap=self.linear_segmented_colormap_mock, + orientation="horizontal", + ) + + def test_colorbar_vis_params_discrete_palette(self): + map_widgets.Colorbar( + vis_params={"palette": self.TEST_COLORS, "min": -1}, discrete=True + ) + self.boundary_norm_class_mock.assert_called_with([-1], ANY) + self.listed_colormap_class_mock.assert_called_with(self.TEST_COLORS_HEX) + self.colorbar_base_class_mock.assert_called_with( + self.ax_mock, + norm=self.boundary_norm_mock, + alpha=1, + cmap=self.listed_colormap, + orientation="horizontal", + ) + + def test_colorbar_vis_params_palette_as_list(self): + map_widgets.Colorbar(vis_params=self.TEST_COLORS, discrete=True) + self.boundary_norm_class_mock.assert_called_with([0], ANY) + self.listed_colormap_class_mock.assert_called_with(self.TEST_COLORS_HEX) + self.colorbar_base_class_mock.assert_called_with( + self.ax_mock, + norm=self.boundary_norm_mock, + alpha=1, + cmap=self.listed_colormap, + orientation="horizontal", + ) + + def test_colorbar_kwargs_colors(self): + map_widgets.Colorbar(colors=self.TEST_COLORS, discrete=True) + self.boundary_norm_class_mock.assert_called_with([0], ANY) + self.listed_colormap_class_mock.assert_called_with(self.TEST_COLORS_HEX) + self.colorbar_base_class_mock.assert_called_with( + self.ax_mock, + norm=self.boundary_norm_mock, + alpha=1, + cmap=self.listed_colormap, + orientation="horizontal", + colors=self.TEST_COLORS, + ) + + def test_colorbar_min_max(self): + map_widgets.Colorbar( + vis_params={"palette": self.TEST_COLORS, "min": -1.5}, vmin=-1, vmax=2 + ) + self.normalize_class_mock.assert_called_with(vmin=-1.5, vmax=1) + + def test_colorbar_invalid_min(self): + with self.assertRaisesRegex(TypeError, "min value must be scalar type"): + map_widgets.Colorbar(vis_params={"min": "invalid_min"}) + + def test_colorbar_invalid_max(self): + with self.assertRaisesRegex(TypeError, "max value must be scalar type"): + map_widgets.Colorbar(vis_params={"max": "invalid_max"}) + + def test_colorbar_opacity(self): + map_widgets.Colorbar(vis_params={"opacity": 0.5}, colors=self.TEST_COLORS) + self.colorbar_base_class_mock.assert_called_with( + ANY, norm=ANY, alpha=0.5, cmap=ANY, orientation=ANY, colors=ANY + ) + + def test_colorbar_alpha(self): + map_widgets.Colorbar(alpha=0.5, colors=self.TEST_COLORS) + self.colorbar_base_class_mock.assert_called_with( + ANY, norm=ANY, alpha=0.5, cmap=ANY, orientation=ANY, colors=ANY + ) + + def test_colorbar_invalid_alpha(self): + with self.assertRaisesRegex( + TypeError, "opacity or alpha value must be type scalar" + ): + map_widgets.Colorbar(alpha="invalid_alpha", colors=self.TEST_COLORS) + + def test_colorbar_vis_params_throws_for_not_dict(self): + with self.assertRaisesRegex(TypeError, "vis_params must be a dictionary"): + map_widgets.Colorbar(vis_params="NOT a dict")