Skip to content

Commit fb863af

Browse files
authored
Merge pull request #32 from YairMZ/numba
Add numba compilation of vnode messages, suppress mypy issues
2 parents c060164 + 2cd1270 commit fb863af

File tree

11 files changed

+57
-35
lines changed

11 files changed

+57
-35
lines changed

.mypy.ini

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,4 +30,7 @@ allow_untyped_calls = True
3030
ignore_missing_imports = True
3131

3232
[mypy-setuptools.*]
33+
ignore_missing_imports = True
34+
35+
[mypy-numba.*]
3336
ignore_missing_imports = True

ldpc/decoder/log_spa_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ def decode(self, channel_word: Sequence[np.float_], max_iter: Optional[int] = No
8383
break
8484

8585
# for each vnode how many equations are failed
86-
vnode_validity: npt.NDArray[np.int_] = np.dot(syndrome, self.h)
86+
vnode_validity: npt.NDArray[np.int_] = np.dot(syndrome, self.h) # type: ignore
8787
# for each vnode how many equations are fulfilled
8888
# vnode_validity: npt.NDArray[np.int_] = np.zeros(self.n, dtype=np.int_)
8989
# syndrome_compliance = {cnode: int(val == 0) for cnode, val in zip(self.ordered_cnodes_uids, syndrome)}

ldpc/decoder/node.py

Lines changed: 29 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,27 @@
11
from __future__ import annotations
22
import numpy as np
33
import itertools
4-
from typing import Any, Optional
4+
from typing import Any, Optional, Tuple
55
from functools import total_ordering
66
from abc import ABC, abstractmethod
77
import numpy.typing as npt
88
from ldpc.decoder.channel_models import ChannelModel
9+
from numba import jit
910

1011

1112
__all__ = ["Node", "CNode", "VNode"]
1213

1314

15+
@jit(nopython=True, cache=True) # type: ignore
16+
def c_message(requester_uid: int, senders: Tuple[int], received_messages: Tuple[float]) -> np.float_:
17+
q: npt.NDArray[np.float_] = np.array([msg for uid, msg in zip(senders, received_messages) if uid != requester_uid])
18+
return -np.prod(np.sign(q)) * np.log( # type: ignore
19+
np.maximum(np.tanh(
20+
sum(
21+
-np.log(np.maximum(np.tanh(np.absolute(q) / 2), 1e3 * np.finfo(np.float_).eps))
22+
) / 2), 1e3 * np.finfo(np.float_).eps))
23+
24+
1425
@total_ordering # type: ignore
1526
class Node(ABC):
1627
"""Base class VNodes anc CNodes.
@@ -28,7 +39,7 @@ def __init__(self, name: str = "", ordering_key: Optional[int] = None) -> None:
2839
self.name = name or str(self.uid)
2940
self.ordering_key = ordering_key if ordering_key is not None else self.uid
3041
self.neighbors: dict[int, Node] = {} # keys as senders uid
31-
self.received_messages: dict[int, Any] = {} # keys as senders uid, values as messages
42+
self.received_messages: dict[int, np.float_] = {} # keys as senders uid, values as messages
3243

3344
def register_neighbor(self, neighbor: Node) -> None:
3445
self.neighbors[neighbor.uid] = neighbor
@@ -48,7 +59,7 @@ def receive_messages(self) -> None:
4859
self.received_messages[node_id] = node.message(self.uid)
4960

5061
@abstractmethod
51-
def message(self, requester_uid: int) -> Any:
62+
def message(self, requester_uid: int) -> np.float_:
5263
"""Used to return a message to the requesting node"""
5364
pass
5465

@@ -87,23 +98,30 @@ def initialize(self) -> None:
8798
"""
8899
clear received messages
89100
"""
90-
self.received_messages = {node_uid: 0 for node_uid in self.neighbors}
101+
self.received_messages = {node_uid: 0 for node_uid in self.neighbors} # type: ignore
91102

92103
def message(self, requester_uid: int) -> np.float_:
93104
"""
94105
pass messages from c-nodes to v-nodes
95106
:param requester_uid: uid of requesting v-node
96107
"""
97-
q: npt.NDArray[np.float_] = np.array([msg for uid, msg in self.received_messages.items() if uid != requester_uid])
98108
if self.decoder_type == "MS":
109+
q: npt.NDArray[np.float_] = np.array([msg for uid, msg in self.received_messages.items() if uid != requester_uid])
99110
return np.prod(np.sign(q)) * np.absolute(q).min() # type: ignore
100111

101-
102112
# full BP
103-
def phi(x: npt.NDArray[np.float_]) -> Any:
104-
"""see sources for definition and reasons for use of this function"""
105-
return -np.log(np.maximum(np.tanh(x/2), 1e3*np.finfo(np.float_).eps))
106-
return np.prod(np.sign(q))*phi(sum(phi(np.absolute(q)))) # type: ignore
113+
return c_message(requester_uid, tuple(self.received_messages.keys()), tuple(self.received_messages.values())) # type: ignore
114+
115+
# def phi(x: npt.NDArray[np.float_]) -> npt.NDArray[np.float_]:
116+
# """see sources for definition and reasons for use of this function"""
117+
# return -np.log(np.maximum(np.tanh(x / 2), 1e3 * np.finfo(np.float_).eps))
118+
# return np.prod(np.sign(q))*phi(sum(phi(np.absolute(q)))) # type: ignore
119+
120+
# return -np.prod(np.sign(q)) * np.log(
121+
# np.maximum(np.tanh(
122+
# sum(
123+
# -np.log(np.maximum(np.tanh(np.absolute(q) / 2), 1e3 * np.finfo(np.float_).eps))
124+
# ) / 2), 1e3 * np.finfo(np.float_).eps))
107125

108126

109127
class VNode(Node):
@@ -129,7 +147,7 @@ def initialize(self, channel_symbol: np.float_) -> None: # type: ignore
129147
self.channel_llr = self.channel_model(channel_symbol)
130148
else:
131149
self.channel_llr = channel_symbol
132-
self.received_messages = {node_uid: 0 for node_uid in self.neighbors}
150+
self.received_messages = {node_uid: 0 for node_uid in self.neighbors} # type: ignore
133151

134152
def message(self, requester_uid: int) -> np.float_:
135153
"""

ldpc/encoder/h_based_encoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,6 @@ def encode(self, information_bits: Bits) -> Bits:
3434
p: npt.NDArray[np.int_] = np.mod(np.matmul(self.h[:, :self.k], info), 2)
3535
if not self.identity_p:
3636
for l in range(1, self.m):
37-
p[l] += np.mod(np.dot(self.h[l, self.k:self.k+l], p[:l]), 2)
37+
p[l] += np.mod(np.dot(self.h[l, self.k:self.k+l], p[:l]), 2) # type: ignore
3838
encoded[self.k:] = p
3939
return Bits(encoded)

requirements.txt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
numpy~=1.22.0
2-
bitstring~=3.1.9
3-
scipy~=1.7.2
4-
networkx~=2.6.3
1+
numpy~=1.21.6
2+
bitstring>=3.1.9
3+
scipy>=1.7.2
4+
networkx>=2.6.3
5+
numba

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
11
[metadata]
2-
description-file = README.md
2+
description_file = README.md

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from setuptools import setup
22
import pathlib
33

4-
VERSION = '0.3.2'
4+
VERSION = '0.3.3'
55

66
# The directory containing this file
77
HERE = pathlib.Path(__file__).resolve().parent

tests/test_a_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def test_array_io(self) -> None:
4444
[0, 1, 3],
4545
[0, 2, 3]]
4646
b = a.to_array()
47-
np.testing.assert_array_equal(arr, b)
47+
np.testing.assert_array_equal(arr, b) # type: ignore
4848

