Skip to content

Commit c2fdfaf

Browse files
author
Dominic Price
committed
Minor bug fixes, new test case and more debugging output
1 parent 97f55ec commit c2fdfaf

File tree

3 files changed

+132
-25
lines changed

3 files changed

+132
-25
lines changed

core/algorithms/young_reduce.cc

Lines changed: 99 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,46 @@ std::string ex_to_string(cadabra::Ex::iterator it, const cadabra::Kernel& kernel
2626
return ex_to_string(cadabra::Ex(it), kernel);
2727
}
2828

29-
#define DEBUG_OUTPUT 0
29+
std::string adjform_to_string(const cadabra::yr::adjform_t& adjform, const std::vector<cadabra::nset_t::iterator>& index_map)
30+
{
31+
std::map<cadabra::yr::index_t, int> dummy_map;
32+
int dummy_counter = 0;
33+
std::string res;
34+
for (const auto& elem : adjform) {
35+
if (elem < 0) {
36+
res += *index_map[-(elem + 1)];
37+
}
38+
else if (dummy_map.find(elem) != dummy_map.end()) {
39+
res += "d_" + std::to_string(dummy_map[elem]);
40+
}
41+
else {
42+
dummy_map[adjform[adjform[elem]]] = dummy_counter++;
43+
res += "d_" + std::to_string(dummy_map[adjform[adjform[elem]]]);
44+
}
45+
}
46+
return res;
47+
}
48+
49+
std::string pf_to_string (const cadabra::yr::ProjectedForm& projform, const std::vector<cadabra::nset_t::iterator>& index_map)
50+
{
51+
std::stringstream os;
52+
int i = 0;
53+
int max = 20;
54+
auto it = projform.data.begin();
55+
while (i < max && i < projform.data.size()) {
56+
for (const auto& elem : it->first)
57+
os << elem << ' ';
58+
os << '\t' << it->second << '\n';
59+
++i;
60+
++it;
61+
}
62+
if (i == max) {
63+
os << "(skipped " << (projform.data.size() - max) << " terms)\n";
64+
}
65+
return os.str();
66+
}
67+
68+
#define DEBUG_OUTPUT 1
3069
#define cdebug if (!DEBUG_OUTPUT) {} else std::cerr
3170

3271

