Skip to content

Commit 4c72f2d

Browse files
authored
Adding one-hot to binary helper function (#456)
* Adding one-hot to binary helper function * Addressed comments for one-hot to binary helper function
1 parent e63b71d commit 4c72f2d

File tree

3 files changed

+76
-2
lines changed

3 files changed

+76
-2
lines changed

pyrtl/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
from .helperfuncs import find_and_print_loop
4141
from .helperfuncs import wire_struct
4242
from .helperfuncs import wire_matrix
43+
from .helperfuncs import one_hot_to_binary
4344

4445
from .corecircuits import and_all_bits
4546
from .corecircuits import or_all_bits

pyrtl/helperfuncs.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
from .core import working_block, _NameIndexer, _get_debug_mode, Block
1414
from .pyrtlexceptions import PyrtlError, PyrtlInternalError
1515
from .wire import WireVector, Input, Output, Const, Register, WrappedWireVector
16-
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list
16+
from .corecircuits import as_wires, rtl_all, rtl_any, concat, concat_list, select
1717

1818
# -----------------------------------------------------------------
1919
# ___ __ ___ __ __
@@ -1683,3 +1683,34 @@ def __len__(self):
16831683
return len(self._components)
16841684

16851685
return _WireMatrix
1686+
1687+
1688+
def one_hot_to_binary(w) -> WireVector:
1689+
'''Takes a one-hot input and returns the bit position of the high bit in binary.
1690+
1691+
:param w: WireVector or a WireVector-like object or something that can be converted
1692+
into a Const (in accordance with the :py:func:`as_wires()` required input). Example
1693+
inputs: 0b0010, 64, 0b01.
1694+
:return: The bit position of the high bit in binary as a WireVector.
1695+
1696+
If the input contains multiple 1s, the bit position of the first 1 will
1697+
be returned. If the input contains no 1s, 0 will be returned.
1698+
1699+
Examples::
1700+
1701+
one_hot_to_binary(0b0010) # returns 1
1702+
one_hot_to_binary(64) # returns 6
1703+
one_hot_to_binary(0b1100) # returns 2, the bit position of the first 1
1704+
one_hot_to_binary(0) # returns 0
1705+
'''
1706+
1707+
w = as_wires(w)
1708+
1709+
pos = 0 # Bit position of the first 1
1710+
already_found = as_wires(False) # True if first 1 already found, False otherwise
1711+
1712+
for i in range(len(w)):
1713+
pos = select(w[i] & ~already_found, i, pos)
1714+
already_found = already_found | w[i]
1715+
1716+
return pos

tests/test_helperfuncs.py

Lines changed: 43 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@
1010
import pyrtl.helperfuncs
1111
from pyrtl.rtllib import testingutils as utils
1212

13-
1413
# ---------------------------------------------------------------
1514

15+
1616
class TestWireVectorList(unittest.TestCase):
1717
def setUp(self):
1818
pass
@@ -1772,5 +1772,47 @@ def test_byte_matrix_input_concatenate(self):
17721772
self.assertEqual(sim.inspect('byte_matrix[0].low'), 0xB)
17731773

17741774

1775+
class TestOneHotToBinary(unittest.TestCase):
1776+
def setUp(self):
1777+
pyrtl.reset_working_block()
1778+
1779+
def test_simple_onehot(self):
1780+
i = pyrtl.Input(bitwidth=8, name='i')
1781+
o = pyrtl.Output(bitwidth=3, name='o')
1782+
o <<= pyrtl.one_hot_to_binary(i)
1783+
1784+
sim = pyrtl.Simulation()
1785+
sim.step({i: 0b00000001})
1786+
self.assertEqual(sim.inspect('o'), 0)
1787+
sim.step({i: 0b10000000})
1788+
self.assertEqual(sim.inspect('o'), 7)
1789+
sim.step({i: 32})
1790+
self.assertEqual(sim.inspect('o'), 5)
1791+
sim.step({i: 16})
1792+
self.assertEqual(sim.inspect('o'), 4)
1793+
1794+
def test_multiple_ones(self):
1795+
i = pyrtl.Input(bitwidth=8, name='i')
1796+
o = pyrtl.Output(bitwidth=3, name='o')
1797+
o <<= pyrtl.one_hot_to_binary(i)
1798+
1799+
sim = pyrtl.Simulation()
1800+
sim.step({i: 0b00000101})
1801+
self.assertEqual(sim.inspect('o'), 0)
1802+
sim.step({i: 0b11000000})
1803+
self.assertEqual(sim.inspect('o'), 6)
1804+
sim.step({i: 0b10010010})
1805+
self.assertEqual(sim.inspect('o'), 1)
1806+
1807+
def test_no_ones(self):
1808+
i = pyrtl.Input(bitwidth=8, name='i')
1809+
o = pyrtl.Output(bitwidth=3, name='o')
1810+
o <<= pyrtl.one_hot_to_binary(i)
1811+
1812+
sim = pyrtl.Simulation()
1813+
sim.step({i: 0b00000000})
1814+
self.assertEqual(sim.inspect('o'), 0)
1815+
1816+
17751817
if __name__ == "__main__":
17761818
unittest.main()

0 commit comments

Comments
 (0)