Skip to content

Commit e096263

Browse files
committed
region order parity with tck2connectome
1 parent cc80ee2 commit e096263

File tree

2 files changed

+109
-13
lines changed

2 files changed

+109
-13
lines changed

cpp/cmd/trx2connectome.cpp

Lines changed: 55 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "types.h"
2727

2828
#include "connectome/connectome.h"
29+
#include "connectome/lut.h"
2930
#include "connectome/mat2vec.h"
3031
#include "dwi/tractography/connectome/connectome.h"
3132
#include "dwi/tractography/connectome/mapped_track.h"
@@ -87,6 +88,13 @@ void usage() {
8788
"the prefix (and trailing underscore) is stripped from node names in the output")
8889
+ Argument ("prefix").type_text()
8990

91+
+ Option ("lut", "lookup table mapping node names to numeric indices "
92+
"(supports FreeSurfer, AAL, ITK-SNAP, and MRtrix LUT formats); "
93+
"when provided, the output matrix rows and columns are ordered by node index "
94+
"rather than alphabetically, matching tck2connectome output ordering. "
95+
"Requires -group_prefix")
96+
+ Argument ("path").type_file_in()
97+
9098
+ MR::DWI::Tractography::TrackWeightsInOption
9199

92100
+ MR::DWI::Tractography::Connectome::EdgeStatisticOption
@@ -122,23 +130,45 @@ void execute(const node_t max_node_index,
122130
for (size_t i = 0; i < streamline_groups.size(); ++i) {
123131
const auto &grps = streamline_groups[i];
124132
const float w = weights.empty() ? 1.0f : weights[i];
125-
// Each streamline's group list contains one entry per endpoint assignment.
126-
// Deduplicate and take the first two distinct nodes as the edge endpoints;
127-
// using Mapped_track_nodepair ensures exactly one edge per streamline
128-
// (Mapped_track_nodelist would add spurious self-connections).
133+
// Deduplicate node memberships for this streamline.
134+
// For a single-atlas run, a streamline has at most 2 unique nodes (one per
135+
// endpoint), producing one edge — identical to tck2connectome behaviour.
136+
// For a combined-atlas run, a streamline may have 4 unique nodes (2 per atlas),
137+
// so we emit one edge per unique pair to populate all atlas-block sub-matrices.
129138
std::vector<node_t> unique_grps;
130139
for (const auto n : grps) {
131140
if (std::find(unique_grps.begin(), unique_grps.end(), n) == unique_grps.end())
132141
unique_grps.push_back(n);
133142
}
134-
const node_t n1 = unique_grps.size() >= 1 ? unique_grps[0] : node_t(0);
135-
const node_t n2 = unique_grps.size() >= 2 ? unique_grps[1] : node_t(0);
136-
Mapped_track_nodepair mapped;
137-
mapped.set_track_index(i);
138-
mapped.set_factor(1.0f);
139-
mapped.set_weight(w);
140-
mapped.set_nodes(NodePair(n1, n2));
141-
connectome(mapped);
143+
if (unique_grps.empty()) {
144+
// Unassigned streamline — record as (0, 0)
145+
Mapped_track_nodepair mapped;
146+
mapped.set_track_index(i);
147+
mapped.set_factor(1.0f);
148+
mapped.set_weight(w);
149+
mapped.set_nodes(NodePair(node_t(0), node_t(0)));
150+
connectome(mapped);
151+
} else if (unique_grps.size() == 1) {
152+
// One endpoint assigned — record as (node, 0)
153+
Mapped_track_nodepair mapped;
154+
mapped.set_track_index(i);
155+
mapped.set_factor(1.0f);
156+
mapped.set_weight(w);
157+
mapped.set_nodes(NodePair(unique_grps[0], node_t(0)));
158+
connectome(mapped);
159+
} else {
160+
// Two or more unique nodes: emit one edge per unique pair.
161+
for (size_t j = 0; j < unique_grps.size(); ++j) {
162+
for (size_t k = j + 1; k < unique_grps.size(); ++k) {
163+
Mapped_track_nodepair mapped;
164+
mapped.set_track_index(i);
165+
mapped.set_factor(1.0f);
166+
mapped.set_weight(w);
167+
mapped.set_nodes(NodePair(unique_grps[j], unique_grps[k]));
168+
connectome(mapped);
169+
}
170+
}
171+
}
142172
++progress;
143173
}
144174
}
@@ -178,12 +208,24 @@ void run() {
178208
group_prefix = std::string(opt[0][0]) + "_";
179209
}
180210

