Skip to content

Commit 554b002

Browse files
authored
Merge pull request #47 from campagnola/qmatrix-auto-black-formatting
style: apply black formatting for qmatrix
2 parents f22bdc2 + 98f52e8 commit 554b002

File tree

1 file changed

+70
-50
lines changed

1 file changed

+70
-50
lines changed

coorx/base_transform.py

Lines changed: 70 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ class Transform(object):
3434
where the inverse mapping is ambiguous or otherwise meaningless.
3535
3636
"""
37+
3738
# Flags used to describe the transformation. Subclasses should define each
3839
# as True or False.
3940
# (usually used for making optimization decisions)
@@ -64,7 +65,13 @@ class Transform(object):
6465
# List of keys that will be saved and restored in __getstate__ and __setstate__
6566
state_keys = []
6667

67-
def __init__(self, dims:Dims=None, from_cs:StrOrNone=None, to_cs:StrOrNone=None, cs_graph:StrOrNone=None):
68+
def __init__(
69+
self,
70+
dims: Dims = None,
71+
from_cs: StrOrNone = None,
72+
to_cs: StrOrNone = None,
73+
cs_graph: StrOrNone = None,
74+
):
6875
if dims is None or np.isscalar(dims):
6976
dims = (dims, dims)
7077
if not isinstance(dims, tuple) or len(dims) != 2:
@@ -81,23 +88,22 @@ def __init__(self, dims:Dims=None, from_cs:StrOrNone=None, to_cs:StrOrNone=None,
8188

8289
@property
8390
def dims(self):
84-
"""Tuple holding the (input, output) dimensions for this transform.
85-
"""
91+
"""Tuple holding the (input, output) dimensions for this transform."""
8692
return self._dims
8793

88-
def _dims_from_params(self, params:dict, dims=None):
94+
def _dims_from_params(self, params: dict, dims=None):
8995
"""Determine dimensionality from parameters.
9096
9197
If *dims* is provided, then it must be a tuple (in, out) and it must agree with the length of all provided parameters.
9298
If *dims* is not provided, then it is determined from length of parameters, which must be in agreement.
9399
"""
94100
assert len(params) > 0
95-
inferred_dims = {k:(len(v), len(v)) for k,v in params.items() if v is not None}
101+
inferred_dims = {k: (len(v), len(v)) for k, v in params.items() if v is not None}
96102
if dims is not None:
97-
for k,v in inferred_dims.items():
103+
for k, v in inferred_dims.items():
98104
assert v == dims, f"Length of {k} ({len(dims[k])}) does not match dims {dims}"
99105
return dims
100-
106+
101107
if len(inferred_dims) == 0:
102108
msg = f"Could not determine dimensionality of transform. "
103109
param_names = ' '.join(list(params.keys()))
@@ -106,21 +112,24 @@ def _dims_from_params(self, params:dict, dims=None):
106112
else:
107113
msg += f"Specify dims or at least one of {param_names}."
108114
raise Exception(msg)
109-
115+
110116
keys = list(inferred_dims.keys())
111117
dims = inferred_dims[keys[0]]
112118
for k in keys[1:]:
113-
assert inferred_dims[k] == dims, f"Length of {k} ({len(params[k])}) does not match length of {keys[0]} ({len(params[keys[0]])})"
119+
assert (
120+
inferred_dims[k] == dims
121+
), f"Length of {k} ({len(params[k])}) does not match length of {keys[0]} ({len(params[keys[0]])})"
114122
return dims
115123

116124
@property
117125
def systems(self):
118-
"""The CoordinateSystem instances mapped from and to by this transform.
119-
"""
126+
"""The CoordinateSystem instances mapped from and to by this transform."""
120127
return self._systems
121128

122129
def set_systems(self, from_cs, to_cs, cs_graph=None):
123-
assert (from_cs is None) == (to_cs is None), "from_cs and to_cs must both be None or both be coordinate systems"
130+
assert (from_cs is None) == (
131+
to_cs is None
132+
), "from_cs and to_cs must both be None or both be coordinate systems"
124133
if from_cs is not None:
125134
if cs_graph is None and isinstance(from_cs, CoordinateSystem):
126135
cs_graph = from_cs.graph
@@ -146,7 +155,7 @@ def _map(self, arr):
146155
"""
147156
raise NotImplementedError
148157

149-
def imap(self, obj:Mappable):
158+
def imap(self, obj: Mappable):
150159
"""
151160
Return *obj* mapped through the inverse transformation.
152161
@@ -164,20 +173,20 @@ def _imap(self, arr):
164173
"""
165174
raise NotImplementedError
166175

