Skip to content

Commit 9c79389

Browse files
committed
Fix numpy_to_pds_type library function
1 parent bd996ad commit 9c79389

File tree

3 files changed: 74 additions and 18 deletions

pds4_tools/reader/data_types.py

Lines changed: 26 additions & 17 deletions
Original file line number | Diff line number | Diff line change
@@ -310,53 +310,62 @@ def numpy_to_pds_type(dtype, ascii_numerics=False):
310310
dtype : np.dtype
311311
A NumPy data type.
312312
ascii_numerics
313-
If True, the returned PDS4 data type will be an ASCII numeric type if the input dtype is numeric
314-
or boolean. If False, the returned PDS4 data type will be a binary type. Defaults to False.
313+
If True, the returned PDS4 data type will be an ASCII numeric or boolean type if the input dtype
314+
is numeric or boolean. If False, the returned PDS4 data type will be a binary type. Defaults to
315+
False.
315316
316317
Returns
317318
-------
318319
PDSdtype
319320
A PDS4 data type that could plausibly (see description above) correspond to the input dtype.
320321
"""
321322

323+
# Ensure *dtype* is a NumPy dtype
324+
dtype = np.dtype(dtype)
325+
322326
# For string dtypes
323327
if np.issubdtype(dtype, np_unicode):
324328
data_type = 'UTF8_String'
325329

326-
elif np.issubdtype(dtype, np.string_):
330+
elif np.issubdtype(dtype, np.bytes_):
327331
data_type = 'ASCII_String'
328332

329333
# For datetime dtypes
330334
elif np.issubdtype(dtype, np.datetime64):
331-
data_type = 'ASCII_Date_Time_YMD'
335+
336+
if dtype.name == PDS4_DATE_TYPES['ASCII_Date_YMD'][1]:
337+
data_type = 'ASCII_Date_YMD'
338+
else:
339+
data_type = 'ASCII_Date_Time_YMD'
332340

333341
# For numeric dtypes
334342
else:
335343

336-
# Get numeric ASCII types. We obtain these from builtin portion because if we attempt to match
337-
# e.g. 'int16' to 'int64' it would fail but for ASCII types this should succeed.
338-
ascii_types = dict((value[2], key)
339-
for key, value in six.iteritems(PDS_NUMERIC_TYPES)
340-
if ('ASCII' in key) and ('Numeric_Base' not in key))
344+
# Get numeric ASCII types
345+
# (compare via np.dtype.kind because kind is unique for each ASCII type supported here)
346+
if ascii_numerics:
341347

342-
# Get numeric non-ASCII types, including the correct endianness.
343-
non_ascii_types = dict((np.dtype(value[1]).newbyteorder(value[0]), key)
348+
ascii_types = dict((np.dtype(value[1]).kind, key)
344349
for key, value in six.iteritems(PDS_NUMERIC_TYPES)
345-
if ('ASCII' not in key) and ('Numeric_Base' not in key))
350+
if ('ASCII' in key) and ('Numeric_Base' not in key))
346351

347-
if ascii_numerics:
348-
349-
builtin_type = type(np.asscalar(np.array(0, dtype=dtype))).__name__
350-
data_type = ascii_types.get(builtin_type, None)
352+
data_type = ascii_types.get(dtype.kind, None)
351353

354+
# Get numeric non-ASCII types, including the correct endianness
355+
# (compare via full np.dtype)
352356
else:
357+
358+
non_ascii_types = dict((np.dtype(value[1]).newbyteorder(value[0]), key)
359+
for key, value in six.iteritems(PDS_NUMERIC_TYPES)
360+
if ('ASCII' not in key) and ('Numeric_Base' not in key))
361+
353362
data_type = non_ascii_types.get(dtype, None)
354363

355364
# Raise error if we were unable to find a match
356365
if data_type is None:
357366

358367
raise ValueError("Unable to convert NumPy data type, '{0}', to a PDS4 {1} data type.".
359-
format(dtype, 'ASCII' if ascii_numerics else 'binary'))
368+
format(dtype.name, 'ASCII' if ascii_numerics else 'binary'))
360369

361370
return PDSdtype(data_type)
362371

pds4_tools/tests/compat.py

Lines changed: 10 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,16 @@
66
import sys
77
import xml.etree.ElementTree as ET
88

9+
import numpy as np
10+
11+
912
PY26 = sys.version_info[0:2] == (2, 6)
1013

14+
# ElementTree compat (Python 2.7+ and 3.3+)
1115
ET_Element = ET._Element if PY26 else ET.Element
16+
17+
# NumPy compat (NumPy 2.0+)
18+
try:
19+
np_unicode = np.unicode_
20+
except AttributeError:
21+
np_unicode = np.str_

pds4_tools/tests/test_core.py

Lines changed: 38 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -13,7 +13,7 @@
1313

1414
from pds4_tools import pds4_read
1515
from pds4_tools.reader.data import PDS_ndarray, PDS_marray
16-
from pds4_tools.reader.data_types import data_type_convert_dates
16+
from pds4_tools.reader.data_types import data_type_convert_dates, numpy_to_pds_type, PDSdtype
1717
from pds4_tools.reader.array_objects import ArrayStructure
1818
from pds4_tools.reader.table_objects import TableStructure, TableManifest
1919
from pds4_tools.reader.label_objects import Label
@@ -1463,6 +1463,43 @@ def test_download_file(self):
14631463
assert xml_equal(structures_web2.label, structures_local.label)
14641464

14651465

1466+
class TestLibraryFunctions(PDS4ToolsTestCase):
1467+
1468+
def test_numpy_to_pds_type(self):
1469+
1470+
# Test strings
1471+
assert PDSdtype('UTF8_String') == numpy_to_pds_type(compat.np_unicode)
1472+
assert PDSdtype('ASCII_String') == numpy_to_pds_type(np.bytes_)
1473+
1474+
# Test bool
1475+
assert PDSdtype('ASCII_Boolean') == numpy_to_pds_type(np.bool_, ascii_numerics=True)
1476+
1477+
# Test ASCII numbers
1478+
assert PDSdtype('ASCII_Real') == numpy_to_pds_type(np.float32, ascii_numerics=True)
1479+
assert PDSdtype('ASCII_Integer') == numpy_to_pds_type(np.int8, ascii_numerics=True)
1480+
assert PDSdtype('ASCII_NonNegative_Integer') == numpy_to_pds_type(np.uint32, ascii_numerics=True)
1481+
1482+
# Test binary numbers
1483+
np_lsb_float64 = np.dtype(np.float64).newbyteorder('<')
1484+
np_int8 = np.dtype(np.int8)
1485+
np_lsb_uint64 = np.dtype(np.uint64).newbyteorder('<')
1486+
np_msb_int32 = np.dtype(np.int32).newbyteorder('>')
1487+
1488+
assert PDSdtype('IEEE754LSBDouble') == numpy_to_pds_type(np_lsb_float64, ascii_numerics=False)
1489+
assert PDSdtype('SignedByte') == numpy_to_pds_type(np_int8, ascii_numerics=False)
1490+
assert PDSdtype('UnsignedLSB8') == numpy_to_pds_type(np_lsb_uint64, ascii_numerics=False)
1491+
assert PDSdtype('SignedMSB4') == numpy_to_pds_type(np_msb_int32, ascii_numerics=False)
1492+
1493+
# Test dates
1494+
np_date = np.datetime64("2000-01-01").dtype
1495+
np_specific_datetime = np.datetime64("2000-01-01 00:00").dtype
1496+
np_generic_datetime = np.datetime64
1497+
1498+
assert PDSdtype('ASCII_Date_YMD') == numpy_to_pds_type(np_date)
1499+
assert PDSdtype('ASCII_Date_Time_YMD') == numpy_to_pds_type(np_specific_datetime)
1500+
assert PDSdtype('ASCII_Date_Time_YMD') == numpy_to_pds_type(np_generic_datetime)
1501+
1502+
14661503
class TestDeprecation(PDS4ToolsTestCase):
14671504

14681505
def test_deprecated(self):

0 commit comments

Comments
 (0)