211+
// Load LUT if provided (requires -group_prefix)
212+
std::unique_ptr<MR::Connectome::LUT> lut;
213+
{
214+
auto opt = get_options("lut");
215+
if (!opt.empty()) {
216+
if (group_prefix.empty())
217+
throw Exception("-lut requires -group_prefix to identify which groups to match against the lookup table");
218+
lut = std::make_unique<MR::Connectome::LUT>(std::string(opt[0][0]));
219+
}
220+
}
221+
181222
std::vector<std::string> group_names = collect_group_names(*trx, group_prefix);
182223

183224
if (group_names.empty())
184225
throw Exception("No groups match the specified prefix '" + group_prefix + "'; check the group names with tckinfo");
185226

186-
GroupNodeMapping mapping = build_group_node_mapping(group_names, group_prefix);
227+
GroupNodeMapping mapping = lut ? build_group_node_mapping(group_names, group_prefix, *lut)
228+
: build_group_node_mapping(group_names, group_prefix);
187229
const node_t max_node_index = static_cast<node_t>(mapping.max_node_index);
188230

189231
const auto streamline_groups_u32 = invert_group_memberships(*trx, mapping.group_to_node);

cpp/core/dwi/tractography/trx_utils.h

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,6 +173,60 @@ inline GroupNodeMapping build_group_node_mapping(const std::vector<std::string>
173173
return mapping;
174174
}
175175

176+
// Overload that uses a LUT to recover numeric node-ID ordering.
177+
// The LUT maps node_id → LUT_node; we invert to name → node_id and match
178+
// stripped group names against the LUT names. This produces a matrix with
179+
// rows/columns ordered by node index, matching tck2connectome output.
180+
template <typename LUT_type>
181+
inline GroupNodeMapping
182+
build_group_node_mapping(const std::vector<std::string> &group_names, const std::string &prefix, const LUT_type &lut) {
183+
GroupNodeMapping mapping;
184+
if (group_names.empty())
185+
return mapping;
186+
187+
// Build reverse map: sanitized LUT name → node_id
188+
// (sanitize the same way trxlabel does — spaces and special chars become '_')
189+
auto sanitize = [](std::string name) {
190+
for (char &c : name) {
191+
if (c == ' ' || c == '/' || c == '\\' || c == ':' || c == '*' || c == '?' || c == '"' || c == '<' || c == '>')
192+
c = '_';
193+
}
194+
return name;
195+
};
196+
197+
std::map<std::string, uint32_t> name_to_node;
198+
uint32_t max_id = 0;
199+
for (const auto &[node_id, entry] : lut) {
200+
const std::string sanitized = sanitize(entry.get_name());
201+
name_to_node[sanitized] = static_cast<uint32_t>(node_id);
202+
if (static_cast<uint32_t>(node_id) > max_id)
203+
max_id = static_cast<uint32_t>(node_id);
204+
}
205+
206+
mapping.integer_names = true;
207+
mapping.max_node_index = max_id;
208+
mapping.ordered_display_names.assign(static_cast<size_t>(max_id) + 1, "");
209+
210+
// Populate display names from LUT (all entries, even those with no groups)
211+
for (const auto &[node_id, entry] : lut) {
212+
const std::string sanitized = sanitize(entry.get_name());
213+
mapping.ordered_display_names[static_cast<size_t>(node_id)] = sanitized;
214+
}
215+
216+
// Map each group to its LUT node ID
217+
for (const auto &name : group_names) {
218+
const std::string stripped = prefix.empty() ? name : name.substr(prefix.size());
219+
auto it = name_to_node.find(stripped);
220+
if (it != name_to_node.end()) {
221+
mapping.group_to_node[name] = it->second;
222+
} else {
223+
WARN("Group '" + name + "' (stripped: '" + stripped + "') not found in LUT; skipping");
224+
}
225+
}
226+
227+
return mapping;
228+
}
229+
176230
inline std::vector<std::string> collect_group_names(const trx::TrxFile<float> &trx, const std::string &prefix) {
177231
std::vector<std::string> names;
178232
for (const auto &[name, _] : trx.groups) {

0 commit comments

Comments
 (0)