4949
def test_verify_alist(self) -> None:
5050
original_file = "ldpc/code_specs/Mackay_96.3.963.alist"

tests/test_encoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ def test_params(self) -> None:
1717
enc = EncoderG(g)
1818
assert enc.n == 7
1919
assert enc.k == 4
20-
np.testing.assert_array_equal(g, enc.generator)
20+
np.testing.assert_array_equal(g, enc.generator) # type: ignore
2121

2222
def test_encoding(self) -> None:
2323
g = AList.from_file("tests/test_data/Hamming_7_4_g.alist").to_array()
@@ -47,7 +47,7 @@ def test_params(self) -> None:
4747
assert enc.n == 4098
4848
assert enc.m == 3095
4949
assert enc.k == 4098-3095
50-
np.testing.assert_array_equal(h, enc.h)
50+
np.testing.assert_array_equal(h, enc.h) # type: ignore
5151

5252
def test_encoding(self) -> None:
5353
h = AList.from_file("tests/test_data/systematic_4098_3095.alist").to_array()

tests/test_ieee802_11.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ def test_incorrect_length(self) -> None:
1818
def test_encoding_648_r12(self) -> None:
1919
# comparing encodings with reference implementation by: https://github.com/tavildar/LDPC
2020
info_bits = Bits(auto=np.genfromtxt(
21-
'tests/test_data/ieee_802_11/info_bits_N648_R12.csv', delimiter=',', dtype=np.int_))
21+
'tests/test_data/ieee_802_11/info_bits_N648_R12.csv', delimiter=',', dtype=np.int_)) # type: ignore
2222
encoded_ref = Bits(auto=np.genfromtxt(
23-
'tests/test_data/ieee_802_11/encoded_N648_R12.csv', delimiter=',', dtype=np.int_))
23+
'tests/test_data/ieee_802_11/encoded_N648_R12.csv', delimiter=',', dtype=np.int_)) # type: ignore
2424
enc = EncoderWiFi(WiFiSpecCode.N648_R12)
2525

