Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 66 additions & 9 deletions src/computation_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,17 +85,21 @@ struct MutualInfoKey {
}
};

// value: Info3PointBlock
// 3-point information is permutation invariant with respect to {X, Y, Z}, but
// shifted 3-point information is permutation invariant with respect to only {X,
// Y} because of the complexity term.
// Forward declaration for lookup helper
struct Info3PointLookup;

// Persistent key
struct Info3PointKey {
set<int> XY;
int Z;
set<int> ui;
vector<int> ui;

Info3PointKey(int X, int Y, int Z, const vector<int>& ui)
: XY({X, Y}), Z(Z), ui(begin(ui), end(ui)) {}
Info3PointKey(int X, int Y, int Z, const vector<int>& ui_input)
: XY({X, Y}), Z(Z){
ui = ui_input;
std::sort(ui.begin(), ui.end());
ui.erase(std::unique(ui.begin(), ui.end()), ui.end()); // Ensure uniqueness
}

bool operator<(const Info3PointKey& other) const {
if (XY == other.XY) {
Expand All @@ -108,8 +112,60 @@ struct Info3PointKey {
}
};

// Lightweight lookup helper - avoids constructing sets
struct Info3PointLookup {
int X, Y; // Store as ints, not set
int Z;
const vector<int>* ui_ptr; // No heap allocation, just pointer to ui vector

Info3PointLookup(int X, int Y, int Z, const vector<int>& ui)
: X(X), Y(Y), Z(Z), ui_ptr(&ui) {}
};


// Transparent comparator for std::map
struct Info3PointCompare {
using is_transparent = void; // Enable transparent lookup (C++14)

// Key vs Key
bool operator()(const Info3PointKey& lhs, const Info3PointKey& rhs) const {
if (lhs.XY != rhs.XY) return lhs.XY < rhs.XY;
if (lhs.Z != rhs.Z) return lhs.Z < rhs.Z;
return lhs.ui < rhs.ui; // Both vectors already canonicalized
}

// Key vs Lookup
bool operator()(const Info3PointKey& lhs, const Info3PointLookup& rhs) const {
set<int> rhs_xy{rhs.X, rhs.Y};
if (lhs.XY != rhs_xy) return lhs.XY < rhs_xy;
if (lhs.Z != rhs.Z) return lhs.Z < rhs.Z;

// Compare ui: lhs.ui is canonicalized, need to canonicalize rhs
vector<int> rhs_ui = *rhs.ui_ptr; // Copy
std::sort(rhs_ui.begin(), rhs_ui.end());
rhs_ui.erase(std::unique(rhs_ui.begin(), rhs_ui.end()), rhs_ui.end());

return lhs.ui < rhs_ui;
}

// Lookup vs Key
bool operator()(const Info3PointLookup& lhs, const Info3PointKey& rhs) const {
// Compare XY
set<int> lhs_xy{lhs.X, lhs.Y};
if (lhs_xy != rhs.XY) return lhs_xy < rhs.XY;
if (lhs.Z != rhs.Z) return lhs.Z < rhs.Z;

// Compare ui: rhs.ui is canonicalized, need to canonicalize lhs
vector<int> lhs_ui = *lhs.ui_ptr; // Copy
std::sort(lhs_ui.begin(), lhs_ui.end());
lhs_ui.erase(std::unique(lhs_ui.begin(), lhs_ui.end()), lhs_ui.end());

return lhs_ui < rhs.ui;
}
};

using MutualInfoMap = std::map<MutualInfoKey, InfoBlock>;
using Info3PointMap = std::map<Info3PointKey, Info3PointBlock>;
using Info3PointMap = std::map<Info3PointKey, Info3PointBlock, Info3PointCompare>; // Use transparent comparator
using EntropyMap = std::map<Info3PointKey, double>;

class InfoScoreCache {
Expand All @@ -131,7 +187,8 @@ class InfoScoreCache {

pair<Info3PointBlock, bool> getInfo3Point(
int X, int Y, int Z, const vector<int>& ui) {
auto it = i3_map_.find(Info3PointKey(X, Y, Z, ui));
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

avoid constructing these tmp objects each time we look up in the cache and use a transparent lookup

// Use lightweight lookup - no set construction!
auto it = i3_map_.find(Info3PointLookup(X, Y, Z, ui));
bool found = it != i3_map_.end();
return std::make_pair(found ? it->second : Info3PointBlock(), found);
}
Expand Down
Loading