@@ -4607,110 +4607,88 @@ def reduce_C_function(C: array) -> float
46074607 nx = gridsize
46084608 ny = int (nx / math .sqrt (3 ))
46094609 # Count the number of data in each hexagon
4610- x = np .array (x , float )
4611- y = np .array (y , float )
4610+ x = np .asarray (x , float )
4611+ y = np .asarray (y , float )
46124612
4613- if marginals :
4614- xorig = x . copy ()
4615- yorig = y . copy ()
4613+ # Will be log()'d if necessary, and then rescaled.
4614+ tx = x
4615+ ty = y
46164616
46174617 if xscale == 'log' :
46184618 if np .any (x <= 0.0 ):
4619- raise ValueError ("x contains non-positive values, so can not"
4620- " be log-scaled" )
4621- x = np .log10 (x )
4619+ raise ValueError ("x contains non-positive values, so can not "
4620+ "be log-scaled" )
4621+ tx = np .log10 (tx )
46224622 if yscale == 'log' :
46234623 if np .any (y <= 0.0 ):
4624- raise ValueError ("y contains non-positive values, so can not"
4625- " be log-scaled" )
4626- y = np .log10 (y )
4624+ raise ValueError ("y contains non-positive values, so can not "
4625+ "be log-scaled" )
4626+ ty = np .log10 (ty )
46274627 if extent is not None :
46284628 xmin , xmax , ymin , ymax = extent
46294629 else :
4630- xmin , xmax = (np .min (x ), np .max (x )) if len (x ) else (0 , 1 )
4631- ymin , ymax = (np .min (y ), np .max (y )) if len (y ) else (0 , 1 )
4630+ xmin , xmax = (tx .min (), tx .max ()) if len (x ) else (0 , 1 )
4631+ ymin , ymax = (ty .min (), ty .max ()) if len (y ) else (0 , 1 )
46324632
46334633 # to avoid issues with singular data, expand the min/max pairs
46344634 xmin , xmax = mtransforms .nonsingular (xmin , xmax , expander = 0.1 )
46354635 ymin , ymax = mtransforms .nonsingular (ymin , ymax , expander = 0.1 )
46364636
4637+ nx1 = nx + 1
4638+ ny1 = ny + 1
4639+ nx2 = nx
4640+ ny2 = ny
4641+ n = nx1 * ny1 + nx2 * ny2
4642+
46374643 # In the x-direction, the hexagons exactly cover the region from
46384644 # xmin to xmax. Need some padding to avoid roundoff errors.
46394645 padding = 1.e-9 * (xmax - xmin )
46404646 xmin -= padding
46414647 xmax += padding
46424648 sx = (xmax - xmin ) / nx
46434649 sy = (ymax - ymin ) / ny
4644-
4645- x = (x - xmin ) / sx
4646- y = (y - ymin ) / sy
4647- ix1 = np .round (x ).astype (int )
4648- iy1 = np .round (y ).astype (int )
4649- ix2 = np .floor (x ).astype (int )
4650- iy2 = np .floor (y ).astype (int )
4651-
4652- nx1 = nx + 1
4653- ny1 = ny + 1
4654- nx2 = nx
4655- ny2 = ny
4656- n = nx1 * ny1 + nx2 * ny2
4657-
4658- d1 = (x - ix1 ) ** 2 + 3.0 * (y - iy1 ) ** 2
4659- d2 = (x - ix2 - 0.5 ) ** 2 + 3.0 * (y - iy2 - 0.5 ) ** 2
4650+ # Positions in hexagon index coordinates.
4651+ ix = (tx - xmin ) / sx
4652+ iy = (ty - ymin ) / sy
4653+ ix1 = np .round (ix ).astype (int )
4654+ iy1 = np .round (iy ).astype (int )
4655+ ix2 = np .floor (ix ).astype (int )
4656+ iy2 = np .floor (iy ).astype (int )
4657+ # flat indices, plus one so that out-of-range points go to position 0.
4658+ i1 = np .where ((0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ),
4659+ ix1 * ny1 + iy1 + 1 , 0 )
4660+ i2 = np .where ((0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ),
4661+ ix2 * ny2 + iy2 + 1 , 0 )
4662+
4663+ d1 = (ix - ix1 ) ** 2 + 3.0 * (iy - iy1 ) ** 2
4664+ d2 = (ix - ix2 - 0.5 ) ** 2 + 3.0 * (iy - iy2 - 0.5 ) ** 2
46604665 bdist = (d1 < d2 )
4661- if C is None :
4662- lattice1 = np .zeros ((nx1 , ny1 ))
4663- lattice2 = np .zeros ((nx2 , ny2 ))
4664- c1 = (0 <= ix1 ) & (ix1 < nx1 ) & (0 <= iy1 ) & (iy1 < ny1 ) & bdist
4665- c2 = (0 <= ix2 ) & (ix2 < nx2 ) & (0 <= iy2 ) & (iy2 < ny2 ) & ~ bdist
4666- np .add .at (lattice1 , (ix1 [c1 ], iy1 [c1 ]), 1 )
4667- np .add .at (lattice2 , (ix2 [c2 ], iy2 [c2 ]), 1 )
4668- if mincnt is not None :
4669- lattice1 [lattice1 < mincnt ] = np .nan
4670- lattice2 [lattice2 < mincnt ] = np .nan
4671- accum = np .concatenate ([lattice1 .ravel (), lattice2 .ravel ()])
4672- good_idxs = ~ np .isnan (accum )
46734666
4667+ if C is None : # [1:] drops out-of-range points.
4668+ counts1 = np .bincount (i1 [bdist ], minlength = 1 + nx1 * ny1 )[1 :]
4669+ counts2 = np .bincount (i2 [~ bdist ], minlength = 1 + nx2 * ny2 )[1 :]
4670+ accum = np .concatenate ([counts1 , counts2 ]).astype (float )
4671+ if mincnt is not None :
4672+ accum [accum < mincnt ] = np .nan
4673+ C = np .ones (len (x ))
46744674 else :
4675- if mincnt is None :
4676- mincnt = 0
4677-
4678- # create accumulation arrays
4679- lattice1 = np .empty ((nx1 , ny1 ), dtype = object )
4680- for i in range (nx1 ):
4681- for j in range (ny1 ):
4682- lattice1 [i , j ] = []
4683- lattice2 = np .empty ((nx2 , ny2 ), dtype = object )
4684- for i in range (nx2 ):
4685- for j in range (ny2 ):
4686- lattice2 [i , j ] = []
4687-
4675+ # store the C values in a list per hexagon index
4676+ Cs_at_i1 = [[] for _ in range (1 + nx1 * ny1 )]
4677+ Cs_at_i2 = [[] for _ in range (1 + nx2 * ny2 )]
46884678 for i in range (len (x )):
46894679 if bdist [i ]:
4690- if 0 <= ix1 [i ] < nx1 and 0 <= iy1 [i ] < ny1 :
4691- lattice1 [ix1 [i ], iy1 [i ]].append (C [i ])
4680+ Cs_at_i1 [i1 [i ]].append (C [i ])
46924681 else :
4693- if 0 <= ix2 [i ] < nx2 and 0 <= iy2 [i ] < ny2 :
4694- lattice2 [ix2 [i ], iy2 [i ]].append (C [i ])
4695-
4696- for i in range (nx1 ):
4697- for j in range (ny1 ):
4698- vals = lattice1 [i , j ]
4699- if len (vals ) > mincnt :
4700- lattice1 [i , j ] = reduce_C_function (vals )
4701- else :
4702- lattice1 [i , j ] = np .nan
4703- for i in range (nx2 ):
4704- for j in range (ny2 ):
4705- vals = lattice2 [i , j ]
4706- if len (vals ) > mincnt :
4707- lattice2 [i , j ] = reduce_C_function (vals )
4708- else :
4709- lattice2 [i , j ] = np .nan
4682+ Cs_at_i2 [i2 [i ]].append (C [i ])
4683+ if mincnt is None :
4684+ mincnt = 0
4685+ accum = np .array (
4686+ [reduce_C_function (acc ) if len (acc ) > mincnt else np .nan
4687+ for Cs_at_i in [Cs_at_i1 , Cs_at_i2 ]
4688+ for acc in Cs_at_i [1 :]], # [1:] drops out-of-range points.
4689+ float )
47104690
4711- accum = np .concatenate ([lattice1 .astype (float ).ravel (),
4712- lattice2 .astype (float ).ravel ()])
4713- good_idxs = ~ np .isnan (accum )
4691+ good_idxs = ~ np .isnan (accum )
47144692
47154693 offsets = np .zeros ((n , 2 ), float )
47164694 offsets [:nx1 * ny1 , 0 ] = np .repeat (np .arange (nx1 ), ny1 )
@@ -4767,8 +4745,7 @@ def reduce_C_function(C: array) -> float
47674745 vmin = vmax = None
47684746 bins = None
47694747
4770- # autoscale the norm with current accum values if it hasn't
4771- # been set
4748+ # autoscale the norm with current accum values if it hasn't been set
47724749 if norm is not None :
47734750 if norm .vmin is None and norm .vmax is None :
47744751 norm .autoscale (accum )
@@ -4798,92 +4775,55 @@ def reduce_C_function(C: array) -> float
47984775 return collection
47994776
48004777 # Process marginals
4801- if C is None :
4802- C = np .ones (len (x ))
4778+ bars = []
4779+ for zname , z , zmin , zmax , zscale , nbins in [
4780+ ("x" , x , xmin , xmax , xscale , nx ),
4781+ ("y" , y , ymin , ymax , yscale , 2 * ny ),
4782+ ]:
48034783
4804- def coarse_bin (x , y , bin_edges ):
4805- """
4806- Sort x-values into bins defined by *bin_edges*, then for all the
4807- corresponding y-values in each bin use *reduce_c_function* to
4808- compute the bin value.
4809- """
4810- nbins = len (bin_edges ) - 1
4811- # Sort x-values into bins
4812- bin_idxs = np .searchsorted (bin_edges , x ) - 1
4813- mus = np .zeros (nbins ) * np .nan
4784+ if zscale == "log" :
4785+ bin_edges = np .geomspace (zmin , zmax , nbins + 1 )
4786+ else :
4787+ bin_edges = np .linspace (zmin , zmax , nbins + 1 )
4788+
4789+ verts = np .empty ((nbins , 4 , 2 ))
4790+ verts [:, 0 , 0 ] = verts [:, 1 , 0 ] = bin_edges [:- 1 ]
4791+ verts [:, 2 , 0 ] = verts [:, 3 , 0 ] = bin_edges [1 :]
4792+ verts [:, 0 , 1 ] = verts [:, 3 , 1 ] = .00
4793+ verts [:, 1 , 1 ] = verts [:, 2 , 1 ] = .05
4794+ if zname == "y" :
4795+ verts = verts [:, :, ::- 1 ] # Swap x and y.
4796+
4797+ # Sort z-values into bins defined by bin_edges.
4798+ bin_idxs = np .searchsorted (bin_edges , z ) - 1
4799+ values = np .empty (nbins )
48144800 for i in range (nbins ):
4815- # Get y-values for each bin
4816- yi = y [bin_idxs == i ]
4817- if len (yi ) > 0 :
4818- mus [i ] = reduce_C_function (yi )
4819- return mus
4820-
4821- if xscale == 'log' :
4822- bin_edges = np .geomspace (xmin , xmax , nx + 1 )
4823- else :
4824- bin_edges = np .linspace (xmin , xmax , nx + 1 )
4825- xcoarse = coarse_bin (xorig , C , bin_edges )
4826-
4827- verts , values = [], []
4828- for bin_left , bin_right , val in zip (
4829- bin_edges [:- 1 ], bin_edges [1 :], xcoarse ):
4830- if np .isnan (val ):
4831- continue
4832- verts .append ([(bin_left , 0 ),
4833- (bin_left , 0.05 ),
4834- (bin_right , 0.05 ),
4835- (bin_right , 0 )])
4836- values .append (val )
4837-
4838- values = np .array (values )
4839- trans = self .get_xaxis_transform (which = 'grid' )
4840-
4841- hbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4842-
4843- hbar .set_array (values )
4844- hbar .set_cmap (cmap )
4845- hbar .set_norm (norm )
4846- hbar .set_alpha (alpha )
4847- hbar .update (kwargs )
4848- self .add_collection (hbar , autolim = False )
4849-
4850- if yscale == 'log' :
4851- bin_edges = np .geomspace (ymin , ymax , 2 * ny + 1 )
4852- else :
4853- bin_edges = np .linspace (ymin , ymax , 2 * ny + 1 )
4854- ycoarse = coarse_bin (yorig , C , bin_edges )
4855-
4856- verts , values = [], []
4857- for bin_bottom , bin_top , val in zip (
4858- bin_edges [:- 1 ], bin_edges [1 :], ycoarse ):
4859- if np .isnan (val ):
4860- continue
4861- verts .append ([(0 , bin_bottom ),
4862- (0 , bin_top ),
4863- (0.05 , bin_top ),
4864- (0.05 , bin_bottom )])
4865- values .append (val )
4866-
4867- values = np .array (values )
4868-
4869- trans = self .get_yaxis_transform (which = 'grid' )
4870-
4871- vbar = mcoll .PolyCollection (verts , transform = trans , edgecolors = 'face' )
4872- vbar .set_array (values )
4873- vbar .set_cmap (cmap )
4874- vbar .set_norm (norm )
4875- vbar .set_alpha (alpha )
4876- vbar .update (kwargs )
4877- self .add_collection (vbar , autolim = False )
4878-
4879- collection .hbar = hbar
4880- collection .vbar = vbar
4801+ # Get C-values for each bin, and compute bin value with
4802+ # reduce_C_function.
4803+ ci = C [bin_idxs == i ]
4804+ values [i ] = reduce_C_function (ci ) if len (ci ) > 0 else np .nan
4805+
4806+ mask = ~ np .isnan (values )
4807+ verts = verts [mask ]
4808+ values = values [mask ]
4809+
4810+ trans = getattr (self , f"get_{ zname } axis_transform" )(which = "grid" )
4811+ bar = mcoll .PolyCollection (
4812+ verts , transform = trans , edgecolors = "face" )
4813+ bar .set_array (values )
4814+ bar .set_cmap (cmap )
4815+ bar .set_norm (norm )
4816+ bar .set_alpha (alpha )
4817+ bar .update (kwargs )
4818+ bars .append (self .add_collection (bar , autolim = False ))
4819+
4820+ collection .hbar , collection .vbar = bars
48814821
48824822 def on_changed (collection ):
4883- hbar .set_cmap (collection .get_cmap ())
4884- hbar .set_clim (collection .get_clim ())
4885- vbar .set_cmap (collection .get_cmap ())
4886- vbar .set_clim (collection .get_clim ())
4823+ collection . hbar .set_cmap (collection .get_cmap ())
4824+ collection . hbar .set_cmap (collection .get_cmap ())
4825+ collection . vbar .set_clim (collection .get_clim ())
4826+ collection . vbar .set_clim (collection .get_clim ())
48874827
48884828 collection .callbacks .connect ('changed' , on_changed )
48894829
0 commit comments