Skip to content

Commit 65b0edf

Browse files
committed
Add add a min-sum optional decoder.
1 parent 86583e1 commit 65b0edf

File tree

6 files changed

+71
-21
lines changed

6 files changed

+71
-21
lines changed

examples/ieee802_11_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@
1919

2020
# create a decoder which assumes a probability of p=0.05 for bit flips by the channel
2121
# allow up to 20 iterations for the bp decoder.
22-
p = 0.05
23-
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p))
22+
p = 0.03
23+
decoder = DecoderWiFi(spec=WiFiSpecCode.N648_R12, max_iter=20, channel_model=bsc_llr(p=p), decoder_type="MS")
2424

2525
# create a corrupted version of encoded codeword with error rate p
2626
corrupted = BitArray(encoded)

ldpc/decoder/graph.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,13 @@ def add_v_node(self, ordering_key: int, name: str = "", channel_model: Optional[
3030
self.v_nodes[node.uid] = node
3131
return node
3232

33-
def add_c_node(self, name: str = "", ordering_key: Optional[int] = None) -> CNode:
33+
def add_c_node(self, name: str = "", ordering_key: Optional[int] = None, decoder: Optional[str] = "BP") -> CNode:
3434
"""
3535
:param ordering_key: use only for debug purposes
3636
:param name: name of node
37+
:param decoder: must be either "BP" or "MS" for min-sum decoder
3738
"""
38-
node = CNode(name, ordering_key)
39+
node = CNode(name, ordering_key, decoder=decoder)
3940
self.c_nodes[node.uid] = node
4041
return node
4142

@@ -68,10 +69,10 @@ def add_edges_by_name(self, edges_set: set[tuple[str, str]]) -> None:
6869
for v_name, c_name in edges_set:
6970
v_uid = [node.uid for node in self.v_nodes.values() if node.name == v_name]
7071
if not v_uid:
71-
raise ValueError("No v-node with name " + v_name + " in graph")
72+
raise ValueError(f"No v-node with name {v_name} in graph")
7273
c_uid = [node.uid for node in self.c_nodes.values() if node.name == c_name]
7374
if not c_uid:
74-
raise ValueError("No c-node with name " + c_name + " in graph")
75+
raise ValueError(f"No c-node with name {c_name} in graph")
7576
self.add_edge(v_uid[0], c_uid[0])
7677

7778
def get_edges(self, by_name: bool = False) -> Union[set[tuple[str, str]], EdgesSet]:
@@ -94,12 +95,14 @@ def to_nx(self) -> nx.Graph:
9495
return g
9596

9697
@classmethod
97-
def from_biadjacency_matrix(cls, h: npt.ArrayLike, channel_model: Optional[ChannelModel] = None) -> TannerGraph:
98+
def from_biadjacency_matrix(cls, h: npt.ArrayLike, channel_model: Optional[ChannelModel] = None,
99+
decoder: Optional[str] = "BP") -> TannerGraph:
98100
"""
99101
Creates a Tanner Graph from a biadjacency matrix, nodes are ordered according to matrix indices.
100102
101103
:param channel_model: channel model to compute channel symbols llr within v nodes
102104
:param h: parity check matrix, shape MXN with M check nodes and N variable nodes. assumed binary matrix.
105+
:param decoder: must be either "BP" or "MS" for min-sum decoder
103106
"""
104107
g = TannerGraph()
105108
h = np.array(h)
@@ -109,7 +112,7 @@ def from_biadjacency_matrix(cls, h: npt.ArrayLike, channel_model: Optional[Chann
109112
v_uid = g.add_v_node(name=f"v{i}", channel_model=channel_model, ordering_key=i).uid
110113
ordered_vnode_uid[i] = v_uid
111114
for j in range(m):
112-
c_uid = g.add_c_node(name=f"c{j}", ordering_key=j).uid
115+
c_uid = g.add_c_node(name=f"c{j}", ordering_key=j, decoder=decoder).uid
113116
for i in range(n):
114117
if h[j, i] == 1:
115118
g.add_edge(ordered_vnode_uid[i], c_uid)

ldpc/decoder/ieee802_11_decoder.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,19 @@ class DecoderWiFi(LogSpaDecoder):
1111
"""Decode messages according to the codes in the IEEE802.11n standard using Log SPA decoder"""
1212
_spec_base_path: str = os.path.join(os.path.dirname(os.path.dirname(__file__)), 'code_specs', 'ieee802.11')
1313

14-
def __init__(self, spec: WiFiSpecCode, max_iter: int, channel_model: Optional[ChannelModel] = None):
14+
def __init__(self, spec: WiFiSpecCode, max_iter: int, channel_model: Optional[ChannelModel] = None,
15+
decoder_type: Optional[str] = "BP"):
1516
"""
1617
1718
:param channel_model: a callable which receives a channel input, and returns the channel llr
1819
:param spec: specify which code from the spec we use
1920
:param max_iter: The maximal number of iterations for belief propagation algorithm
21+
:param decoder_type: must be either "BP" or "MS" for min-sum decoder
2022
"""
2123
self.spec = spec
2224
qc_file = QCFile.from_file(os.path.join(self._spec_base_path, spec.name + ".qc"))
2325
h = qc_file.to_array()
2426
m, n = h.shape
2527
k = n - m
26-
super().__init__(h, max_iter, info_idx=np.array([True] * k + [False] * m), channel_model=channel_model)
28+
super().__init__(h, max_iter, info_idx=np.array([True] * k + [False] * m), channel_model=channel_model,
29+
decoder_type=decoder_type)

ldpc/decoder/log_spa_decoder.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -20,18 +20,20 @@ class InfoBitsNotSpecified(Exception):
2020
class LogSpaDecoder:
2121
"""Decode codewords according to Log-SPA version of the belief propagation algorithm"""
2222
def __init__(self, h: ArrayLike, max_iter: int, info_idx: Optional[NDArray[np.bool_]] = None,
23-
channel_model: Optional[ChannelModel] = None):
23+
channel_model: Optional[ChannelModel] = None, decoder_type: Optional[str] = "BP"):
2424
"""
2525
2626
:param h:the parity check code matrix of the code
2727
:param max_iter: The maximal number of iterations for belief propagation algorithm
2828
:param info_idx: a boolean array representing the indices of information bits in the code
2929
:param channel_model: optional, a callable which receives a channel input, and returns the channel llr. If not
3030
specified, llr is expected to be fed into the decoder.
31+
:param decoder_type: must be either "BP" or "MS" for min-sum decoder
3132
"""
33+
self.decoder_type = decoder_type
3234
self.info_idx = info_idx
3335
self.h: npt.NDArray[np.int_] = np.array(h)
34-
self.graph = TannerGraph.from_biadjacency_matrix(h=self.h, channel_model=channel_model)
36+
self.graph = TannerGraph.from_biadjacency_matrix(h=self.h, channel_model=channel_model, decoder=decoder_type)
3537
self.n = len(self.graph.v_nodes)
3638
self.max_iter = max_iter
3739
ordered_cnodes = sorted(self.graph.c_nodes.values())
@@ -80,14 +82,16 @@ def decode(self, channel_word: Sequence[np.float_], max_iter: Optional[int] = No
8082
if not syndrome.any():
8183
break
8284

85+
# for each vnode how many equations are failed
86+
vnode_validity: npt.NDArray[np.int_] = np.dot(syndrome, self.h)
8387
# for each vnode how many equations are fulfilled
84-
vnode_validity: npt.NDArray[np.int_] = np.zeros(self.n, dtype=np.int_)
85-
syndrome_compliance = {cnode: int(val == 0) for cnode, val in zip(self.ordered_cnodes_uids, syndrome)}
86-
87-
for idx, vnode in enumerate(self._ordered_vnodes):
88-
neighbors = vnode.get_neighbors()
89-
for neighbor in neighbors:
90-
vnode_validity[idx] += 2*syndrome_compliance[neighbor] - 1
88+
# vnode_validity: npt.NDArray[np.int_] = np.zeros(self.n, dtype=np.int_)
89+
# syndrome_compliance = {cnode: int(val == 0) for cnode, val in zip(self.ordered_cnodes_uids, syndrome)}
90+
#
91+
# for idx, vnode in enumerate(self._ordered_vnodes):
92+
# neighbors = vnode.get_neighbors()
93+
# for neighbor in neighbors:
94+
# vnode_validity[idx] += 2*syndrome_compliance[neighbor] - 1
9195
return estimate, llr, not syndrome.any(), iteration+1, syndrome, vnode_validity
9296

9397
def info_bits(self, estimate: NDArray[np.int_]) -> Bits:

ldpc/decoder/node.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,16 @@ def __lt__(self, other: Any) -> bool:
7373

7474
class CNode(Node):
7575
"""Check nodes in Tanner graph"""
76+
77+
def __init__(self, name: str = "", ordering_key: Optional[int] = None, decoder: Optional[str] = "BP") -> None:
78+
"""
79+
:param ordering_key: used to order nodes per their order in the parity check matrix
80+
:param name: optional name of node
81+
:param decoder: must be either "BP" or "MS" for min-sum decoder
82+
"""
83+
self.decoder_type = decoder
84+
super().__init__(name, ordering_key)
85+
7686
def initialize(self) -> None:
7787
"""
7888
clear received messages
@@ -84,11 +94,16 @@ def message(self, requester_uid: int) -> np.float_:
8494
pass messages from c-nodes to v-nodes
8595
:param requester_uid: uid of requesting v-node
8696
"""
97+
q: npt.NDArray[np.float_] = np.array([msg for uid, msg in self.received_messages.items() if uid != requester_uid])
98+
if self.decoder_type == "MS":
99+
return np.prod(np.sign(q)) * np.absolute(q).min() # type: ignore
100+
101+
102+
# full BP
87103
def phi(x: npt.NDArray[np.float_]) -> Any:
88104
"""see sources for definition and reasons for use of this function"""
89105
return -np.log(np.maximum(np.tanh(x/2), 1e3*np.finfo(np.float_).eps))
90-
q: npt.NDArray[np.float_] = np.array([msg for uid, msg in self.received_messages.items() if uid != requester_uid])
91-
return np.prod(np.sign(q))*phi(sum(phi(np.absolute(q)+np.finfo(np.float_).eps))) # type: ignore
106+
return np.prod(np.sign(q))*phi(sum(phi(np.absolute(q)))) # type: ignore
92107

93108

94109
class VNode(Node):

tests/test_ieee802_11.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,31 @@ def test_decoder_648_r12(self) -> None:
102102
assert sum(encoded_ref ^ decoded) == 0
103103
assert sum(info_bits ^ decoded_info) == 0
104104

105+
def test_ms_decoder_1296_r23(self) -> None:
106+
info_bits = Bits(auto=np.genfromtxt(
107+
'tests/test_data/ieee_802_11/info_bits_N1296_R23.csv', delimiter=',', dtype=np.int_))
108+
encoded_ref = Bits(auto=np.genfromtxt(
109+
'tests/test_data/ieee_802_11/encoded_N1296_R23.csv', delimiter=',', dtype=np.int_))
110+
p = 0.01
111+
112+
corrupted = BitArray(encoded_ref)
113+
no_errors = int(len(corrupted) * p)
114+
rng = np.random.default_rng()
115+
error_idx = rng.choice(len(corrupted), size=no_errors, replace=False)
116+
for idx in error_idx:
117+
corrupted[idx] = not corrupted[idx]
118+
119+
decoder = DecoderWiFi(spec=WiFiSpecCode.N1296_R23, max_iter=20, channel_model=bsc_llr(p=p),decoder_type="MS")
120+
decoded = Bits()
121+
decoded_info = Bits()
122+
for frame_idx in range(len(corrupted) // decoder.n):
123+
decoder_output = decoder.decode(corrupted[frame_idx * decoder.n: (frame_idx + 1) * decoder.n])
124+
decoded += decoder_output[0]
125+
decoded_info += decoder.info_bits(decoder_output[0])
126+
assert decoder_output[2] is True
127+
assert sum(encoded_ref ^ decoded) == 0
128+
assert sum(info_bits ^ decoded_info) == 0
129+
105130
def test_decoder_no_info(self) -> None:
106131
p = 0.1
107132
h = QCFile.from_file("ldpc/code_specs/ieee802.11/N648_R12.qc").to_array()

0 commit comments

Comments
 (0)