@@ -46,6 +85,7 @@ namespace cadabra {
4685
{
4786
mpq_class ProjectedForm::compare(const ProjectedForm& other) const
4887
{
88+
cdebug << "entered compare\n";
4989
// Early failure checks
5090
if (data.empty() || data.size() != other.data.size())
5191
return 0;
@@ -54,20 +94,40 @@ namespace cadabra {
5494
// other terms checking that the factor is the same. If not, return 0
5595
auto a_it = data.begin(), b_it = other.data.begin(), a_end = data.end();
5696
mpq_class factor = a_it->second / b_it->second;
97+
cdebug << "factor is " << factor << '\n';
5798
while (a_it != a_end) {
58-
if (a_it->second / b_it->second != factor)
99+
cdebug << "comparing " << a_it->second << " * ";
100+
for (const auto& elem : a_it->first)
101+
cdebug << elem << ' ';
102+
cdebug << " to " << b_it->second << " * ";
103+
for (const auto& elem : b_it->first)
104+
cdebug << elem << ' ';
105+
cdebug << '\n';
106+
if (a_it->second / b_it->second != factor) {
107+
cdebug << "factor was " << (a_it->second / b_it->second) << "!\n";
59108
return 0;
109+
}
60110
++a_it, ++b_it;
61111
}
112+
cdebug << "matched all terms!\n";
62113
return factor;
63114
}
64115

65-
void ProjectedForm::combine(const ProjectedForm& other)
116+
void ProjectedForm::combine(const ProjectedForm& other, mpq_class factor)
66117
{
67-
for (const auto& kv : other.data) {
68-
data[kv.first] += kv.second;
69-
if (data[kv.first] == 0)
70-
data.erase(kv.first);
118+
if (factor == 1) {
119+
for (const auto& kv : other.data) {
120+
data[kv.first] += kv.second;
121+
if (data[kv.first] == 0)
122+
data.erase(kv.first);
123+
}
124+
}
125+
else {
126+
for (const auto& kv : other.data) {
127+
data[kv.first] += kv.second * factor;
128+
if (data[kv.first] == 0)
129+
data.erase(kv.first);
130+
}
71131
}
72132
}
73133

@@ -147,14 +207,20 @@ namespace cadabra {
147207
Ex::iterator l1 = lhs.begin(), l2 = lhs.end();
148208
Ex::iterator r1 = rhs.begin(), r2 = rhs.end();
149209

210+
std::vector<Ex::iterator> l_indices, r_indices;
211+
150212
// Loop over all tree nodes using a depth first iterator. If the
151213
// entry is an index ensure that it has the same parent_rel, if it
152214
// is any other type of node check that the names match.
153215
while (l1 != l2 && r1 != r2) {
154216
if (l1->is_index()) {
217+
l1.skip_children();
218+
r1.skip_children();
155219
if (l1->fl.parent_rel != r1->fl.parent_rel) {
156220
return false;
157221
}
222+
l_indices.push_back(l1);
223+
r_indices.push_back(r1);
158224
}
159225
else {
160226
if (l1->name != r1->name || l1->multiplier != r1->multiplier) {
@@ -185,20 +251,20 @@ namespace cadabra {
185251
}
186252
}
187253

188-
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex pat)
254+
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex::iterator pat)
189255
{
190256
if (*it->name == delim) {
191257
std::vector<Ex::iterator> res;
192258
Ex::sibling_iterator beg = it.begin(), end = it.end();
193259
while (beg != end) {
194-
if (check_structure(beg, pat.begin()))
260+
if (check_structure(beg, pat))
195261
res.push_back(beg);
196262
++beg;
197263
}
198264
return res;
199265
}
200266
else {
201-
if (check_structure(it, pat.begin()))
267+
if (check_structure(it, pat))
202268
return std::vector<Ex::iterator>(1, it);
203269
else
204270
return std::vector<Ex::iterator>();
@@ -251,13 +317,7 @@ young_reduce::~young_reduce()
251317

252318
bool young_reduce::can_apply(iterator it)
253319
{
254-
if (pat == Ex::iterator()) {
255-
// No pattern set, can only apply to a sum node
256-
return *it->name == "\\sum";
257-
}
258-
else {
259-
return true;
260-
}
320+
return true;
261321
}
262322

263323
young_reduce::result_t young_reduce::apply(iterator& it)
@@ -285,9 +345,18 @@ young_reduce::result_t young_reduce::apply_known(iterator& it)
285345
ProjectedForm it_sym;
286346
auto nodes = split_ex(it, "\\sum", pat);
287347
cdebug << "Found " << nodes.size() << " terms which match pat:\n";
348+
if (nodes.size() == 0)
349+
return result_t::l_no_action;
350+
288351
for (auto& node : nodes) {
289352
cdebug << "\t" << ex_to_string(node, kernel) << "\n";
290-
it_sym.combine(symmetrize(node));
353+
if (subtree_equal(&kernel.properties, pat, node, -2, true, -1)) {
354+
cdebug << "Matched pat; combining\n";
355+
it_sym.combine(pat_sym, *node->multiplier / *pat->multiplier);
356+
}
357+
else {
358+
it_sym.combine(symmetrize(node));
359+
}
291360
}
292361

293362
// Check if projection yielded zero
@@ -300,7 +369,7 @@ young_reduce::result_t young_reduce::apply_known(iterator& it)
300369
// Check if projection is a multiple of 'pat'
301370
auto factor = it_sym.compare(pat_sym);
302371
if (factor != 0) {
303-
cdebug << "Projection was a multiple of pat; reducing...\n";
372+
cdebug << "Projection was a multiple (" << factor << ") of pat; reducing...\n";
304373
it = tr.replace(nodes.back(), pat);
305374
nodes.pop_back();
306375
for (auto node : nodes)
@@ -328,6 +397,12 @@ young_reduce::result_t young_reduce::apply_unknown(iterator& it)
328397
while (!terms.empty()) {
329398
cdebug << "Examining " << ex_to_string(terms.back(), kernel) << "...";
330399
bool can_reduce = set_pattern(terms.back());
400+
if (can_reduce && pat_sym.data.empty()) {
401+
cdebug << "pat_sym is zero; zeroing node\n";
402+
node_zero(pat);
403+
res = result_t::l_applied;
404+
continue;
405+
}
331406
std::vector<Ex::iterator> cur_terms;
332407
for (index_t i = terms.size() - 2; i != -1; --i) {
333408
if (check_structure(terms.back(), terms[i])) {
@@ -371,6 +446,8 @@ young_reduce::result_t young_reduce::apply_unknown(iterator& it)
371446
}
372447
}
373448

449+
pat = Ex::iterator();
450+
pat_sym.clear();
374451
return res;
375452
}
376453

@@ -405,6 +482,7 @@ bool young_reduce::set_pattern(Ex::iterator new_pat)
405482

406483
ProjectedForm young_reduce::symmetrize(Ex::iterator it)
407484
{
485+
cdebug << "symmetrizing " << ex_to_string(it, kernel) << "produces:\n";
408486
ProjectedForm sym;
409487
sym.insert(to_adjform(it), 1);
410488

@@ -461,6 +539,8 @@ ProjectedForm young_reduce::symmetrize(Ex::iterator it)
461539

462540
sym.multiply(*it->multiplier);
463541

542+
cdebug << sym << '\n';
543+
464544
return sym;
465545
}
466546

core/algorithms/young_reduce.hh

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ namespace cadabra {
2222
mpq_class compare(const ProjectedForm& other) const;
2323

2424
// Add all contributions from 'other' into 'this'
25-
void combine(const ProjectedForm& other);
25+
void combine(const ProjectedForm& other, mpq_class factor = 1);
2626

2727
// Multiply all terms by a constant factor
2828
void multiply(mpq_class k);
@@ -52,7 +52,7 @@ namespace cadabra {
5252
// entry. If 'pat' is specified, any terms not matching 'pat' are
5353
// ignored
5454
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim);
55-
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex pat);
55+
std::vector<Ex::iterator> split_ex(Ex::iterator it, const std::string& delim, Ex::iterator pat);
5656

5757
// Rewrite adjform type indices as dummy indices
5858
adjform_t collapse_dummy_indices(adjform_t adjform);

tests/youngreduce.cdb

Lines changed: 31 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,10 +81,13 @@ def test09():
8181
__cdbkernel__ = create_scope()
8282
R_{a b c d}::RiemannTensor.
8383
ex := R_{a b c d} + R_{a d b c};
84-
# This one still segfaults:
85-
# young_reduce(ex);
84+
85+
young_reduce(ex)
86+
assert ex == $R_{a b c d} + R_{a d b c}$
87+
8688
young_reduce(ex, $R_{a c d b}$)
8789
assert ex == $-R_{a c d b}$
90+
8891
print("Test 09 passed")
8992

9093
test09()
@@ -130,9 +133,33 @@ def test12():
130133
test12()
131134

132135
def test13():
136+
__cdbkernel__ = create_scope()
133137
R_{a b c d}::RiemannTensor.
134138
A^{a b c}::AntiSymmetric.
135-
ex:= R_{a b c d} A^{a b c};
136-
young_reduce(ex, $R_{a b c d} A^{a b c}$);
139+
ex:= R_{a b c d} A^{a b c}:
140+
young_reduce(ex, $R_{a b c d} A^{a b c}$)
137141
assert ex==0
138142
print("Test 13 passed")
143+
144+
test13()
145+
146+
def test14():
147+
__cdbkernel__ = create_scope()
148+
A_{a b c d e}::AntiSymmetric.
149+
ex = young_reduce($9A_{b a c e d} - 3A_{e d c b a}$, $A_{b a c e d}$)
150+
display(ex)
151+
assert ex == $6A_{b a c e d}$
152+
print("Test 14 passed")
153+
154+
test14()
155+
156+
def test15():
157+
__cdbkernel__ = create_scope()
158+
R_{a b c d}::RiemannTensor.
159+
A_{a b c d}::AntiSymmetric.
160+
ex := R_{b a c d}:
161+
young_reduce(ex, $A_{b a c d})
162+
assert ex == $R_{b a c d}
163+
print("Test 15 passed")
164+
165+
test15()

0 commit comments

Comments
 (0)