Skip to content

Commit 89d26a5

Browse files
author
grigory
committed
Reorganize code of Fast SSC Polar codec
1 parent 0e954b6 commit 89d26a5

File tree

9 files changed

+159
-118
lines changed

9 files changed

+159
-118
lines changed
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .codec import FastSSCPolarCodec
2+
from .decoder import FastSSCDecoder
3+
from .functions import *
4+
from .node import FastSSCNode

python_polar_coding/polar_codes/fast_ssc.py renamed to python_polar_coding/polar_codes/fast_ssc/codec.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
1-
from .base.polar_code import BasicPolarCode
2-
from .decoders.fast_ssc_decoder import FastSSCDecoder
1+
from python_polar_coding.polar_codes.base import BasePolarCodec
32

3+
from .decoder import FastSSCDecoder
44

5-
class FastSSCPolarCode(BasicPolarCode):
5+
6+
class FastSSCPolarCodec(BasePolarCodec):
67
"""Polar code with SC decoding algorithm.
78
89
Based on: https://arxiv.org/pdf/1307.7154.pdf
910
1011
"""
1112
decoder_class = FastSSCDecoder
1213

13-
def get_decoder(self):
14+
def init_decoder(self):
1415
return self.decoder_class(n=self.n, mask=self.mask,
1516
is_systematic=self.is_systematic)
1617

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import numpy as np
2+
from anytree import PreOrderIter
3+
4+
from python_polar_coding.polar_codes.sc import SCDecoder
5+
6+
from .node import FastSSCNode
7+
8+
9+
class FastSSCDecoder(SCDecoder):
10+
"""Implements Fast SSC decoding algorithm."""
11+
node_class = FastSSCNode
12+
13+
def __init__(self, n: int,
14+
mask: np.array,
15+
is_systematic: bool = True,
16+
code_min_size: int = 0):
17+
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
18+
self._decoding_tree = self.node_class(
19+
mask=self.mask,
20+
N_min=code_min_size,
21+
)
22+
self._position = 0
23+
24+
def _set_initial_state(self, received_llr):
25+
"""Initialize decoder with received message."""
26+
self.current_state = np.zeros(self.n, dtype=np.int8)
27+
self.previous_state = np.ones(self.n, dtype=np.int8)
28+
29+
# LLR values at intermediate steps
30+
self._position = 0
31+
self._decoding_tree.root.alpha = received_llr
32+
33+
def decode_internal(self, received_llr: np.array) -> np.array:
34+
"""Implementation of SC decoding method."""
35+
self._set_initial_state(received_llr)
36+
37+
# Reset the state of the tree before decoding
38+
for node in PreOrderIter(self._decoding_tree):
39+
node.is_computed = False
40+
41+
for leaf in self._decoding_tree.leaves:
42+
self._set_decoder_state(self._position)
43+
self.compute_intermediate_alpha(leaf)
44+
leaf.compute_leaf_beta()
45+
self.compute_intermediate_beta(leaf)
46+
self.set_next_state(leaf.N)
47+
48+
return self.result
49+
50+
@property
51+
def root(self):
52+
"""Returns root node of decoding tree."""
53+
return self._decoding_tree.root
54+
55+
@property
56+
def result(self):
57+
if self.is_systematic:
58+
return self.root.beta
59+
60+
@property
61+
def M(self):
62+
return self._decoding_tree.M
63+
64+
def compute_intermediate_alpha(self, leaf):
65+
"""Compute intermediate Alpha values (LLR)."""
66+
for node in leaf.path[1:]:
67+
if node.is_computed:
68+
continue
69+
70+
parent_alpha = node.parent.alpha
71+
72+
if node.is_left:
73+
node.alpha = self._compute_left_alpha(parent_alpha)
74+
continue
75+
76+
left_node = node.siblings[0]
77+
left_beta = left_node.beta
78+
node.alpha = self._compute_right_alpha(parent_alpha, left_beta)
79+
node.is_computed = True
80+
81+
def compute_intermediate_beta(self, node):
82+
"""Compute intermediate Beta values (BIT)."""
83+
if node.is_left:
84+
return
85+
86+
if node.is_root:
87+
return
88+
89+
parent = node.parent
90+
left = node.siblings[0]
91+
parent.beta = self.compute_parent_beta(left.beta, node.beta)
92+
return self.compute_intermediate_beta(parent)
93+
94+
def set_next_state(self, leaf_size):
95+
self._position += leaf_size
96+
97+
@staticmethod
98+
def compute_parent_beta(left, right):
99+
"""Compute Beta (BITS) of a parent Node."""
100+
N = left.size
101+
# append - njit incompatible
102+
return np.append((left + right) % 2, right)
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import numba
2+
import numpy as np
3+
4+
from ..base import make_hard_decision
5+
6+
7+
@numba.njit
8+
def compute_single_parity_check(llr):
9+
"""Compute bits for Single Parity Check node.
10+
11+
Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, A.
12+
13+
"""
14+
bits = make_hard_decision(llr)
15+
parity = np.sum(bits) % 2
16+
arg_min = np.abs(llr).argmin()
17+
bits[arg_min] = (bits[arg_min] + parity) % 2
18+
return bits
19+
20+
21+
@numba.njit
22+
def compute_repetition(llr):
23+
"""Compute bits for Repetition node.
24+
25+
Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, B.
26+
27+
"""
28+
return (
29+
np.zeros(llr.size, dtype=np.int8) if np.sum(llr) >= 0
30+
else np.ones(llr.size, dtype=np.int8)
31+
)

