Skip to content

Commit bd6e5a9

Browse files
committed
Merge remote-tracking branch 'dept/young_reduce'
2 parents decd4cf + b02af61 commit bd6e5a9

File tree

4 files changed

+228
-45
lines changed

4 files changed

+228
-45
lines changed

core/algorithms/young_reduce.cc

Lines changed: 162 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,97 @@ std::string ex_to_string(cadabra::Ex::iterator it, const cadabra::Kernel& kernel
2626
return ex_to_string(cadabra::Ex(it), kernel);
2727
}
2828

29-
#define DEBUG_OUTPUT 1
29+
std::string adjform_to_string(const cadabra::yr::adjform_t& adjform, const std::vector<cadabra::nset_t::iterator>* index_map = nullptr)
30+
{
31+
std::map<cadabra::yr::index_t, int> dummy_map;
32+
int dummy_counter = 0;
33+
std::string res;
34+
for (const auto& elem : adjform) {
35+
if (index_map) {
36+
if (elem < 0) {
37+
res += *(*index_map)[-(elem + 1)];
38+
}
39+
else if (dummy_map.find(elem) != dummy_map.end()) {
40+
res += "d_" + std::to_string(dummy_map[elem]);
41+
}
42+
else {
43+
dummy_map[adjform[adjform[elem]]] = dummy_counter++;
44+
res += "d_" + std::to_string(dummy_map[adjform[adjform[elem]]]);
45+
}
46+
}
47+
else {
48+
res += std::to_string(elem);
49+
}
50+
}
51+
return res;
52+
}
53+
54+
std::string pf_to_string (const cadabra::yr::ProjectedForm& projform, const std::vector<cadabra::nset_t::iterator>* index_map = nullptr)
55+
{
56+
std::stringstream os;
57+
int i = 0;
58+
int max = std::min(std::size_t(200), projform.data.size());
59+
auto it = projform.data.begin();
60+
while (i < max) {
61+
os << adjform_to_string(it->first/*, index_map*/);
62+
os << '\t' << it->second << '\n';
63+
++i;
64+
++it;
65+
}
66+
if (i == max) {
67+
os << "(skipped " << (projform.data.size() - max) << " terms)\n";
68+
}
69+
return os.str();
70+
}
71+
72+
#define DEBUG_OUTPUT 0
3073
#define cdebug if (!DEBUG_OUTPUT) {} else std::cerr
3174

3275

3376
////////////////////////////////////////////////////////////////////
3477

