@@ -59,48 +59,33 @@ inline bool is_trx_field_name(std::string_view input_path, std::string_view outp
5959 output_arg.find (' \\ ' ) == std::string_view::npos;
6060}
6161
62- // Load a TRX file, auto-detecting zip vs directory format, and expose positions
63- // as float32.
64- //
65- // For non-float32 positions, conversion is performed in bounded chunks directly
66- // into the loaded float32 matrix to avoid allocating a second full-size copy .
62+ // Load a TRX file with positions guaranteed as float32.
63+ // For float16/float64 files, delegates to trx::load_float32_positions() which
64+ // performs chunked in-place conversion into owned storage — no segfault, no
65+ // double-size allocation.
66+ // Use this whenever streamline coordinates will be read or written .
6767inline std::unique_ptr<trx::TrxFile<float >> load_trx (std::string_view path) {
6868 const std::string filename (path);
6969 const auto dtype = trx::detect_positions_scalar_type (filename, trx::TrxScalarType::Float32);
70- if (dtype == trx::TrxScalarType::Float32)
71- return trx::load<float >(filename);
72- if (dtype == trx::TrxScalarType::Float16) {
70+ if (dtype == trx::TrxScalarType::Float16)
7371 WARN (" TRX file '" + filename + " ' has float16 positions; converting to float32 on load" );
74- } else {
72+ else if (dtype == trx::TrxScalarType::Float64)
7573 WARN (" TRX file '" + filename + " ' has float64 positions; converting to float32 (precision loss)" );
76- }
77- auto trx_f32 = trx::load<float >(filename);
78- if (!trx_f32 || !trx_f32->streamlines )
79- return trx_f32;
80- const Eigen::Index n_vertices = static_cast <Eigen::Index>(trx_f32->num_vertices ());
81- if (n_vertices <= 0 )
82- return trx_f32;
83- const Eigen::Index chunk_rows = static_cast <Eigen::Index>(1 ) << 20 ;
84- trx::with_trx_reader (filename, [&](auto &reader, trx::TrxScalarType) -> int {
85- const auto &src = reader->streamlines ->_data ;
86- if (src.rows () != n_vertices || src.cols () != 3 )
87- throw Exception (" unexpected TRX positions shape while converting to float32" );
88- for (Eigen::Index row0 = 0 ; row0 < n_vertices; row0 += chunk_rows) {
89- const Eigen::Index row1 = std::min<Eigen::Index>(n_vertices, row0 + chunk_rows);
90- for (Eigen::Index r = row0; r < row1; ++r) {
91- trx_f32->streamlines ->_data (r, 0 ) = static_cast <float >(src (r, 0 ));
92- trx_f32->streamlines ->_data (r, 1 ) = static_cast <float >(src (r, 1 ));
93- trx_f32->streamlines ->_data (r, 2 ) = static_cast <float >(src (r, 2 ));
94- }
95- }
96- return 0 ;
97- });
98- return trx_f32;
74+ return trx::load_float32_positions (filename);
75+ }
76+
77+ // Load a TRX file for metadata inspection only (groups, dps, dpv field
78+ // names/shapes, streamline/vertex counts). Positions are mmapped with their
79+ // native dtype but never read, so this is safe for float16/float64 files and
80+ // avoids any conversion overhead.
81+ // Do NOT read streamline coordinates from the returned TrxFile; use load_trx().
82+ inline std::unique_ptr<trx::TrxFile<float >> load_trx_header_only (std::string_view path) {
83+ return trx::load<float >(std::string (path));
9984}
10085
10186// Count streamlines and total vertices in a TRX file
10287inline std::pair<size_t , size_t > count_trx (std::string_view path) {
103- auto trx = load_trx (path);
88+ auto trx = load_trx_header_only (path);
10489 if (!trx)
10590 throw Exception (" Failed to load TRX file: " + std::string (path));
10691 return {trx->num_streamlines (), trx->num_vertices ()};
@@ -338,7 +323,7 @@ append_group(std::string_view trx_path, const std::string &name, const std::vect
338323// because TrxFile<float> already presents all values as float.
339324// Group indices are copied as uint32_t.
340325inline void copy_trx_sidecar_data (std::string_view src_path, std::string_view dst_path, bool include_dpv = false ) {
341- auto src = load_trx (src_path);
326+ auto src = load_trx_header_only (src_path);
342327 if (!src)
343328 return ;
344329 const std::string dst (dst_path);
@@ -758,7 +743,7 @@ inline std::vector<float> resolve_dps_weights(std::string_view tractogram_path,
758743 }
759744 if (!is_trx (tractogram_path))
760745 throw Exception (" cannot resolve \" " + field_name_or_path + " \" : not an existing file and input is not a TRX file" );
761- auto trx = load_trx (tractogram_path);
746+ auto trx = load_trx_header_only (tractogram_path);
762747 auto it = trx->data_per_streamline .find (field_name_or_path);
763748 if (it == trx->data_per_streamline .end () || !it->second )
764749 throw Exception (" TRX file has no dps field named \" " + field_name_or_path + " \" " );
@@ -818,7 +803,7 @@ inline std::vector<float> resolve_dpv_scalars(std::string_view tractogram_path,
818803 }
819804 if (!is_trx (tractogram_path))
820805 throw Exception (" cannot resolve \" " + field_name_or_path + " \" : not an existing file and input is not a TRX file" );
821- auto trx = load_trx (tractogram_path);
806+ auto trx = load_trx_header_only (tractogram_path);
822807 auto it = trx->data_per_vertex .find (field_name_or_path);
823808 if (it == trx->data_per_vertex .end () || !it->second )
824809 throw Exception (" TRX file has no dpv field named \" " + field_name_or_path + " \" " );
@@ -838,7 +823,7 @@ inline std::vector<float> resolve_dpv_scalars(std::string_view tractogram_path,
838823class TRXScalarReader {
839824public:
840825 TRXScalarReader (std::string_view trx_path, const std::string &field_name) : current_(0 ) {
841- auto trx = load_trx (trx_path);
826+ auto trx = load_trx_header_only (trx_path);
842827 if (!trx || !trx->streamlines )
843828 throw Exception (" Failed to load TRX file: " + std::string (trx_path));
844829
0 commit comments