Skip to content

Commit d85b459

Browse files
author
grigory
committed
Refactor alpha computation
1 parent 1cc0125 commit d85b459

File tree

6 files changed

+44
-17
lines changed

6 files changed

+44
-17
lines changed

python_polar_coding/polar_codes/base/functions.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,21 @@ def compute_alpha(a, b):
3232
def make_hard_decision(soft_input):
3333
"""Makes hard decision based on soft input values (LLR)."""
3434
return np.array([s < 0 for s in soft_input], dtype=np.int8)
35+
36+
37+
@numba.njit
38+
def compute_left_alpha(llr):
39+
"""Compute Alpha for left node during SC-based decoding."""
40+
N = llr.size // 2
41+
left = llr[:N]
42+
right = llr[N:]
43+
return compute_alpha(left, right)
44+
45+
46+
@numba.njit
47+
def compute_right_alpha(llr, left_beta):
48+
"""Compute Alpha for right node during SC-based decoding."""
49+
N = llr.size // 2
50+
left = llr[:N]
51+
right = llr[N:]
52+
return right - (2 * left_beta - 1) * left

python_polar_coding/polar_codes/fast_ssc/decoder.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22
from anytree import PreOrderIter
33

44
from python_polar_coding.polar_codes.sc import SCDecoder
5+
from python_polar_coding.polar_codes.base.functions import (
6+
compute_left_alpha,
7+
compute_right_alpha,
8+
)
59

610
from .node import FastSSCNode
711

@@ -18,12 +22,13 @@ def __init__(
1822
code_min_size: int = 0,
1923
):
2024
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
21-
self._decoding_tree = self.node_class(
22-
mask=self.mask,
23-
N_min=code_min_size,
24-
)
25+
self._decoding_tree = self.setup_decoding_tree(code_min_size)
2526
self._position = 0
2627

28+
def setup_decoding_tree(self, N_min, **kwargs):
29+
"""Setup decoding tree."""
30+
return self.node_class(mask=self.mask, N_min=N_min)
31+
2732
def _set_initial_state(self, received_llr):
2833
"""Initialize decoder with received message."""
2934
self.current_state = np.zeros(self.n, dtype=np.int8)
@@ -73,12 +78,12 @@ def compute_intermediate_alpha(self, leaf):
7378
parent_alpha = node.parent.alpha
7479

7580
if node.is_left:
76-
node.alpha = self._compute_left_alpha(parent_alpha)
81+
node.alpha = compute_left_alpha(parent_alpha)
7782
continue
7883

7984
left_node = node.siblings[0]
8085
left_beta = left_node.beta
81-
node.alpha = self._compute_right_alpha(parent_alpha, left_beta)
86+
node.alpha = compute_right_alpha(parent_alpha, left_beta)
8287
node.is_computed = True
8388

8489
def compute_intermediate_beta(self, node):

python_polar_coding/polar_codes/fast_ssc/node.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def get_node_type(self):
138138
return FastSSCNode.ZERO_NODE
139139
if self._check_is_one(self._mask) and self.N >= self.one_min_size:
140140
return FastSSCNode.ONE_NODE
141-
if self.N >= self.repetition_min_size and self._check_is_parity(self._mask): # noqa
141+
if self.N >= self.repetition_min_size and self._check_is_rep(self._mask): # noqa
142142
return FastSSCNode.REPETITION
143143
if self.N >= self.spc_min_size and self._check_is_spc(self._mask):
144144
return FastSSCNode.SINGLE_PARITY_CHECK
@@ -153,7 +153,7 @@ def _check_is_zero(self, mask):
153153
def _check_is_spc(self, mask):
154154
return mask[0] == 0 and np.sum(mask) == mask.size - 1
155155

156-
def _check_is_parity(self, mask):
156+
def _check_is_rep(self, mask):
157157
return mask[-1] == 1 and np.sum(mask) == 1
158158

159159
def _build_decoding_tree(self):

python_polar_coding/polar_codes/g_fast_ssc/decoder.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,14 @@ def __init__(
1616
code_min_size: int = 0,
1717
AF: int = 1,
1818
):
19-
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
20-
self._decoding_tree = self.node_class(
21-
mask=self.mask,
22-
N_min=code_min_size,
23-
AF=AF,
19+
self.AF = AF
20+
super().__init__(
21+
n=n,
22+
mask=mask,
23+
is_systematic=is_systematic,
24+
code_min_size=code_min_size,
2425
)
25-
self._position = 0
26+
27+
def setup_decoding_tree(self, N_min, **kwargs):
28+
"""Setup decoding tree."""
29+
return self.node_class(mask=self.mask, N_min=N_min, AF=self.AF)

python_polar_coding/polar_codes/g_fast_ssc/functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from ..fast_ssc import compute_single_parity_check
66

77

8-
@numba.njit
8+
# @numba.njit
99
def compute_g_repetition(llr, mask_steps, last_chunk_type, N):
1010
"""Compute bits for Generalized Repetition node.
1111

python_polar_coding/polar_codes/sc/decoder.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ def _compute_intermediate_alpha(self, position):
8888
continue
8989

9090
if self.current_state[i - 1] == 0:
91-
self.intermediate_llr[i] = self._compute_left_alpha(llr)
91+
self.intermediate_llr[i] = functions.compute_left_alpha(llr)
9292
continue
9393

9494
end = position
9595
start = end - np.power(2, self.n - i)
9696
left_bits = self.intermediate_bits[i][start: end]
97-
self.intermediate_llr[i] = self._compute_right_alpha(llr, left_bits)
97+
self.intermediate_llr[i] = functions.compute_right_alpha(llr, left_bits)
9898

9999
def _compute_beta(self, position):
100100
"""Make decision about current decoding value."""

0 commit comments

Comments
 (0)