Skip to content

Commit 3331777

Browse files
authored
Merge pull request matplotlib#21352 from anntzer/hexbin
Refactor hexbin().
2 parents cfcf737 + ee206a1 commit 3331777

File tree

1 file changed

+100
-160
lines changed

1 file changed

+100
-160
lines changed

lib/matplotlib/axes/_axes.py

Lines changed: 100 additions & 160 deletions
Original file line number | Diff line number | Diff line change
@@ -4669,110 +4669,88 @@ def reduce_C_function(C: array) -> float
46694669
nx = gridsize
46704670
ny = int(nx / math.sqrt(3))
46714671
# Count the number of data in each hexagon
4672-
x = np.array(x, float)
4673-
y = np.array(y, float)
4672+
x = np.asarray(x, float)
4673+
y = np.asarray(y, float)
46744674

4675-
if marginals:
4676-
xorig = x.copy()
4677-
yorig = y.copy()
4675+
# Will be log()'d if necessary, and then rescaled.
4676+
tx = x
4677+
ty = y
46784678

46794679
if xscale == 'log':
46804680
if np.any(x <= 0.0):
4681-
raise ValueError("x contains non-positive values, so can not"
4682-
" be log-scaled")
4683-
x = np.log10(x)
4681+
raise ValueError("x contains non-positive values, so can not "
4682+
"be log-scaled")
4683+
tx = np.log10(tx)
46844684
if yscale == 'log':
46854685
if np.any(y <= 0.0):
4686-
raise ValueError("y contains non-positive values, so can not"
4687-
" be log-scaled")
4688-
y = np.log10(y)
4686+
raise ValueError("y contains non-positive values, so can not "
4687+
"be log-scaled")
4688+
ty = np.log10(ty)
46894689
if extent is not None:
46904690
xmin, xmax, ymin, ymax = extent
46914691
else:
4692-
xmin, xmax = (np.min(x), np.max(x)) if len(x) else (0, 1)
4693-
ymin, ymax = (np.min(y), np.max(y)) if len(y) else (0, 1)
4692+
xmin, xmax = (tx.min(), tx.max()) if len(x) else (0, 1)
4693+
ymin, ymax = (ty.min(), ty.max()) if len(y) else (0, 1)
46944694

46954695
# to avoid issues with singular data, expand the min/max pairs
46964696
xmin, xmax = mtransforms.nonsingular(xmin, xmax, expander=0.1)
46974697
ymin, ymax = mtransforms.nonsingular(ymin, ymax, expander=0.1)
46984698

