Skip to content

Commit be45df8

Browse files
committed
BF - more robust dtype mapper to fix errors with numpy 1.2.1
1 parent 0420250 commit be45df8

File tree

3 files changed

+144
-28
lines changed

3 files changed

+144
-28
lines changed

nibabel/tests/test_recoder.py

Lines changed: 71 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,11 @@
88
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
99
''' Tests recoder class '''
1010

11-
from nose.tools import assert_equal, assert_raises, assert_true, assert_false
11+
import numpy as np
12+
13+
from ..volumeutils import Recoder, DtypeMapper, native_code, swapped_code
1214

13-
from ..volumeutils import Recoder
15+
from nose.tools import assert_equal, assert_raises, assert_true, assert_false
1416

1517
def test_recoder():
1618
# simplest case, no aliases
@@ -54,6 +56,33 @@ def test_recoder():
5456
# Don't allow funny names
5557
yield assert_raises, KeyError, Recoder, codes, ['field1']
5658

59+
60+
def test_custom_dicter():
61+
# Allow custom dict-like object in constructor
62+
class MyDict(object):
63+
def __init__(self):
64+
self._keys = []
65+
def __setitem__(self, key, value):
66+
self._keys.append(key)
67+
def __getitem__(self, key):
68+
if key in self._keys:
69+
return 'spam'
70+
return 'eggs'
71+
def keys(self):
72+
return ['some', 'keys']
73+
def values(self):
74+
return ['funny', 'list']
75+
# code, label, aliases
76+
codes = ((1,'one','1','first'), (2,'two'))
77+
rc = Recoder(codes, map_maker=MyDict)
78+
yield assert_equal, rc.code[1], 'spam'
79+
yield assert_equal, rc.code['one'], 'spam'
80+
yield assert_equal, rc.code['first'], 'spam'
81+
yield assert_equal, rc.code['bizarre'], 'eggs'
82+
yield assert_equal, rc.value_set(), set(['funny', 'list'])
83+
yield assert_equal, list(rc.keys()), ['some', 'keys']
84+
85+
5786
def test_add_codes():
5887
codes = ((1,'one','1','first'), (2,'two'))
5988
rc = Recoder(codes)
@@ -63,6 +92,7 @@ def test_add_codes():
6392
yield assert_equal, rc.code['three'], 3
6493
yield assert_equal, rc.code['number 1'], 1
6594

95+
6696
def test_sugar():
6797
# Syntactic sugar for recoder class
6898
codes = ((1,'one','1','first'), (2,'two'))
@@ -84,3 +114,42 @@ def test_sugar():
84114
yield assert_true, 'one' in rc
85115
yield assert_false, 'three' in rc
86116

117+
118+
def test_dtmapper():
119+
# dict-like that will lookup on dtypes, even if they don't hash properly
120+
d = DtypeMapper()
121+
assert_raises(KeyError, d.__getitem__, 1)
122+
d[1] = 'something'
123+
assert_equal(d[1], 'something')
124+
assert_equal(list(d.keys()), [1])
125+
assert_equal(list(d.values()), ['something'])
126+
intp_dt = np.dtype('intp')
127+
if intp_dt == np.dtype('int32'):
128+
canonical_dt = np.dtype('int32')
129+
elif intp_dt == np.dtype('int64'):
130+
canonical_dt = np.dtype('int64')
131+
else:
132+
raise RuntimeError('Can I borrow your computer?')
133+
native_dt = canonical_dt.newbyteorder('=')
134+
explicit_dt = canonical_dt.newbyteorder(native_code)
135+
d[canonical_dt] = 'spam'
136+
assert_equal(d[canonical_dt], 'spam')
137+
assert_equal(d[native_dt], 'spam')
138+
assert_equal(d[explicit_dt], 'spam')
139+
# Test keys, values
140+
d = DtypeMapper()
141+
assert_equal(list(d.keys()), [])
142+
assert_equal(list(d.keys()), [])
143+
d[canonical_dt] = 'spam'
144+
assert_equal(list(d.keys()), [canonical_dt])
145+
assert_equal(list(d.values()), ['spam'])
146+
# With other byte order
147+
d = DtypeMapper()
148+
sw_dt = canonical_dt.newbyteorder(swapped_code)
149+
d[sw_dt] = 'spam'
150+
assert_raises(KeyError, d.__getitem__, canonical_dt)
151+
assert_equal(d[sw_dt], 'spam')
152+
sw_intp_dt = intp_dt.newbyteorder(swapped_code)
153+
assert_equal(d[sw_intp_dt], 'spam')
154+
155+

