Skip to content

Commit f526f79

Browse files
author
grigory
committed
Refactor RC-SCAN decoder
1 parent 34460ea commit f526f79

File tree

6 files changed

+111
-112
lines changed

6 files changed

+111
-112
lines changed

python_polar_coding/polar_codes/rc_scan/codec.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Union
22

3+
import numpy as np
4+
35
from python_polar_coding.polar_codes.base import BasePolarCodec
46

57
from .decoder import RCSCANDecoder
@@ -38,3 +40,7 @@ def to_dict(self):
3840
d = super().to_dict()
3941
d.update({'I': self.I})
4042
return d
43+
44+
def decode(self, received_message: np.array) -> np.array:
45+
"""Decode received message presented as LLR values."""
46+
return self.decoder(received_message)
Lines changed: 27 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
11
import numpy as np
22
from anytree import PreOrderIter
33

4-
from python_polar_coding.polar_codes.fast_ssc import FastSSCDecoder
5-
6-
from ..base import make_hard_decision
7-
from .functions import function_1, function_2
4+
from python_polar_coding.polar_codes.base import BaseTreeDecoder
5+
from python_polar_coding.polar_codes.base.functions import make_hard_decision
6+
7+
from .functions import (
8+
compute_left_alpha,
9+
compute_parent_beta,
10+
compute_right_alpha,
11+
)
812
from .node import RCSCANNode
913

1014

11-
class RCSCANDecoder(FastSSCDecoder):
15+
class RCSCANDecoder(BaseTreeDecoder):
1216
"""Implements Reduced-complexity SCAN decoding algorithm.
1317
1418
Based on:
@@ -22,26 +26,24 @@ def __init__(
2226
self,
2327
n: int,
2428
mask: np.array,
25-
code_min_size: int = 0,
2629
I: int = 1,
2730
):
28-
super().__init__(n=n, mask=mask, is_systematic=True,
29-
code_min_size=code_min_size)
31+
super().__init__(n=n, mask=mask)
3032
self.I = I
3133

32-
def decode_internal(self, received_llr: np.array) -> np.array:
34+
def decode(self, received_llr: np.array) -> np.array:
3335
"""Implementation of SC decoding method."""
34-
self.clean_before_decoding()
36+
self._clean_before_decoding()
3537

36-
for leaf in self._decoding_tree.leaves:
38+
for leaf in self.leaves:
3739
leaf.initialize_leaf_beta()
3840

3941
for _ in range(self.I):
40-
super().decode_internal(received_llr)
42+
super().decode(received_llr)
4143

4244
return self.result
4345

44-
def clean_before_decoding(self):
46+
def _clean_before_decoding(self):
4547
"""Reset intermediate BETA values.
4648
4749
Run this before calling `__call__` method.
@@ -51,7 +53,7 @@ def clean_before_decoding(self):
5153
if not (node.is_zero or node.is_one):
5254
node.beta *= 0
5355

54-
def compute_intermediate_alpha(self, leaf):
56+
def _compute_intermediate_alpha(self, leaf):
5557
"""Compute intermediate Alpha values (LLR)."""
5658
for node in leaf.path[1:]:
5759
if node.is_computed or node.is_zero or node.is_one:
@@ -61,69 +63,34 @@ def compute_intermediate_alpha(self, leaf):
6163

6264
if node.is_left:
6365
right_beta = node.siblings[0].beta
64-
node.alpha = self.compute_left_alpha(parent_alpha, right_beta)
66+
node.alpha = compute_left_alpha(parent_alpha, right_beta)
6567

6668
if node.is_right:
6769
left_beta = node.siblings[0].beta
68-
node.alpha = self.compute_right_alpha(parent_alpha, left_beta)
70+
node.alpha = compute_right_alpha(parent_alpha, left_beta)
6971

7072
node.is_computed = True
7173

72-
def compute_intermediate_beta(self, node):
74+
def _compute_intermediate_beta(self, node):
7375
"""Compute intermediate BETA values."""
7476
parent = node.parent
7577
if node.is_left or node.is_root or parent.is_root:
7678
return
7779

7880
left = node.siblings[0]
79-
parent.beta = self.compute_parent_beta(left.beta, node.beta, parent.alpha) # noqa
80-
return self.compute_intermediate_beta(parent)
81+
parent.beta = compute_parent_beta(left.beta, node.beta, parent.alpha)
82+
return self._compute_intermediate_beta(parent)
8183

