Skip to content

Commit 25365ef

Browse files
committed
fix _set_vector_unsafe and add tests on _nb_fill
1 parent bd6f01f commit 25365ef

File tree

5 files changed

+62
-2
lines changed

5 files changed

+62
-2
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
/src/lgdo/_version.py
22

3+
#uv
4+
uv.lock
5+
36
# Byte-compiled / optimized / DLL files
47
__pycache__/
58
*.py[cod]

.pre-commit-config.yaml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ repos:
5050
rev: "v2.4.1"
5151
hooks:
5252
- id: codespell
53+
additional_dependencies:
54+
- tomli
5355

5456
- repo: https://github.com/shellcheck-py/shellcheck-py
5557
rev: "v0.10.0.1"

src/lgdo/types/vectorofvectors.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -476,7 +476,10 @@ def _set_vector_unsafe(
476476
else:
477477
nan_val = np.nan
478478
vovutils._nb_fill(
479-
vec, lens, nan_val, self.flattened_data.nda[start : cum_lens[-1]]
479+
vec,
480+
lens,
481+
np.array([nan_val]).astype(self.flattened_data.nda.dtype),
482+
self.flattened_data.nda[start : cum_lens[-1]],
480483
)
481484

482485
# add new vector(s) length to cumulative_length

src/lgdo/types/vovutils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _nb_fill(
131131
for i, ll in enumerate(len_in):
132132
stop = start + ll
133133
if ll > max_len:
134-
flattened_array_out[start : start + max_len] = aoa_in[i, :]
134+
flattened_array_out[start : start + max_len] = aoa_in[i, :max_len]
135135
flattened_array_out[start + max_len : stop] = nan_val
136136
else:
137137
flattened_array_out[start:stop] = aoa_in[i, :ll]

tests/types/test_vovutils.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,58 @@ def testvov():
2727
return VovColl(v2d, v3d, v4d)
2828

2929

30+
def test_nb_fill():
31+
# test 1d array of int
32+
aoa_in = np.arange(5, dtype="int32").reshape(1, 5)
33+
len_in = np.array([5])
34+
nan_val = np.array([0], dtype=aoa_in.dtype)
35+
flattened_array_out = np.empty(5, dtype=aoa_in.dtype)
36+
37+
vovutils._nb_fill(aoa_in, len_in, nan_val, flattened_array_out)
38+
assert np.array_equal(
39+
flattened_array_out, np.array([0, 1, 2, 3, 4], dtype=aoa_in.dtype)
40+
)
41+
# test 1d array of uint
42+
aoa_in = np.arange(5, dtype="uint16").reshape(1, 5)
43+
len_in = np.array([5])
44+
nan_val = np.array([0], dtype=aoa_in.dtype)
45+
flattened_array_out = np.empty(5, dtype=aoa_in.dtype)
46+
47+
vovutils._nb_fill(aoa_in, len_in, nan_val, flattened_array_out)
48+
assert np.array_equal(
49+
flattened_array_out, np.array([0, 1, 2, 3, 4], dtype=aoa_in.dtype)
50+
)
51+
# test 1d array of float
52+
aoa_in = np.arange(5, dtype="float32").reshape(1, 5)
53+
len_in = np.array([5])
54+
nan_val = np.array([0], dtype=aoa_in.dtype)
55+
flattened_array_out = np.empty(5, dtype=aoa_in.dtype)
56+
57+
vovutils._nb_fill(aoa_in, len_in, nan_val, flattened_array_out)
58+
assert np.array_equal(
59+
flattened_array_out, np.array([0, 1, 2, 3, 4], dtype=aoa_in.dtype)
60+
)
61+
# test 2d array of int
62+
aoa_in = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
63+
len_in = np.array([3, 3])
64+
nan_val = np.array([0], dtype=aoa_in.dtype)
65+
flattened_array_out = np.empty(6, dtype=aoa_in.dtype)
66+
vovutils._nb_fill(aoa_in, len_in, nan_val, flattened_array_out)
67+
assert np.array_equal(
68+
flattened_array_out,
69+
np.array([1, 2, 3, 4, 5, 6], dtype=aoa_in.dtype),
70+
)
71+
# test nan value addition
72+
aoa_in = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
73+
len_in = np.array([4, 3])
74+
flattened_array_out = np.empty(7, dtype=aoa_in[0].dtype)
75+
vovutils._nb_fill(aoa_in, len_in, nan_val, flattened_array_out)
76+
assert np.array_equal(
77+
flattened_array_out,
78+
np.array([1, 2, 3, 0, 4, 5, 6], dtype=aoa_in[0].dtype),
79+
)
80+
81+
3082
def test_ak_input_validity(testvov):
3183
for v in testvov:
3284
assert vovutils._ak_is_jagged(v) is True

0 commit comments

Comments
 (0)