diff --git a/docs/helpers.rst b/docs/helpers.rst index 1dba2319..894a2dcf 100644 --- a/docs/helpers.rst +++ b/docs/helpers.rst @@ -111,4 +111,5 @@ Encoders and Decoders --------------------- .. autofunction:: pyrtl.helperfuncs.one_hot_to_binary +.. autofunction:: pyrtl.helperfuncs.binary_to_one_hot diff --git a/pyrtl/__init__.py b/pyrtl/__init__.py index 821b9941..e30e6b41 100644 --- a/pyrtl/__init__.py +++ b/pyrtl/__init__.py @@ -41,6 +41,7 @@ from .helperfuncs import wire_struct from .helperfuncs import wire_matrix from .helperfuncs import one_hot_to_binary +from .helperfuncs import binary_to_one_hot from .corecircuits import and_all_bits from .corecircuits import or_all_bits diff --git a/pyrtl/helperfuncs.py b/pyrtl/helperfuncs.py index 2c8fec8b..9731b239 100644 --- a/pyrtl/helperfuncs.py +++ b/pyrtl/helperfuncs.py @@ -13,7 +13,15 @@ from .core import working_block, _NameIndexer, _get_debug_mode, Block from .pyrtlexceptions import PyrtlError, PyrtlInternalError from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector -from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select +from .corecircuits import ( + as_wires, + rtl_all, + rtl_any, + concat, + concat_list, + select, + shift_left_logical +) # ----------------------------------------------------------------- # ___ __ ___ __ __ @@ -1715,3 +1723,35 @@ def one_hot_to_binary(w) -> WireVector: already_found = already_found | w[i] return pos + + +def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector: + '''Takes an input representing a bit position and returns a WireVector + with that bit position set to 1 and the others to 0. + + :param bit_position: WireVector, WireVector-like object, or something that can be converted + into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()` + required input). Example inputs: ``0b10``, ``0b1000``, ``4``. + :param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector. + :return: WireVector with the bit position given by the input set to 1 and all other bits + set to 0 (bit position 0 being the least significant bit). + + If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1, + a ``0`` WireVector of size max_bitwidth will be returned. + + Examples:: + + binary_to_onehot(0) # returns 0b01 + binary_to_onehot(3) # returns 0b1000 + binary_to_onehot(0b100) # returns 0b10000 + ''' + + bit_position = as_wires(bit_position) + + if max_bitwidth is not None: + bitwidth = max_bitwidth + else: + bitwidth = 2 ** len(bit_position) + + # Need to dynamically set the appropriate bit position since bit_position may not be a Const + return shift_left_logical(Const(1, bitwidth=bitwidth), bit_position) diff --git a/tests/test_helperfuncs.py b/tests/test_helperfuncs.py index 972fb502..a9798fee 100644 --- a/tests/test_helperfuncs.py +++ b/tests/test_helperfuncs.py @@ -1814,5 +1814,47 @@ def test_no_ones(self): self.assertEqual(sim.inspect('o'), 0) +class TestBinaryToOneHot(unittest.TestCase): + def setUp(self): + pyrtl.reset_working_block() + + def test_simple_binary_to_one_hot(self): + bit_position = pyrtl.Input(bitwidth=8, name='bit_position') + one_hot = pyrtl.Output(name='one_hot') + one_hot <<= pyrtl.binary_to_one_hot(bit_position) + + self.assertEqual(one_hot.bitwidth, 256) + + sim = pyrtl.Simulation() + sim.step({bit_position: 0}) + self.assertEqual(sim.inspect('one_hot'), 0b01) + sim.step({bit_position: 2}) + self.assertEqual(sim.inspect('one_hot'), 0b0100) + sim.step({bit_position: 5}) + self.assertEqual(sim.inspect('one_hot'), 0b00100000) + sim.step({bit_position: 12}) + self.assertEqual(sim.inspect('one_hot'), 0b0001000000000000) + sim.step({bit_position: 15}) + self.assertEqual(sim.inspect('one_hot'), 0b1000000000000000) + + # Tests with the max_bitwidth set + def test_with_max_bitwidth(self): + bit_position = pyrtl.Input(bitwidth=8, name='bit_position') + one_hot = pyrtl.Output(name='one_hot') + one_hot <<= pyrtl.binary_to_one_hot(bit_position, max_bitwidth=4) + + self.assertEqual(one_hot.bitwidth, 4) + + sim = pyrtl.Simulation() + sim.step({bit_position: 0}) + self.assertEqual(sim.inspect('one_hot'), 0b0001) + sim.step({bit_position: 3}) + self.assertEqual(sim.inspect('one_hot'), 0b1000) + + # The max_bitwidth set is not enough for a bit position of 4 + sim.step({bit_position: 4}) + self.assertEqual(sim.inspect('one_hot'), 0b0000) + + if __name__ == "__main__": unittest.main()