Skip to content

Commit 135885d

Browse files
author
Kasper Peeters
committed
Fix a bug in accessing the Weight value in Python. Fix an error with setting the multiplier of a node from Python.
1 parent 2ac0254 commit 135885d

File tree

6 files changed

+38
-5
lines changed

6 files changed

+38
-5
lines changed

core/ExNode.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ void ExNode::set_multiplier(pybind11::object mult)
325325
if(!ex->is_valid(it))
326326
throw ConsistencyException("Cannot set the multiplier of an iterator before the first 'next'.");
327327

328-
pybind11::object mpq = pybind11::module::import("gmpy2").attr("mpq");
329-
multiply(it->multiplier, pybind11::cast<int>(mult));
328+
set(it->multiplier, multiplier_t(mult.attr("numerator").cast<long>(),
329+
mult.attr("denominator").cast<long>()) );
330330
}
331331

332332

core/Storage.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1014,6 +1014,12 @@ namespace cadabra {
10141014
num=rat_set.insert(fac).first;
10151015
}
10161016

1017+
void set(rset_t::iterator& num, multiplier_t fac)
1018+
{
1019+
fac.canonicalize();
1020+
num=rat_set.insert(fac).first;
1021+
}
1022+
10171023
void add(rset_t::iterator& num, multiplier_t fac)
10181024
{
10191025
fac+=*num;

core/Storage.hh

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,8 @@ namespace cadabra {
129129
void one(rset_t::iterator&);
130130
void flip_sign(rset_t::iterator&);
131131
void half(rset_t::iterator&);
132-
132+
void set(rset_t::iterator&, multiplier_t);
133+
133134
/// \ingroup core
134135
///
135136
/// Basic storage class for symbolic mathemematical expressions. The

core/algorithms/substitute.cc

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ substitute::substitute(const Kernel& k, Ex& tr, Ex& args_, bool partial)
1818
{
1919
if(args.is_empty())
2020
throw ArgumentException("substitute: Replacement rule is an empty expression.");
21-
21+
22+
Stopwatch sw;
23+
sw.start();
2224
cadabra::do_list(args, args.begin(), [&](Ex::iterator arrow) {
2325
//args.print_recursive_treeform(std::cerr, arrow);
2426
if(*arrow->name!="\\arrow" && *arrow->name!="\\equals")
@@ -81,6 +83,8 @@ substitute::substitute(const Kernel& k, Ex& tr, Ex& args_, bool partial)
8183
}
8284
return true;
8385
});
86+
sw.stop();
87+
std::cerr << "preparation took " << sw << std::endl;
8488
}
8589

8690
bool substitute::can_apply(iterator st)

core/pythoncdb/py_properties.cc

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,14 @@ namespace cadabra {
360360
def_abstract_prop<Py_DependsBase>(m, "DependsBase")
361361
.def("dependencies", [](const Py_DependsBase & p) { return p.get_prop()->dependencies(p.get_kernel(), p.get_it()); });
362362
def_abstract_prop<Py_WeightBase>(m, "WeightBase")
363-
.def("value", [](const Py_WeightBase & p, const std::string& forcedLabel) { return p.get_prop()->value(p.get_kernel(), p.get_it(), forcedLabel); });
363+
.def("value", [](const Py_WeightBase & p, const std::string& forcedLabel) {
364+
// This is mpq_class, convert to the Python equivalent.
365+
pybind11::object mpq = pybind11::module::import("gmpy2").attr("mpq");
366+
auto m = p.get_prop()->value(p.get_kernel(), p.get_it(), forcedLabel);
367+
pybind11::object mult = mpq(m.get_num().get_si(), m.get_den().get_si());
368+
return mult;
369+
});
370+
364371
def_abstract_prop<Py_DifferentialFormBase>(m, "DifferentialFormBase")
365372
.def("degree", [](const Py_DifferentialFormBase & p) { return p.get_prop()->degree(p.get_props(), p.get_it()); });
366373

tests/programming.cdb

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -341,5 +341,20 @@ def test16():
341341
{i,j,k}::Indices(isospin, position=independent).
342342
assert($\Lambda_{a}$.matches($\Lambda_{i}$)==False)
343343
assert($\Lambda_{a}$.matches($\Lambda^{i}$)==False)
344+
print("Test 16 passed")
344345

345346
test16()
347+
348+
def test17():
349+
x::Weight(value=42, label=field);
350+
tst1 = Weight.get($x$, label="field").value("field")
351+
assert(tst1==42)
352+
print("Test 17a passed")
353+
ex:= 3 a;
354+
ex.top().multiplier = tst1
355+
tst2:= 42 a - @(ex);
356+
assert(tst2==0)
357+
print("Test 17b passed")
358+
359+
test17()
360+

0 commit comments

Comments
 (0)