4699+
nx1 = nx + 1
4700+
ny1 = ny + 1
4701+
nx2 = nx
4702+
ny2 = ny
4703+
n = nx1 * ny1 + nx2 * ny2
4704+
46994705
# In the x-direction, the hexagons exactly cover the region from
47004706
# xmin to xmax. Need some padding to avoid roundoff errors.
47014707
padding = 1.e-9 * (xmax - xmin)
47024708
xmin -= padding
47034709
xmax += padding
47044710
sx = (xmax - xmin) / nx
47054711
sy = (ymax - ymin) / ny
4706-
4707-
x = (x - xmin) / sx
4708-
y = (y - ymin) / sy
4709-
ix1 = np.round(x).astype(int)
4710-
iy1 = np.round(y).astype(int)
4711-
ix2 = np.floor(x).astype(int)
4712-
iy2 = np.floor(y).astype(int)
4713-
4714-
nx1 = nx + 1
4715-
ny1 = ny + 1
4716-
nx2 = nx
4717-
ny2 = ny
4718-
n = nx1 * ny1 + nx2 * ny2
4719-
4720-
d1 = (x - ix1) ** 2 + 3.0 * (y - iy1) ** 2
4721-
d2 = (x - ix2 - 0.5) ** 2 + 3.0 * (y - iy2 - 0.5) ** 2
4712+
# Positions in hexagon index coordinates.
4713+
ix = (tx - xmin) / sx
4714+
iy = (ty - ymin) / sy
4715+
ix1 = np.round(ix).astype(int)
4716+
iy1 = np.round(iy).astype(int)
4717+
ix2 = np.floor(ix).astype(int)
4718+
iy2 = np.floor(iy).astype(int)
4719+
# flat indices, plus one so that out-of-range points go to position 0.
4720+
i1 = np.where((0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1),
4721+
ix1 * ny1 + iy1 + 1, 0)
4722+
i2 = np.where((0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2),
4723+
ix2 * ny2 + iy2 + 1, 0)
4724+
4725+
d1 = (ix - ix1) ** 2 + 3.0 * (iy - iy1) ** 2
4726+
d2 = (ix - ix2 - 0.5) ** 2 + 3.0 * (iy - iy2 - 0.5) ** 2
47224727
bdist = (d1 < d2)
4723-
if C is None:
4724-
lattice1 = np.zeros((nx1, ny1))
4725-
lattice2 = np.zeros((nx2, ny2))
4726-
c1 = (0 <= ix1) & (ix1 < nx1) & (0 <= iy1) & (iy1 < ny1) & bdist
4727-
c2 = (0 <= ix2) & (ix2 < nx2) & (0 <= iy2) & (iy2 < ny2) & ~bdist
4728-
np.add.at(lattice1, (ix1[c1], iy1[c1]), 1)
4729-
np.add.at(lattice2, (ix2[c2], iy2[c2]), 1)
4730-
if mincnt is not None:
4731-
lattice1[lattice1 < mincnt] = np.nan
4732-
lattice2[lattice2 < mincnt] = np.nan
4733-
accum = np.concatenate([lattice1.ravel(), lattice2.ravel()])
4734-
good_idxs = ~np.isnan(accum)
47354728

4729+
if C is None: # [1:] drops out-of-range points.
4730+
counts1 = np.bincount(i1[bdist], minlength=1 + nx1 * ny1)[1:]
4731+
counts2 = np.bincount(i2[~bdist], minlength=1 + nx2 * ny2)[1:]
4732+
accum = np.concatenate([counts1, counts2]).astype(float)
4733+
if mincnt is not None:
4734+
accum[accum < mincnt] = np.nan
4735+
C = np.ones(len(x))
47364736
else:
4737-
if mincnt is None:
4738-
mincnt = 0
4739-
4740-
# create accumulation arrays
4741-
lattice1 = np.empty((nx1, ny1), dtype=object)
4742-
for i in range(nx1):
4743-
for j in range(ny1):
4744-
lattice1[i, j] = []
4745-
lattice2 = np.empty((nx2, ny2), dtype=object)
4746-
for i in range(nx2):
4747-
for j in range(ny2):
4748-
lattice2[i, j] = []
4749-
4737+
# store the C values in a list per hexagon index
4738+
Cs_at_i1 = [[] for _ in range(1 + nx1 * ny1)]
4739+
Cs_at_i2 = [[] for _ in range(1 + nx2 * ny2)]
47504740
for i in range(len(x)):
47514741
if bdist[i]:
4752-
if 0 <= ix1[i] < nx1 and 0 <= iy1[i] < ny1:
4753-
lattice1[ix1[i], iy1[i]].append(C[i])
4742+
Cs_at_i1[i1[i]].append(C[i])
47544743
else:
4755-
if 0 <= ix2[i] < nx2 and 0 <= iy2[i] < ny2:
4756-
lattice2[ix2[i], iy2[i]].append(C[i])
4757-
4758-
for i in range(nx1):
4759-
for j in range(ny1):
4760-
vals = lattice1[i, j]
4761-
if len(vals) > mincnt:
4762-
lattice1[i, j] = reduce_C_function(vals)
4763-
else:
4764-
lattice1[i, j] = np.nan
4765-
for i in range(nx2):
4766-
for j in range(ny2):
4767-
vals = lattice2[i, j]
4768-
if len(vals) > mincnt:
4769-
lattice2[i, j] = reduce_C_function(vals)
4770-
else:
4771-
lattice2[i, j] = np.nan
4744+
Cs_at_i2[i2[i]].append(C[i])
4745+
if mincnt is None:
4746+
mincnt = 0
4747+
accum = np.array(
4748+
[reduce_C_function(acc) if len(acc) > mincnt else np.nan
4749+
for Cs_at_i in [Cs_at_i1, Cs_at_i2]
4750+
for acc in Cs_at_i[1:]], # [1:] drops out-of-range points.
4751+
float)
47724752

