Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/helpers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,5 @@ Encoders and Decoders
---------------------

.. autofunction:: pyrtl.helperfuncs.one_hot_to_binary
.. autofunction:: pyrtl.helperfuncs.binary_to_one_hot

1 change: 1 addition & 0 deletions pyrtl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
from .helperfuncs import wire_struct
from .helperfuncs import wire_matrix
from .helperfuncs import one_hot_to_binary
from .helperfuncs import binary_to_one_hot

from .corecircuits import and_all_bits
from .corecircuits import or_all_bits
Expand Down
44 changes: 43 additions & 1 deletion pyrtl/helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,15 @@
from .core import working_block, _NameIndexer, _get_debug_mode, Block
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select
from .corecircuits import (
as_wires,
rtl_all,
rtl_any,
concat,
concat_list,
select,
shift_left_logical
)

# -----------------------------------------------------------------
# ___ __ ___ __ __
Expand Down Expand Up @@ -1715,3 +1723,37 @@ def one_hot_to_binary(w) -> WireVector:
already_found = already_found | w[i]

return pos


def binary_to_one_hot(bit_position, max_bitwidth: int = None) -> WireVector:
'''Takes an input representing a bit position and returns a WireVector
with that bit position set to 1 and the others to 0.

:param bit_position: WireVector, WireVector-like object, or something that can be converted
into a :py:class:`.Const` (in accordance with the :py:func:`.as_wires()`
required input). Example inputs: ``0b10``, ``0b1000, ``4``.
:param max_bitwidth: Optional integer maximum bitwidth for the resulting one-hot WireVector.
:return: WireVector with the bit position, given by the input, set to 1 and all others set to 0.

If the max_bitwidth provided is not sufficient for the given bit_position to be set to 1,
a ``0`` WireVector of size max_bitwidth will be returned.

Examples::

binary_to_onehot(2) # returns 0b0100
binary_to_onehot(8) # returns 0b00100000
binary_to_onehot(12) # returns 0b0001000000000000
binary_to_onehot(15) # returns 0b1000000000000000

'''

bit_position = as_wires(bit_position)

if max_bitwidth is not None:
bitwidth = max_bitwidth
else:
bitwidth = 2 ** len(bit_position)

onehot = Const(1, bitwidth=bitwidth)

return shift_left_logical(onehot, bit_position)
42 changes: 42 additions & 0 deletions tests/test_helperfuncs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,5 +1814,47 @@ def test_no_ones(self):
self.assertEqual(sim.inspect('o'), 0)


class TestBinaryToOneHot(unittest.TestCase):
def setUp(self):
pyrtl.reset_working_block()

def test_simple_binary_to_one_hot(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=16, name='o')
o <<= pyrtl.binary_to_one_hot(i)

sim = pyrtl.Simulation()
sim.step({i: 0})
self.assertEqual(sim.inspect('o'), 0b01)
sim.step({i: 2})
self.assertEqual(sim.inspect('o'), 0b0100)
sim.step({i: 5})
self.assertEqual(sim.inspect('o'), 0b00100000)
sim.step({i: 12})
self.assertEqual(sim.inspect('o'), 0b0001000000000000)
sim.step({i: 15})
self.assertEqual(sim.inspect('o'), 0b1000000000000000)

def test_sufficient_max_bitwidth(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=16, name='o')
o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=8)

sim = pyrtl.Simulation()
sim.step({i: 0})
self.assertEqual(sim.inspect('o'), 0b0001)
sim.step({i: 6})
self.assertEqual(sim.inspect('o'), 0b01000000)

def test_insufficient_max_bitwidth(self):
i = pyrtl.Input(bitwidth=8, name='i')
o = pyrtl.Output(bitwidth=16, name='o')
o <<= pyrtl.binary_to_one_hot(i, max_bitwidth=4)

sim = pyrtl.Simulation()
sim.step({i: 5})
self.assertEqual(sim.inspect('o'), 0b0000)


if __name__ == "__main__":
unittest.main()