Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyrtl/passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -848,7 +848,7 @@ def arg(num):
elif net.op == '^':
all_1 = arg(0) & arg(1)
all_0 = ~arg(0) & ~arg(1)
dest <<= all_0 & ~all_1
dest <<= ~all_0 & ~all_1
elif net.op == 'n':
dest <<= ~(arg(0) & arg(1))
else:
Expand Down
85 changes: 83 additions & 2 deletions tests/test_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@
import os
import sys
import unittest
from typing import Callable

import pyrtl
from pyrtl.wire import Const, Output
from pyrtl.rtllib import testingutils as utils

from pyrtl.wire import Const, Output
from .test_transform import NetWireNumTestCases


Expand Down Expand Up @@ -733,6 +733,87 @@ def test_nested_elimination(self):
pyrtl.working_block().sanity_check()


class TestSynthPasses(unittest.TestCase):
in0: pyrtl.Input
in1: pyrtl.Input
out: pyrtl.Output

def setUp(self):
pyrtl.reset_working_block()
self.in0 = pyrtl.Input(bitwidth=5, name='in0')
self.in1 = pyrtl.Input(bitwidth=5, name='in1')
self.out = pyrtl.Output(bitwidth=5, name='out')

def check_synth(self, operation: Callable[[int, int], int]):
"""
Simulates the current circuit with some test input pairs (in0, in1) and checks the outputs
against the provided operation. Any synthesis/passes should be run before this gets called.

:param operation: The operation to test against. This should be a lambda taking an input
pair (in0, in1) and returning the expected output from this circuit.
"""
sim_trace = pyrtl.SimulationTrace()
sim = pyrtl.Simulation(tracer=sim_trace)

values = [(1, 2), (4, 5), (7, 11)]
"""A list of input pairs (in0, in1) to test on."""

for in0, in1 in values:
expected_output = operation(in0, in1)
sim.step({'in0': in0, 'in1': in1})
# compare simulation output to expected output
self.assertEqual(sim.inspect('out'), expected_output,
msg=f"Failed on inputs {in0} and {in1}")

def test_nand_synth_and(self):
self.out <<= self.in0 & self.in1
pyrtl.synthesize()
pyrtl.nand_synth()
self.check_synth(lambda a, b: a & b)

def test_nand_synth_or(self):
self.out <<= self.in0 | self.in1
pyrtl.synthesize()
pyrtl.nand_synth()
self.check_synth(lambda a, b: a | b)

def test_nand_synth_xor(self):
self.out <<= self.in0 ^ self.in1
pyrtl.synthesize()
pyrtl.nand_synth()
self.check_synth(lambda a, b: a ^ b)

def test_nand_synth_adder(self):
self.out <<= self.in0 + self.in1
pyrtl.synthesize()
pyrtl.nand_synth()
self.check_synth(lambda a, b: a + b)

def test_and_inverter_synth_and(self):
self.out <<= self.in0 & self.in1
pyrtl.synthesize()
pyrtl.and_inverter_synth()
self.check_synth(lambda a, b: a & b)

def test_and_inverter_synth_or(self):
self.out <<= self.in0 | self.in1
pyrtl.synthesize()
pyrtl.and_inverter_synth()
self.check_synth(lambda a, b: a | b)

def test_and_inverter_synth_xor(self):
self.out <<= self.in0 ^ self.in1
pyrtl.synthesize()
pyrtl.and_inverter_synth()
self.check_synth(lambda a, b: a ^ b)

def test_and_inverter_synth_adder(self):
self.out <<= self.in0 + self.in1
pyrtl.synthesize()
pyrtl.and_inverter_synth()
self.check_synth(lambda a, b: a + b)


class TestSynthOptTiming(NetWireNumTestCases):
def setUp(self):
pyrtl.reset_working_block()
Expand Down