78+
// Get the next permutation of adjform and return the number of swaps
79+
// required for the transformation
80+
int next_perm(cadabra::yr::adjform_t& adjform)
81+
{
82+
int n = adjform.size();
83+
84+
// Find longest non-increasing suffix to get pivot
85+
int pivot = n - 2;
86+
while (pivot > -1) {
87+
if (adjform[pivot + 1] > adjform[pivot])
88+
break;
89+
--pivot;
90+
}
91+
92+
// Entire sequence is already sorted, return
93+
if (pivot == -1)
94+
return 0;
95+
96+
// Find rightmost element greater than pivot
97+
int idx = n - 1;
98+
while (idx > pivot) {
99+
if (adjform[idx] > adjform[pivot])
100+
break;
101+
--idx;
102+
}
103+
104+
// Swap with pivot
105+
std::swap(adjform[pivot], adjform[idx]);
106+
107+
// Reverse the suffix
108+
int swaps = 1;
109+
int maxswaps = (n - pivot - 1) / 2;
110+
for (int i = 0; i < maxswaps; ++i) {
111+
if (adjform[pivot + i + 1] != adjform[n - i - 1]) {
112+
std::swap(adjform[pivot + i + 1], adjform[n - i - 1]);
113+
++swaps;
114+
}
115+
}
116+
117+
return swaps;
118+
}
119+
35120
// Returns the position of 'val' between 'begin' and 'end', starting
36121
// the search at 'offset'
37122
template <typename It, typename T>
@@ -46,6 +131,7 @@ namespace cadabra {
46131
{
47132
mpq_class ProjectedForm::compare(const ProjectedForm& other) const
48133
{
134+
cdebug << "entered compare\n";
49135
// Early failure checks
50136
if (data.empty() || data.size() != other.data.size())
51137
return 0;
@@ -54,20 +140,40 @@ namespace cadabra {
54140
// other terms checking that the factor is the same. If not, return 0
55141
auto a_it = data.begin(), b_it = other.data.begin(), a_end = data.end();
56142
mpq_class factor = a_it->second / b_it->second;
143+
cdebug << "factor is " << factor << '\n';
57144
while (a_it != a_end) {
58-
if (a_it->second / b_it->second != factor)
145+
cdebug << "comparing " << a_it->second << " * ";
146+
for (const auto& elem : a_it->first)
147+
cdebug << elem << ' ';
148+
cdebug << " to " << b_it->second << " * ";
149+
for (const auto& elem : b_it->first)
150+
cdebug << elem << ' ';
151+
cdebug << '\n';
152+
if (a_it->second / b_it->second != factor) {
153+
cdebug << "factor was " << (a_it->second / b_it->second) << "!\n";
59154
return 0;
155+
}
60156
++a_it, ++b_it;
61157
}
158+
cdebug << "matched all terms!\n";
62159
return factor;
63160
}
64161

65-
void ProjectedForm::combine(const ProjectedForm& other)
162+
void ProjectedForm::combine(const ProjectedForm& other, mpq_class factor)
66163
{
67-
for (const auto& kv : other.data) {
68-
data[kv.first] += kv.second;
69-
if (data[kv.first] == 0)
70-
data.erase(kv.first);
164+
if (factor == 1) {
165+
for (const auto& kv : other.data) {
166+
data[kv.first] += kv.second;
167+
if (data[kv.first] == 0)
168+
data.erase(kv.first);
169+
}
170+
}
171+
else {
172+
for (const auto& kv : other.data) {
173+
data[kv.first] += kv.second * factor;
174+
if (data[kv.first] == 0)
175+
data.erase(kv.first);
176+
}
71177
}
72178
}
73179

@@ -87,19 +193,23 @@ namespace cadabra {
87193
data[adjform] = value;
88194
}
89195

90-
void ProjectedForm::apply_young_symmetry(const std::vector<index_t>& indices, bool antisymmetric)
196+
void ProjectedForm::apply_young_symmetry(const adjform_t& indices, bool antisymmetric)
91197
{
92-
map_t old_data = data;
198+
map_t old_data;
199+
std::swap(old_data, data);
93200

94201
// Loop over all entries, for each one looping over all permutations
95202
// of the indices to be symmetrized and creating a new term for that
96203
// permutation; then add the new term to the list of entries
97204
for (const auto& kv : old_data) {
205+
206+
cdebug << "Applying young_symmetry " << (antisymmetric ? -kv.second : kv.second) << " * " << adjform_to_string(indices) << " to term " << adjform_to_string(kv.first) << '\n';
207+
98208
auto perm = indices;
99-
bool flip = false;
100209
int parity = 1;
101-
while (std::next_permutation(perm.begin(), perm.end())) {
102-
if (antisymmetric && (flip = !flip))
210+
int swaps = 2;
211+
do {
212+
if (antisymmetric && swaps % 2 != 0)
103213
parity *= -1;
104214
auto ret = kv.first;
105215
for (size_t i = 0; i < indices.size(); ++i) {
@@ -109,8 +219,9 @@ namespace cadabra {
109219
if (ret[index] >= 0)
110220
ret[ret[index]] = index;
111221
}
222+
cdebug << "\tMade term " << adjform_to_string(ret) << " * " << (parity * kv.second) << '\n';
112223
data[ret] += parity * kv.second;
113-
}
224+
} while (swaps = next_perm(perm));
114225
}
115226
}
116227

@@ -147,14 +258,20 @@ namespace cadabra {
147258
Ex::iterator l1 = lhs.begin(), l2 = lhs.end();
148259
Ex::iterator r1 = rhs.begin(), r2 = rhs.end();
149260

261+
std::vector<Ex::iterator> l_indices, r_indices;
262+
150263
// Loop over all tree nodes using a depth first iterator. If the
151264
// entry is an index ensure that it has the same parent_rel, if it
152265
// is any other type of node check that the names match.
153266
while (l1 != l2 && r1 != r2) {
154267
if (l1->is_index()) {
268+
l1.skip_children();
269+
r1.skip_children();
155270
if (l1->fl.parent_rel != r1->fl.parent_rel) {
156271
return false;
157272
}
273+
l_indices.push_back(l1);
274+
r_indices.push_back(r1);
158275
}
159276
else {
160277
if (l1->name != r1->name || l1->multiplier != r1->multiplier) {
@@ -185,20 +302,20 @@ namespace cadabra {
185302
}
186303
}
187304

188-
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex pat)
305+
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex::iterator pat)
189306
{
190307
if (*it->name == delim) {
191308
std::vector<Ex::iterator> res;
192309
Ex::sibling_iterator beg = it.begin(), end = it.end();
193310
while (beg != end) {
194-
if (check_structure(beg, pat.begin()))
311+
if (check_structure(beg, pat))
195312
res.push_back(beg);
196313
++beg;
197314
}
198315
return res;
199316
}
200317
else {
201-
if (check_structure(it, pat.begin()))
318+
if (check_structure(it, pat))
202319
return std::vector<Ex::iterator>(1, it);
203320
else
204321
return std::vector<Ex::iterator>();
@@ -251,13 +368,7 @@ young_reduce::~young_reduce()
251368

252369
bool young_reduce::can_apply(iterator it)
253370
{
254-
if (pat == Ex::iterator()) {
255-
// No pattern set, can only apply to a sum node
256-
return *it->name == "\\sum";
257-
}
258-
else {
259-
return true;
260-
}
371+
return true;
261372
}
262373

263374
young_reduce::result_t young_reduce::apply(iterator& it)
@@ -271,8 +382,11 @@ young_reduce::result_t young_reduce::apply(iterator& it)
271382
res = apply_known(it);
272383
}
273384

274-
if (res != result_t::l_no_action)
385+
if (res != result_t::l_no_action) {
386+
cdebug << "Action taken; cleaning...";
275387
cleanup_dispatch(kernel, tr, it);
388+
cdebug << "done!\n";
389+
}
276390
return res;
277391
}
278392

@@ -282,9 +396,18 @@ young_reduce::result_t young_reduce::apply_known(iterator& it)
282396
ProjectedForm it_sym;
283397
auto nodes = split_ex(it, "\\sum", pat);
284398
cdebug << "Found " << nodes.size() << " terms which match pat:\n";
399+
if (nodes.size() == 0)
400+
return result_t::l_no_action;
401+
285402
for (auto& node : nodes) {
286403
cdebug << "\t" << ex_to_string(node, kernel) << "\n";
287-
it_sym.combine(symmetrize(node));
404+
if (subtree_equal(&kernel.properties, pat, node, -2, true, -1)) {
405+
cdebug << "Matched pat; combining\n";
406+
it_sym.combine(pat_sym, *node->multiplier / *pat->multiplier);
407+
}
408+
else {
409+
it_sym.combine(symmetrize(node));
410+
}
288411
}
289412

290413
// Check if projection yielded zero
@@ -297,12 +420,12 @@ young_reduce::result_t young_reduce::apply_known(iterator& it)
297420
// Check if projection is a multiple of 'pat'
298421
auto factor = it_sym.compare(pat_sym);
299422
if (factor != 0) {
300-
cdebug << "Projection was a multiple of pat; reducing...\n";
301-
auto newit = tr.replace(nodes.back(), pat);
423+
cdebug << "Projection was a multiple (" << factor << ") of pat; reducing...\n";
424+
it = tr.replace(nodes.back(), pat);
302425
nodes.pop_back();
303426
for (auto node : nodes)
304427
node_zero(node);
305-
multiply(newit->multiplier, factor);
428+
multiply(it->multiplier, factor);
306429
cdebug << "Produced " << ex_to_string(it, kernel) << '\n';
307430
return result_t::l_applied;
308431
}
@@ -325,6 +448,12 @@ young_reduce::result_t young_reduce::apply_unknown(iterator& it)
325448
while (!terms.empty()) {
326449
cdebug << "Examining " << ex_to_string(terms.back(), kernel) << "...";
327450
bool can_reduce = set_pattern(terms.back());
451+
if (can_reduce && pat_sym.data.empty()) {
452+
cdebug << "pat_sym is zero; zeroing node\n";
453+
node_zero(pat);
454+
res = result_t::l_applied;
455+
continue;
456+
}
328457
std::vector<Ex::iterator> cur_terms;
329458
for (index_t i = terms.size() - 2; i != -1; --i) {
330459
if (check_structure(terms.back(), terms[i])) {
@@ -368,6 +497,8 @@ young_reduce::result_t young_reduce::apply_unknown(iterator& it)
368497
}
369498
}
370499

500+
pat = Ex::iterator();
501+
pat_sym.clear();
371502
return res;
372503
}
373504

@@ -402,6 +533,7 @@ bool young_reduce::set_pattern(Ex::iterator new_pat)
402533

403534
ProjectedForm young_reduce::symmetrize(Ex::iterator it)
404535
{
536+
cdebug << "symmetrizing " << ex_to_string(it, kernel) << "produces:\n";
405537
ProjectedForm sym;
406538
sym.insert(to_adjform(it), 1);
407539

@@ -458,6 +590,8 @@ ProjectedForm young_reduce::symmetrize(Ex::iterator it)
458590

459591
sym.multiply(*it->multiplier);
460592

593+
cdebug << pf_to_string(sym, &index_map) << '\n';
594+
461595
return sym;
462596
}
463597

core/algorithms/young_reduce.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace cadabra {
2222
mpq_class compare(const ProjectedForm& other) const;
2323

2424
// Add all contributions from 'other' into 'this'
25-
void combine(const ProjectedForm& other);
25+
void combine(const ProjectedForm& other, mpq_class factor = 1);
2626

2727
// Multiply all terms by a constant factor
2828
void multiply(mpq_class k);
@@ -52,7 +52,7 @@ namespace cadabra {
5252
// entry. If 'pat' is specified, any terms not matching 'pat' are
5353
// ignored
5454
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim);
55-
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex pat);
55+
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex::iterator pat);
5656

5757
// Rewrite adjform type indices as dummy indices
5858
adjform_t collapse_dummy_indices(adjform_t adjform);

frontend/gtkmm/NotebookWindow.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1928,6 +1928,11 @@ void NotebookWindow::compare_git_choose()
19281928
error_dialog.set_title("Git error");
19291929
error_dialog.run();
19301930
}
1931+
#else
1932+
Gtk::MessageDialog not_supported_dialog("Due to a bug in the Windows version of Gtkmm this feature isn't currently supported. Sorry for the inconvenience!");
1933+
not_supported_dialog.set_transient_for(*this);
1934+
not_supported_dialog.set_title("Feature not supported");
1935+
not_supported_dialog.run();
19311936
#endif
19321937
}
19331938

0 commit comments

Comments
 (0)