Skip to content

Commit 09ecbd1

Browse files
eric-wieserutensil
authored andcommitted
Implement _make_scalar in terms of _make_grade (#95)
This eliminates some special-casing in _make_mv and _make_spinor Note that to avoid making gh-81 worse, we now need to expicitly only allow strings
1 parent e46f829 commit 09ecbd1

File tree

1 file changed

+31
-33
lines changed

1 file changed

+31
-33
lines changed

galgebra/mv.py

Lines changed: 31 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import copy
66
import numbers
7-
from operator import itemgetter, mul, add
7+
import operator
88
from functools import reduce, cmp_to_key
99

1010
from sympy import (
@@ -173,20 +173,23 @@ def characterise_Mv(self):
173173
@staticmethod
174174
def _make_grade(ga, __name_or_coeffs, __grade, **kwargs):
175175
""" Make a pure grade multivector. """
176+
def add_superscript(root, s):
177+
if not s:
178+
return root
179+
return '{}__{}'.format(root, s)
176180
grade = __grade
177181
if utils.isstr(__name_or_coeffs):
178182
name = __name_or_coeffs
179-
root = name + '__'
180183
if isinstance(kwargs['f'], bool) and not kwargs['f']: #Is a constant mulitvector function
181-
return sum([Symbol(root + super_script, real=True) * base
184+
return sum([Symbol(add_superscript(name, super_script), real=True) * base
182185
for (super_script, base) in zip(ga.blade_super_scripts[grade], ga.blades[grade])])
183186

184187
else:
185188
if isinstance(kwargs['f'], bool): #Is a multivector function of all coordinates
186-
return sum([Function(root + super_script, real=True)(*ga.coords) * base
189+
return sum([Function(add_superscript(name, super_script), real=True)(*ga.coords) * base
187190
for (super_script, base) in zip(ga.blade_super_scripts[grade], ga.blades[grade])])
188191
else: #Is a multivector function of tuple kwargs['f'] variables
189-
return sum([Function(root + super_script, real=True)(*kwargs['f']) * base
192+
return sum([Function(add_superscript(name, super_script), real=True)(*kwargs['f']) * base
190193
for (super_script, base) in zip(ga.blade_super_scripts[grade], ga.blades[grade])])
191194
elif isinstance(__name_or_coeffs, (list, tuple)):
192195
coeffs = __name_or_coeffs
@@ -203,14 +206,7 @@ def _make_scalar(ga, __name_or_value, **kwargs):
203206
""" Make a scalar multivector """
204207
if utils.isstr(__name_or_value):
205208
name = __name_or_value
206-
if 'f' in kwargs and isinstance(kwargs['f'],bool):
207-
if kwargs['f']:
208-
return Function(name)(*ga.coords)
209-
else:
210-
return Symbol(name, real=True)
211-
else:
212-
if 'f' in kwargs and isinstance(kwargs['f'],tuple):
213-
return Function(name)(*kwargs['f'])
209+
return Mv._make_grade(ga, name, 0, **kwargs)
214210
else:
215211
value = __name_or_value
216212
return value
@@ -233,28 +229,30 @@ def _make_pseudo(ga, __name_or_coeffs, **kwargs):
233229
@staticmethod
234230
def _make_mv(ga, __name, **kwargs):
235231
""" Make a general (2**n components) multivector """
236-
tmp = Mv._make_scalar(ga, __name, **kwargs)
237-
for grade in ga.n_range:
238-
tmp += Mv._make_grade(ga, __name, grade + 1, **kwargs)
239-
return tmp
232+
if not isinstance(__name, str):
233+
raise TypeError("Must be a string")
234+
return reduce(operator.add, (
235+
Mv._make_grade(ga, __name, grade, **kwargs)
236+
for grade in range(ga.n + 1)
237+
))
240238

241239
@staticmethod
242240
def _make_spinor(ga, __name, **kwargs):
243241
""" Make a general even (spinor) multivector """
244-
tmp = Mv._make_scalar(ga, __name, **kwargs)
245-
for grade in ga.n_range:
246-
if (grade + 1) % 2 == 0:
247-
tmp += Mv._make_grade(ga, __name, grade + 1, **kwargs)
248-
return tmp
242+
if not isinstance(__name, str):
243+
raise TypeError("Must be a string")
244+
return reduce(operator.add, (
245+
Mv._make_grade(ga, __name, grade, **kwargs)
246+
for grade in range(0, ga.n + 1, 2)
247+
))
249248

250249
@staticmethod
251250
def _make_odd(ga, __name_or_coeffs, **kwargs):
252251
""" Make a general odd multivector """
253-
tmp = S(0)
254-
for grade in ga.n_range:
255-
if (grade + 1) % 2 == 1:
256-
tmp += Mv._make_grade(ga, __name_or_coeffs, grade + 1, **kwargs)
257-
return tmp
252+
return reduce(operator.add, (
253+
Mv._make_grade(ga, __name_or_coeffs, grade, **kwargs)
254+
for grade in range(1, ga.n + 1, 2)
255+
), S(0)) # base case needed in case n == 0
258256

259257
# aliases
260258
_make_grade2 = _make_bivector
@@ -305,7 +303,7 @@ def __init__(self, *args, **kwargs):
305303
self.obj = make_func(self.Ga, *make_args, **kwargs)
306304
elif isinstance(args[1], int): # args[1] = r (integer) Construct grade r multivector
307305
if args[1] == 0:
308-
# make_grade does not work for scalars (gh-82)
306+
# _make_scalar interprets its coefficient argument differently
309307
make_args = list(args)
310308
make_args.pop(1)
311309
self.obj = Mv._make_scalar(self.Ga, *make_args, **kwargs)
@@ -567,7 +565,7 @@ def Mv_str(self):
567565
for arg in args:
568566
c, nc = arg.args_cnc()
569567
if len(c) > 0:
570-
c = reduce(mul, c)
568+
c = reduce(operator.mul, c)
571569
else:
572570
c = S(1)
573571
if len(nc) > 0:
@@ -584,7 +582,7 @@ def Mv_str(self):
584582
if grade0 != S(0):
585583
terms[-1] = (grade0, S(1), -1)
586584
terms = list(terms.items())
587-
sorted_terms = sorted(terms, key=itemgetter(0)) # sort via base indexes
585+
sorted_terms = sorted(terms, key=operator.itemgetter(0)) # sort via base indexes
588586

589587
s = str(sorted_terms[0][1][0] * sorted_terms[0][1][1])
590588
if printer.GaPrinter.fmt == 3:
@@ -654,7 +652,7 @@ def append_plus(c_str):
654652
for arg in args:
655653
c, nc = arg.args_cnc(split_1=False)
656654
if len(c) > 0:
657-
c = reduce(mul, c)
655+
c = reduce(operator.mul, c)
658656
else:
659657
c = S(1)
660658
if len(nc) > 0:
@@ -672,7 +670,7 @@ def append_plus(c_str):
672670
terms[-1] = (grade0, S(1), 0)
673671
terms = list(terms.items())
674672

675-
sorted_terms = sorted(terms, key=itemgetter(0)) # sort via base indexes
673+
sorted_terms = sorted(terms, key=operator.itemgetter(0)) # sort via base indexes
676674

677675
if len(sorted_terms) == 1 and sorted_terms[0][1][2] == 0: # scalar
678676
return printer.latex(printer.coef_simplify(sorted_terms[0][1][0]))
@@ -1327,7 +1325,7 @@ def list(self):
13271325
if index not in indexes:
13281326
key_coefs.append((S(0), index))
13291327

1330-
key_coefs = sorted(key_coefs, key=itemgetter(1))
1328+
key_coefs = sorted(key_coefs, key=operator.itemgetter(1))
13311329
coefs = [x[0] for x in key_coefs]
13321330
return coefs
13331331

0 commit comments

Comments
 (0)