4773-
accum = np.concatenate([lattice1.astype(float).ravel(),
4774-
lattice2.astype(float).ravel()])
4775-
good_idxs = ~np.isnan(accum)
4753+
good_idxs = ~np.isnan(accum)
47764754

47774755
offsets = np.zeros((n, 2), float)
47784756
offsets[:nx1 * ny1, 0] = np.repeat(np.arange(nx1), ny1)
@@ -4830,8 +4808,7 @@ def reduce_C_function(C: array) -> float
48304808
vmin = vmax = None
48314809
bins = None
48324810

4833-
# autoscale the norm with current accum values if it hasn't
4834-
# been set
4811+
# autoscale the norm with current accum values if it hasn't been set
48354812
if norm is not None:
48364813
if norm.vmin is None and norm.vmax is None:
48374814
norm.autoscale(accum)
@@ -4861,92 +4838,55 @@ def reduce_C_function(C: array) -> float
48614838
return collection
48624839

48634840
# Process marginals
4864-
if C is None:
4865-
C = np.ones(len(x))
4841+
bars = []
4842+
for zname, z, zmin, zmax, zscale, nbins in [
4843+
("x", x, xmin, xmax, xscale, nx),
4844+
("y", y, ymin, ymax, yscale, 2 * ny),
4845+
]:
48664846

