Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 35 additions & 8 deletions skeliner/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,6 @@ def to_npz(
if not path.suffix:
path = path.with_suffix(".npz")

c = {} if not compress else {"compress": True}

if cache_kdtree:
skeleton._ensure_nodes_kdtree()
skeleton._ensure_node_neighbors()
Expand All @@ -367,11 +365,39 @@ def to_npz(
# ragged node2verts → index + offset
if skeleton.node2verts is not None:
n2v_idx = np.concatenate(skeleton.node2verts)
n2v_off = np.cumsum([0, *map(len, skeleton.node2verts)]).astype(np.int64)
n2v_off = np.cumsum([0, *map(len, skeleton.node2verts)])

# Determine smallest dtype for n2v_idx
if len(n2v_idx) > 0:
idx_min, idx_max = n2v_idx.min(), n2v_idx.max()
if idx_min >= 0 and idx_max <= np.iinfo(np.uint16).max:
n2v_idx = n2v_idx.astype(np.uint16)
elif idx_min < 0 and idx_min >= np.iinfo(np.int16).min and idx_max <= np.iinfo(np.int16).max:
n2v_idx = n2v_idx.astype(np.int16)
elif idx_min >= 0 and idx_max <= np.iinfo(np.uint32).max:
n2v_idx = n2v_idx.astype(np.uint32)
elif idx_min < 0 and idx_min >= np.iinfo(np.int32).min and idx_max <= np.iinfo(np.int32).max:
n2v_idx = n2v_idx.astype(np.int32)
else:
n2v_idx = n2v_idx.astype(np.int64)
else:
n2v_idx = n2v_idx.astype(np.uint16)

# Determine smallest dtype for n2v_off
off_min, off_max = n2v_off.min(), n2v_off.max()
if off_min >= 0 and off_max <= np.iinfo(np.uint16).max:
n2v_off = n2v_off.astype(np.uint16)
elif off_min < 0 and off_min >= np.iinfo(np.int16).min and off_max <= np.iinfo(np.int16).max:
n2v_off = n2v_off.astype(np.int16)
elif off_min >= 0 and off_max <= np.iinfo(np.uint32).max:
n2v_off = n2v_off.astype(np.uint32)
elif off_min < 0 and off_min >= np.iinfo(np.int32).min and off_max <= np.iinfo(np.int32).max:
n2v_off = n2v_off.astype(np.int32)
else:
n2v_off = n2v_off.astype(np.int64)
else:
n2v_idx = np.array([], dtype=np.int64)
n2v_off = np.array([0], dtype=np.int64)

n2v_idx = np.array([], dtype=np.uint16)
n2v_off = np.array([0], dtype=np.uint16)
# ----------- NEW: persist the metadata dict -------------------------
# We wrap it in a 0-D object array because np.savez can only store
# ndarrays — this keeps the archive a single *.npz* with no sidecars.
Expand All @@ -397,7 +423,9 @@ def to_npz(
tree_payload["neighbors_idx"] = data.astype(np.int64, copy=False)
tree_payload["neighbors_off"] = offsets.astype(np.int64, copy=False)

np.savez(
save_fun = np.savez_compressed if compress else np.savez

save_fun(
path,
nodes=skeleton.nodes,
edges=skeleton.edges,
Expand All @@ -416,7 +444,6 @@ def to_npz(
**extra,
**meta,
**tree_payload,
**c,
)


Expand Down