Skip to content

Commit 30f2c29

Browse files
authored
Merge pull request #24 from YairMZ/bp_decoder
allow decoder to work by receiving llr's, and without specifying a channel model
2 parents eb11a41 + 399a580 commit 30f2c29

File tree

10 files changed

+90
-35
lines changed

10 files changed

+90
-35
lines changed

README.md

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,11 @@ python -m pytest -n auto --cov-report=html
2323
```
2424
to run tests in parallel (with number of CPU's dictated by machine) to speed up tests.
2525

26+
Verify static typing with
27+
```shell
28+
mypy --strict --config-file .mypy.ini .
29+
```
30+
2631
-----
2732
## Included modules
2833
- [Utilities](ldpc/utils/README.md): implementing various utility operations to assist with encoding, decoding and
@@ -36,7 +41,7 @@ simulations.
3641
```python
3742
import numpy as np
3843
from bitstring import BitArray, Bits
39-
from ldpc.decoder import LogSpaDecoder, bsc_llr
44+
from ldpc.decoder import DecoderWiFi, bsc_llr
4045
from ldpc.encoder import EncoderWiFi
4146
from ldpc.wifi_spec_codes import WiFiSpecCode
4247
from ldpc.utils import QCFile
@@ -56,7 +61,7 @@ np.dot(h, np.array(encoded)) % 2 # creates an all zero vector as required.
5661
# create a decoder which assumes a probability of p=0.05 for bit flips by the channel
5762
# allow up to 20 iterations for the bp decoder.
5863
p = 0.05
59-
decoder = LogSpaDecoder(bsc_llr(p=p), h=h, max_iter=20, info_idx=np.array([True]*324 + [False]*324))
64+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p))
6065

