Skip to content

Commit 65eff09

Browse files
committed
Reworked AsinhLocator to allow rounding on arbitrary number base
1 parent 80f2600 commit 65eff09

File tree

2 files changed

+48
-20
lines changed

2 files changed

+48
-20
lines changed

lib/matplotlib/scale.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -464,6 +464,8 @@ class AsinhTransform(Transform):
464464

465465
def __init__(self, linear_width):
466466
super().__init__()
467+
if linear_width <= 0.0:
468+
raise ValueError("Scale parameter 'a0' must be strictly positive")
467469
self.linear_width = linear_width
468470

469471
def transform_non_affine(self, a):
@@ -509,7 +511,8 @@ class AsinhScale(ScaleBase):
509511

510512
name = 'asinh'
511513

512-
def __init__(self, axis, *, linear_width=1.0, **kwargs):
514+
def __init__(self, axis, *, linear_width=1.0,
515+
base=10, subs=(2, 5), **kwargs):
513516
"""
514517
Parameters
515518
----------
@@ -520,18 +523,26 @@ def __init__(self, axis, *, linear_width=1.0, **kwargs):
520523
becomes asympotically logarithmic.
521524
"""
522525
super().__init__(axis)
523-
if linear_width <= 0.0:
524-
raise ValueError("Scale parameter 'a0' must be strictly positive")
525-
self.linear_width = linear_width
526+
self._transform = AsinhTransform(linear_width)
527+
self._base = int(base)
528+
self._subs = subs
529+
530+
linear_width = property(lambda self: self._transform.linear_width)
526531

527532
def get_transform(self):
528-
return AsinhTransform(self.linear_width)
533+
return self._transform
529534

530535
def set_default_locators_and_formatters(self, axis):
531-
axis.set(major_locator=AsinhLocator(self.linear_width),
532-
minor_locator=AutoLocator(),
533-
major_formatter='{x:.3g}',
536+
axis.set(major_locator=AsinhLocator(self.linear_width,
537+
base=self._base),
538+
minor_locator=AsinhLocator(self.linear_width,
539+
base=self._base,
540+
subs=self._subs),
534541
minor_formatter=NullFormatter())
542+
if self._base > 1:
543+
axis.set_major_formatter(LogFormatterSciNotation(self._base))
544+
else:
545+
axis.set_major_formatter('{x:.3g}'),
535546

536547

537548
class LogitTransform(Transform):

lib/matplotlib/ticker.py

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2592,7 +2592,8 @@ class AsinhLocator(Locator):
25922592
This is very unlikely to have any use beyond
25932593
the `~.scale.AsinhScale` class.
25942594
"""
2595-
def __init__(self, linear_width, numticks=11, symthresh=0.2):
2595+
def __init__(self, linear_width, numticks=11, symthresh=0.2,
2596+
base=0, subs=None):
25962597
"""
25972598
Parameters
25982599
----------
@@ -2611,6 +2612,8 @@ def __init__(self, linear_width, numticks=11, symthresh=0.2):
26112612
self.linear_width = linear_width
26122613
self.numticks = numticks
26132614
self.symthresh = symthresh
2615+
self.base = base
2616+
self.subs = subs
26142617

26152618
def set_params(self, numticks=None, symthresh=None):
26162619
"""Set parameters within this locator."""
@@ -2642,19 +2645,33 @@ def tick_values(self, vmin, vmax):
26422645

26432646
# Transform the "on-screen" grid to the data space:
26442647
xs = self.linear_width * np.sinh(ys / self.linear_width)
2645-
zero_xs = (xs == 0)
2646-
2647-
# Round the data-space values to be intuitive decimal numbers:
2648-
decades = (
2649-
np.where(xs >= 0, 1, -1) *
2650-
np.power(10, np.where(zero_xs, 0.0,
2651-
np.floor(np.log10(np.abs(xs)
2652-
+ zero_xs*1e-6))))
2653-
)
2654-
qs = decades * np.round(xs / decades)
2648+
zero_xs = (ys == 0)
2649+
2650+
# Round the data-space values to be intuitive base-n numbers:
2651+
if self.base > 1:
2652+
log_base = math.log(self.base)
2653+
powers = (
2654+
np.where(zero_xs, 0, np.where(xs >=0, 1, -1)) *
2655+
np.power(self.base,
2656+
np.where(zero_xs, 0.0,
2657+
np.floor(np.log(np.abs(xs) + zero_xs*1e-6)
2658+
/ log_base)))
2659+
)
2660+
if self.subs:
2661+
qs = np.outer(powers, self.subs).flatten()
2662+
else:
2663+
qs = powers
2664+
else:
2665+
powers = (
2666+
np.where(xs >= 0, 1, -1) *
2667+
np.power(10, np.where(zero_xs, 0.0,
2668+
np.floor(np.log10(np.abs(xs)
2669+
+ zero_xs*1e-6))))
2670+
)
2671+
qs = powers * np.round(xs / powers)
26552672
ticks = np.array(sorted(set(qs)))
26562673

2657-
if len(ticks) > self.numticks // 2:
2674+
if len(ticks) >= 2:
26582675
return ticks
26592676
else:
26602677
return np.linspace(vmin, vmax, self.numticks)

0 commit comments

Comments
 (0)