Skip to content

Commit 28530c4

Browse files
authored
Merge pull request #1324 from mathics/coeffarrays
CoefficientArrays and Collect
2 parents 3de2a23 + 9d587f7 commit 28530c4

File tree

6 files changed

+366
-6
lines changed

6 files changed

+366
-6
lines changed

CHANGES.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ New variables and builtins
1414
++++++++++++++++++++++++++
1515

1616
* ``Arg``
17+
* ``CoefficientArrays`` and ``Collect`` (#1174, #1194)
1718
* ``Dispatch``
1819
* ``FullSimplify``
1920
* ``LetterNumber`` #1298. The ``alphabet`` parameter supports only a minimal number of languages.
@@ -52,6 +53,7 @@ Enhancements
5253
* ``ToString`` accepts an optional *form* parameter.
5354
* The implementation of Streams was redone.
5455
* ``ToExpression`` handles multi-line string input
56+
* ``$VersionNumber`` now set to 10.0 (was 6.0)
5557
* The implementation of Streams was redone.
5658

5759

mathics/builtin/numbers/algebra.py

Lines changed: 357 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
Integer,
1414
Integer0,
1515
Integer1,
16+
RationalOneHalf,
1617
Number,
1718
Symbol,
1819
SymbolFalse,
@@ -213,7 +214,6 @@ def unconvert_subexprs(expr):
213214
)
214215

215216
sympy_expr = convert_sympy(expr)
216-
217217
if deep:
218218
# thread over everything
219219
for (i, sub_expr,) in enumerate(sub_exprs):
@@ -270,7 +270,6 @@ def unconvert_subexprs(expr):
270270
sympy_expr = sympy_expr.expand(**hints)
271271
result = from_sympy(sympy_expr)
272272
result = unconvert_subexprs(result)
273-
274273
return result
275274

276275

