Skip to content

Commit 0e64c8d

Browse files
committed
Allow combine() to accept a trace
If we add an optional argument, we can allow the user's operators with a Trace property to be used whenever traces are generated by combine(). This is not strictly necessary as everyone knows what (A)_{a a} means but it can add some flexibility in notation and also serve as a warm-up for transpose. Also, we should remember that the most sensible use of an indexbracket is to enclose objects that have the ImplicitIndex property. Signed-off-by: Connor Behan <[email protected]>
1 parent d1c24de commit 0e64c8d

File tree

4 files changed

+58
-29
lines changed

4 files changed

+58
-29
lines changed

core/algorithms/combine.cc

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,8 @@
55

66
using namespace cadabra;
77

8-
combine::combine(const Kernel& k, Ex& e)
9-
: Algorithm(k, e)
8+
combine::combine(const Kernel& k, Ex& e, Ex& t)
9+
: Algorithm(k, e), trace_op(t)
1010
{
1111
}
1212

@@ -101,6 +101,7 @@ Algorithm::result_t combine::apply(iterator& it)
101101
}
102102
}
103103

104+
std::string trace_start="";
104105
std::vector<Ex::iterator>::iterator dums1=dummies.begin(), dums2;
105106
dums2=dums1;
106107
++dums2;
@@ -151,19 +152,19 @@ Algorithm::result_t combine::apply(iterator& it)
151152
bool isbrack=*(sib->name)=="\\indexbracket";
152153
if(isbrack && isbrack2) {
153154
auto es=compare.equal_subtree(tr.begin(parn2), tr.begin(sib));
154-
sign*=compare.can_swap(tr.begin(parn2), tr.begin(sib), es, true);
155+
sign*=compare.can_swap_components(tr.begin(parn2), tr.begin(sib), es);
155156
}
156157
else if(isbrack && !isbrack2) {
157158
auto es=compare.equal_subtree(parn2, tr.begin(sib));
158-
sign*=compare.can_swap(parn2, tr.begin(sib), es, true);
159+
sign*=compare.can_swap_components(parn2, tr.begin(sib), es);
159160
}
160161
else if(!isbrack && isbrack2) {
161162
auto es=compare.equal_subtree(tr.begin(parn2), sib);
162-
sign*=compare.can_swap(tr.begin(parn2), sib, es, true);
163+
sign*=compare.can_swap_components(tr.begin(parn2), sib, es);
163164
}
164165
else {
165166
auto es=compare.equal_subtree(parn2, sib);
166-
sign*=compare.can_swap(parn2, sib, es, true);
167+
sign*=compare.can_swap_components(parn2, sib, es);
167168
}
168169
}
169170
if(sib==parn1 || sib==parn2) ++hits;
@@ -235,7 +236,22 @@ Algorithm::result_t combine::apply(iterator& it)
235236
if(consecutive) {
236237
++dums1;
237238
++dums2;
239+
if(dums2!=dummies.end() && trace_op.size()>0) {
240+
if(*(*dums2)->name==trace_start) {
241+
iterator parn=tr.parent(*dums2);
242+
iterator trace=tr.insert(parn, str_node(trace_op.begin()->name));
243+
sibling_iterator nxt=tr.begin(parn);
244+
++nxt;
245+
++dums1;
246+
++dums2;
247+
tr.reparent(trace, tr.begin(parn), nxt);
248+
multiply(trace->multiplier, *parn->multiplier);
249+
tr.erase(parn);
250+
trace_start="";
251+
}
252+
}
238253
}
254+
else trace_start=*(*dums1)->name;
239255
++dums1;
240256
++dums2;
241257
}

