Skip to content

Commit 40f334a

Browse files
author
Kasper Peeters
committed
Added 'join' function to join two objects into a list.
1 parent 16bdd6a commit 40f334a

File tree

2 files changed

+60
-1
lines changed

2 files changed

+60
-1
lines changed

core/pythoncdb/py_ex.cc

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,42 @@ namespace cadabra {
9595
// }
9696
}
9797

98+
Ex_ptr Ex_join(const Ex_ptr ex1, const Ex_ptr ex2)
99+
{
100+
if (ex1->size() == 0) return ex2;
101+
if (ex2->size() == 0) return ex1;
102+
103+
bool comma1 = (*ex1->begin()->name == "\\comma");
104+
bool comma2 = (*ex2->begin()->name == "\\comma");
105+
106+
if(comma1 || comma2) {
107+
if (comma1) {
108+
auto ret = std::make_shared<Ex>(*ex1);
109+
auto loc = ret->append_child(ret->begin(), ex2->begin());
110+
if (comma2)
111+
ret->flatten_and_erase(loc);
112+
return ret;
113+
}
114+
else {
115+
auto ret = std::make_shared<Ex>(ex2->begin());
116+
auto loc = ret->prepend_child(ret->begin(), ex1->begin());
117+
if (comma1)
118+
ret->flatten_and_erase(loc);
119+
return ret;
120+
}
121+
}
122+
else {
123+
auto ret = std::make_shared<Ex>(*ex1);
124+
if (*ret->begin()->name != "\\comma")
125+
ret->wrap(ret->begin(), str_node("\\comma"));
126+
ret->append_child(ret->begin(), ex2->begin());
127+
128+
auto it = ret->begin();
129+
cleanup_dispatch(*get_kernel_from_scope(), *ret, it);
130+
return ret;
131+
}
132+
}
133+
98134
Ex_ptr Ex_mul(const Ex_ptr ex1, const Ex_ptr ex2)
99135
{
100136
return Ex_mul(ex1, ex2, ex2->begin());
@@ -652,6 +688,7 @@ namespace cadabra {
652688
.def("from_sympy", &sympy::SympyBridge::import_ex)
653689
;
654690

691+
m.def("join", &Ex_join);
655692
m.def("tree", &print_tree);
656693

657694
m.def("map_sympy", &map_sympy_wrapper,

tests/programming.cdb

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,28 @@ def test14():
288288
__cdbkernel__ = create_scope()
289289
ex:={A,B} ~ {C,D};
290290
assert(ex==${A,B,C,D}$)
291-
print("Test 14 passed")
291+
print("Test 14a passed")
292+
ex1:= {A,B};
293+
ex2:= {C,D};
294+
ex3 = join(ex1, ex2)
295+
assert(ex3==${A,B,C,D}$)
296+
print("Test 14b passed")
297+
ex1:= A;
298+
ex2:= {C,D};
299+
ex3 = join(ex1, ex2)
300+
assert(ex3==${A,C,D}$)
301+
print("Test 14c passed")
302+
ex1:= {A,B};
303+
ex2:= C;
304+
ex3 = join(ex1, ex2)
305+
assert(ex3==${A,B,C}$)
306+
print("Test 14d passed")
307+
ex1:= A;
308+
ex2:= C;
309+
ex3 = join(ex1, ex2)
310+
assert(ex3==${A,C}$)
311+
print("Test 14e passed")
292312

293313
test14()
314+
315+

0 commit comments

Comments
 (0)