Skip to content

Commit f5da64f

Browse files
eric-wieserutensil
authored andcommitted
Replace the broken linear_expand(x, mode=False) (#165)
This introduces a new `linear_expand_terms` function with the same meaning as the above The original was incorrect because the early return paths did not respect the mode argument. By using this function through mv.py, we can remove a line or two from a lot of functions
1 parent f35f01d commit f5da64f

File tree

2 files changed

+25
-33
lines changed

2 files changed

+25
-33
lines changed

galgebra/metric.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def apply_function_list(f,x):
2424
return f(x)
2525

2626

27-
def linear_expand(expr, mode=True):
27+
def linear_expand(expr):
2828

2929
if isinstance(expr, Expr):
3030
expr = expand(expr)
@@ -58,10 +58,13 @@ def linear_expand(expr, mode=True):
5858
else:
5959
bases.append(base)
6060
coefs.append(coef)
61-
if mode:
62-
return (coefs, bases)
63-
else:
64-
return list(zip(coefs, bases))
61+
return (coefs, bases)
62+
63+
64+
def linear_expand_terms(expr):
65+
coefs, bases = linear_expand(expr)
66+
return zip(coefs, bases)
67+
6568

6669
def collect(A, nc_list):
6770
"""

galgebra/mv.py

Lines changed: 17 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -906,9 +906,8 @@ def collect(self,deep=False):
906906
self.obj = self.obj.collect(c)
907907
return self
908908
"""
909-
coefs, bases = metric.linear_expand(self.obj)
910909
obj_dict = {}
911-
for (coef, base) in zip(coefs, bases):
910+
for coef, base in metric.linear_expand_terms(self.obj):
912911
if base in list(obj_dict.keys()):
913912
obj_dict[base] += coef
914913
else:
@@ -997,14 +996,12 @@ def get_grade(self, r):
997996
return Mv(self.Ga.get_grade(self.obj, r), ga=self.Ga)
998997

999998
def components(self):
1000-
(coefs, bases) = metric.linear_expand(self.obj)
1001-
cb = list(zip(coefs, bases))
999+
cb = metric.linear_expand_terms(self.obj)
10021000
cb = sorted(cb, key=lambda x: self.Ga._all_blades_lst.index(x[1]))
10031001
return [self.Ga.mv(coef * base) for (coef, base) in cb]
10041002

10051003
def get_coefs(self, grade):
1006-
(coefs, bases) = metric.linear_expand(self.obj)
1007-
cb = list(zip(coefs, bases))
1004+
cb = metric.linear_expand_terms(self.obj)
10081005
cb = sorted(cb, key=lambda x: self.Ga.blades[grade].index(x[1]))
10091006
(coefs, bases) = list(zip(*cb))
10101007
return coefs
@@ -1040,9 +1037,8 @@ def proj(self, bases_lst):
10401037
part of multivector with the same bases as in the bases_lst.
10411038
"""
10421039
bases_lst = [x.obj for x in bases_lst]
1043-
(coefs, bases) = metric.linear_expand(self.obj)
10441040
obj = 0
1045-
for (coef, base) in zip(coefs, bases):
1041+
for coef, base in metric.linear_expand_terms(self.obj):
10461042
if base in bases_lst:
10471043
obj += coef * base
10481044
return Mv(obj, ga=self.Ga)
@@ -1313,9 +1309,8 @@ def inv(self):
13131309
raise TypeError('In inv() for self =' + str(self) + 'self, or self*self or self*self.rev() is not a scalar')
13141310

13151311
def func(self, fct): # Apply function, fct, to each coefficient of multivector
1316-
(coefs, bases) = metric.linear_expand(self.obj)
13171312
s = S(0)
1318-
for (coef, base) in zip(coefs, bases):
1313+
for coef, base in metric.linear_expand_terms(self.obj):
13191314
s += fct(coef) * base
13201315
fct_self = Mv(s, ga=self.Ga)
13211316
fct_self.characterise_Mv()
@@ -1325,38 +1320,33 @@ def trigsimp(self):
13251320
return self.func(trigsimp)
13261321

13271322
def simplify(self, modes=simplify):
1328-
(coefs, bases) = metric.linear_expand(self.obj)
1323+
if not isinstance(modes, (list, tuple)):
1324+
modes = [modes]
1325+
13291326
obj = S(0)
1330-
if isinstance(modes, list) or isinstance(modes, tuple):
1331-
for (coef, base) in zip(coefs, bases):
1332-
for mode in modes:
1333-
coef = mode(coef)
1334-
obj += coef * base
1335-
else:
1336-
for (coef, base) in zip(coefs, bases):
1337-
obj += modes(coef) * base
1327+
for coef, base in metric.linear_expand_terms(self.obj):
1328+
for mode in modes:
1329+
coef = mode(coef)
1330+
obj += coef * base
13381331
return Mv(obj, ga=self.Ga)
13391332

13401333
def subs(self, d):
13411334
# For each scalar coef of the multivector apply substitution argument d
1342-
(coefs, bases) = metric.linear_expand(self.obj)
13431335
obj = sum((
1344-
coef.subs(d) * base for coef, base in zip(coefs, bases)
1336+
coef.subs(d) * base for coef, base in metric.linear_expand_terms(self.obj)
13451337
), S(0))
13461338
return Mv(obj, ga=self.Ga)
13471339

13481340
def expand(self):
1349-
coefs, bases = metric.linear_expand(self.obj)
13501341
obj = sum((
1351-
expand(coef) * base for coef, base in zip(coefs, bases)
1342+
expand(coef) * base for coef, base in metric.linear_expand_terms(self.obj)
13521343
), S(0))
13531344
return Mv(obj, ga=self.Ga)
13541345

13551346
def list(self):
1356-
(coefs, bases) = metric.linear_expand(self.obj)
13571347
indexes = []
13581348
key_coefs = []
1359-
for (coef, base) in zip(coefs, bases):
1349+
for coef, base in metric.linear_expand_terms(self.obj):
13601350
if base in self.Ga.basis:
13611351
index = self.Ga.basis.index(base)
13621352
key_coefs.append((coef, index))
@@ -2032,7 +2022,7 @@ def blade_rep(self):
20322022
coefs = N * [[]]
20332023
bases = N * [0]
20342024
for term in self.terms:
2035-
for (coef, base) in metric.linear_expand(self.terms[0].obj, mode=False):
2025+
for coef, base in metric.linear_expand_terms(self.terms[0].obj):
20362026
index = self.blades.index(base)
20372027
coefs[index] = coef
20382028
bases[index] = base
@@ -2256,8 +2246,7 @@ def Dop_mv_expand(self, modes=None):
22562246

22572247
for (coef, pdiff) in self.terms:
22582248
if isinstance(coef, Mv) and not coef.is_scalar():
2259-
mv_terms = metric.linear_expand(coef.obj, mode=False)
2260-
for (mv_coef, mv_base) in mv_terms:
2249+
for mv_coef, mv_base in metric.linear_expand_terms(coef.obj):
22612250
if mv_base in bases:
22622251
index = bases.index(mv_base)
22632252
coefs[index] += Sdop([(mv_coef, pdiff)], ga=self.Ga)

0 commit comments

Comments
 (0)