167-
def _prepare_and_map(self, obj:Mappable):
176+
def _prepare_and_map(self, obj: Mappable):
168177
"""
169-
Convert a mappable object to a 2D numpy array, pass it through this Transform's _map method,
170-
then convert and return the result.
171-
178+
Convert a mappable object to a 2D numpy array, pass it through this Transform's _map method,
179+
then convert and return the result.
180+
172181
The Transform's _map method will be called with a 2D array
173-
of shape (N, M), where N is the number of points and M is the number of dimensions.
182+
of shape (N, M), where N is the number of points and M is the number of dimensions.
174183
Accepts lists, tuples, and arrays of any dimensionality and flattens extra dimensions into N.
175184
After mapping, any flattened axes are re-expanded to match the original input shape.
176185
177-
For list, tuple, and array inputs, the return value is a numpy array of the same shape as
186+
For list, tuple, and array inputs, the return value is a numpy array of the same shape as
178187
the input, with the exception that the last dimension is determined only by the return value.
179188
180-
Alternatively, any class may determine how to map itself by defining a _coorx_transform()
189+
Alternatively, any class may determine how to map itself by defining a _coorx_transform()
181190
method that accepts this transform as an argument.
182191
"""
183192
if hasattr(obj, '_coorx_transform'):
@@ -186,10 +195,15 @@ def _prepare_and_map(self, obj:Mappable):
186195
elif isinstance(obj, (tuple, list, np.ndarray)):
187196
arr_2d, original_shape = self._prepare_arg_for_mapping(obj)
188197
if self.dims[0] not in (None, arr_2d.shape[1]):
189-
raise TypeError(f"Transform maps from {self.dims[0]}D, but data to be mapped is {arr_2d.shape[1]}D")
198+
raise TypeError(
199+
f"Transform maps from {self.dims[0]}D, but data to be mapped is {arr_2d.shape[1]}D"
200+
)
190201
ret = self._map(arr_2d)
191202
assert ret.ndim == 2
192-
assert self.dims[1] in (None, ret.shape[1]), f"Transform maps to {self.dims[1]}D, but mapping generated {ret.shape[1]}D"
203+
assert self.dims[1] in (
204+
None,
205+
ret.shape[1],
206+
), f"Transform maps to {self.dims[1]}D, but mapping generated {ret.shape[1]}D"
193207
return self._restore_shape(ret, original_shape)
194208
else:
195209
raise TypeError(f"Cannot use argument for mapping: {obj}")
@@ -200,7 +214,7 @@ def _prepare_arg_for_mapping(arg):
200214
201215
If the argument ndim is > 2, then all dimensions except the last are flattened.
202216
203-
Return the reshaped array and a tuple containing the original shape.
217+
Return the reshaped array and a tuple containing the original shape.
204218
"""
205219
arg = np.asarray(arg)
206220
original_shape = arg.shape
@@ -209,26 +223,24 @@ def _prepare_arg_for_mapping(arg):
209223

210224
@staticmethod
211225
def _restore_shape(arg, shape):
212-
"""Return an array with shape determined by shape[:-1] + (arg.shape[-1],)
213-
"""
226+
"""Return an array with shape determined by shape[:-1] + (arg.shape[-1],)"""
214227
if arg is None:
215228
return arg
216229
return arg.reshape(shape[:-1] + (arg.shape[-1],))
217230

218231
@property
219232
def inverse(self):
220-
""" The inverse of this transform.
221-
"""
233+
"""The inverse of this transform."""
222234
if self._inverse is None:
223235
self._inverse = InverseTransform(self)
224236
return self._inverse
225237

226238
@property
227239
def dynamic(self):
228-
"""Boolean flag that indicates whether this transform is expected to
240+
"""Boolean flag that indicates whether this transform is expected to
229241
change frequently.
230-
231-
Transforms that are flagged as dynamic will not be collapsed in
242+
243+
Transforms that are flagged as dynamic will not be collapsed in
232244
``ChainTransform.simplified``. This allows changes to the transform
233245
to propagate through the chain without requiring the chain to be
234246
re-simplified.
@@ -241,20 +253,18 @@ def dynamic(self, d):
241253

242254
@property
243255
def params(self):
244-
"""Return a dict of parameters specifying this transform.
245-
"""
256+
"""Return a dict of parameters specifying this transform."""
246257
raise NotImplementedError(f"{self.__class__.__name__}.params")
247258

248259
def set_params(self, **kwds):
249260
"""Set parameters specifying this transform.
250-
261+
251262
Parameter names must be the same as the keys in self.params.
252263
"""
253264
raise NotImplementedError(f"{self.__class__.__name__}.set_params")
254265

255266
def save_state(self):
256-
"""Return serializable parameters that specify this transform.
257-
"""
267+
"""Return serializable parameters that specify this transform."""
258268
return {
259269
'type': type(self).__name__,
260270
'dims': self.dims,
@@ -263,14 +273,13 @@ def save_state(self):
263273
}
264274

265275
def as_affine(self):
266-
"""Return an equivalent affine transform if possible.
267-
"""
276+
"""Return an equivalent affine transform if possible."""
268277
raise NotImplementedError()
269278

270279
@property
271280
def full_matrix(self) -> np.ndarray:
272281
"""
273-
Return the full transformation matrix for this transform, if possible.
282+
Return the full transformation matrix for this transform, if possible.
274283
275284
Modifying the returned array has no effect on the transform instance that generated it.
276285
"""
@@ -279,19 +288,21 @@ def full_matrix(self) -> np.ndarray:
279288
def as_vispy(self):
280289
"""Return a VisPy transform that is equivalent to this transform, if possible."""
281290
from vispy.visuals.transforms import MatrixTransform
291+
282292
# a functional default if nothing else is implemented
283293
return MatrixTransform(self.full_matrix.T)
284294

285295
def as_pyqtgraph(self):
286296
"""Return a PyQtGraph transform that is equivalent to this transform, if possible."""
287297
from pyqtgraph import SRTTransform3D
288298
from pyqtgraph.Qt import QtGui
299+
289300
# a functional default if nothing else is implemented
290301
return SRTTransform3D(QtGui.QMatrix4x4(self.full_matrix.reshape(-1)))
291302

292303
def add_change_callback(self, cb):
293304
self._change_callbacks.append(cb)
294-
305+
295306
def remove_change_callback(self, cb):
296307
self._change_callbacks.remove(cb)
297308

@@ -359,7 +370,11 @@ def __getstate__(self):
359370
if self.systems[0] is None:
360371
state['_systems'] = (None, None, None)
361372
else:
362-
state['_systems'] = (self.systems[0].name, self.systems[1].name, self.systems[0].graph.name)
373+
state['_systems'] = (
374+
self.systems[0].name,
375+
self.systems[1].name,
376+
self.systems[0].graph.name,
377+
)
363378
return state
364379

365380
def __setstate__(self, state):
@@ -369,8 +384,7 @@ def __setstate__(self, state):
369384
self.set_systems(from_cs, to_cs, graph)
370385

371386
def copy(self, from_cs=None, to_cs=None):
372-
"""Return a copy of this transform.
373-
"""
387+
"""Return a copy of this transform."""
374388
tr = self.__class__(dims=self.dims)
375389
state = self.__getstate__()
376390
if from_cs is not None or to_cs is not None:
@@ -403,7 +417,9 @@ def __eq__(self, tr):
403417

404418
def validate_transform_for_mul(self, tr):
405419
if tr.systems[1] != self.systems[0]:
406-
raise TypeError(f"Cannot multiply transforms with different inner coordinate systems: {self.systems[0]} != {tr.systems[1]}")
420+
raise TypeError(
421+
f"Cannot multiply transforms with different inner coordinate systems: {self.systems[0]} != {tr.systems[1]}"
422+
)
407423

408424

409425
class InverseTransform(Transform):
@@ -422,7 +438,12 @@ def set_systems(self, from_cs, to_cs, cs_graph=None):
422438

423439
def as_affine(self):
424440
affine = self._inverse.as_affine()
425-
return type(affine)(matrix=affine.inv_matrix, offset=affine.inv_matrix @ affine.inv_offset, from_cs=self.systems[0], to_cs=self.systems[1])
441+
return type(affine)(
442+
matrix=affine.inv_matrix,
443+
offset=affine.inv_matrix @ affine.inv_offset,
444+
from_cs=self.systems[0],
445+
to_cs=self.systems[1],
446+
)
426447

427448
def copy(self, from_cs=None, to_cs=None):
428449
return self._inverse.copy(from_cs=to_cs, to_cs=from_cs).inverse
@@ -456,10 +477,10 @@ def NonScaling(self):
456477
@property
457478
def Isometric(self):
458479
return self._inverse.Isometric
459-
480+
460481
def __repr__(self):
461-
return ("<Inverse of %r>" % repr(self._inverse))
462-
482+
return "<Inverse of %r>" % repr(self._inverse)
483+
463484

464485
class ChangeEvent:
465486
def __init__(self, transform, source_event=None):
@@ -468,12 +489,11 @@ def __init__(self, transform, source_event=None):
468489

469490
@property
470491
def sources(self):
471-
"""A list of all transforms that changed leading to this event
472-
"""
492+
"""A list of all transforms that changed leading to this event"""
473493
s = [self]
474494
if self.source_event is not None:
475495
s += self.source_event.sources
476-
return s
496+
return s
477497

478498

479499
# import here to avoid import cycle; needed for Transform.__mul__.

0 commit comments

Comments
 (0)