Skip to content

Commit b4485eb

Browse files
committed
test broadcasting
1 parent 6f94e07 commit b4485eb

File tree

3 files changed

+42
-7
lines changed

3 files changed

+42
-7
lines changed

src/h5json/hdf5db.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -753,7 +753,7 @@ def setDatasetValues(self, dset_id, sel, arr):
753753
if sel.select_type != selections.H5S_SELECT_HYPERSLABS:
754754
raise ValueError("tbd")
755755
arr = arr.reshape(sel.mshape)
756-
updates.append((sel, arr.copy()))
756+
updates.append((sel, arr))
757757
self.make_dirty(dset_id)
758758

759759
def resizeDataset(self, dset_id, shape):

test/unit/h5py_writer_test.py

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,13 @@ def testSimple(self):
9797
g1_1_id = db.createGroup()
9898
db.createHardLink(g1_id, "g1.1", g1_1_id)
9999
dset_111_id = db.createDataset(shape=(10, 10), dtype=np.int32)
100-
arr = np.zeros((10, 10), dtype=np.int32)
101-
for i in range(10):
102-
for j in range(10):
103-
arr[i, j] = i * j
100+
101+
# try setting dset values with broadcasting
102+
arr_one_value = np.zeros((1, 1), dtype=np.int32)
103+
arr_one_value[0, 0] = 42
104104
sel_all = selections.select((10, 10), ...)
105-
db.setDatasetValues(dset_111_id, sel_all, arr)
105+
db.setDatasetValues(dset_111_id, sel_all, arr_one_value)
106+
106107
db.createHardLink(g1_1_id, "dset1.1.1", dset_111_id)
107108
db.createSoftLink(g2_id, "slink", "somewhere")
108109
db.createExternalLink(g2_id, "extlink", "somewhere", "someplace")
@@ -126,12 +127,29 @@ def testSimple(self):
126127
self.assertEqual(dset.shape, (10, 10))
127128
for i in range(10):
128129
for j in range(10):
129-
self.assertEqual(dset[i, j], i * j)
130+
self.assertEqual(dset[i, j], 42)
130131
self.assertTrue("g2" in f)
131132
g2 = f["g2"]
132133
self.assertTrue("extlink" in g2)
133134
self.assertTrue("slink" in g2)
134135

136+
# write dataset values element by element
137+
db.open()
138+
arr = np.zeros((10, 10), dtype=np.int32)
139+
for i in range(10):
140+
for j in range(10):
141+
arr[i, j] = i * j
142+
sel_all = selections.select((10, 10), ...)
143+
db.setDatasetValues(dset_111_id, sel_all, arr)
144+
db.close()
145+
146+
# verify changes in h5py
147+
with h5py.File(filepath) as f:
148+
dset = f["/g1/g1.1/dset1.1.1"]
149+
for i in range(10):
150+
for j in range(10):
151+
self.assertEqual(dset[i, j], i * j)
152+
135153
db.open()
136154
db.createAttribute(g1_id, "a1", "hello")
137155
db.createAttribute(g1_id, "a2", "bye-bye")

test/unit/hdf5db_test.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -478,6 +478,23 @@ def testSimpleDataset(self):
478478
self.assertEqual(val.shape, (1, 1))
479479
self.assertEqual(val[0, 0], i * 10 + j)
480480

481+
# test select all write
482+
sel = selections.select(shape, ...)
483+
print("got sel:", sel)
484+
print(sel.select_type)
485+
arr = np.zeros(shape, dtype=dtype)
486+
arr[...] = 42
487+
db.setDatasetValues(dset_id, sel, arr)
488+
arr = db.getDatasetValues(dset_id, sel)
489+
for i in range(nrows):
490+
for j in range(ncols):
491+
self.assertEqual(arr[i, j], 42)
492+
493+
# try with broadcasting
494+
arr_one_value = np.zeros((1, 1), dtype=dtype)
495+
arr_one_value[0, 0] = 7
496+
db.setDatasetValues(dset_id, sel, arr_one_value)
497+
481498
db.close()
482499

483500
def testStringDataset(self):

0 commit comments

Comments
 (0)