6166
# create a corrupted version of encoded codeword with error rate p
6267
corrupted = BitArray(encoded)
@@ -68,6 +73,15 @@ decoded, llr, decode_success, num_of_iterations = decoder.decode(corrupted)
6873
# Verify correct decoding
6974
print(Bits(decoded) == encoded) # true
7075
info = decoder.info_bits(decoded)
76+
77+
# a decoder can also be instantiated without a channel model, in which case llr is expected to be sent for decoding instead of
78+
# hard channel outputs.
79+
channel = bsc_llr(p=p)
80+
channel_llr = channel(np.array(corrupted, dtype=np.int_))
81+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20)
82+
decoded, llr2, decode_success, num_of_iterations = decoder.decode(channel_llr)
83+
print(Bits(decoded) == encoded) # true
84+
info = decoder.info_bits(decoded)
7185
```
7286
The example is also included as a jupyter notebook. Note however, that you need to launch the notebook from the correct
7387
path for it to be able to access installed packages. To run the notebook:

examples/ieee802_11_example.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
# create a decoder which assumes a probability of p=0.05 for bit flips by the channel
2121
# allow up to 20 iterations for the bp decoder.
2222
p = 0.05
23-
decoder = DecoderWiFi(bsc_llr(p=p), spec=WiFiSpecCode.N648_R12, max_iter=20)
23+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p))
2424

2525
# create a corrupted version of encoded codeword with error rate p
2626
corrupted = BitArray(encoded)
@@ -32,3 +32,12 @@
3232
# Verify correct decoding
3333
print(Bits(decoded) == encoded) # true
3434
info = decoder.info_bits(decoded)
35+
36+
# a decoder can also be instantiated without a channel model, in which case llr is expected to be sent for decoding instead of
37+
# hard channel outputs.
38+
channel = bsc_llr(p=p)
39+
channel_llr = channel(np.array(corrupted, dtype=np.int_))
40+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20)
41+
decoded, llr2, decode_success, num_of_iterations = decoder.decode(channel_llr)
42+
print(Bits(decoded) == encoded) # true
43+
info = decoder.info_bits(decoded)

ldpc/decoder/channel_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
__all__ = ["ChannelModel", "bsc_llr"]
55

6-
ChannelModel = Callable[[int], np.float_]
6+
ChannelModel = Callable[[np.float_], np.float_]
77

88

99
def bsc_llr(p: float) -> ChannelModel:

ldpc/decoder/graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,13 +20,13 @@ def __init__(self) -> None:
2020
self.c_nodes: dict[int, CNode] = {}
2121
self.edges: EdgesSet = set()
2222

23-
def add_v_node(self, channel_model: ChannelModel, ordering_key: int, name: str = "") -> VNode:
23+
def add_v_node(self, ordering_key: int, name: str = "", channel_model: Optional[ChannelModel] = None) -> VNode:
2424
"""
2525
:param ordering_key: should reflect order according to parity check matrix, channel symbols in order
2626
:param name: name of node.
27-
:param channel_model: add an exiting node to graph. If not used a new node is created.
27+
:param channel_model: optional channel model to compute llr out of hard channel outputs. Of not used llr are expected.
2828
"""
29-
node = VNode(channel_model, ordering_key, name)
29+
node = VNode(ordering_key, name=name, channel_model=channel_model)
3030
self.v_nodes[node.uid] = node
3131
return node
3232

@@ -94,7 +94,7 @@ def to_nx(self) -> nx.Graph:
9494
return g
9595

9696
@classmethod
97-
def from_biadjacency_matrix(cls, h: npt.ArrayLike, channel_model: ChannelModel) -> TannerGraph:
97+
def from_biadjacency_matrix(cls, h: npt.ArrayLike, channel_model: Optional[ChannelModel] = None) -> TannerGraph:
9898
"""
9999
Creates a Tanner Graph from a biadjacency matrix, nodes are ordered according to matrix indices.
100100

ldpc/decoder/ieee802_11_decoder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from ldpc.wifi_spec_codes import WiFiSpecCode
55
import os
66
from ldpc.utils.qc_format import QCFile
7+
from typing import Optional
78

89

910
class DecoderWiFi(LogSpaDecoder):
1011
"""Decode messages according to the codes in the IEEE802.11n standard using Log SPA decoder"""
1112
_spec_base_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'code_specs', 'ieee802.11')
1213

13-
def __init__(self, channel_model: ChannelModel, spec: WiFiSpecCode, max_iter: int):
14+
def __init__(self, spec: WiFiSpecCode, max_iter: int, channel_model: Optional[ChannelModel] = None):
1415
"""
1516
1617
:param channel_model: a callable which receives a channel input, and returns the channel llr
@@ -22,4 +23,4 @@ def __init__(self, channel_model: ChannelModel, spec: WiFiSpecCode, max_iter: in
2223
h = qc_file.to_array()
2324
m, n = h.shape
2425
k = n - m
25-
super().__init__(channel_model, h, max_iter, info_idx=np.array([True] * k + [False] * m))
26+
super().__init__(h, max_iter, info_idx=np.array([True] * k + [False] * m), channel_model=channel_model)

ldpc/decoder/log_spa_decoder.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,22 +19,23 @@ class InfoBitsNotSpecified(Exception):
1919

2020
class LogSpaDecoder:
2121
"""Decode codewords according to Log-SPA version of the belief propagation algorithm"""
22-
def __init__(self, channel_model: ChannelModel, h: ArrayLike, max_iter: int,
23-
info_idx: Optional[NDArray[np.bool_]] = None):
22+
def __init__(self, h: ArrayLike, max_iter: int, info_idx: Optional[NDArray[np.bool_]] = None,
23+
channel_model: Optional[ChannelModel] = None):
2424
"""
2525
26-
:param channel_model: a callable which receives a channel input, and returns the channel llr
2726
:param h:the parity check code matrix of the code
2827
:param max_iter: The maximal number of iterations for belief propagation algorithm
2928
:param info_idx: a boolean array representing the indices of information bits in the code
29+
:param channel_model: optional, a callable which receives a channel input, and returns the channel llr. If not
30+
specified, llr is expected to be fed into the decoder.
3031
"""
3132
self.info_idx = info_idx
3233
self.h: npt.NDArray[np.int_] = np.array(h)
3334
self.graph = TannerGraph.from_biadjacency_matrix(h=self.h, channel_model=channel_model)
3435
self.n = len(self.graph.v_nodes)
3536
self.max_iter = max_iter
3637

37-
def decode(self, channel_word: Sequence[int], max_iter: Optional[int] = None) \
38+
def decode(self, channel_word: Sequence[np.float_], max_iter: Optional[int] = None) \
3839
-> tuple[NDArray[np.int_], NDArray[np.float_], bool, int]:
3940
"""
4041
decode a sequence received from the channel
@@ -69,7 +70,8 @@ def decode(self, channel_word: Sequence[int], max_iter: Optional[int] = None) \
6970

7071
# Check stop condition
7172
llr: npt.NDArray[np.float_] = np.array([node.estimate() for node in self.graph.ordered_v_nodes()])
72-
estimate: npt.NDArray[np.int_] = np.array([1 if node_llr < 0 else 0 for node_llr in llr], dtype=np.int_)
73+
estimate: npt.NDArray[np.int_] = np.array(llr < 0, dtype=np.int_)
74+
# np.array([1 if node_llr < 0 else 0 for node_llr in llr], dtype=np.int_)
7375
syndrome = self.h.dot(estimate) % 2
7476
if not syndrome.any():
7577
break

ldpc/decoder/node.py

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, name: str = "", ordering_key: Optional[int] = None) -> None:
2525
:param name: name of node
2626
"""
2727
self.uid = next(Node._uid_generator)
28-
self.name = name if name else str(self.uid)
28+
self.name = name or str(self.uid)
2929
self.ordering_key = ordering_key if ordering_key is not None else self.uid
3030
self.neighbors: dict[int, Node] = {} # keys as senders uid
3131
self.received_messages: dict[int, Any] = {} # keys as senders uid, values as messages
@@ -34,10 +34,7 @@ def register_neighbor(self, neighbor: Node) -> None:
3434
self.neighbors[neighbor.uid] = neighbor
3535

3636
def __str__(self) -> str:
37-
if self.name:
38-
return self.name
39-
else:
40-
return str(self.uid)
37+
return self.name or str(self.uid)
4138

4239
def get_neighbors(self) -> list[int]:
4340
"""
@@ -96,24 +93,27 @@ def phi(x: npt.NDArray[np.float_]) -> Any:
9693

9794
class VNode(Node):
9895
"""Variable nodes in Tanner graph"""
99-
def __init__(self, channel_model: ChannelModel, ordering_key: int, name: str = ""):
96+
def __init__(self, ordering_key: int, name: str = "", channel_model: Optional[ChannelModel] = None):
10097
"""
10198
:param channel_model: a function which receives channel outputs anr returns relevant message
10299
:param ordering_key: used to order nodes per their order in the parity check matrix
103100
:param name: optional name of node
104101
"""
105-
self.channel_model: ChannelModel = channel_model
106-
self.channel_symbol: int = None # type: ignore # currently assuming hard channel symbols
102+
self.channel_model: Optional[ChannelModel] = channel_model
103+
self.channel_symbol: np.float_ = None # type: ignore
107104
self.channel_llr: np.float_ = None # type: ignore
108105
super().__init__(name, ordering_key)
109106

110-
def initialize(self, channel_symbol: int) -> None: # type: ignore
107+
def initialize(self, channel_symbol: np.float_) -> None: # type: ignore
111108
"""
112109
clear received messages and initialize channel llr with channel bit
113110
:param channel_symbol: bit received from channel, currently assumes hard inputs.
114111
"""
115112
self.channel_symbol = channel_symbol
116-
self.channel_llr = self.channel_model(channel_symbol)
113+
if self.channel_model is not None:
114+
self.channel_llr = self.channel_model(channel_symbol)
115+
else:
116+
self.channel_llr = channel_symbol
117117
self.received_messages = {node_uid: 0 for node_uid in self.neighbors}
118118

119119
def message(self, requester_uid: int) -> np.float_:
@@ -127,4 +127,4 @@ def message(self, requester_uid: int) -> np.float_:
127127

128128
def estimate(self) -> np.float_:
129129
"""provide a soft bit estimate"""
130-
return self.channel_llr + np.sum(list(self.received_messages.values())) # type: ignore
130+
return self.channel_llr + np.sum(list(self.received_messages.values()), dtype=np.float_) # type: ignore

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22
import pathlib
33

4-
VERSION = '0.1.3'
4+
VERSION = '0.2.0'
55

66
# The directory containing this file
77
HERE = pathlib.Path(__file__).resolve().parent

tests/test_ieee802_11.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ class TestDecoder802_11:
7272
def test_incorrect_length(self) -> None:
7373
p = 0.1
7474
h = QCFile.from_file("ldpc/code_specs/ieee802.11/N648_R12.qc").to_array()
75-
decoder = LogSpaDecoder(bsc_llr(p=p), h=h, max_iter=20)
75+
decoder = LogSpaDecoder(h=h, max_iter=20, channel_model=bsc_llr(p=p))
7676
bits: npt.NDArray[np.int_] = np.array([1, 1, 0], dtype=np.int_)
7777
with pytest.raises(IncorrectLength):
7878
decoder.decode(bits) # type: ignore
@@ -91,7 +91,7 @@ def test_decoder_648_r12(self) -> None:
9191
for idx in error_idx:
9292
corrupted[idx] = not corrupted[idx]
9393

94-
decoder = DecoderWiFi(bsc_llr(p=p), spec=WiFiSpecCode.N648_R12, max_iter=20)
94+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p))
9595
decoded = Bits()
9696
decoded_info = Bits()
9797
for frame_idx in range(len(corrupted) // decoder.n):
@@ -105,7 +105,7 @@ def test_decoder_648_r12(self) -> None:
105105
def test_decoder_no_info(self) -> None:
106106
p = 0.1
107107
h = QCFile.from_file("ldpc/code_specs/ieee802.11/N648_R12.qc").to_array()
108-
decoder = LogSpaDecoder(bsc_llr(p=p), h=h, max_iter=20)
108+
decoder = LogSpaDecoder(h=h, max_iter=20, channel_model=bsc_llr(p=p))
109109
estimate: npt.NDArray[np.int_] = np.array(Bits(bytes=bytes(list(range(81)))), np.int_)
110110
with pytest.raises(InfoBitsNotSpecified):
111111
decoder.info_bits(estimate)

wifi_example.ipynb

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,16 @@
1111
"source": [
1212
"import numpy as np\n",
1313
"from bitstring import BitArray, Bits\n",
14-
"from ldpc.decoder import LogSpaDecoder, bsc_llr\n",
14+
"from ldpc.decoder import DecoderWiFi, bsc_llr\n",
1515
"from ldpc.encoder import EncoderWiFi\n",
1616
"from ldpc.wifi_spec_codes import WiFiSpecCode"
1717
],
18-
"execution_count": 4,
18+
"execution_count": 1,
1919
"outputs": []
2020
},
2121
{
2222
"cell_type": "code",
23-
"execution_count": 5,
23+
"execution_count": 2,
2424
"outputs": [],
2525
"source": [
2626
"# create information bearing bits\n",
@@ -38,7 +38,7 @@
3838
"# create a decoder which assumes a probability of p=0.05 for bit flips by the channel\n",
3939
"# allow up to 20 iterations for the bp decoder.\n",
4040
"p = 0.05\n",
41-
"decoder = LogSpaDecoder(bsc_llr(p=p), h=h, max_iter=20, info_idx=np.array([True]*324 + [False]*324))"
41+
"decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p))"
4242
],
4343
"metadata": {
4444
"collapsed": false,
@@ -49,7 +49,7 @@
4949
},
5050
{
5151
"cell_type": "code",
52-
"execution_count": 6,
52+
"execution_count": 3,
5353
"outputs": [
5454
{
5555
"name": "stdout",
@@ -77,6 +77,35 @@
7777
"name": "#%%\n"
7878
}
7979
}
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 4,
84+
"outputs": [
85+
{
86+
"name": "stdout",
87+
"output_type": "stream",
88+
"text": [
89+
"True\n"
90+
]
91+
}
92+
],
93+
"source": [
94+
"# a decoder can also be instantiated without a channel model, in which case llr is expected to be sent for decoding instead of\n",
95+
"# hard channel outputs.\n",
96+
"channel = bsc_llr(p=p)\n",
97+
"channel_llr = channel(np.array(corrupted, dtype=np.int_))\n",
98+
"decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20)\n",
99+
"decoded, llr2, decode_success, num_of_iterations = decoder.decode(channel_llr)\n",
100+
"print(Bits(decoded) == encoded) # true\n",
101+
"info = decoder.info_bits(decoded)"
102+
],
103+
"metadata": {
104+
"collapsed": false,
105+
"pycharm": {
106+
"name": "#%%\n"
107+
}
108+
}
80109
}
81110
],
82111
"metadata": {

0 commit comments

Comments
 (0)