Skip to content

Commit 924ee00

Browse files
committed
fix for scalar datasets
1 parent 5561767 commit 924ee00

File tree

2 files changed

+34
-6
lines changed

2 files changed

+34
-6
lines changed

src/h5json/hdf5db.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,11 @@ def getDatasetValues(self, dset_id, sel):
603603
stop = start + sel_inter.count[dim]
604604
slices.append(slice(start, stop, 1))
605605
slices = tuple(slices)
606-
arr[slices] = update_val
606+
# TBD: needs updating to work in the general case!
607+
if slices == ():
608+
arr[slices] = update_val[slices]
609+
else:
610+
arr[slices] = update_val
607611

608612
return arr
609613

@@ -620,6 +624,11 @@ def setDatasetValues(self, dset_id, sel, arr):
620624
raise ValueError("Only hyperslab selections are currently supported")
621625
if not isinstance(arr, np.ndarray):
622626
raise TypeError("Expected ndarray for data value")
627+
tgt_dt = self.getDtype(dset_json)
628+
src_dt = arr.dtype
629+
if src_dt != tgt_dt:
630+
raise TypeError("arr.dtype doesn't match dataset dtype")
631+
623632
if shape_json["class"] == "H5S_NULL":
624633
raise ValueError("writing to null space dataset not supported")
625634
if shape_json["class"] == "H5S_SCALAR":

test/unit/hsds_writer_test.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,23 +244,42 @@ def testReaderWriter(self):
244244
self.assertTrue(db.writer.lastModified is None) # no flush yet
245245

246246
# create a scalar dataset
247-
dset_id = db.createDataset(shape=(), dtype=np.int32)
248-
dset_json = db.getObjectById(dset_id)
247+
dsetA_id = db.createDataset(shape=(), dtype=np.int32)
248+
dset_json = db.getObjectById(dsetA_id)
249249
self.assertTrue("created" in dset_json)
250250
dset_create_time = dset_json["created"]
251251
self.assertTrue(dset_create_time > 0)
252252

253+
db.createHardLink(root_id, "dset_a", dsetA_id)
254+
253255
arr = np.zeros((), dtype=np.int32)
254256
arr[()] = 42
255257
sel_all = selections.select((), ...)
256-
db.setDatasetValues(dset_id, sel_all, arr)
257-
dset_json = db.getObjectById(dset_id)
258+
db.setDatasetValues(dsetA_id, sel_all, arr)
259+
260+
dset_json = db.getObjectById(dsetA_id)
258261
self.assertTrue("lastModified" in dset_json)
259262
self.assertTrue(dset_json["lastModified"] > dset_create_time)
260263

261-
arr = db.getDatasetValues(dset_id, sel_all)
264+
arr = db.getDatasetValues(dsetA_id, sel_all)
262265
self.assertEqual(arr[()], 42)
263266

267+
# create a scalar dataset with string
268+
dt_str = special_dtype(vlen=str)
269+
dsetB_id = db.createDataset(shape=(), dtype=dt_str)
270+
dset_json = db.getObjectById(dsetB_id)
271+
db.createHardLink(root_id, "dset_b", dsetB_id)
272+
273+
arr = np.zeros((), dtype=dt_str)
274+
arr[()] = "hello world"
275+
db.setDatasetValues(dsetB_id, sel_all, arr)
276+
277+
arr = db.getDatasetValues(dsetB_id, sel_all)
278+
279+
e = arr[()]
280+
self.assertEqual(e, "hello world")
281+
self.assertTrue(isinstance(e, str))
282+
264283
db.close()
265284

266285
def testH5PyToHS(self):

0 commit comments

Comments
 (0)