Skip to content

Commit 34460ea

Browse files
author
grigory
committed
Refactor Fast-SSC codec
1 parent 04a9918 commit 34460ea

File tree

7 files changed

+125
-82
lines changed

7 files changed

+125
-82
lines changed

python_polar_coding/polar_codes/base/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from .codec import BaseCRCPolarCodec, BasePolarCodec
22
from .constants import *
3-
from .decoder import BaseDecoder
3+
from .decoder import BaseDecoder, BaseTreeDecoder
44
from .decoding_path import DecodingPathMixin
55
from .encoder import *
66
from .functions import *

python_polar_coding/polar_codes/base/decoder.py

Lines changed: 97 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import abc
22

33
import numpy as np
4+
from anytree import PreOrderIter
45

56

67
class BaseDecoder(metaclass=abc.ABCMeta):
@@ -14,13 +15,13 @@ def __init__(self, n, mask: np.array, is_systematic: bool = True):
1415

1516
def decode(self, received_llr: np.array) -> np.array:
1617
decoded = self.decode_internal(received_llr)
17-
return self.get_result(decoded)
18+
return self.extract_result(decoded)
1819

1920
@abc.abstractmethod
2021
def decode_internal(self, received_llr: np.array) -> np.array:
2122
"""Implementation of particular decoding method."""
2223