4867-
def coarse_bin(x, y, bin_edges):
4868-
"""
4869-
Sort x-values into bins defined by *bin_edges*, then for all the
4870-
corresponding y-values in each bin use *reduce_c_function* to
4871-
compute the bin value.
4872-
"""
4873-
nbins = len(bin_edges) - 1
4874-
# Sort x-values into bins
4875-
bin_idxs = np.searchsorted(bin_edges, x) - 1
4876-
mus = np.zeros(nbins) * np.nan
4847+
if zscale == "log":
4848+
bin_edges = np.geomspace(zmin, zmax, nbins + 1)
4849+
else:
4850+
bin_edges = np.linspace(zmin, zmax, nbins + 1)
4851+
4852+
verts = np.empty((nbins, 4, 2))
4853+
verts[:, 0, 0] = verts[:, 1, 0] = bin_edges[:-1]
4854+
verts[:, 2, 0] = verts[:, 3, 0] = bin_edges[1:]
4855+
verts[:, 0, 1] = verts[:, 3, 1] = .00
4856+
verts[:, 1, 1] = verts[:, 2, 1] = .05
4857+
if zname == "y":
4858+
verts = verts[:, :, ::-1] # Swap x and y.
4859+
4860+
# Sort z-values into bins defined by bin_edges.
4861+
bin_idxs = np.searchsorted(bin_edges, z) - 1
4862+
values = np.empty(nbins)
48774863
for i in range(nbins):
4878-
# Get y-values for each bin
4879-
yi = y[bin_idxs == i]
4880-
if len(yi) > 0:
4881-
mus[i] = reduce_C_function(yi)
4882-
return mus
4883-
4884-
if xscale == 'log':
4885-
bin_edges = np.geomspace(xmin, xmax, nx + 1)
4886-
else:
4887-
bin_edges = np.linspace(xmin, xmax, nx + 1)
4888-
xcoarse = coarse_bin(xorig, C, bin_edges)
4889-
4890-
verts, values = [], []
4891-
for bin_left, bin_right, val in zip(
4892-
bin_edges[:-1], bin_edges[1:], xcoarse):
4893-
if np.isnan(val):
4894-
continue
4895-
verts.append([(bin_left, 0),
4896-
(bin_left, 0.05),
4897-
(bin_right, 0.05),
4898-
(bin_right, 0)])
4899-
values.append(val)
4900-
4901-
values = np.array(values)
4902-
trans = self.get_xaxis_transform(which='grid')
4903-
4904-
hbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4905-
4906-
hbar.set_array(values)
4907-
hbar.set_cmap(cmap)
4908-
hbar.set_norm(norm)
4909-
hbar.set_alpha(alpha)
4910-
hbar.update(kwargs)
4911-
self.add_collection(hbar, autolim=False)
4912-
4913-
if yscale == 'log':
4914-
bin_edges = np.geomspace(ymin, ymax, 2 * ny + 1)
4915-
else:
4916-
bin_edges = np.linspace(ymin, ymax, 2 * ny + 1)
4917-
ycoarse = coarse_bin(yorig, C, bin_edges)
4918-
4919-
verts, values = [], []
4920-
for bin_bottom, bin_top, val in zip(
4921-
bin_edges[:-1], bin_edges[1:], ycoarse):
4922-
if np.isnan(val):
4923-
continue
4924-
verts.append([(0, bin_bottom),
4925-
(0, bin_top),
4926-
(0.05, bin_top),
4927-
(0.05, bin_bottom)])
4928-
values.append(val)
4929-
4930-
values = np.array(values)
4931-
4932-
trans = self.get_yaxis_transform(which='grid')
4933-
4934-
vbar = mcoll.PolyCollection(verts, transform=trans, edgecolors='face')
4935-
vbar.set_array(values)
4936-
vbar.set_cmap(cmap)
4937-
vbar.set_norm(norm)
4938-
vbar.set_alpha(alpha)
4939-
vbar.update(kwargs)
4940-
self.add_collection(vbar, autolim=False)
4941-
4942-
collection.hbar = hbar
4943-
collection.vbar = vbar
4864+
# Get C-values for each bin, and compute bin value with
4865+
# reduce_C_function.
4866+
ci = C[bin_idxs == i]
4867+
values[i] = reduce_C_function(ci) if len(ci) > 0 else np.nan
4868+
4869+
mask = ~np.isnan(values)
4870+
verts = verts[mask]
4871+
values = values[mask]
4872+
4873+
trans = getattr(self, f"get_{zname}axis_transform")(which="grid")
4874+
bar = mcoll.PolyCollection(
4875+
verts, transform=trans, edgecolors="face")
4876+
bar.set_array(values)
4877+
bar.set_cmap(cmap)
4878+
bar.set_norm(norm)
4879+
bar.set_alpha(alpha)
4880+
bar.update(kwargs)
4881+
bars.append(self.add_collection(bar, autolim=False))
4882+
4883+
collection.hbar, collection.vbar = bars
49444884

49454885
def on_changed(collection):
    """
    Propagate the hexbin collection's colormap and color limits to the
    marginal bar collections.

    Parameters
    ----------
    collection
        Object exposing ``get_cmap()``, ``get_clim()`` and the two marginal
        bar collections as ``collection.hbar`` and ``collection.vbar``.
    """
    # Both bars need *both* properties.  The refactored version called
    # hbar.set_cmap twice and vbar.set_clim twice, so hbar never received
    # the color limits and vbar never received the colormap.
    for bar in (collection.hbar, collection.vbar):
        bar.set_cmap(collection.get_cmap())
        bar.set_clim(collection.get_clim())
49504890

49514891
collection.callbacks.connect('changed', on_changed)
49524892

0 commit comments

Comments
 (0)