python_polar_coding/polar_codes/decoders/fast_ssc_decoder.py renamed to python_polar_coding/polar_codes/fast_ssc/node.py

Lines changed: 7 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
import numpy as np
2-
from anytree import Node, PreOrderIter
2+
from anytree import Node
33

4-
from ..base import functions
5-
from .sc_decoder import SCDecoder
4+
from python_polar_coding.polar_codes.base import make_hard_decision
5+
6+
from .functions import compute_repetition, compute_single_parity_check
67

78

89
class FastSSCNode(Node):
@@ -109,11 +110,11 @@ def compute_leaf_beta(self):
109110
if self._node_type == FastSSCNode.ZERO_NODE:
110111
self._beta = np.zeros(self.N, dtype=np.int8)
111112
if self._node_type == FastSSCNode.ONE_NODE:
112-
self._beta = functions.make_hard_decision(self.alpha)
113+
self._beta = make_hard_decision(self.alpha)
113114
if self._node_type == FastSSCNode.SINGLE_PARITY_CHECK:
114-
self._beta = functions.compute_single_parity_check(self.alpha)
115+
self._beta = compute_single_parity_check(self.alpha)
115116
if self._node_type == FastSSCNode.REPETITION:
116-
self._beta = functions.compute_repetition(self.alpha)
117+
self._beta = compute_repetition(self.alpha)
117118

118119
def _initialize_beta(self):
119120
"""Initialize BETA values on tree building."""
@@ -164,99 +165,3 @@ def _build_decoding_tree(self):
164165
cls = self.__class__
165166
cls(mask=left_mask, name=self.LEFT, N_min=self.M, parent=self)
166167
cls(mask=right_mask, name=self.RIGHT, N_min=self.M, parent=self)
167-
168-
169-
class FastSSCDecoder(SCDecoder):
170-
"""Implements Fast SSC decoding algorithm."""
171-
node_class = FastSSCNode
172-
173-
def __init__(self, n: int,
174-
mask: np.array,
175-
is_systematic: bool = True,
176-
code_min_size: int = 0):
177-
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
178-
self._decoding_tree = self.node_class(
179-
mask=self.mask,
180-
N_min=code_min_size,
181-
)
182-
self._position = 0
183-
184-
def _set_initial_state(self, received_llr):
185-
"""Initialize decoder with received message."""
186-
self.current_state = np.zeros(self.n, dtype=np.int8)
187-
self.previous_state = np.ones(self.n, dtype=np.int8)
188-
189-
# LLR values at intermediate steps
190-
self._position = 0
191-
self._decoding_tree.root.alpha = received_llr
192-
193-
def decode_internal(self, received_llr: np.array) -> np.array:
194-
"""Implementation of SC decoding method."""
195-
self._set_initial_state(received_llr)
196-
197-
# Reset the state of the tree before decoding
198-
for node in PreOrderIter(self._decoding_tree):
199-
node.is_computed = False
200-
201-
for leaf in self._decoding_tree.leaves:
202-
self._set_decoder_state(self._position)
203-
self.compute_intermediate_alpha(leaf)
204-
leaf.compute_leaf_beta()
205-
self.compute_intermediate_beta(leaf)
206-
self.set_next_state(leaf.N)
207-
208-
return self.result
209-
210-
@property
211-
def root(self):
212-
"""Returns root node of decoding tree."""
213-
return self._decoding_tree.root
214-
215-
@property
216-
def result(self):
217-
if self.is_systematic:
218-
return self.root.beta
219-
220-
@property
221-
def M(self):
222-
return self._decoding_tree.M
223-
224-
def compute_intermediate_alpha(self, leaf):
225-
"""Compute intermediate Alpha values (LLR)."""
226-
for node in leaf.path[1:]:
227-
if node.is_computed:
228-
continue
229-
230-
parent_alpha = node.parent.alpha
231-
232-
if node.is_left:
233-
node.alpha = self._compute_left_alpha(parent_alpha)
234-
continue
235-
236-
left_node = node.siblings[0]
237-
left_beta = left_node.beta
238-
node.alpha = self._compute_right_alpha(parent_alpha, left_beta)
239-
node.is_computed = True
240-
241-
def compute_intermediate_beta(self, node):
242-
"""Compute intermediate Beta values (BIT)."""
243-
if node.is_left:
244-
return
245-
246-
if node.is_root:
247-
return
248-
249-
parent = node.parent
250-
left = node.siblings[0]
251-
parent.beta = self.compute_parent_beta(left.beta, node.beta)
252-
return self.compute_intermediate_beta(parent)
253-
254-
def set_next_state(self, leaf_size):
255-
self._position += leaf_size
256-
257-
@staticmethod
258-
def compute_parent_beta(left, right):
259-
"""Compute Beta (BITS) of a parent Node."""
260-
N = left.size
261-
# append - njit incompatible
262-
return np.append((left + right) % 2, right)