nibabel/tests/test_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_dtypes():
298298
# check we have the fields we were expecting
299299
assert_equal(dtr.value_set(), set((16,)))
300300
assert_equal(dtr.fields, ('code', 'label', 'type',
301-
'dtype', 'native_dtype', 'sw_dtype'))
301+
'dtype', 'sw_dtype'))
302302
# These of course should pass regardless of dtype
303303
assert_equal(dtr[np.float32], 16)
304304
assert_equal(dtr['float32'], 16)
@@ -314,13 +314,13 @@ def test_dtypes():
314314
assert_equal(dtr[np.dtype('f4').newbyteorder('S')], 16)
315315
assert_equal(dtr.value_set(), set((16,)))
316316
assert_equal(dtr.fields, ('code', 'label', 'type', 'niistring',
317-
'dtype', 'native_dtype', 'sw_dtype'))
317+
'dtype', 'sw_dtype'))
318318
assert_equal(dtr.niistring[16], 'ASTRING')
319319
# And that unequal elements raises error
320320
dt_defs = ((16, 'float32', np.float32, 'ASTRING'),
321321
(16, 'float32', np.float32))
322322
assert_raises(ValueError, make_dt_codes, dt_defs)
323-
# And that 2 or 5 elements raises error
323+
# And that 2 or 5 elements raises error
324324
dt_defs = ((16, 'float32'),)
325325
assert_raises(ValueError, make_dt_codes, dt_defs)
326326
dt_defs = ((16, 'float32', np.float32, 'ASTRING', 'ANOTHERSTRING'),)

nibabel/volumeutils.py

