@@ -4669,110 +4669,88 @@ def reduce_C_function(C: array) -> float
46694669 nx = gridsize
46704670 ny = int (nx / math .sqrt (3 ))
46714671 # Count the number of data in each hexagon
4672- x = np .array (x , float )
4673- y = np .array (y , float )
4672+ x = np .asarray (x , float )
4673+ y = np .asarray (y , float )
46744674
4675- if marginals :
4676- xorig = x . copy ()
4677- yorig = y . copy ()
4675+ # Will be log()'d if necessary, and then rescaled.
4676+ tx = x
4677+ ty = y
46784678
46794679 if xscale == 'log' :
46804680 if np .any (x <= 0.0 ):
4681- raise ValueError ("x contains non-positive values, so can not"
4682- " be log-scaled" )
4683- x = np .log10 (x )
4681+ raise ValueError ("x contains non-positive values, so can not "
4682+ "be log-scaled" )
4683+ tx = np .log10 (tx )
46844684 if yscale == 'log' :
46854685 if np .any (y <= 0.0 ):
4686- raise ValueError ("y contains non-positive values, so can not"
4687- " be log-scaled" )
4688- y = np .log10 (y )
4686+ raise ValueError ("y contains non-positive values, so can not "
4687+ "be log-scaled" )
4688+ ty = np .log10 (ty )
46894689 if extent is not None :
46904690 xmin , xmax , ymin , ymax = extent
46914691 else :
4692- xmin , xmax = (np .min (x ), np .max (x )) if len (x ) else (0 , 1 )
4693- ymin , ymax = (np .min (y ), np .max (y )) if len (y ) else (0 , 1 )
4692+ xmin , xmax = (tx .min (), tx .max ()) if len (x ) else (0 , 1 )
4693+ ymin , ymax = (ty .min (), ty .max ()) if len (y ) else (0 , 1 )
46944694
46954695 # to avoid issues with singular data, expand the min/max pairs
46964696 xmin , xmax = mtransforms .nonsingular (xmin , xmax , expander = 0.1 )
46974697 ymin , ymax = mtransforms .nonsingular (ymin , ymax , expander = 0.1 )
46984698
4699+ nx1 = nx + 1
4700+ ny1 = ny + 1
4701+ nx2 = nx
4702+ ny2 = ny
4703+ n = nx1 * ny1 + nx2 * ny2
4704+
46994705 # In the x-direction, the hexagons exactly cover the region from
47004706 # xmin to xmax. Need some padding to avoid roundoff errors.
47014707 padding = 1.e-9 * (xmax - xmin )
47024708 xmin -= padding
47034709 xmax += padding
47044710 sx = (xmax - xmin ) / nx
47054711 sy = (ymax - ymin ) / ny
4706-
4707- x = (x - xmin ) / sx
4708- y = (y - ymin ) / sy
4709- ix1 = np .round (x ).astype (int )
4710- iy1 = np .round (y ).astype (int )
4711- ix2 = np .floor (x ).astype (int )
4712- iy2 = np .floor (y ).astype (int )
4713-
4714- nx1 = nx + 1
4715- ny1 = ny + 1
4716- nx2 = nx
4717- ny2 = ny
4718- n = nx1 * ny1 + nx2 * ny2
4719-
4720- d1 = (x - ix1 ) ** 2 + 3.0 * (y - iy1 ) ** 2
4721- d2 = (x - ix2 - 0.5 ) ** 2 + 3.0 * (y - iy2 - 0.5 ) ** 2
4712+ # Positions in hexagon index coordinates.
4713+ ix = (tx - xmin ) / sx
4714+ iy = (ty - ymin ) / sy
4715+ ix1 = np .round (ix ).astype (int )
4716+ iy1 = np .round (iy ).astype (int )
4717+ ix2 = np .floor (ix ).astype (int )
4718+ iy2 = np .floor (iy ).astype (int )
4719+ # flat indices, plus one so that out-of-range points go to position 0.
4720+ i1 = np .where ((0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ),
4721+ ix1 * ny1 + iy1 + 1 , 0 )
4722+ i2 = np .where ((0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ),
4723+ ix2 * ny2 + iy2 + 1 , 0 )
4724+
4725+ d1 = (ix - ix1 ) ** 2 + 3.0 * (iy - iy1 ) ** 2
4726+ d2 = (ix - ix2 - 0.5 ) ** 2 + 3.0 * (iy - iy2 - 0.5 ) ** 2
47224727 bdist = (d1 < d2 )
4723- if C is None :
4724- lattice1 = np .zeros ((nx1 , ny1 ))
4725- lattice2 = np .zeros ((nx2 , ny2 ))
4726- c1 = (0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ) & bdist
4727- c2 = (0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ) & ~ bdist
4728- np .add .at (lattice1 , (ix1 [c1 ], iy1 [c1 ]), 1 )
4729- np .add .at (lattice2 , (ix2 [c2 ], iy2 [c2 ]), 1 )
4730- if mincnt is not None :
4731- lattice1 [lattice1 < mincnt ] = np .nan
4732- lattice2 [lattice2 < mincnt ] = np .nan
4733- accum = np .concatenate ([lattice1 .ravel (), lattice2 .ravel ()])
4734- good_idxs = ~ np .isnan (accum )
47354728
4729+ if C is None : # [1:] drops out-of-range points.
4730+ counts1 = np .bincount (i1 [bdist ], minlength = 1 + nx1 * ny1 )[1 :]
4731+ counts2 = np .bincount (i2 [~ bdist ], minlength = 1 + nx2 * ny2 )[1 :]
4732+ accum = np .concatenate ([counts1 , counts2 ]).astype (float )
4733+ if mincnt is not None :
4734+ accum [accum < mincnt ] = np .nan
4735+ C = np .ones (len (x ))
47364736 else :
4737- if mincnt is None :
4738- mincnt = 0
4739-
4740- # create accumulation arrays
4741- lattice1 = np .empty ((nx1 , ny1 ), dtype = object )
4742- for i in range (nx1 ):
4743- for j in range (ny1 ):
4744- lattice1 [i , j ] = []
4745- lattice2 = np .empty ((nx2 , ny2 ), dtype = object )
4746- for i in range (nx2 ):
4747- for j in range (ny2 ):
4748- lattice2 [i , j ] = []
4749-
4737+ # store the C values in a list per hexagon index
4738+ Cs_at_i1 = [[] for _ in range (1 + nx1 * ny1 )]
4739+ Cs_at_i2 = [[] for _ in range (1 + nx2 * ny2 )]
47504740 for i in range (len (x )):
47514741 if bdist [i ]:
4752- if 0 <= ix1 [i ] < nx1 and 0 <= iy1 [i ] < ny1 :
4753- lattice1 [ix1 [i ], iy1 [i ]].append (C [i ])
4742+ Cs_at_i1 [i1 [i ]].append (C [i ])
47544743 else :
4755- if 0 <= ix2 [i ] < nx2 and 0 <= iy2 [i ] < ny2 :
4756- lattice2 [ix2 [i ], iy2 [i ]].append (C [i ])
4757-
4758- for i in range (nx1 ):
4759- for j in range (ny1 ):
4760- vals = lattice1 [i , j ]
4761- if len (vals ) > mincnt :
4762- lattice1 [i , j ] = reduce_C_function (vals )
4763- else :
4764- lattice1 [i , j ] = np .nan
4765- for i in range (nx2 ):
4766- for j in range (ny2 ):
4767- vals = lattice2 [i , j ]
4768- if len (vals ) > mincnt :
4769- lattice2 [i , j ] = reduce_C_function (vals )
4770- else :
4771- lattice2 [i , j ] = np .nan
4744+ Cs_at_i2 [i2 [i ]].append (C [i ])
4745+ if mincnt is None :
4746+ mincnt = 0
4747+ accum = np .array (
4748+ [reduce_C_function (acc ) if len (acc ) > mincnt else np .nan
4749+ for Cs_at_i in [Cs_at_i1 , Cs_at_i2 ]
4750+ for acc in Cs_at_i [1 :]], # [1:] drops out-of-range points.
4751+ float )
47724752
4773- accum = np .concatenate ([lattice1 .astype (float ).ravel (),
4774- lattice2 .astype (float ).ravel ()])
4775- good_idxs = ~ np .isnan (accum )
4753+ good_idxs = ~ np .isnan (accum )
47764754
47774755 offsets = np .zeros ((n , 2 ), float )
47784756 offsets [:nx1 * ny1 , 0 ] = np .repeat (np .arange (nx1 ), ny1 )
@@ -4830,8 +4808,7 @@ def reduce_C_function(C: array) -> float
48304808 vmin = vmax = None
48314809 bins = None
48324810
4833- # autoscale the norm with current accum values if it hasn't
4834- # been set
4811+ # autoscale the norm with current accum values if it hasn't been set
48354812 if norm is not None :
48364813 if norm .vmin is None and norm .vmax is None :
48374814 norm .autoscale (accum )
@@ -4861,92 +4838,55 @@ def reduce_C_function(C: array) -> float
48614838 return collection
48624839
48634840 # Process marginals
4864- if C is None :
4865- C = np .ones (len (x ))
4841+ bars = []
4842+ for zname , z , zmin , zmax , zscale , nbins in [
4843+ ("x" , x , xmin , xmax , xscale , nx ),
4844+ ("y" , y , ymin , ymax , yscale , 2 * ny ),
4845+ ]:
48664846
4867- def coarse_bin (x , y , bin_edges ):
4868- """
4869- Sort x-values into bins defined by *bin_edges*, then for all the
4870- corresponding y-values in each bin use *reduce_c_function* to
4871- compute the bin value.
4872- """
4873- nbins = len (bin_edges ) - 1
4874- # Sort x-values into bins
4875- bin_idxs = np .searchsorted (bin_edges , x ) - 1
4876- mus = np .zeros (nbins ) * np .nan
4847+ if zscale == "log" :
4848+ bin_edges = np .geomspace (zmin , zmax , nbins + 1 )
4849+ else :
4850+ bin_edges = np .linspace (zmin , zmax , nbins + 1 )
4851+
4852+ verts = np .empty ((nbins , 4 , 2 ))
4853+ verts [:, 0 , 0 ] = verts [:, 1 , 0 ] = bin_edges [:- 1 ]
4854+ verts [:, 2 , 0 ] = verts [:, 3 , 0 ] = bin_edges [1 :]
4855+ verts [:, 0 , 1 ] = verts [:, 3 , 1 ] = .00
4856+ verts [:, 1 , 1 ] = verts [:, 2 , 1 ] = .05
4857+ if zname == "y" :
4858+ verts = verts [:, :, ::- 1 ] # Swap x and y.
4859+
4860+ # Sort z-values into bins defined by bin_edges.
4861+ bin_idxs = np .searchsorted (bin_edges , z ) - 1
4862+ values = np .empty (nbins )
48774863 for i in range (nbins ):
4878- # Get y-values for each bin
4879- yi = y [bin_idxs == i ]
4880- if len (yi ) > 0 :
4881- mus [i ] = reduce_C_function (yi )
4882- return mus
4883-
4884- if xscale == 'log' :
4885- bin_edges = np .geomspace (xmin , xmax , nx + 1 )
4886- else :
4887- bin_edges = np .linspace (xmin , xmax , nx + 1 )
4888- xcoarse = coarse_bin (xorig , C , bin_edges )
4889-
4890- verts , values = [], []
4891- for bin_left , bin_right , val in zip (
4892- bin_edges [:- 1 ], bin_edges [1 :], xcoarse ):
4893- if np .isnan (val ):
4894- continue
4895- verts .append ([(bin_left , 0 ),
4896- (bin_left , 0.05 ),
4897- (bin_right , 0.05 ),
4898- (bin_right , 0 )])
4899- values .append (val )
4900-
4901- values = np .array (values )
4902- trans = self .get_xaxis_transform (which = 'grid' )
4903-
4904- hbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4905-
4906- hbar .set_array (values )
4907- hbar .set_cmap (cmap )
4908- hbar .set_norm (norm )
4909- hbar .set_alpha (alpha )
4910- hbar .update (kwargs )
4911- self .add_collection (hbar , autolim = False )
4912-
4913- if yscale == 'log' :
4914- bin_edges = np .geomspace (ymin , ymax , 2 * ny + 1 )
4915- else :
4916- bin_edges = np .linspace (ymin , ymax , 2 * ny + 1 )
4917- ycoarse = coarse_bin (yorig , C , bin_edges )
4918-
4919- verts , values = [], []
4920- for bin_bottom , bin_top , val in zip (
4921- bin_edges [:- 1 ], bin_edges [1 :], ycoarse ):
4922- if np .isnan (val ):
4923- continue
4924- verts .append ([(0 , bin_bottom ),
4925- (0 , bin_top ),
4926- (0.05 , bin_top ),
4927- (0.05 , bin_bottom )])
4928- values .append (val )
4929-
4930- values = np .array (values )
4931-
4932- trans = self .get_yaxis_transform (which = 'grid' )
4933-
4934- vbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4935- vbar .set_array (values )
4936- vbar .set_cmap (cmap )
4937- vbar .set_norm (norm )
4938- vbar .set_alpha (alpha )
4939- vbar .update (kwargs )
4940- self .add_collection (vbar , autolim = False )
4941-
4942- collection .hbar = hbar
4943- collection .vbar = vbar
4864+ # Get C-values for each bin, and compute bin value with
4865+ # reduce_C_function.
4866+ ci = C [bin_idxs == i ]
4867+ values [i ] = reduce_C_function (ci ) if len (ci ) > 0 else np .nan
4868+
4869+ mask = ~ np .isnan (values )
4870+ verts = verts [mask ]
4871+ values = values [mask ]
4872+
4873+ trans = getattr (self , f"get_{ zname } axis_transform" )(which = "grid" )
4874+ bar = mcoll .PolyCollection (
4875+ verts , transform = trans , edgecolors = "face" )
4876+ bar .set_array (values )
4877+ bar .set_cmap (cmap )
4878+ bar .set_norm (norm )
4879+ bar .set_alpha (alpha )
4880+ bar .update (kwargs )
4881+ bars .append (self .add_collection (bar , autolim = False ))
4882+
4883+ collection .hbar , collection .vbar = bars
49444884
49454885 def on_changed (collection ):
4946- hbar .set_cmap (collection .get_cmap ())
4947- hbar .set_clim (collection .get_clim ())
4948- vbar .set_cmap (collection .get_cmap ())
4949- vbar .set_clim (collection .get_clim ())
4886+ collection . hbar .set_cmap (collection .get_cmap ())
4887+ collection . hbar .set_cmap (collection .get_cmap ())
4888+ collection . vbar .set_clim (collection .get_clim ())
4889+ collection . vbar .set_clim (collection .get_clim ())
49504890
49514891 collection .callbacks .connect ('changed' , on_changed )
49524892
0 commit comments