8284
@property
8385
def result(self):
84-
if not self.is_systematic:
85-
raise TypeError('Code must be systematic')
86-
return make_hard_decision(self.root.alpha + self._compute_result_beta())
87-
88-
@staticmethod
89-
def compute_left_alpha(parent_alpha, beta):
90-
"""Compute LLR for left node."""
91-
return RCSCANDecoder.compute_alpha(parent_alpha, beta, is_left=True)
92-
93-
@staticmethod
94-
def compute_right_alpha(parent_alpha, beta):
95-
"""Compute LLR for right node."""
96-
return RCSCANDecoder.compute_alpha(parent_alpha, beta, is_left=False)
97-
98-
@staticmethod
99-
def compute_alpha(parent_alpha, beta, is_left):
100-
"""Compute ALPHA values for left or right node."""
101-
N = parent_alpha.size // 2
102-
left_parent_alpha = parent_alpha[:N]
103-
right_parent_alpha = parent_alpha[N:]
104-
105-
if is_left:
106-
result_alpha = function_1(left_parent_alpha, right_parent_alpha, beta)
107-
else:
108-
result_alpha = function_2(left_parent_alpha, beta, right_parent_alpha)
109-
return result_alpha
110-
111-
@staticmethod
112-
def compute_parent_beta(left_beta, right_beta, parent_alpha):
113-
"""Compute bits of a parent Node."""
114-
N = parent_alpha.size // 2
115-
left_parent_alpha = parent_alpha[:N]
116-
right_parent_alpha = parent_alpha[N:]
117-
118-
parent_beta_left = function_1(left_beta, right_beta, right_parent_alpha)
119-
parent_beta_right = function_2(left_beta, left_parent_alpha, right_beta)
120-
121-
return np.append(parent_beta_left, parent_beta_right)
122-
123-
def _compute_result_beta(self):
86+
return make_hard_decision(self.root.alpha +
87+
self._compute_result_beta())
88+
89+
def _compute_result_beta(self) -> np.array:
12490
"""Compute result BETA values."""
12591
alpha = self.root.alpha
12692
if not self.root.children:
12793
return self.root.beta
94+
12895
left, right = self.root.children
129-
return self.compute_parent_beta(left.beta, right.beta, alpha)
96+
return compute_parent_beta(left.beta, right.beta, alpha)
Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,64 @@
11
import numba
22
import numpy as np
33

4-
from ..base import INFINITY, compute_alpha
4+
from python_polar_coding.polar_codes.base.functions.alpha import (
5+
function_1,
6+
function_2,
7+
)
8+
9+
from ..base import INFINITY
510

611

712
@numba.njit
8-
def function_1(a, b, c):
9-
"""Function 1.
13+
def compute_beta_zero_node(alpha):
14+
"""Compute beta values for ZERO node.
1015
11-
Source: doi:10.1007/s12243-018-0634-7, formula 1.
16+
https://arxiv.org/pdf/1510.06495.pdf Section III.C.
1217
1318
"""
14-
return compute_alpha(a, b + c)
19+
return np.ones(alpha.size, dtype=np.double) * INFINITY
1520

1621

1722
@numba.njit
18-
def function_2(a, b, c):
19-
"""Function 2.
23+
def compute_beta_one_node(alpha):
24+
"""Compute beta values for ONE node.
2025
21-
Source: doi:10.1007/s12243-018-0634-7, formula 2.
26+
https://arxiv.org/pdf/1510.06495.pdf Section III.C.
2227
2328
"""
24-
return compute_alpha(a, b) + c
29+
return np.zeros(alpha.size, dtype=np.double)
2530

2631

2732
@numba.njit
28-
def compute_beta_zero_node(alpha):
29-
"""Compute beta values for ZERO node.
33+
def compute_left_alpha(parent_alpha, beta):
34+
"""Compute LLR for left node."""
35+
N = parent_alpha.size // 2
36+
left_parent_alpha = parent_alpha[:N]
37+
right_parent_alpha = parent_alpha[N:]
3038

31-
https://arxiv.org/pdf/1510.06495.pdf Section III.C.
39+
return function_1(left_parent_alpha, right_parent_alpha, beta)
3240

33-
"""
34-
return np.ones(alpha.size, dtype=np.double) * INFINITY
41+
42+
@numba.njit
43+
def compute_right_alpha(parent_alpha, beta):
44+
"""Compute LLR for right node."""
45+
N = parent_alpha.size // 2
46+
left_parent_alpha = parent_alpha[:N]
47+
right_parent_alpha = parent_alpha[N:]
48+
49+
return function_2(left_parent_alpha, beta, right_parent_alpha)
3550

3651

3752
@numba.njit
38-
def compute_beta_one_node(alpha):
39-
"""Compute beta values for ONE node.
53+
def compute_parent_beta(left_beta, right_beta, parent_alpha):
54+
"""Compute bits of a parent Node."""
55+
N = parent_alpha.size // 2
56+
left_parent_alpha = parent_alpha[:N]
57+
right_parent_alpha = parent_alpha[N:]
4058

41-
https://arxiv.org/pdf/1510.06495.pdf Section III.C.
59+
result = np.zeros(parent_alpha.size, dtype=np.double)
4260

43-
"""
44-
return np.zeros(alpha.size, dtype=np.double)
61+
result[:N] = function_1(left_beta, right_beta, right_parent_alpha)
62+
result[N:] = function_2(left_beta, left_parent_alpha, right_beta)
63+
64+
return result
Lines changed: 28 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,31 @@
1-
import numpy as np
1+
from typing import Dict
22

3-
from python_polar_coding.polar_codes.fast_ssc import FastSSCNode
3+
from python_polar_coding.polar_codes.base.functions.beta_soft import one, zero
44

5-
from .functions import compute_beta_one_node, compute_beta_zero_node
5+
from ..base import NodeTypes, SoftNode
66