core/algorithms/combine.hh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,15 @@ namespace cadabra {
77

88
class combine : public Algorithm {
99
public:
10-
combine(const Kernel&, Ex&);
10+
combine(const Kernel&, Ex&, Ex&);
1111

1212
virtual bool can_apply(iterator) override;
1313
virtual result_t apply(iterator&) override;
1414

1515
private:
1616
typedef std::map<nset_t::iterator, iterator> indexlocmap_t;
1717

18+
Ex trace_op;
1819
indexlocmap_t iloc;
1920
};
2021

core/pythoncdb/py_algorithms.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ namespace cadabra {
8181
def_algo<collect_components>(m, "collect_components", true, false, 0);
8282
def_algo<collect_factors>(m, "collect_factors", true, false, 0);
8383
def_algo<collect_terms>(m, "collect_terms", true, false, 0);
84-
def_algo<combine>(m, "combine", true, false, 0);
8584
def_algo<decompose_product>(m, "decompose_product", true, false, 0);
8685
def_algo<distribute>(m, "distribute", true, false, 0);
8786
def_algo<eliminate_kronecker>(m, "eliminate_kronecker", true, false, 0);
@@ -100,9 +99,10 @@ namespace cadabra {
10099
def_algo<sort_sum>(m, "sort_sum", true, false, 0);
101100
def_algo<tabdimension>(m, "tab_dimension", true, false, 0);
102101
def_algo<young_project_product>(m, "young_project_product", true, false, 0);
103-
def_algo<drop_weight, Ex>(m, "drop_weight", false, false, 0, py::arg("condition") = Ex{});
102+
def_algo<combine, Ex>(m, "combine", true, false, 0, py::arg("trace_op") = Ex{});
104103
def_algo<complete, Ex>(m, "complete", false, false, 0, py::arg("add"));
105104
def_algo<decompose, Ex>(m, "decompose", false, false, 0, py::arg("basis"));
105+
def_algo<drop_weight, Ex>(m, "drop_weight", false, false, 0, py::arg("condition") = Ex{});
106106
def_algo<eliminate_metric, Ex>(m, "eliminate_metric", true, false, 0, py::arg("preferred") = Ex{});
107107
def_algo<eliminate_vielbein, Ex>(m, "eliminate_vielbein", true, false, 0, py::arg("preferred") = Ex{});
108108
def_algo<keep_weight, Ex>(m, "keep_weight", false, false, 0, py::arg("condition"));

tests/implicit.cdb

Lines changed: 32 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -83,16 +83,28 @@ test06()
8383
def test07():
8484
__cdbkernel__=create_scope()
8585
{a,b}::Indices(vector);
86-
{A,B}::AntiCommuting;
86+
Tr{#}::Trace(indices=vector);
8787
ex:=(A)_{a b} (B)_{b a};
88-
combine(_)
89-
tst:= (-B A)_{b b} - @(ex);
88+
combine(_, trace_op=$Tr$)
89+
tst:= Tr( B A ) - @(ex);
9090
assert(tst==0)
9191
print("Test 07 passed")
9292

9393
test07()
9494

9595
def test08():
96+
__cdbkernel__=create_scope()
97+
{a,b}::Indices(vector);
98+
{A,B}::AntiCommuting;
99+
ex:=(A)_{a b} (B)_{b a};
100+
combine(_)
101+
tst:= (-B A)_{b b} - @(ex);
102+
assert(tst==0)
103+
print("Test 08 passed")
104+
105+
test08()
106+
107+
def test09():
96108
__cdbkernel__=create_scope()
97109
{m,n}::Indices(vector);
98110
{a,b,c}::Indices(spinor, position=fixed);
@@ -102,33 +114,33 @@ def test08():
102114
expand(_)
103115
tst:= (\Gamma^{m})^{a}_{b} (\Gamma^{n})^{b}_{c} - @(ex);
104116
assert(tst==0)
105-
print("Test 08 passed")
117+
print("Test 09 passed")
106118

107-
test08()
119+
test09()
108120

109-
def test09():
121+
def test10():
110122
__cdbkernel__=create_scope()
111123
{a,b}::Indices(vector);
112124
ex:=(u)_{a} (M)^{a b} (v)_{b};
113125
combine(_)
114126
tst:= \indexbracket(u M v) - @(ex);
115127
assert(tst==0)
116-
print("Test 09 passed")
128+
print("Test 10 passed")
117129

118-
test09()
130+
test10()
119131

120-
def test10():
132+
def test11():
121133
__cdbkernel__=create_scope()
122134
{a,b,c}::Indices(vector);
123135
ex:=(B)_{b c} (A)_{a b};
124136
combine(_);
125137
tst:= (A B)_{a c} - @(ex);
126138
assert(tst==0)
127-
print("Test 10 passed")
139+
print("Test 11 passed")
128140

129-
test10()
141+
test11()
130142

131-
def test11():
143+
def test12():
132144
__cdbkernel__=create_scope()
133145
{a,b}::Indices(vector);
134146
Tr{#}::Trace(indices=vector);
@@ -140,11 +152,11 @@ def test11():
140152
sort_product(_);
141153
tst:= -C D Tr( A B ) - @(ex);
142154
assert(tst==0)
143-
print("Test 11 passed")
155+
print("Test 12 passed")
144156

145-
test11()
157+
test12()
146158

147-
def test12():
159+
def test13():
148160
__cdbkernel__=create_scope()
149161
{a,b}::Indices(vector);
150162
Tr{#}::Trace(indices=vector);
@@ -155,18 +167,18 @@ def test12():
155167
sort_product(_);
156168
tst:= Tr( -A B ) - @(ex);
157169
assert(tst==0)
158-
print("Test 12 passed")
170+
print("Test 13 passed")
159171

160-
test12()
172+
test13()
161173

162-
def test13():
174+
def test14():
163175
__cdbkernel__=create_scope()
164176
{A,B,C}::ImplicitIndex;
165177
ex:=(B A C)_{a a};
166178
sort_product(_);
167179
tst:= (A C B)_{a a} - @(ex);
168180
assert(tst==0)
169-
print("Test 13 passed")
181+
print("Test 14 passed")
170182

171-
test13()
183+
test14()
172184

0 commit comments

Comments
 (0)