diff --git a/pyrtl/passes.py b/pyrtl/passes.py index 6765ef47..e26d1af3 100644 --- a/pyrtl/passes.py +++ b/pyrtl/passes.py @@ -848,7 +848,7 @@ def arg(num): elif net.op == '^': all_1 = arg(0) & arg(1) all_0 = ~arg(0) & ~arg(1) - dest <<= all_0 & ~all_1 + dest <<= ~all_0 & ~all_1 elif net.op == 'n': dest <<= ~(arg(0) & arg(1)) else: diff --git a/tests/test_passes.py b/tests/test_passes.py index f5dcbcd5..bdd66112 100644 --- a/tests/test_passes.py +++ b/tests/test_passes.py @@ -3,11 +3,11 @@ import os import sys import unittest +from typing import Callable import pyrtl -from pyrtl.wire import Const, Output from pyrtl.rtllib import testingutils as utils - +from pyrtl.wire import Const, Output from .test_transform import NetWireNumTestCases @@ -733,6 +733,87 @@ def test_nested_elimination(self): pyrtl.working_block().sanity_check() +class TestSynthPasses(unittest.TestCase): + in0: pyrtl.Input + in1: pyrtl.Input + out: pyrtl.Output + + def setUp(self): + pyrtl.reset_working_block() + self.in0 = pyrtl.Input(bitwidth=5, name='in0') + self.in1 = pyrtl.Input(bitwidth=5, name='in1') + self.out = pyrtl.Output(bitwidth=5, name='out') + + def check_synth(self, operation: Callable[[int, int], int]): + """ + Simulates the current circuit with some test input pairs (in0, in1) and checks the outputs + against the provided operation. Any synthesis/passes should be run before this gets called. + + :param operation: The operation to test against. This should be a lambda taking an input + pair (in0, in1) and returning the expected output from this circuit. + """ + sim_trace = pyrtl.SimulationTrace() + sim = pyrtl.Simulation(tracer=sim_trace) + + values = [(1, 2), (4, 5), (7, 11)] + """A list of input pairs (in0, in1) to test on.""" + + for in0, in1 in values: + expected_output = operation(in0, in1) + sim.step({'in0': in0, 'in1': in1}) + # compare simulation output to expected output + self.assertEqual(sim.inspect('out'), expected_output, + msg=f"Failed on inputs {in0} and {in1}") + + def test_nand_synth_and(self): + self.out <<= self.in0 & self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a & b) + + def test_nand_synth_or(self): + self.out <<= self.in0 | self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a | b) + + def test_nand_synth_xor(self): + self.out <<= self.in0 ^ self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a ^ b) + + def test_nand_synth_adder(self): + self.out <<= self.in0 + self.in1 + pyrtl.synthesize() + pyrtl.nand_synth() + self.check_synth(lambda a, b: a + b) + + def test_and_inverter_synth_and(self): + self.out <<= self.in0 & self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a & b) + + def test_and_inverter_synth_or(self): + self.out <<= self.in0 | self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a | b) + + def test_and_inverter_synth_xor(self): + self.out <<= self.in0 ^ self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a ^ b) + + def test_and_inverter_synth_adder(self): + self.out <<= self.in0 + self.in1 + pyrtl.synthesize() + pyrtl.and_inverter_synth() + self.check_synth(lambda a, b: a + b) + + class TestSynthOptTiming(NetWireNumTestCases): def setUp(self): pyrtl.reset_working_block()