@@ -1606,3 +1605,359 @@ def apply(self, expr, form, h, evaluation):
16061605
return Expression(
16071606
"List", *[Expression(h, *[i for i in s]) for s in exponents]
16081607
)
1608+
1609+
1610+
class _CoefficientHandler(Builtin):
1611+
def coeff_power_internal(self, expr, var_exprs, filt, evaluation, form="expr"):
1612+
from mathics.builtin.patterns import match
1613+
1614+
if len(var_exprs) == 0:
1615+
if form == "expr":
1616+
return expr
1617+
else:
1618+
return [([], expr)]
1619+
if len(var_exprs) == 1:
1620+
target_pat = Pattern.create(var_exprs[0])
1621+
var_pats = [target_pat]
1622+
else:
1623+
target_pat = Pattern.create(Expression("Alternatives", *var_exprs))
1624+
var_pats = [Pattern.create(var) for var in var_exprs]
1625+
1626+
####### Auxiliary functions #########
1627+
def key_powers(lst):
1628+
key = Expression("Plus", *lst)
1629+
key = key.evaluate(evaluation)
1630+
if key.is_numeric():
1631+
return key.to_python()
1632+
return 0
1633+
1634+
def powers_list(pf):
1635+
powers = [Integer0 for i, p in enumerate(var_pats)]
1636+
if pf is None:
1637+
return powers
1638+
if pf.is_symbol():
1639+
for i, pat in enumerate(var_pats):
1640+
if match(pf, pat, evaluation):
1641+
powers[i] = Integer(1)
1642+
return powers
1643+
if pf.has_form("Sqrt", 1):
1644+
for i, pat in enumerate(var_pats):
1645+
if match(pf._leaves[0], pat, evaluation):
1646+
powers[i] = RationalOneHalf
1647+
return powers
1648+
if pf.has_form("Power", 2):
1649+
for i, pat in enumerate(var_pats):
1650+
matchval = match(pf._leaves[0], pat, evaluation)
1651+
if matchval:
1652+
powers[i] = pf._leaves[1]
1653+
return powers
1654+
if pf.has_form("Times", None):
1655+
contrib = [powers_list(factor) for factor in pf._leaves]
1656+
for i in range(len(var_pats)):
1657+
powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate(
1658+
evaluation
1659+
)
1660+
return powers
1661+
return powers
1662+
1663+
def split_coeff_pow(term):
1664+
"""
1665+
This function factorizes term in a coefficent free
1666+
of powers of the target variables, and a factor with
1667+
that powers.
1668+
"""
1669+
coeffs = []
1670+
powers = []
1671+
# First, split factors on those which are powers of the variables
1672+
# and the rest.
1673+
if term.is_free(target_pat, evaluation):
1674+
coeffs.append(term)
1675+
elif (
1676+
term.is_symbol()
1677+
or term.has_form("Power", 2)
1678+
or term.has_form("Sqrt", 1)
1679+
):
1680+
powers.append(term)
1681+
elif term.has_form("Times", None):
1682+
for factor in term.leaves:
1683+
if factor.is_free(target_pat, evaluation):
1684+
coeffs.append(factor)
1685+
elif match(factor, target_pat, evaluation):
1686+
powers.append(factor)
1687+
elif (
1688+
factor.has_form("Power", 2) or factor.has_form("Sqrt", 1)
1689+
) and match(factor._leaves[0], target_pat, evaluation):
1690+
powers.append(factor)
1691+
else:
1692+
coeffs.append(factor)
1693+
else:
1694+
coeffs.append(term)
1695+
# Now, rebuild both factors
1696+
if len(coeffs) == 0:
1697+
coeffs = None
1698+
elif len(coeffs) == 1:
1699+
coeffs = coeffs[0]
1700+
else:
1701+
coeffs = Expression("Times", *coeffs)
1702+
if len(powers) == 0:
1703+
powers = None
1704+
elif len(powers) == 1:
1705+
powers = powers[0]
1706+
else:
1707+
powers = Expression("Times", *sorted(powers))
1708+
return coeffs, powers
1709+
1710+
################# The actual begin ####################
1711+
expr = expand(
1712+
expr,
1713+
numer=True,
1714+
denom=False,
1715+
deep=False,
1716+
trig=False,
1717+
modulus=None,
1718+
target_pat=target_pat,
1719+
)
1720+
1721+
if expr.is_free(target_pat, evaluation):
1722+
if filt:
1723+
expr = Expression(filt, expr).evaluate(evaluation)
1724+
if form == "expr":
1725+
return expr
1726+
else:
1727+
return [(powers_list(None), expr)]
1728+
elif (
1729+
expr.is_symbol()
1730+
or match(expr, target_pat, evaluation)
1731+
or expr.has_form("Power", 2)
1732+
or expr.has_form("Sqrt", 1)
1733+
):
1734+
coeff = (
1735+
Expression(filt, Integer1).evaluate(evaluation) if filt else Integer1
1736+
)
1737+
if form == "expr":
1738+
if coeff is Integer1:
1739+
return expr
1740+
else:
1741+
return Expression("Times", coeff, expr)
1742+
else:
1743+
if not coeff.is_free(target_pat, evaluation):
1744+
return []
1745+
return [(powers_list(expr), coeff)]
1746+
elif expr.has_form("Times", None):
1747+
coeff, powers = split_coeff_pow(expr)
1748+
if coeff is None:
1749+
coeff = Integer1
1750+
else:
1751+
if form != "expr" and not coeff.is_free(target_pat, evaluation):
1752+
return []
1753+
if filt:
1754+
coeff = Expression(filt, coeff).evaluate(evaluation)
1755+
1756+
if form == "expr":
1757+
if powers is None:
1758+
return coeff
1759+
else:
1760+
if coeff is Integer1:
1761+
return powers
1762+
else:
1763+
return Expression("Times", coeff, powers)
1764+
else:
1765+
pl = powers_list(powers)
1766+
return [(pl, coeff)]
1767+
elif expr.has_form("Plus", None):
1768+
coeff_dict = {}
1769+
powers_dict = {}
1770+
powers_order = {}
1771+
for term in expr._leaves:
1772+
coeff, powers = split_coeff_pow(term)
1773+
if (
1774+
form != "expr"
1775+
and coeff is not None
1776+
and not coeff.is_free(target_pat, evaluation)
1777+
):
1778+
return []
1779+
pl = powers_list(powers)
1780+
key = str(pl)
1781+
if not key in powers_dict:
1782+
if form == "expr":
1783+
powers_dict[key] = powers
1784+
else:
1785+
# TODO: check if pl is a monomial...
1786+
powers_dict[key] = pl
1787+
coeff_dict[key] = []
1788+
powers_order[key] = key_powers(pl)
1789+
1790+
coeff_dict[key].append(Integer1 if coeff is None else coeff)
1791+
1792+
terms = []
1793+
for key in sorted(
1794+
coeff_dict, key=lambda kv: powers_order[kv], reverse=False
1795+
):
1796+
val = coeff_dict[key]
1797+
if len(val) == 0:
1798+
continue
1799+
elif len(val) == 1:
1800+
coeff = val[0]
1801+
else:
1802+
coeff = Expression("Plus", *val)
1803+
if filt:
1804+
coeff = Expression(filt, coeff).evaluate(evaluation)
1805+
1806+
powerfactor = powers_dict[key]
1807+
if form == "expr":
1808+
if powerfactor:
1809+
terms.append(Expression("Times", coeff, powerfactor))
1810+
else:
1811+
terms.append(coeff)
1812+
else:
1813+
terms.append([powerfactor, coeff])
1814+
if form == "expr":
1815+
return Expression("Plus", *terms)
1816+
else:
1817+
return terms
1818+
else:
1819+
# expr is not a polynomial.
1820+
if form == "expr":
1821+
if filt:
1822+
expr = Expression(filt, expr).evaluate(evaluation)
1823+
return expr
1824+
else:
1825+
return []
1826+
1827+
1828+
class CoefficientArrays(_CoefficientHandler):
1829+
"""
1830+
<dl>
1831+
<dt>'CoefficientArrays[$polys$, $vars$]'
1832+
<dd>returns a list of arrays of coefficients of the variables $vars$ in the polynomial $poly$.
1833+
</dl>
1834+
1835+
>> CoefficientArrays[1 + x^3, x]
1836+
= {1, {0}, {{0}}, {{{1}}}}
1837+
>> CoefficientArrays[1 + x y+ x^3, {x, y}]
1838+
= {1, {0, 0}, {{0, 1}, {0, 0}}, {{{1, 0}, {0, 0}}, {{0, 0}, {0, 0}}}}
1839+
>> CoefficientArrays[{1 + x^2, x y}, {x, y}]
1840+
= {{1, 0}, {{0, 0}, {0, 0}}, {{{1, 0}, {0, 0}}, {{0, 1}, {0, 0}}}}
1841+
>> CoefficientArrays[(x+y+Sin[z])^3, {x,y}]
1842+
= {Sin[z] ^ 3, {3 Sin[z] ^ 2, 3 Sin[z] ^ 2}, {{3 Sin[z], 6 Sin[z]}, {0, 3 Sin[z]}}, {{{1, 3}, {0, 3}}, {{0, 0}, {0, 1}}}}
1843+
>> CoefficientArrays[(x + y + Sin[z])^3, {x, z}]
1844+
: (x + y + Sin[z]) ^ 3 is not a polynomial in {x, z}
1845+
= CoefficientArrays[(x + y + Sin[z]) ^ 3, {x, z}]
1846+
"""
1847+
1848+
options = {
1849+
"Symmetric": "False",
1850+
}
1851+
messages = {
1852+
"poly": "`1` is not a polynomial in `2`",
1853+
}
1854+
1855+
def apply_list(self, polys, varlist, evaluation, options):
1856+
"%(name)s[polys_, varlist_, OptionsPattern[]]"
1857+
from mathics.builtin.lists import walk_parts
1858+
1859+
if polys.has_form("List", None):
1860+
list_polys = polys.leaves
1861+
else:
1862+
list_polys = [polys]
1863+
1864+
if varlist.is_symbol():
1865+
var_exprs = [varlist]
1866+
elif varlist.has_form("List", None):
1867+
var_exprs = varlist.get_leaves()
1868+
else:
1869+
var_exprs = [varlist]
1870+
1871+
coeffs = [
1872+
self.coeff_power_internal(pol, var_exprs, None, evaluation, "coeffs")
1873+
for pol in list_polys
1874+
]
1875+
1876+
dim1 = len(coeffs)
1877+
dim2 = len(var_exprs)
1878+
arrays = []
1879+
if dim1 == 1:
1880+
arrays.append(Integer(0))
1881+
for i, component in enumerate(coeffs):
1882+
if len(component) == 0:
1883+
evaluation.message("CoefficientArrays", "poly", polys, varlist)
1884+
return
1885+
for idxcoeff in component:
1886+
idx, coeff = idxcoeff
1887+
order = Expression("Plus", *idx).evaluate(evaluation).get_int_value()
1888+
if order is None:
1889+
evaluation.message("CoefficientArrays", "poly", polys, varlist)
1890+
return
1891+
while len(arrays) <= order:
1892+
cur_ord = len(arrays)
1893+
range2 = Expression(SymbolList, Integer(dim2))
1894+
its2 = [range2 for k in range(cur_ord)]
1895+
# TODO: Use SparseArray...
1896+
# This constructs a tensor or range cur_ord+1
1897+
if dim1 > 1:
1898+
newtable = Expression(
1899+
"Table",
1900+
Integer(0),
1901+
Expression(SymbolList, Integer(dim1)),
1902+
*its2
1903+
)
1904+
else:
1905+
newtable = Expression("Table", Integer(0), *its2)
1906+
arrays.append(newtable.evaluate(evaluation))
1907+
curr_array = arrays[order]
1908+
arrayidx = [
1909+
Integer(n + 1)
1910+
for n, j in enumerate(idx)
1911+
for q in range(j.get_int_value())
1912+
]
1913+
if dim1 > 1:
1914+
arrayidx = [Integer(i + 1)] + arrayidx
1915+
if dim1 == 1 and order == 0:
1916+
arrays[0] = coeff
1917+
else:
1918+
arrays[order] = walk_parts(
1919+
[curr_array], arrayidx, evaluation, coeff
1920+
)
1921+
return Expression("List", *arrays)
1922+
1923+
1924+
class Collect(_CoefficientHandler):
1925+
"""
1926+
<dl>
1927+
<dt>'Collect[$expr$, $x$]'
1928+
<dd> Expands $expr$ and collect together terms having the same power of $x$.
1929+
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]'
1930+
<dd> Expands $expr$ and collect together terms having the same powers of
1931+
$x_1$, $x_2$, ....
1932+
<dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]'
1933+
<dd> After collect the terms, applies $filter$ to each coefficient.
1934+
</dl>
1935+
1936+
>> Collect[(x+y)^3, y]
1937+
= x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3
1938+
>> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y]
1939+
= 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z]
1940+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y]
1941+
= 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3
1942+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}]
1943+
= 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3
1944+
>> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h]
1945+
= x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1]
1946+
"""
1947+
1948+
rules = {
1949+
"Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]",
1950+
}
1951+
1952+
def apply_var_filter(self, expr, varlst, filt, evaluation):
1953+
"""Collect[expr_, varlst_, filt_]"""
1954+
if filt == Symbol("Identity"):
1955+
filt = None
1956+
if varlst.is_symbol():
1957+
var_exprs = [varlst]
1958+
elif varlst.has_form("List", None):
1959+
var_exprs = varlst.get_leaves()
1960+
else:
1961+
var_exprs = [varlst]
1962+
1963+
return self.coeff_power_internal(expr, var_exprs, filt, evaluation, "expr")

0 commit comments

Comments
 (0)