Skip to content

Commit cb7d836

Browse files
authored
feat: add shifted method (#52)
* feat: add shifted method * fix shifting * add test
1 parent 5e1a778 commit cb7d836

File tree

2 files changed

+104
-1
lines changed

2 files changed

+104
-1
lines changed

src/cmap/_colormap.py

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -434,6 +434,49 @@ def reversed(self, name: str | None = None) -> Colormap:
434434
self.color_stops.reversed(), name=name, category=self.category
435435
)
436436

437+
def shifted(
438+
self,
439+
shift: float = 0.5,
440+
name: str | None = None,
441+
mode: Literal["wrap", "clip"] = "wrap",
442+
) -> Colormap:
443+
"""Return a new Colormap, with colors shifted by a scalar value.
444+
445+
This method shifts the stops in the colormap by a scalar value.
446+
It makes most sense for cyclic colormaps, but can be used with any colormap.
447+
448+
Parameters
449+
----------
450+
shift : float
451+
The amount to shift the colormap. Positive values shift the colormap
452+
towards the end, negative values shift the colormap towards the beginning.
453+
name : str
454+
A new name for the colormap. If not provided, the name of the new colormap
455+
will be the name of the original colormap with "_shifted{shift}" appended.
456+
mode : {'wrap', 'clip'}, optional
457+
The mode to use when shifting the colormap. Must be one of 'wrap' or
458+
'clip'. If 'wrap', the colormap will be shifted and wrapped around the ends.
459+
If 'clip', the colormap will be shifted and the colors at the ends will be
460+
clipped and/or repeated as necessary.
461+
462+
Returns
463+
-------
464+
Colormap
465+
A new colormap with the colors shifted.
466+
"""
467+
if name is None:
468+
name = f"{self.name}_shifted{shift}"
469+
470+
return type(self)(
471+
self.color_stops.shifted(shift=shift, mode=mode),
472+
name=name,
473+
category=self.category,
474+
interpolation=self.interpolation,
475+
under=self.under_color,
476+
over=self.over_color,
477+
bad=self.bad_color,
478+
)
479+
437480
def to_css(
438481
self,
439482
max_stops: int | None = None,
@@ -969,7 +1012,35 @@ def reversed(self) -> ColorStops:
9691012
# invert the positions in the stops
9701013
rev_stops = self._stops[::-1]
9711014
rev_stops[:, 0] = 1 - rev_stops[:, 0]
972-
return type(self)(rev_stops)
1015+
return type(self)(rev_stops, interpolation=self._interpolation)
1016+
1017+
def shifted(
1018+
self, shift: float, mode: Literal["wrap", "clip"] = "wrap"
1019+
) -> ColorStops:
1020+
"""Return a new ColorStops object with all positions shifted by `shift`.
1021+
1022+
Parameters
1023+
----------
1024+
shift : float
1025+
The amount to shift the colormap. Positive values shift the colormap
1026+
towards the end, negative values shift the colormap towards the beginning.
1027+
mode : {'wrap', 'clip'}, optional
1028+
The mode to use when shifting the colormap. Must be one of 'wrap' or
1029+
'clip'. If 'wrap', the colormap will be shifted and wrapped around the ends.
1030+
If 'clip', the colormap will be shifted and the colors at the ends will be
1031+
clipped and/or repeated as necessary.
1032+
"""
1033+
if mode not in {"wrap", "clip"}: # pragma: no cover
1034+
raise ValueError("mode must be 'wrap' or 'clip'")
1035+
1036+
if mode == "wrap":
1037+
stops = _wrap_shift_color_stops(self._stops, shift)
1038+
else:
1039+
stops = self._stops.copy()
1040+
stops[:, 0] += shift
1041+
# throw away stops that are out of bounds
1042+
stops = stops[(stops[:, 0] >= 0) & (stops[:, 0] <= 1)]
1043+
return type(self)(stops, interpolation=self._interpolation)
9731044

9741045
@classmethod
9751046
def __get_pydantic_core_schema__(
@@ -1317,3 +1388,16 @@ def _html_color_patch(color: Color | None) -> str:
13171388
"border: 1px solid #555; "
13181389
f'background-color: {color.hex};"></div>'
13191390
)
1391+
1392+
1393+
def _wrap_shift_color_stops(data: np.ndarray, shift_amount: float) -> np.ndarray:
1394+
"""Shift (N, 5) array of color stops by `shift_amount` and wrap around."""
1395+
out = np.array(data, copy=True)
1396+
# ensure that 0 <= data < 1
1397+
# this is important for dealing with wraparound when shifting
1398+
out[:, 0][out[:, 0] == 1] = 1 - np.finfo(float).eps
1399+
out[:, 0] += shift_amount
1400+
out[:, 0] %= 1
1401+
# sort the array by the first column
1402+
out = out[out[:, 0].argsort()]
1403+
return out

tests/test_colormap.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,22 @@ def test_with_extremes() -> None:
215215

216216
assert cm2.under_color == cm.under_color == Color("green")
217217
assert "under" in cm2._repr_html_()
218+
219+
220+
def test_shifted() -> None:
221+
cm = Colormap(["red", "blue", "yellow"])
222+
assert cm.shifted(1) == cm
223+
assert "shifted0.3" in cm.shifted(0.3).name
224+
# two shifts of 0.5 should give the original array
225+
assert cm.shifted().shifted() == cm
226+
227+
wrapped = cm.shifted(0.2, mode="wrap")
228+
assert wrapped == Colormap([(0.2, "yellow"), (0.2, "red"), (0.7, "blue")])
229+
clipped = cm.shifted(0.5, mode="clip")
230+
assert clipped == Colormap([(0.5, "red"), (1, "blue")])
231+
232+
cm = Colormap("viridis")
233+
# forward and backward shifts should cancel out
234+
assert cm.shifted(0.5).shifted(-0.5) == cm
235+
# two shifts of 0.5 should give the original array
236+
assert cm.shifted().shifted() == cm

0 commit comments

Comments
 (0)