@@ -34,6 +34,7 @@ def skip_if_different_mount_drives():
3434 from stack import Local , Stack
3535 import tier1_generator
3636 import optimizer_generator
37+ import partial_evaluator_generator
3738
3839
3940def handle_stderr ():
@@ -1431,6 +1432,9 @@ def test_instruction_size_macro(self):
14311432
14321433
14331434class TestGeneratedAbstractCases (unittest .TestCase ):
1435+
1436+ generator = None
1437+
14341438 def setUp (self ) -> None :
14351439 super ().setUp ()
14361440 self .maxDiff = None
@@ -1466,7 +1470,8 @@ def run_cases_test(self, input: str, input2: str, expected: str):
14661470 temp_input .flush ()
14671471
14681472 with handle_stderr ():
1469- optimizer_generator .generate_tier2_abstract_from_files (
1473+ assert self .generator is not None
1474+ self .generator .generate_tier2_abstract_from_files (
14701475 [self .temp_input_filename , self .temp_input2_filename ],
14711476 self .temp_output_filename
14721477 )
@@ -1480,6 +1485,9 @@ def run_cases_test(self, input: str, input2: str, expected: str):
14801485 actual = "" .join (lines )
14811486 self .assertEqual (actual .strip (), expected .strip ())
14821487
1488+
1489+ class TestGeneratedOptimizerCases (TestGeneratedAbstractCases ):
1490+ generator = optimizer_generator
14831491 def test_overridden_abstract (self ):
14841492 input = """
14851493 pure op(OP, (--)) {
@@ -1580,5 +1588,166 @@ def test_missing_override_failure(self):
15801588 self .run_cases_test (input , input2 , output )
15811589
15821590
1591+ class TestGeneratedPECases (TestGeneratedAbstractCases ):
1592+ generator = partial_evaluator_generator
1593+
1594+ def test_overridden_abstract (self ):
1595+ input = """
1596+ pure op(OP, (--)) {
1597+ SPAM();
1598+ }
1599+ """
1600+ input2 = """
1601+ pure op(OP, (--)) {
1602+ eggs();
1603+ }
1604+ """
1605+ output = """
1606+ case OP: {
1607+ eggs();
1608+ break;
1609+ }
1610+ """
1611+ self .run_cases_test (input , input2 , output )
1612+
1613+ def test_overridden_abstract_args (self ):
1614+ input = """
1615+ pure op(OP, (arg1 -- out)) {
1616+ out = SPAM(arg1);
1617+ }
1618+ op(OP2, (arg1 -- out)) {
1619+ out = EGGS(arg1);
1620+ }
1621+ """
1622+ input2 = """
1623+ op(OP, (arg1 -- out)) {
1624+ out = EGGS(arg1);
1625+ }
1626+ """
1627+ output = """
1628+ case OP: {
1629+ _Py_UopsPESlot arg1;
1630+ _Py_UopsPESlot out;
1631+ arg1 = stack_pointer[-1];
1632+ arg1 = stack_pointer[-1];
1633+ out = EGGS(arg1);
1634+ stack_pointer[-1] = out;
1635+ break;
1636+ }
1637+
1638+ case OP2: {
1639+ _Py_UopsPESlot arg1;
1640+ _Py_UopsPESlot out;
1641+ MATERIALIZE_INST();
1642+ arg1 = stack_pointer[-1];
1643+ materialize(&arg1);
1644+ out = sym_new_not_null(ctx);
1645+ stack_pointer[-1] = out;
1646+ break;
1647+ }
1648+ """
1649+ self .run_cases_test (input , input2 , output )
1650+
1651+ def test_no_overridden_case (self ):
1652+ input = """
1653+ pure op(OP, (arg1 -- out)) {
1654+ out = SPAM(arg1);
1655+ }
1656+
1657+ pure op(OP2, (arg1 -- out)) {
1658+ }
1659+
1660+ """
1661+ input2 = """
1662+ pure op(OP2, (arg1 -- out)) {
1663+ out = NULL;
1664+ }
1665+ """
1666+ output = """
1667+ case OP: {
1668+ _Py_UopsPESlot arg1;
1669+ _Py_UopsPESlot out;
1670+ MATERIALIZE_INST();
1671+ arg1 = stack_pointer[-1];
1672+ materialize(&arg1);
1673+ out = sym_new_not_null(ctx);
1674+ stack_pointer[-1] = out;
1675+ break;
1676+ }
1677+
1678+ case OP2: {
1679+ _Py_UopsPESlot arg1;
1680+ _Py_UopsPESlot out;
1681+ arg1 = stack_pointer[-1];
1682+ out = NULL;
1683+ stack_pointer[-1] = out;
1684+ break;
1685+ }
1686+ """
1687+ self .run_cases_test (input , input2 , output )
1688+
1689+ def test_missing_override_failure (self ):
1690+ input = """
1691+ pure op(OP, (arg1 -- out)) {
1692+ SPAM();
1693+ }
1694+ """
1695+ input2 = """
1696+ pure op(OTHER, (arg1 -- out)) {
1697+ }
1698+ """
1699+ output = """
1700+ """
1701+ with self .assertRaisesRegex (AssertionError , "All abstract uops" ):
1702+ self .run_cases_test (input , input2 , output )
1703+
1704+
1705+ def test_validate_inputs (self ):
1706+ input = """
1707+ pure op(OP, (arg1 --)) {
1708+ SPAM();
1709+ }
1710+ """
1711+ input2 = """
1712+ // Non-matching input!
1713+ pure op(OP, (arg1, arg2 --)) {
1714+ }
1715+ """
1716+ output = """
1717+ """
1718+ with self .assertRaisesRegex (AssertionError , "input length don't match" ):
1719+ self .run_cases_test (input , input2 , output )
1720+
1721+ def test_materialize_inputs (self ):
1722+ input = """
1723+ pure op(OP2, (arg1, arg2, arg3[oparg] --)) {
1724+ }
1725+ """
1726+ input2 = """
1727+ pure op(OP2, (arg1, arg2, arg3[oparg] --)) {
1728+ MATERIALIZE_INPUTS();
1729+ }
1730+ """
1731+ output = """
1732+ case OP2: {
1733+ _Py_UopsPESlot *arg3;
1734+ _Py_UopsPESlot arg2;
1735+ _Py_UopsPESlot arg1;
1736+ arg3 = &stack_pointer[-2 - oparg];
1737+ arg2 = stack_pointer[-2];
1738+ arg1 = stack_pointer[-1];
1739+ materialize(&arg1);
1740+ materialize(&arg2);
1741+ for (int _i = oparg; --_i >= 0;) {
1742+ materialize(&arg3[_i]);
1743+ }
1744+ stack_pointer += -2 - oparg;
1745+ assert(WITHIN_STACK_BOUNDS());
1746+ break;
1747+ }
1748+ """
1749+ self .run_cases_test (input , input2 , output )
1750+
1751+
15831752if __name__ == "__main__" :
15841753 unittest .main ()
0 commit comments