77

8-
class RCSCANNode(FastSSCNode):
8+
class RCSCANNode(SoftNode):
9+
supported_nodes = (
10+
NodeTypes.ZERO,
11+
NodeTypes.ONE,
12+
)
13+
14+
@property
15+
def is_zero(self) -> bool:
16+
"""Check is the node is Zero node."""
17+
return self.node_type == NodeTypes.ZERO
18+
19+
@property
20+
def is_one(self) -> bool:
21+
"""Check is the node is One node."""
22+
return self.node_type == NodeTypes.ONE
23+
24+
def get_decoding_params(self) -> Dict:
25+
return dict(
26+
node_type=self.node_type,
27+
llr=self.alpha,
28+
)
929

1030
def compute_leaf_beta(self):
1131
"""Do nothing for ZERO and ONE nodes.
@@ -24,22 +44,7 @@ def initialize_leaf_beta(self):
2444
if not self.is_leaf:
2545
return
2646

27-
if self._node_type == RCSCANNode.ZERO_NODE:
28-
self._beta = compute_beta_zero_node(self.alpha)
29-
if self._node_type == RCSCANNode.ONE_NODE:
30-
self._beta = compute_beta_one_node(self.alpha)
31-
32-
def get_node_type(self):
33-
"""Get the type of RC SCAN Node.
34-
35-
* Zero node - [0, 0, 0, 0, 0, 0, 0, 0];
36-
* One node - [1, 1, 1, 1, 1, 1, 1, 1];
37-
38-
Or other type.
39-
40-
"""
41-
if np.all(self._mask == 0):
42-
return RCSCANNode.ZERO_NODE
43-
if np.all(self._mask == 1):
44-
return RCSCANNode.ONE_NODE
45-
return RCSCANNode.OTHER
47+
if self.is_zero:
48+
self._beta = zero(self.alpha)
49+
if self.is_one:
50+
self._beta = one(self.alpha)

python_polar_coding/tests/test_rc_scan/test_decoder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def test_sub_codes(self):
7070
)
7171

7272
for i, leaf in enumerate(decoder._decoding_tree.leaves):
73-
np.testing.assert_equal(leaf._mask, self.sub_codes[i])
73+
np.testing.assert_equal(leaf.mask, self.sub_codes[i])
7474

7575
def test_no_noise(self):
7676
decoder = self._get_decoder()

python_polar_coding/tests/test_rc_scan/test_node.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22

33
import numpy as np
44

5-
from python_polar_coding.polar_codes.rc_scan import INFINITY, RCSCANNode
5+
from python_polar_coding.polar_codes.base import INFINITY
6+
from python_polar_coding.polar_codes.base.functions import NodeTypes
7+
from python_polar_coding.polar_codes.rc_scan import RCSCANNode
68

79

810
class TestRCSCANNode(TestCase):
@@ -13,7 +15,7 @@ def setUpClass(cls):
1315

1416
def test_zero_node(self):
1517
node = RCSCANNode(mask=np.zeros(4))
16-
self.assertEqual(node._node_type, RCSCANNode.ZERO_NODE)
18+
self.assertTrue(node.is_zero)
1719

1820
node.llr = self.llr
1921
node.initialize_leaf_beta()
@@ -24,7 +26,7 @@ def test_zero_node(self):
2426

2527
def test_one_node(self):
2628
node = RCSCANNode(mask=np.ones(4))
27-
self.assertEqual(node._node_type, RCSCANNode.ONE_NODE)
29+
self.assertTrue(node.is_one)
2830

2931
node.llr = self.llr
3032
node.initialize_leaf_beta()
@@ -43,7 +45,6 @@ def test_with_multiple_nodes(self):
4345
0, 0, 0, 0,
4446
1, 1, 1, 1,
4547
]))
46-
self.assertEqual(node._node_type, RCSCANNode.OTHER)
4748

4849
leaf_path_lengths = [5, 5, 4, 4, 4, 3, 3]
4950
leaf_masks = [
@@ -52,12 +53,12 @@ def test_with_multiple_nodes(self):
5253
np.array([0, 0, 0, 0, ]), np.array([1, 1, 1, 1, ]),
5354
]
5455
leaf_types = [
55-
node.ZERO_NODE, node.ONE_NODE, node.ZERO_NODE,
56-
node.ZERO_NODE, node.ONE_NODE,
57-
node.ZERO_NODE, node.ONE_NODE,
56+
NodeTypes.ZERO, NodeTypes.ONE, NodeTypes.ZERO,
57+
NodeTypes.ZERO, NodeTypes.ONE,
58+
NodeTypes.ZERO, NodeTypes.ONE,
5859
]
5960

6061
for i, leaf in enumerate(node.leaves):
6162
self.assertEqual(len(leaf.path), leaf_path_lengths[i])
62-
np.testing.assert_equal(leaf._mask, leaf_masks[i])
63-
self.assertEqual(leaf._node_type, leaf_types[i])
63+
np.testing.assert_equal(leaf.mask, leaf_masks[i])
64+
self.assertTrue(leaf.node_type, leaf_types[i])

0 commit comments

Comments
 (0)