python_polar_coding/tests/test_fast_ssc/__init__.py

Whitespace-only changes.

python_polar_coding/tests/test_as_model/test_fast_ssc.py renamed to python_polar_coding/tests/test_fast_ssc/test_codec.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,51 @@
11
from unittest import TestCase
22

3-
from python_polar_coding.polar_codes.fast_ssc import FastSSCPolarCode
4-
5-
from .base import BasicVerifyPolarCode
3+
from python_polar_coding.polar_codes.fast_ssc import FastSSCPolarCodec
4+
from python_polar_coding.tests.base import BasicVerifyPolarCode
65

76

87
class TestFastSSCCode_1024_512(BasicVerifyPolarCode, TestCase):
9-
polar_code_class = FastSSCPolarCode
8+
polar_code_class = FastSSCPolarCodec
109
code_parameters = {
1110
'N': 1024,
1211
'K': 512,
1312
}
1413

1514

1615
class TestFastSSCCode_1024_256(BasicVerifyPolarCode, TestCase):
17-
polar_code_class = FastSSCPolarCode
16+
polar_code_class = FastSSCPolarCodec
1817
code_parameters = {
1918
'N': 1024,
2019
'K': 256,
2120
}
2221

2322

2423
class TestFastSSCCode_1024_768(BasicVerifyPolarCode, TestCase):
25-
polar_code_class = FastSSCPolarCode
24+
polar_code_class = FastSSCPolarCodec
2625
code_parameters = {
2726
'N': 1024,
2827
'K': 768,
2928
}
3029

3130

3231
class TestFastSSCCode_2048_512(BasicVerifyPolarCode, TestCase):
33-
polar_code_class = FastSSCPolarCode
32+
polar_code_class = FastSSCPolarCodec
3433
code_parameters = {
3534
'N': 2048,
3635
'K': 512,
3736
}
3837

3938

4039
class TestFastSSCCode_2048_1024(BasicVerifyPolarCode, TestCase):
41-
polar_code_class = FastSSCPolarCode
40+
polar_code_class = FastSSCPolarCodec
4241
code_parameters = {
4342
'N': 2048,
4443
'K': 1024,
4544
}
4645

4746

4847
class TestFastSSCCode_2048_1536(BasicVerifyPolarCode, TestCase):
49-
polar_code_class = FastSSCPolarCode
48+
polar_code_class = FastSSCPolarCodec
5049
code_parameters = {
5150
'N': 2048,
5251
'K': 1536,

python_polar_coding/tests/test_fast_ssc_decoder.py renamed to python_polar_coding/tests/test_fast_ssc/test_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import numpy as np
44

5-
from python_polar_coding.polar_codes.decoders import FastSSCDecoder
5+
from python_polar_coding.polar_codes.fast_ssc import FastSSCDecoder
66

77

88
class TestFastSSCDecoder(TestCase):

python_polar_coding/tests/test_fast_ssc_node.py renamed to python_polar_coding/tests/test_fast_ssc/test_node.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
import numpy as np
44

5-
from python_polar_coding.polar_codes.decoders.fast_ssc_decoder import \
6-
FastSSCNode
5+
from python_polar_coding.polar_codes.fast_ssc.decoder import FastSSCNode
76

87

98
class FastSSCNodeTest(TestCase):

0 commit comments

Comments
 (0)