Skip to content

Commit bbe79ef

Browse files
committed
make a header-only reader
1 parent e096263 commit bbe79ef

File tree

7 files changed

+51
-44
lines changed

7 files changed

+51
-44
lines changed

cpp/cmd/connectome2tck.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void usage() {
220220
// 1-based indices, matching trx2connectome's ordering convention.
221221
static std::vector<std::vector<node_t>>
222222
assignments_from_trx_groups(const std::string &trx_path, const std::string &prefix_filter, node_t &max_node_index) {
223-
auto trx = load_trx(trx_path);
223+
auto trx = load_trx_header_only(trx_path);
224224
if (!trx || !trx->streamlines)
225225
throw Exception("Failed to load TRX file: " + trx_path);
226226
if (trx->groups.empty())

cpp/cmd/tckconvert.cpp

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,13 @@ void usage() {
139139
.allow_multiple()
140140
+ Argument ("name").type_text()
141141
+ Argument ("datatype").type_text()
142-
+ Argument ("path").type_file_in();
142+
+ Argument ("path").type_file_in()
143+
144+
+ Option ("positions_datatype",
145+
"datatype of the positions array in the output TRX file (float16, float32, or float64). "
146+
"Only applies to TRX output. Default: preserves the source datatype for TRX→TRX conversion; "
147+
"float32 for all other input formats.")
148+
+ Argument ("spec").type_text();
143149

144150
}
145151
// clang-format on
@@ -832,7 +838,23 @@ void run() {
832838
trx->add_dpv_from_tsf(spec.name, spec.dtype, spec.path);
833839
}
834840
try {
835-
trx->save(trx_save_path, ZIP_CM_STORE);
841+
trx::TrxSaveOptions save_opts;
842+
save_opts.compression_standard = ZIP_CM_STORE;
843+
{
844+
auto opt = get_options("positions_datatype");
845+
if (!opt.empty()) {
846+
const std::string spec = opt[0][0];
847+
if (spec == "float16")
848+
save_opts.positions_dtype = trx::TrxScalarType::Float16;
849+
else if (spec == "float32")
850+
save_opts.positions_dtype = trx::TrxScalarType::Float32;
851+
else if (spec == "float64")
852+
save_opts.positions_dtype = trx::TrxScalarType::Float64;
853+
else
854+
throw Exception("Unknown -positions_datatype '" + spec + "'; expected float16, float32, or float64");
855+
}
856+
}
857+
trx->save(trx_save_path, save_opts);
836858
if (rename_trx_directory) {
837859
std::error_code ec;
838860
std::filesystem::remove_all(trx_output, ec);

cpp/cmd/tckinfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ void run() {
5656
std::cout << " Tracks file: \"" << argument[i] << "\"\n";
5757

5858
if (Tractography::TRX::is_trx(argument[i])) {
59-
auto trx = Tractography::TRX::load_trx(argument[i]);
59+
auto trx = Tractography::TRX::load_trx_header_only(argument[i]);
6060
if (!trx)
6161
throw Exception("Failed to load TRX file: " + std::string(argument[i]));
6262
Tractography::TRX::print_info(std::cout, *trx, prefix_depth, !prefix_depth_specified);

cpp/cmd/trx2connectome.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -193,7 +193,7 @@ void run() {
193193
if (!is_trx(tracks_path))
194194
throw Exception("Input must be a TRX file; use trxlabel first to add group assignments");
195195

196-
auto trx = load_trx(tracks_path);
196+
auto trx = load_trx_header_only(tracks_path);
197197
if (!trx)
198198
throw Exception("Failed to load TRX file: " + tracks_path);
199199

cpp/cmd/tsfinfo.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,7 @@ void run() {
6767

6868
if (DWI::Tractography::TRX::is_trx(path)) {
6969
// TRX mode: list dpv fields
70-
auto trx = DWI::Tractography::TRX::load_trx(path);
70+
auto trx = DWI::Tractography::TRX::load_trx_header_only(path);
7171
if (!trx)
7272
throw Exception("Failed to load TRX file: " + path);
7373

cpp/cmd/tsfvalidate.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ void run() {
6464
throw Exception("Use -field to specify the dpv field to validate when the tracks argument is a TRX file");
6565

6666
const std::string field_name(field_opt[0][0]);
67-
auto trx = TRX::load_trx(tck_path);
67+
auto trx = TRX::load_trx_header_only(tck_path);
6868
if (!trx)
6969
throw Exception("Failed to load TRX file: " + tck_path);
7070

cpp/core/dwi/tractography/trx_utils.h

Lines changed: 22 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -59,48 +59,33 @@ inline bool is_trx_field_name(std::string_view input_path, std::string_view outp
5959
output_arg.find('\\') == std::string_view::npos;
6060
}
6161

62-
// Load a TRX file, auto-detecting zip vs directory format, and expose positions
63-
// as float32.
64-
//
65-
// For non-float32 positions, conversion is performed in bounded chunks directly
66-
// into the loaded float32 matrix to avoid allocating a second full-size copy.
62+
// Load a TRX file with positions guaranteed as float32.
63+
// For float16/float64 files, delegates to trx::load_float32_positions() which
64+
// performs chunked in-place conversion into owned storage — no segfault, no
65+
// double-size allocation.
66+
// Use this whenever streamline coordinates will be read or written.
6767
inline std::unique_ptr<trx::TrxFile<float>> load_trx(std::string_view path) {
6868
const std::string filename(path);
6969
const auto dtype = trx::detect_positions_scalar_type(filename, trx::TrxScalarType::Float32);
70-
if (dtype == trx::TrxScalarType::Float32)
71-
return trx::load<float>(filename);
72-
if (dtype == trx::TrxScalarType::Float16) {
70+
if (dtype == trx::TrxScalarType::Float16)
7371
WARN("TRX file '" + filename + "' has float16 positions; converting to float32 on load");
74-
} else {
72+
else if (dtype == trx::TrxScalarType::Float64)
7573
WARN("TRX file '" + filename + "' has float64 positions; converting to float32 (precision loss)");
76-
}
77-
auto trx_f32 = trx::load<float>(filename);
78-
if (!trx_f32 || !trx_f32->streamlines)
79-
return trx_f32;
80-
const Eigen::Index n_vertices = static_cast<Eigen::Index>(trx_f32->num_vertices());
81-
if (n_vertices <= 0)
82-
return trx_f32;
83-
const Eigen::Index chunk_rows = static_cast<Eigen::Index>(1) << 20;
84-
trx::with_trx_reader(filename, [&](auto &reader, trx::TrxScalarType) -> int {
85-
const auto &src = reader->streamlines->_data;
86-
if (src.rows() != n_vertices || src.cols() != 3)
87-
throw Exception("unexpected TRX positions shape while converting to float32");
88-
for (Eigen::Index row0 = 0; row0 < n_vertices; row0 += chunk_rows) {
89-
const Eigen::Index row1 = std::min<Eigen::Index>(n_vertices, row0 + chunk_rows);
90-
for (Eigen::Index r = row0; r < row1; ++r) {
91-
trx_f32->streamlines->_data(r, 0) = static_cast<float>(src(r, 0));
92-
trx_f32->streamlines->_data(r, 1) = static_cast<float>(src(r, 1));
93-
trx_f32->streamlines->_data(r, 2) = static_cast<float>(src(r, 2));
94-
}
95-
}
96-
return 0;
97-
});
98-
return trx_f32;
74+
return trx::load_float32_positions(filename);
75+
}
76+
77+
// Load a TRX file for metadata inspection only (groups, dps, dpv field
78+
// names/shapes, streamline/vertex counts). Positions are mmapped with their
79+
// native dtype but never read, so this is safe for float16/float64 files and
80+
// avoids any conversion overhead.
81+
// Do NOT read streamline coordinates from the returned TrxFile; use load_trx().
82+
inline std::unique_ptr<trx::TrxFile<float>> load_trx_header_only(std::string_view path) {
83+
return trx::load<float>(std::string(path));
9984
}
10085

10186
// Count streamlines and total vertices in a TRX file
10287
inline std::pair<size_t, size_t> count_trx(std::string_view path) {
103-
auto trx = load_trx(path);
88+
auto trx = load_trx_header_only(path);
10489
if (!trx)
10590
throw Exception("Failed to load TRX file: " + std::string(path));
10691
return {trx->num_streamlines(), trx->num_vertices()};
@@ -338,7 +323,7 @@ append_group(std::string_view trx_path, const std::string &name, const std::vect
338323
// because TrxFile<float> already presents all values as float.
339324
// Group indices are copied as uint32_t.
340325
inline void copy_trx_sidecar_data(std::string_view src_path, std::string_view dst_path, bool include_dpv = false) {
341-
auto src = load_trx(src_path);
326+
auto src = load_trx_header_only(src_path);
342327
if (!src)
343328
return;
344329
const std::string dst(dst_path);
@@ -758,7 +743,7 @@ inline std::vector<float> resolve_dps_weights(std::string_view tractogram_path,
758743
}
759744
if (!is_trx(tractogram_path))
760745
throw Exception("cannot resolve \"" + field_name_or_path + "\": not an existing file and input is not a TRX file");
761-
auto trx = load_trx(tractogram_path);
746+
auto trx = load_trx_header_only(tractogram_path);
762747
auto it = trx->data_per_streamline.find(field_name_or_path);
763748
if (it == trx->data_per_streamline.end() || !it->second)
764749
throw Exception("TRX file has no dps field named \"" + field_name_or_path + "\"");
@@ -818,7 +803,7 @@ inline std::vector<float> resolve_dpv_scalars(std::string_view tractogram_path,
818803
}
819804
if (!is_trx(tractogram_path))
820805
throw Exception("cannot resolve \"" + field_name_or_path + "\": not an existing file and input is not a TRX file");
821-
auto trx = load_trx(tractogram_path);
806+
auto trx = load_trx_header_only(tractogram_path);
822807
auto it = trx->data_per_vertex.find(field_name_or_path);
823808
if (it == trx->data_per_vertex.end() || !it->second)
824809
throw Exception("TRX file has no dpv field named \"" + field_name_or_path + "\"");
@@ -838,7 +823,7 @@ inline std::vector<float> resolve_dpv_scalars(std::string_view tractogram_path,
838823
class TRXScalarReader {
839824
public:
840825
TRXScalarReader(std::string_view trx_path, const std::string &field_name) : current_(0) {
841-
auto trx = load_trx(trx_path);
826+
auto trx = load_trx_header_only(trx_path);
842827
if (!trx || !trx->streamlines)
843828
throw Exception("Failed to load TRX file: " + std::string(trx_path));
844829

0 commit comments

Comments
 (0)