Lines changed: 70 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ class Recoder(object):
7070
>>> recodes[2]
7171
2
7272
'''
73-
def __init__(self, codes, fields=('code',)):
73+
def __init__(self, codes, fields=('code',), map_maker=dict):
7474
''' Create recoder object
7575
7676
``codes`` give a sequence of code, alias sequences
@@ -92,16 +92,20 @@ def __init__(self, codes, fields=('code',)):
9292
codes : seqence of sequences
9393
Each sequence defines values (codes) that are equivalent
9494
fields : {('code',) string sequence}, optional
95-
names by which elements in sequences can be accesssed
96-
95+
names by which elements in sequences can be accessed
96+
map_maker: callable, optional
97+
constructor for dict-like objects used to store key value pairs.
98+
Default is ``dict``. ``map_maker()`` generates an empty mapping.
99+
The mapping need only implement ``__getitem__, __setitem__, keys,
100+
values``.
97101
'''
98102
self.fields = tuple(fields)
99103
self.field1 = {} # a placeholder for the check below
100104
for name in fields:
101105
if name in self.__dict__:
102106
raise KeyError('Input name %s already in object dict'
103107
% name)
104-
self.__dict__[name] = {}
108+
self.__dict__[name] = map_maker()
105109
self.field1 = self.__dict__[fields[0]]
106110
self.add_codes(codes)
107111

@@ -153,7 +157,11 @@ def __getitem__(self, key):
153157
def __contains__(self, key):
154158
""" True if field1 in recoder contains `key`
155159
"""
156-
return key in self.field1
160+
try:
161+
self.field1[key]
162+
except KeyError:
163+
return False
164+
return True
157165

158166
def keys(self):
159167
''' Return all available code and alias values
@@ -191,7 +199,6 @@ def value_set(self, name=None):
191199
>>> rc = Recoder(codes, fields=('code', 'label'))
192200
>>> rc.value_set('label') == set(('one', 'two', 'repeat value'))
193201
True
194-
195202
'''
196203
if name is None:
197204
d = self.field1
@@ -204,6 +211,59 @@ def value_set(self, name=None):
204211
endian_codes = Recoder(endian_codes)
205212

206213

214+
class DtypeMapper(object):
215+
""" Specialized mapper for numpy dtypes
216+
217+
We pass this mapper into the Recoder class to deal with numpy dtype hashing.
218+
219+
The hashing problem is that dtypes that compare equal may not have the same
220+
hash. This is true for numpys up to the current at time of writing (1.6.0).
221+
For numpy 1.2.1 at least, even dtypes that look exactly the same in terms of
222+
fields don't always have the same hash. This makes dtypes difficult to use
223+
as keys in a dictionary.
224+
225+
This class wraps a dictionary in order to implement a __getitem__ to deal
226+
with dtype hashing. If the key doesn't appear to be in the mapping, and it
227+
is a dtype, we compare (using ==) all known dtype keys to the input key, and
228+
return any matching values for the matching key.
229+
"""
230+
def __init__(self):
231+
self._dict = {}
232+
self._dtype_keys = []
233+
234+
def keys(self):
235+
return self._dict.keys()
236+
237+
def values(self):
238+
return self._dict.values()
239+
240+
def __setitem__(self, key, value):
241+
""" Set item into mapping, checking for dtype keys
242+
243+
Cache dtype keys for comparison test in __getitem__
244+
"""
245+
self._dict[key] = value
246+
if hasattr(key, 'subdtype'):
247+
self._dtype_keys.append(key)
248+
249+
def __getitem__(self, key):
250+
""" Get item from mapping, checking for dtype keys
251+
252+
First do simple hash lookup, then check for a dtype key that has failed
253+
the hash lookup. Look then for any known dtype keys that compare equal
254+
to `key`.
255+
"""
256+
try:
257+
return self._dict[key]
258+
except KeyError:
259+
pass
260+
if hasattr(key, 'subdtype'):
261+
for dt in self._dtype_keys:
262+
if key == dt:
263+
return self._dict[dt]
264+
raise KeyError(key)
265+
266+
207267
def pretty_mapping(mapping, getterfunc=None):
208268
''' Make pretty string from mapping
209269
@@ -265,7 +325,7 @@ def pretty_mapping(mapping, getterfunc=None):
265325

266326

267327
def make_dt_codes(codes_seqs):
268-
''' Create full dt codes object from datatype codes
328+
''' Create full dt codes Recoder instance from datatype codes
269329
270330
Include created numpy dtype (from numpy type) and opposite endian
271331
numpy dtype
@@ -299,23 +359,10 @@ def make_dt_codes(codes_seqs):
299359
raise ValueError('Sequences must all have the same length')
300360
np_type = seq[2]
301361
this_dt = np.dtype(np_type)
302-
code_syns = list(seq)
303-
dtypes = [this_dt]
304-
# intp type is effectively same as int32 on 32 bit and int64 on 64 bit.
305-
# They compare equal, but in some (all?) numpy versions, they may hash
306-
# differently. If so we need to add them
307-
if this_dt == intp_dt and hash(this_dt) != hash(intp_dt):
308-
dtypes.append(intp_dt)
309-
# To satisfy an oddness in numpy dtype hashing, we need to add the dtype
310-
# with explicit native order as well as the default dtype (=) order
311-
for dt in dtypes:
312-
code_syns +=[dt,
313-
dt.newbyteorder(native_code),
314-
dt.newbyteorder(swapped_code)]
362+
# Add swapped dtype to synonyms
363+
code_syns = list(seq) + [this_dt, this_dt.newbyteorder(swapped_code)]
315364
dt_codes.append(code_syns)
316-
return Recoder(dt_codes, fields + ['dtype',
317-
'native_dtype',
318-
'sw_dtype'])
365+
return Recoder(dt_codes, fields + ['dtype', 'sw_dtype'], DtypeMapper)
319366

320367

321368
def can_cast(in_type, out_type, has_intercept=False, has_slope=False):

0 commit comments

Comments
 (0)