|
13 | 13 | Integer, |
14 | 14 | Integer0, |
15 | 15 | Integer1, |
| 16 | + RationalOneHalf, |
16 | 17 | Number, |
17 | 18 | Symbol, |
18 | 19 | SymbolFalse, |
@@ -213,7 +214,6 @@ def unconvert_subexprs(expr): |
213 | 214 | ) |
214 | 215 |
|
215 | 216 | sympy_expr = convert_sympy(expr) |
216 | | - |
217 | 217 | if deep: |
218 | 218 | # thread over everything |
219 | 219 | for (i, sub_expr,) in enumerate(sub_exprs): |
@@ -270,7 +270,6 @@ def unconvert_subexprs(expr): |
270 | 270 | sympy_expr = sympy_expr.expand(**hints) |
271 | 271 | result = from_sympy(sympy_expr) |
272 | 272 | result = unconvert_subexprs(result) |
273 | | - |
274 | 273 | return result |
275 | 274 |
|
276 | 275 |
|
@@ -1606,3 +1605,359 @@ def apply(self, expr, form, h, evaluation): |
1606 | 1605 | return Expression( |
1607 | 1606 | "List", *[Expression(h, *[i for i in s]) for s in exponents] |
1608 | 1607 | ) |
| 1608 | + |
| 1609 | + |
| 1610 | +class _CoefficientHandler(Builtin): |
| 1611 | + def coeff_power_internal(self, expr, var_exprs, filt, evaluation, form="expr"): |
| 1612 | + from mathics.builtin.patterns import match |
| 1613 | + |
| 1614 | + if len(var_exprs) == 0: |
| 1615 | + if form == "expr": |
| 1616 | + return expr |
| 1617 | + else: |
| 1618 | + return [([], expr)] |
| 1619 | + if len(var_exprs) == 1: |
| 1620 | + target_pat = Pattern.create(var_exprs[0]) |
| 1621 | + var_pats = [target_pat] |
| 1622 | + else: |
| 1623 | + target_pat = Pattern.create(Expression("Alternatives", *var_exprs)) |
| 1624 | + var_pats = [Pattern.create(var) for var in var_exprs] |
| 1625 | + |
| 1626 | + ####### Auxiliary functions ######### |
| 1627 | + def key_powers(lst): |
| 1628 | + key = Expression("Plus", *lst) |
| 1629 | + key = key.evaluate(evaluation) |
| 1630 | + if key.is_numeric(): |
| 1631 | + return key.to_python() |
| 1632 | + return 0 |
| 1633 | + |
| 1634 | + def powers_list(pf): |
| 1635 | + powers = [Integer0 for i, p in enumerate(var_pats)] |
| 1636 | + if pf is None: |
| 1637 | + return powers |
| 1638 | + if pf.is_symbol(): |
| 1639 | + for i, pat in enumerate(var_pats): |
| 1640 | + if match(pf, pat, evaluation): |
| 1641 | + powers[i] = Integer(1) |
| 1642 | + return powers |
| 1643 | + if pf.has_form("Sqrt", 1): |
| 1644 | + for i, pat in enumerate(var_pats): |
| 1645 | + if match(pf._leaves[0], pat, evaluation): |
| 1646 | + powers[i] = RationalOneHalf |
| 1647 | + return powers |
| 1648 | + if pf.has_form("Power", 2): |
| 1649 | + for i, pat in enumerate(var_pats): |
| 1650 | + matchval = match(pf._leaves[0], pat, evaluation) |
| 1651 | + if matchval: |
| 1652 | + powers[i] = pf._leaves[1] |
| 1653 | + return powers |
| 1654 | + if pf.has_form("Times", None): |
| 1655 | + contrib = [powers_list(factor) for factor in pf._leaves] |
| 1656 | + for i in range(len(var_pats)): |
| 1657 | + powers[i] = Expression("Plus", *[c[i] for c in contrib]).evaluate( |
| 1658 | + evaluation |
| 1659 | + ) |
| 1660 | + return powers |
| 1661 | + return powers |
| 1662 | + |
| 1663 | + def split_coeff_pow(term): |
| 1664 | + """ |
| 1665 | + This function factorizes term in a coefficent free |
| 1666 | + of powers of the target variables, and a factor with |
| 1667 | + that powers. |
| 1668 | + """ |
| 1669 | + coeffs = [] |
| 1670 | + powers = [] |
| 1671 | + # First, split factors on those which are powers of the variables |
| 1672 | + # and the rest. |
| 1673 | + if term.is_free(target_pat, evaluation): |
| 1674 | + coeffs.append(term) |
| 1675 | + elif ( |
| 1676 | + term.is_symbol() |
| 1677 | + or term.has_form("Power", 2) |
| 1678 | + or term.has_form("Sqrt", 1) |
| 1679 | + ): |
| 1680 | + powers.append(term) |
| 1681 | + elif term.has_form("Times", None): |
| 1682 | + for factor in term.leaves: |
| 1683 | + if factor.is_free(target_pat, evaluation): |
| 1684 | + coeffs.append(factor) |
| 1685 | + elif match(factor, target_pat, evaluation): |
| 1686 | + powers.append(factor) |
| 1687 | + elif ( |
| 1688 | + factor.has_form("Power", 2) or factor.has_form("Sqrt", 1) |
| 1689 | + ) and match(factor._leaves[0], target_pat, evaluation): |
| 1690 | + powers.append(factor) |
| 1691 | + else: |
| 1692 | + coeffs.append(factor) |
| 1693 | + else: |
| 1694 | + coeffs.append(term) |
| 1695 | + # Now, rebuild both factors |
| 1696 | + if len(coeffs) == 0: |
| 1697 | + coeffs = None |
| 1698 | + elif len(coeffs) == 1: |
| 1699 | + coeffs = coeffs[0] |
| 1700 | + else: |
| 1701 | + coeffs = Expression("Times", *coeffs) |
| 1702 | + if len(powers) == 0: |
| 1703 | + powers = None |
| 1704 | + elif len(powers) == 1: |
| 1705 | + powers = powers[0] |
| 1706 | + else: |
| 1707 | + powers = Expression("Times", *sorted(powers)) |
| 1708 | + return coeffs, powers |
| 1709 | + |
| 1710 | + ################# The actual begin #################### |
| 1711 | + expr = expand( |
| 1712 | + expr, |
| 1713 | + numer=True, |
| 1714 | + denom=False, |
| 1715 | + deep=False, |
| 1716 | + trig=False, |
| 1717 | + modulus=None, |
| 1718 | + target_pat=target_pat, |
| 1719 | + ) |
| 1720 | + |
| 1721 | + if expr.is_free(target_pat, evaluation): |
| 1722 | + if filt: |
| 1723 | + expr = Expression(filt, expr).evaluate(evaluation) |
| 1724 | + if form == "expr": |
| 1725 | + return expr |
| 1726 | + else: |
| 1727 | + return [(powers_list(None), expr)] |
| 1728 | + elif ( |
| 1729 | + expr.is_symbol() |
| 1730 | + or match(expr, target_pat, evaluation) |
| 1731 | + or expr.has_form("Power", 2) |
| 1732 | + or expr.has_form("Sqrt", 1) |
| 1733 | + ): |
| 1734 | + coeff = ( |
| 1735 | + Expression(filt, Integer1).evaluate(evaluation) if filt else Integer1 |
| 1736 | + ) |
| 1737 | + if form == "expr": |
| 1738 | + if coeff is Integer1: |
| 1739 | + return expr |
| 1740 | + else: |
| 1741 | + return Expression("Times", coeff, expr) |
| 1742 | + else: |
| 1743 | + if not coeff.is_free(target_pat, evaluation): |
| 1744 | + return [] |
| 1745 | + return [(powers_list(expr), coeff)] |
| 1746 | + elif expr.has_form("Times", None): |
| 1747 | + coeff, powers = split_coeff_pow(expr) |
| 1748 | + if coeff is None: |
| 1749 | + coeff = Integer1 |
| 1750 | + else: |
| 1751 | + if form != "expr" and not coeff.is_free(target_pat, evaluation): |
| 1752 | + return [] |
| 1753 | + if filt: |
| 1754 | + coeff = Expression(filt, coeff).evaluate(evaluation) |
| 1755 | + |
| 1756 | + if form == "expr": |
| 1757 | + if powers is None: |
| 1758 | + return coeff |
| 1759 | + else: |
| 1760 | + if coeff is Integer1: |
| 1761 | + return powers |
| 1762 | + else: |
| 1763 | + return Expression("Times", coeff, powers) |
| 1764 | + else: |
| 1765 | + pl = powers_list(powers) |
| 1766 | + return [(pl, coeff)] |
| 1767 | + elif expr.has_form("Plus", None): |
| 1768 | + coeff_dict = {} |
| 1769 | + powers_dict = {} |
| 1770 | + powers_order = {} |
| 1771 | + for term in expr._leaves: |
| 1772 | + coeff, powers = split_coeff_pow(term) |
| 1773 | + if ( |
| 1774 | + form != "expr" |
| 1775 | + and coeff is not None |
| 1776 | + and not coeff.is_free(target_pat, evaluation) |
| 1777 | + ): |
| 1778 | + return [] |
| 1779 | + pl = powers_list(powers) |
| 1780 | + key = str(pl) |
| 1781 | + if not key in powers_dict: |
| 1782 | + if form == "expr": |
| 1783 | + powers_dict[key] = powers |
| 1784 | + else: |
| 1785 | + # TODO: check if pl is a monomial... |
| 1786 | + powers_dict[key] = pl |
| 1787 | + coeff_dict[key] = [] |
| 1788 | + powers_order[key] = key_powers(pl) |
| 1789 | + |
| 1790 | + coeff_dict[key].append(Integer1 if coeff is None else coeff) |
| 1791 | + |
| 1792 | + terms = [] |
| 1793 | + for key in sorted( |
| 1794 | + coeff_dict, key=lambda kv: powers_order[kv], reverse=False |
| 1795 | + ): |
| 1796 | + val = coeff_dict[key] |
| 1797 | + if len(val) == 0: |
| 1798 | + continue |
| 1799 | + elif len(val) == 1: |
| 1800 | + coeff = val[0] |
| 1801 | + else: |
| 1802 | + coeff = Expression("Plus", *val) |
| 1803 | + if filt: |
| 1804 | + coeff = Expression(filt, coeff).evaluate(evaluation) |
| 1805 | + |
| 1806 | + powerfactor = powers_dict[key] |
| 1807 | + if form == "expr": |
| 1808 | + if powerfactor: |
| 1809 | + terms.append(Expression("Times", coeff, powerfactor)) |
| 1810 | + else: |
| 1811 | + terms.append(coeff) |
| 1812 | + else: |
| 1813 | + terms.append([powerfactor, coeff]) |
| 1814 | + if form == "expr": |
| 1815 | + return Expression("Plus", *terms) |
| 1816 | + else: |
| 1817 | + return terms |
| 1818 | + else: |
| 1819 | + # expr is not a polynomial. |
| 1820 | + if form == "expr": |
| 1821 | + if filt: |
| 1822 | + expr = Expression(filt, expr).evaluate(evaluation) |
| 1823 | + return expr |
| 1824 | + else: |
| 1825 | + return [] |
| 1826 | + |
| 1827 | + |
| 1828 | +class CoefficientArrays(_CoefficientHandler): |
| 1829 | + """ |
| 1830 | + <dl> |
| 1831 | + <dt>'CoefficientArrays[$polys$, $vars$]' |
| 1832 | + <dd>returns a list of arrays of coefficients of the variables $vars$ in the polynomial $poly$. |
| 1833 | + </dl> |
| 1834 | +
|
| 1835 | + >> CoefficientArrays[1 + x^3, x] |
| 1836 | + = {1, {0}, {{0}}, {{{1}}}} |
| 1837 | + >> CoefficientArrays[1 + x y+ x^3, {x, y}] |
| 1838 | + = {1, {0, 0}, {{0, 1}, {0, 0}}, {{{1, 0}, {0, 0}}, {{0, 0}, {0, 0}}}} |
| 1839 | + >> CoefficientArrays[{1 + x^2, x y}, {x, y}] |
| 1840 | + = {{1, 0}, {{0, 0}, {0, 0}}, {{{1, 0}, {0, 0}}, {{0, 1}, {0, 0}}}} |
| 1841 | + >> CoefficientArrays[(x+y+Sin[z])^3, {x,y}] |
| 1842 | + = {Sin[z] ^ 3, {3 Sin[z] ^ 2, 3 Sin[z] ^ 2}, {{3 Sin[z], 6 Sin[z]}, {0, 3 Sin[z]}}, {{{1, 3}, {0, 3}}, {{0, 0}, {0, 1}}}} |
| 1843 | + >> CoefficientArrays[(x + y + Sin[z])^3, {x, z}] |
| 1844 | + : (x + y + Sin[z]) ^ 3 is not a polynomial in {x, z} |
| 1845 | + = CoefficientArrays[(x + y + Sin[z]) ^ 3, {x, z}] |
| 1846 | + """ |
| 1847 | + |
| 1848 | + options = { |
| 1849 | + "Symmetric": "False", |
| 1850 | + } |
| 1851 | + messages = { |
| 1852 | + "poly": "`1` is not a polynomial in `2`", |
| 1853 | + } |
| 1854 | + |
| 1855 | + def apply_list(self, polys, varlist, evaluation, options): |
| 1856 | + "%(name)s[polys_, varlist_, OptionsPattern[]]" |
| 1857 | + from mathics.builtin.lists import walk_parts |
| 1858 | + |
| 1859 | + if polys.has_form("List", None): |
| 1860 | + list_polys = polys.leaves |
| 1861 | + else: |
| 1862 | + list_polys = [polys] |
| 1863 | + |
| 1864 | + if varlist.is_symbol(): |
| 1865 | + var_exprs = [varlist] |
| 1866 | + elif varlist.has_form("List", None): |
| 1867 | + var_exprs = varlist.get_leaves() |
| 1868 | + else: |
| 1869 | + var_exprs = [varlist] |
| 1870 | + |
| 1871 | + coeffs = [ |
| 1872 | + self.coeff_power_internal(pol, var_exprs, None, evaluation, "coeffs") |
| 1873 | + for pol in list_polys |
| 1874 | + ] |
| 1875 | + |
| 1876 | + dim1 = len(coeffs) |
| 1877 | + dim2 = len(var_exprs) |
| 1878 | + arrays = [] |
| 1879 | + if dim1 == 1: |
| 1880 | + arrays.append(Integer(0)) |
| 1881 | + for i, component in enumerate(coeffs): |
| 1882 | + if len(component) == 0: |
| 1883 | + evaluation.message("CoefficientArrays", "poly", polys, varlist) |
| 1884 | + return |
| 1885 | + for idxcoeff in component: |
| 1886 | + idx, coeff = idxcoeff |
| 1887 | + order = Expression("Plus", *idx).evaluate(evaluation).get_int_value() |
| 1888 | + if order is None: |
| 1889 | + evaluation.message("CoefficientArrays", "poly", polys, varlist) |
| 1890 | + return |
| 1891 | + while len(arrays) <= order: |
| 1892 | + cur_ord = len(arrays) |
| 1893 | + range2 = Expression(SymbolList, Integer(dim2)) |
| 1894 | + its2 = [range2 for k in range(cur_ord)] |
| 1895 | + # TODO: Use SparseArray... |
| 1896 | + # This constructs a tensor or range cur_ord+1 |
| 1897 | + if dim1 > 1: |
| 1898 | + newtable = Expression( |
| 1899 | + "Table", |
| 1900 | + Integer(0), |
| 1901 | + Expression(SymbolList, Integer(dim1)), |
| 1902 | + *its2 |
| 1903 | + ) |
| 1904 | + else: |
| 1905 | + newtable = Expression("Table", Integer(0), *its2) |
| 1906 | + arrays.append(newtable.evaluate(evaluation)) |
| 1907 | + curr_array = arrays[order] |
| 1908 | + arrayidx = [ |
| 1909 | + Integer(n + 1) |
| 1910 | + for n, j in enumerate(idx) |
| 1911 | + for q in range(j.get_int_value()) |
| 1912 | + ] |
| 1913 | + if dim1 > 1: |
| 1914 | + arrayidx = [Integer(i + 1)] + arrayidx |
| 1915 | + if dim1 == 1 and order == 0: |
| 1916 | + arrays[0] = coeff |
| 1917 | + else: |
| 1918 | + arrays[order] = walk_parts( |
| 1919 | + [curr_array], arrayidx, evaluation, coeff |
| 1920 | + ) |
| 1921 | + return Expression("List", *arrays) |
| 1922 | + |
| 1923 | + |
| 1924 | +class Collect(_CoefficientHandler): |
| 1925 | + """ |
| 1926 | + <dl> |
| 1927 | + <dt>'Collect[$expr$, $x$]' |
| 1928 | + <dd> Expands $expr$ and collect together terms having the same power of $x$. |
| 1929 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}]' |
| 1930 | + <dd> Expands $expr$ and collect together terms having the same powers of |
| 1931 | + $x_1$, $x_2$, .... |
| 1932 | + <dt>'Collect[$expr$, {$x_1$, $x_2$, ...}, $filter$]' |
| 1933 | + <dd> After collect the terms, applies $filter$ to each coefficient. |
| 1934 | + </dl> |
| 1935 | +
|
| 1936 | + >> Collect[(x+y)^3, y] |
| 1937 | + = x ^ 3 + 3 x ^ 2 y + 3 x y ^ 2 + y ^ 3 |
| 1938 | + >> Collect[2 Sin[x z] (x+2 y^2 + Sin[y] x), y] |
| 1939 | + = 2 x Sin[x z] + 2 x Sin[x z] Sin[y] + 4 y ^ 2 Sin[x z] |
| 1940 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, y] |
| 1941 | + = 4 x Sin[x z] + x ^ 3 + y (3 x + 3 x ^ 2) + y ^ 2 (3 x + 4 Sin[x z]) + y ^ 3 |
| 1942 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}] |
| 1943 | + = 4 x Sin[x z] + x ^ 3 + 3 x y + 3 x ^ 2 y + 4 y ^ 2 Sin[x z] + 3 x y ^ 2 + y ^ 3 |
| 1944 | + >> Collect[3 x y+2 Sin[x z] (x+2 y^2 + x) + (x+y)^3, {x,y}, h] |
| 1945 | + = x h[4 Sin[x z]] + x ^ 3 h[1] + x y h[3] + x ^ 2 y h[3] + y ^ 2 h[4 Sin[x z]] + x y ^ 2 h[3] + y ^ 3 h[1] |
| 1946 | + """ |
| 1947 | + |
| 1948 | + rules = { |
| 1949 | + "Collect[expr_, varlst_]": "Collect[expr, varlst, Identity]", |
| 1950 | + } |
| 1951 | + |
| 1952 | + def apply_var_filter(self, expr, varlst, filt, evaluation): |
| 1953 | + """Collect[expr_, varlst_, filt_]""" |
| 1954 | + if filt == Symbol("Identity"): |
| 1955 | + filt = None |
| 1956 | + if varlst.is_symbol(): |
| 1957 | + var_exprs = [varlst] |
| 1958 | + elif varlst.has_form("List", None): |
| 1959 | + var_exprs = varlst.get_leaves() |
| 1960 | + else: |
| 1961 | + var_exprs = [varlst] |
| 1962 | + |
| 1963 | + return self.coeff_power_internal(expr, var_exprs, filt, evaluation, "expr") |
0 commit comments