23-
def get_result(self, decoded: np.array) -> np.array:
24+
def extract_result(self, decoded: np.array) -> np.array:
2425
"""Get decoding result.
2526
2627
Extract info bits from decoded message due to polar code mask.
@@ -32,3 +33,97 @@ def get_result(self, decoded: np.array) -> np.array:
3233
if self.mask[i] == 1:
3334
decoded_info = np.append(decoded_info, decoded[i])
3435
return np.array(decoded_info, dtype=np.int)
36+
37+
38+
class BaseTreeDecoder(metaclass=abc.ABCMeta):
39+
"""Basic class for polar decoder that use tree for decoding."""
40+
41+
node_class: 'BaseDecodingNode'
42+
43+
def __init__(self, n, mask: np.array):
44+
self.N = mask.shape[0]
45+
self.n = n
46+
self.mask = mask
47+
48+
self._decoding_tree = self._setup_decoding_tree()
49+
self._position = 0
50+
51+
def __call__(self, received_llr: np.array) -> np.array:
52+
decoded = self.decode(received_llr)
53+
return self.extract_result(decoded)
54+
55+
@property
56+
def leaves(self):
57+
return self._decoding_tree.leaves
58+
59+
@property
60+
def root(self):
61+
"""Returns root node of decoding tree."""
62+
return self._decoding_tree.root
63+
64+
@property
65+
def result(self):
66+
return self.root.beta
67+
68+
def decode(self, received_llr: np.array) -> np.array:
69+
"""Implementation of decoding using tree."""
70+
self._set_initial_state(received_llr)
71+
self._reset_tree_computed_state()
72+
73+
for leaf in self.leaves:
74+
self._set_decoder_state(self._position)
75+
self._compute_intermediate_alpha(leaf)
76+
leaf()
77+
self._compute_intermediate_beta(leaf)
78+
self._set_next_state(leaf.N)
79+
80+
return self.result
81+
82+
def extract_result(self, decoded: np.array) -> np.array:
83+
"""Get decoding result.
84+
85+
Extract info bits from decoded message due to polar code mask.
86+
87+
"""
88+
decoded_info = list()
89+
90+
for i in range(self.N):
91+
if self.mask[i] == 1:
92+
decoded_info = np.append(decoded_info, decoded[i])
93+
return np.array(decoded_info, dtype=np.int)
94+
95+
def _setup_decoding_tree(self, ):
96+
"""Setup decoding tree."""
97+
return self.node_class(mask=self.mask)
98+
99+
def _set_initial_state(self, received_llr):
100+
"""Initialize decoder with received message."""
101+
self.current_state = np.zeros(self.n, dtype=np.int8)
102+
self.previous_state = np.ones(self.n, dtype=np.int8)
103+
104+
# LLR values at intermediate steps
105+
self._position = 0
106+
self._decoding_tree.root.alpha = received_llr
107+
108+
def _reset_tree_computed_state(self):
109+
"""Reset the state of the tree before decoding"""
110+
for node in PreOrderIter(self._decoding_tree):
111+
node.is_computed = False
112+
113+
def _set_decoder_state(self, position):
114+
"""Set current state of the decoder."""
115+
bits = np.unpackbits(
116+
np.array([position], dtype=np.uint32).byteswap().view(np.uint8)
117+
)
118+
self.current_state = bits[-self.n:]
119+
120+
@abc.abstractmethod
121+
def _compute_intermediate_alpha(self, leaf):
122+
"""Compute intermediate Alpha values (LLR)."""
123+
124+
@abc.abstractmethod
125+
def _compute_intermediate_beta(self, node):
126+
"""Compute intermediate Beta values (Bits or LLR)."""
127+
128+
def _set_next_state(self, leaf_size):
129+
self._position += leaf_size

python_polar_coding/polar_codes/base/functions/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,11 @@
55
function_1,
66
function_2,
77
)
8-
from .beta_hard import compute_beta_hard, compute_parent_beta_hard
8+
from .beta_hard import (
9+
compute_beta_hard,
10+
compute_parent_beta_hard,
11+
make_hard_decision,
12+
)
913
from .beta_soft import compute_beta_soft
1014
from .encoding import compute_encoding_step
1115
from .node_types import NodeTypes, get_node_type

python_polar_coding/polar_codes/fast_ssc/codec.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import numpy as np
2+
13
from python_polar_coding.polar_codes.base import BasePolarCodec
24

35
from .decoder import FastSSCDecoder
@@ -12,9 +14,8 @@ class FastSSCPolarCodec(BasePolarCodec):
1214
decoder_class = FastSSCDecoder
1315

1416
def init_decoder(self):
15-
return self.decoder_class(n=self.n, mask=self.mask,
16-
is_systematic=self.is_systematic)
17+
return self.decoder_class(n=self.n, mask=self.mask)
1718

18-
@property
19-
def tree(self):
20-
return self.decoder._decoding_tree
19+
def decode(self, received_message: np.array) -> np.array:
20+
"""Decode received message presented as LLR values."""
21+
return self.decoder(received_message)
Lines changed: 6 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,18 @@
1-
import numpy as np
2-
from anytree import PreOrderIter
3-
1+
from python_polar_coding.polar_codes.base import BaseTreeDecoder
42
from python_polar_coding.polar_codes.base.functions import (
53
compute_left_alpha,
6-
compute_right_alpha,
74
compute_parent_beta_hard,
5+
compute_right_alpha,
86
)
9-
from python_polar_coding.polar_codes.sc import SCDecoder
107

118
from .node import FastSSCNode
129

1310

14-
class FastSSCDecoder(SCDecoder):
11+
class FastSSCDecoder(BaseTreeDecoder):
1512
"""Implements Fast SSC decoding algorithm."""
1613
node_class = FastSSCNode
1714

18-
def __init__(
19-
self,
20-
n: int,
21-
mask: np.array,
22-
is_systematic: bool = True,
23-
code_min_size: int = 0,
24-
):
25-
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
26-
self._decoding_tree = self.setup_decoding_tree(code_min_size)
27-
self._position = 0
28-
29-
def setup_decoding_tree(self, N_min, **kwargs):
30-
"""Setup decoding tree."""
31-
return self.node_class(mask=self.mask, N_min=N_min)
32-
33-
def _set_initial_state(self, received_llr):
34-
"""Initialize decoder with received message."""
35-
self.current_state = np.zeros(self.n, dtype=np.int8)
36-
self.previous_state = np.ones(self.n, dtype=np.int8)
37-
38-
# LLR values at intermediate steps
39-
self._position = 0
40-
self._decoding_tree.root.alpha = received_llr
41-
42-
def decode_internal(self, received_llr: np.array) -> np.array:
43-
"""Implementation of SC decoding method."""
44-
self._set_initial_state(received_llr)
45-
46-
# Reset the state of the tree before decoding
47-
for node in PreOrderIter(self._decoding_tree):
48-
node.is_computed = False
49-
50-
for leaf in self._decoding_tree.leaves:
51-
self._set_decoder_state(self._position)
52-
self.compute_intermediate_alpha(leaf)
53-
leaf()
54-
self.compute_intermediate_beta(leaf)
55-
self.set_next_state(leaf.N)
56-
57-
return self.result
58-
59-
@property
60-
def root(self):
61-
"""Returns root node of decoding tree."""
62-
return self._decoding_tree.root
63-
64-
@property
65-
def result(self):
66-
if self.is_systematic:
67-
return self.root.beta
68-
69-
def compute_intermediate_alpha(self, leaf):
15+
def _compute_intermediate_alpha(self, leaf):
7016
"""Compute intermediate Alpha values (LLR)."""
7117
for node in leaf.path[1:]:
7218
if node.is_computed:
@@ -87,7 +33,7 @@ def compute_intermediate_alpha(self, leaf):
8733
node.alpha = compute_right_alpha(parent_alpha, left_beta)
8834
node.is_computed = True
8935

90-
def compute_intermediate_beta(self, node):
36+
def _compute_intermediate_beta(self, node):
9137
"""Compute intermediate Beta values (BIT)."""
9238
if node.is_left:
9339
return
@@ -98,7 +44,4 @@ def compute_intermediate_beta(self, node):
9844
parent = node.parent
9945
left = node.siblings[0]
10046
parent.beta = compute_parent_beta_hard(left.beta, node.beta)
101-
return self.compute_intermediate_beta(parent)
102-
103-
def set_next_state(self, leaf_size):
104-
self._position += leaf_size
47+
return self._compute_intermediate_beta(parent)

python_polar_coding/tests/test_base/test_codec.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def test_precode_and_extract(self):
128128
precoded = self.systematic_code.encoder._precode(self.message)
129129
self.assertEqual(precoded.size, self.codeword_length)
130130

131-
extracted = self.systematic_code.decoder.get_result(precoded)
131+
extracted = self.systematic_code.decoder.extract_result(precoded)
132132
self.assertEqual(extracted.size, self.info_length)
133133

134134
self.assertTrue(all(extracted == self.message))
@@ -143,13 +143,13 @@ def test_systematic_encode(self):
143143
encoded = self.systematic_code.encode(self.message)
144144
self.assertTrue(all(encoded == self.sys_enc_msg))
145145

146-
extracted = self.systematic_code.decoder.get_result(encoded)
146+
extracted = self.systematic_code.decoder.extract_result(encoded)
147147
self.assertTrue(all(extracted == self.message))
148148

149149
def test_systematic_encode_with_crc(self):
150150
"""Test for systematic encoding with CRC support"""
151151
encoded = self.systematic_crc_code.encode(self.message)
152152
self.assertTrue(all(encoded == self.sys_crc_enc_msg))
153153

154-
extracted = self.systematic_crc_code.decoder.get_result(encoded)
154+
extracted = self.systematic_crc_code.decoder.extract_result(encoded)
155155
self.assertTrue(all(extracted[:self.info_length] == self.message))

python_polar_coding/tests/test_fast_ssc/test_decoder.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def setUpClass(cls):
2323

2424
def test_zero_node_decoder(self):
2525
mask = np.zeros(self.length, dtype=np.int8)
26-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
26+
decoder = FastSSCDecoder(mask=mask, n=self.n)
2727
self.assertEqual(len(decoder._decoding_tree.leaves), 1)
2828

2929
decoder.decode(self.received_llr)
@@ -34,7 +34,7 @@ def test_zero_node_decoder(self):
3434

3535
def test_one_node_decoder(self):
3636
mask = np.ones(self.length, dtype=np.int8)
37-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
37+
decoder = FastSSCDecoder(mask=mask, n=self.n)
3838
self.assertEqual(len(decoder._decoding_tree.leaves), 1)
3939

4040
decoder.decode(self.received_llr)
@@ -45,7 +45,7 @@ def test_one_node_decoder(self):
4545

4646
def test_spc_node_decoder(self):
4747
mask = np.array([0, 1, 1, 1, 1, 1, 1, 1], dtype=np.int8)
48-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
48+
decoder = FastSSCDecoder(mask=mask, n=self.n)
4949
self.assertEqual(len(decoder._decoding_tree.leaves), 1)
5050

5151
decoder.decode(self.received_llr)
@@ -56,7 +56,7 @@ def test_spc_node_decoder(self):
5656

5757
def test_repetition_node_decoder(self):
5858
mask = np.array([0, 0, 0, 0, 0, 0, 0, 1], dtype=np.int8)
59-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
59+
decoder = FastSSCDecoder(mask=mask, n=self.n)
6060
self.assertEqual(len(decoder._decoding_tree.leaves), 1)
6161

6262
decoder.decode(self.received_llr)
@@ -67,7 +67,7 @@ def test_repetition_node_decoder(self):
6767

6868
def test_repetition_spc_node_decoder(self):
6969
mask = np.array([0, 0, 0, 1, 0, 1, 1, 1], dtype=np.int8)
70-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
70+
decoder = FastSSCDecoder(mask=mask, n=self.n)
7171
self.assertEqual(len(decoder._decoding_tree.leaves), 2)
7272

7373
decoder.decode(self.received_llr)
@@ -103,7 +103,7 @@ def test_repetition_spc_node_decoder(self):
103103

104104
def test_spc_repetition_node_decoder(self):
105105
mask = np.array([0, 1, 1, 1, 0, 0, 0, 1], dtype=np.int8)
106-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=self.n)
106+
decoder = FastSSCDecoder(mask=mask, n=self.n)
107107
self.assertEqual(len(decoder._decoding_tree.leaves), 2)
108108

109109
decoder.decode(self.received_llr)
@@ -154,7 +154,7 @@ def test_complex(self):
154154
np.array([0, ], dtype=np.int8),
155155
np.array([0, 1, 1, 1, ], dtype=np.int8),
156156
]
157-
decoder = FastSSCDecoder(mask=mask, is_systematic=True, n=4)
157+
decoder = FastSSCDecoder(mask=mask, n=4)
158158

159159
# Check tree structure
160160
self.assertEqual(len(decoder._decoding_tree.leaves), len(sub_codes))

0 commit comments

Comments
 (0)