@@ -92,6 +92,34 @@ def _add_pad(self, x_min, x_max, y_min, y_max):
9292 return x_min - dx , x_max + dx , y_min - dy , y_max + dy
9393
9494
95+ class _User2DTransform (Transform ):
96+ """A transform defined by two user-set functions."""
97+
98+ input_dims = output_dims = 2
99+
100+ def __init__ (self , forward , backward ):
101+ """
102+ Parameters
103+ ----------
104+ forward, backward : callable
105+ The forward and backward transforms, taking ``x`` and ``y`` as
106+ separate arguments and returning ``(tr_x, tr_y)``.
107+ """
108+ # The normal Matplotlib convention would be to take and return an
109+ # (N, 2) array but axisartist uses the transposed version.
110+ super ().__init__ ()
111+ self ._forward = forward
112+ self ._backward = backward
113+
114+ def transform_non_affine (self , values ):
115+ # docstring inherited
116+ return np .transpose (self ._forward (* np .transpose (values )))
117+
118+ def inverted (self ):
119+ # docstring inherited
120+ return type (self )(self ._backward , self ._forward )
121+
122+
95123class GridFinder :
96124 def __init__ (self ,
97125 transform ,
@@ -123,7 +151,7 @@ def __init__(self,
123151 self .grid_locator2 = grid_locator2
124152 self .tick_formatter1 = tick_formatter1
125153 self .tick_formatter2 = tick_formatter2
126- self .update_transform (transform )
154+ self .set_transform (transform )
127155
128156 def get_grid_info (self , x1 , y1 , x2 , y2 ):
129157 """
@@ -214,27 +242,26 @@ def _clip_grid_lines_and_find_ticks(self, lines, values, levs, bb):
214242
215243 return gi
216244
217- def update_transform (self , aux_trans ):
218- if not isinstance (aux_trans , Transform ) and len (aux_trans ) != 2 :
219- raise TypeError ("'aux_trans' must be either a Transform instance "
220- "or a pair of callables" )
221- self ._aux_transform = aux_trans
245+ def set_transform (self , aux_trans ):
246+ if isinstance (aux_trans , Transform ):
247+ self ._aux_transform = aux_trans
248+ elif len (aux_trans ) == 2 and all (map (callable , aux_trans )):
249+ self ._aux_transform = _User2DTransform (* aux_trans )
250+ else :
251+ raise TypeError ("'aux_trans' must be either a Transform "
252+ "instance or a pair of callables" )
253+
254+ def get_transform (self ):
255+ return self ._aux_transform
256+
257+ update_transform = set_transform # backcompat alias.
222258
223259 def transform_xy (self , x , y ):
224- aux_trf = self ._aux_transform
225- if isinstance (aux_trf , Transform ):
226- return aux_trf .transform (np .column_stack ([x , y ])).T
227- else :
228- transform_xy , inv_transform_xy = aux_trf
229- return transform_xy (x , y )
260+ return self ._aux_transform .transform (np .column_stack ([x , y ])).T
230261
231262 def inv_transform_xy (self , x , y ):
232- aux_trf = self ._aux_transform
233- if isinstance (aux_trf , Transform ):
234- return aux_trf .inverted ().transform (np .column_stack ([x , y ])).T
235- else :
236- transform_xy , inv_transform_xy = aux_trf
237- return inv_transform_xy (x , y )
263+ return self ._aux_transform .inverted ().transform (
264+ np .column_stack ([x , y ])).T
238265
239266 def update (self , ** kw ):
240267 for k in kw :
0 commit comments