2626
encoded = Bits()
@@ -31,9 +31,9 @@ def test_encoding_648_r12(self) -> None:
3131
def test_encoding_648_r56(self) -> None:
3232
# comparing encodings with reference implementation by: https://github.com/tavildar/LDPC
3333
info_bits = Bits(auto=np.genfromtxt(
34-
'tests/test_data/ieee_802_11/info_bits_N648_R56.csv', delimiter=',', dtype=np.int_))
34+
'tests/test_data/ieee_802_11/info_bits_N648_R56.csv', delimiter=',', dtype=np.int_)) # type: ignore
3535
encoded_ref = Bits(auto=np.genfromtxt(
36-
'tests/test_data/ieee_802_11/encoded_N648_R56.csv', delimiter=',', dtype=np.int_))
36+
'tests/test_data/ieee_802_11/encoded_N648_R56.csv', delimiter=',', dtype=np.int_)) # type: ignore
3737
enc = EncoderWiFi(WiFiSpecCode.N648_R56)
3838

3939
encoded = Bits()
@@ -44,9 +44,9 @@ def test_encoding_648_r56(self) -> None:
4444
def test_encoding_1296_r23(self) -> None:
4545
# comparing encodings with reference implementation by: https://github.com/tavildar/LDPC
4646
info_bits = Bits(auto=np.genfromtxt(
47-
'tests/test_data/ieee_802_11/info_bits_N1296_R23.csv', delimiter=',', dtype=np.int_))
47+
'tests/test_data/ieee_802_11/info_bits_N1296_R23.csv', delimiter=',', dtype=np.int_)) # type: ignore
4848
encoded_ref = Bits(auto=np.genfromtxt(
49-
'tests/test_data/ieee_802_11/encoded_N1296_R23.csv', delimiter=',', dtype=np.int_))
49+
'tests/test_data/ieee_802_11/encoded_N1296_R23.csv', delimiter=',', dtype=np.int_)) # type: ignore
5050
enc = EncoderWiFi(WiFiSpecCode.N1296_R23)
5151

5252
encoded = Bits()
@@ -57,9 +57,9 @@ def test_encoding_1296_r23(self) -> None:
5757
def test_encoding_1944_r34(self) -> None:
5858
# comparing encodings with reference implementation by: https://github.com/tavildar/LDPC
5959
info_bits = Bits(auto=np.genfromtxt(
60-
'tests/test_data/ieee_802_11/info_bits_N1944_R34.csv', delimiter=',', dtype=np.int_))
60+
'tests/test_data/ieee_802_11/info_bits_N1944_R34.csv', delimiter=',', dtype=np.int_)) # type: ignore
6161
encoded_ref = Bits(auto=np.genfromtxt(
62-
'tests/test_data/ieee_802_11/encoded_N1944_R34.csv', delimiter=',', dtype=np.int_))
62+
'tests/test_data/ieee_802_11/encoded_N1944_R34.csv', delimiter=',', dtype=np.int_)) # type: ignore
6363
enc = EncoderWiFi(WiFiSpecCode.N1944_R34)
6464

6565
encoded = Bits()
@@ -79,9 +79,9 @@ def test_incorrect_length(self) -> None:
7979

8080
def test_decoder_648_r12(self) -> None:
8181
info_bits = Bits(auto=np.genfromtxt(
82-
'tests/test_data/ieee_802_11/info_bits_N648_R12.csv', delimiter=',', dtype=np.int_))
82+
'tests/test_data/ieee_802_11/info_bits_N648_R12.csv', delimiter=',', dtype=np.int_)) # type: ignore
8383
encoded_ref = Bits(auto=np.genfromtxt(
84-
'tests/test_data/ieee_802_11/encoded_N648_R12.csv', delimiter=',', dtype=np.int_))
84+
'tests/test_data/ieee_802_11/encoded_N648_R12.csv', delimiter=',', dtype=np.int_)) # type: ignore
8585
p = 0.01
8686

8787
corrupted = BitArray(encoded_ref)
@@ -104,9 +104,9 @@ def test_decoder_648_r12(self) -> None:
104104

105105
def test_ms_decoder_1296_r23(self) -> None:
106106
info_bits = Bits(auto=np.genfromtxt(
107-
'tests/test_data/ieee_802_11/info_bits_N1296_R23.csv', delimiter=',', dtype=np.int_))
107+
'tests/test_data/ieee_802_11/info_bits_N1296_R23.csv', delimiter=',', dtype=np.int_)) # type: ignore
108108
encoded_ref = Bits(auto=np.genfromtxt(
109-
'tests/test_data/ieee_802_11/encoded_N1296_R23.csv', delimiter=',', dtype=np.int_))
109+
'tests/test_data/ieee_802_11/encoded_N1296_R23.csv', delimiter=',', dtype=np.int_)) # type: ignore
110110
p = 0.01
111111

112112
corrupted = BitArray(encoded_ref)

0 commit comments

Comments
 (0)