@@ -434,6 +434,49 @@ def reversed(self, name: str | None = None) -> Colormap:
434434 self .color_stops .reversed (), name = name , category = self .category
435435 )
436436
437+ def shifted (
438+ self ,
439+ shift : float = 0.5 ,
440+ name : str | None = None ,
441+ mode : Literal ["wrap" , "clip" ] = "wrap" ,
442+ ) -> Colormap :
443+ """Return a new Colormap, with colors shifted by a scalar value.
444+
445+ This method shifts the stops in the colormap by a scalar value.
446+ It makes most sense for cyclic colormaps, but can be used with any colormap.
447+
448+ Parameters
449+ ----------
450+ shift : float
451+ The amount to shift the colormap. Positive values shift the colormap
452+ towards the end, negative values shift the colormap towards the beginning.
453+ name : str
454+ A new name for the colormap. If not provided, the name of the new colormap
455+ will be the name of the original colormap with "_shifted{shift}" appended.
456+ mode : {'wrap', 'clip'}, optional
457+ The mode to use when shifting the colormap. Must be one of 'wrap' or
458+ 'clip'. If 'wrap', the colormap will be shifted and wrapped around the ends.
459+ If 'clip', the colormap will be shifted and the colors at the ends will be
460+ clipped and/or repeated as necessary.
461+
462+ Returns
463+ -------
464+ Colormap
465+ A new colormap with the colors shifted.
466+ """
467+ if name is None :
468+ name = f"{ self .name } _shifted{ shift } "
469+
470+ return type (self )(
471+ self .color_stops .shifted (shift = shift , mode = mode ),
472+ name = name ,
473+ category = self .category ,
474+ interpolation = self .interpolation ,
475+ under = self .under_color ,
476+ over = self .over_color ,
477+ bad = self .bad_color ,
478+ )
479+
437480 def to_css (
438481 self ,
439482 max_stops : int | None = None ,
@@ -969,7 +1012,35 @@ def reversed(self) -> ColorStops:
9691012 # invert the positions in the stops
9701013 rev_stops = self ._stops [::- 1 ]
9711014 rev_stops [:, 0 ] = 1 - rev_stops [:, 0 ]
972- return type (self )(rev_stops )
1015+ return type (self )(rev_stops , interpolation = self ._interpolation )
1016+
1017+ def shifted (
1018+ self , shift : float , mode : Literal ["wrap" , "clip" ] = "wrap"
1019+ ) -> ColorStops :
1020+ """Return a new ColorStops object with all positions shifted by `shift`.
1021+
1022+ Parameters
1023+ ----------
1024+ shift : float
1025+ The amount to shift the colormap. Positive values shift the colormap
1026+ towards the end, negative values shift the colormap towards the beginning.
1027+ mode : {'wrap', 'clip'}, optional
1028+ The mode to use when shifting the colormap. Must be one of 'wrap' or
1029+ 'clip'. If 'wrap', the colormap will be shifted and wrapped around the ends.
1030+ If 'clip', the colormap will be shifted and the colors at the ends will be
1031+ clipped and/or repeated as necessary.
1032+ """
1033+ if mode not in {"wrap" , "clip" }: # pragma: no cover
1034+ raise ValueError ("mode must be 'wrap' or 'clip'" )
1035+
1036+ if mode == "wrap" :
1037+ stops = _wrap_shift_color_stops (self ._stops , shift )
1038+ else :
1039+ stops = self ._stops .copy ()
1040+ stops [:, 0 ] += shift
1041+ # throw away stops that are out of bounds
1042+ stops = stops [(stops [:, 0 ] >= 0 ) & (stops [:, 0 ] <= 1 )]
1043+ return type (self )(stops , interpolation = self ._interpolation )
9731044
9741045 @classmethod
9751046 def __get_pydantic_core_schema__ (
@@ -1317,3 +1388,16 @@ def _html_color_patch(color: Color | None) -> str:
13171388 "border: 1px solid #555; "
13181389 f'background-color: { color .hex } ;"></div>'
13191390 )
1391+
1392+
1393+ def _wrap_shift_color_stops (data : np .ndarray , shift_amount : float ) -> np .ndarray :
1394+ """Shift (N, 5) array of color stops by `shift_amount` and wrap around."""
1395+ out = np .array (data , copy = True )
1396+ # ensure that 0 <= data < 1
1397+ # this is important for dealing with wraparound when shifting
1398+ out [:, 0 ][out [:, 0 ] == 1 ] = 1 - np .finfo (float ).eps
1399+ out [:, 0 ] += shift_amount
1400+ out [:, 0 ] %= 1
1401+ # sort the array by the first column
1402+ out = out [out [:, 0 ].argsort ()]
1403+ return out
0 commit comments