Skip to content

Commit 04a9918

Browse files
author
grigory
committed
Refactor Fast SSC node
1 parent e0f925f commit 04a9918

File tree

25 files changed

+709
-549
lines changed

25 files changed

+709
-549
lines changed

.isort.cfg

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,4 @@ line_length=79
33
length_sort_stdlib=1
44
multi_line_output=3
55
include_trailing_comma=True
6+
skip=apak,venv

python_polar_coding/__init__.py

100644100755
File mode changed.

python_polar_coding/polar_codes/__init__.py

100644100755
Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
from .fast_scan import FastSCANCodec
2-
from .fast_ssc import FastSSCPolarCodec
3-
from .g_fast_scan import GFastSCANCodec
4-
from .g_fast_ssc import GeneralizedFastSSCPolarCodec
5-
from .rc_scan import RCSCANPolarCodec
6-
from .sc import SCPolarCodec
7-
from .sc_list import SCListPolarCodec

python_polar_coding/polar_codes/base/__init__.py

100644100755
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .decoding_path import DecodingPathMixin
55
from .encoder import *
66
from .functions import *
7+
from .node import *

python_polar_coding/polar_codes/base/functions.py

Lines changed: 0 additions & 52 deletions
This file was deleted.
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .alpha import (
2+
compute_alpha,
3+
compute_left_alpha,
4+
compute_right_alpha,
5+
function_1,
6+
function_2,
7+
)
8+
from .beta_hard import compute_beta_hard, compute_parent_beta_hard
9+
from .beta_soft import compute_beta_soft
10+
from .encoding import compute_encoding_step
11+
from .node_types import NodeTypes, get_node_type
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numba
2+
import numpy as np
3+
4+
5+
@numba.njit
def compute_alpha(a: np.array, b: np.array) -> np.array:
    """Combine two LLR vectors into intermediate LLR values.

    Element-wise: the sign is the product of the input signs, the
    magnitude is the smaller of the two input magnitudes (min-sum
    approximation of the f-function).
    """
    size = a.shape[0]
    result = np.zeros(size)
    for idx in range(size):
        magnitude = min(np.fabs(a[idx]), np.fabs(b[idx]))
        result[idx] = np.sign(a[idx]) * np.sign(b[idx]) * magnitude
    return result
16+
17+
18+
@numba.njit
def compute_left_alpha(llr: np.array) -> np.array:
    """Compute Alpha for the left child node during SC-based decoding.

    Splits the parent LLR vector into halves and combines them with
    the basic alpha function.
    """
    half = llr.size // 2
    return compute_alpha(llr[:half], llr[half:])
25+
26+
27+
@numba.njit
def compute_right_alpha(llr: np.array, left_beta: np.array) -> np.array:
    """Compute Alpha for the right child node during SC-based decoding.

    The left half of the parent LLRs is subtracted from or added to the
    right half depending on the already-decoded left Beta bits, which
    are mapped from {0, 1} to {-1, +1}.
    """
    half = llr.size // 2
    bit_signs = 2 * left_beta - 1  # 0 -> -1, 1 -> +1
    return llr[half:] - bit_signs * llr[:half]
34+
35+
36+
@numba.njit
def function_1(a: np.array, b: np.array, c: np.array) -> np.array:
    """Function 1: alpha of `a` against the element-wise sum of `b` and `c`.

    Source: doi:10.1007/s12243-018-0634-7, formula 1.

    """
    summed = b + c
    return compute_alpha(a, summed)
44+
45+
46+
@numba.njit
def function_2(a: np.array, b: np.array, c: np.array) -> np.array:
    """Function 2: alpha of `a` and `b`, shifted element-wise by `c`.

    Source: doi:10.1007/s12243-018-0634-7, formula 2.

    """
    combined = compute_alpha(a, b)
    return combined + c
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
"""Common functions for polar coding."""
2+
import numba
3+
import numpy as np
4+
5+
from .node_types import NodeTypes
6+
7+
# -----------------------------------------------------------------------------
8+
# Making hard decisions during the decoding
9+
# -----------------------------------------------------------------------------
10+
11+
12+
@numba.njit
def zero(
    llr: np.array,
    mask_steps: int = 0,
    last_chunk_type: int = 0,
) -> np.array:
    """Compute bits for a Zero (frozen) node: all bits are 0.

    The LLR values are ignored; only their count determines the output
    size.  The unused `mask_steps` / `last_chunk_type` parameters keep
    the signature uniform with the other decoding methods.
    """
    # The previous docstring ("Makes hard decision...") was copied from
    # make_hard_decision and did not describe this function.
    return np.zeros(llr.size, dtype=np.int8)
20+
21+
22+
@numba.njit
def make_hard_decision(
    llr: np.array,
    mask_steps: int = 0,
    last_chunk_type: int = 0,
) -> np.array:
    """Make a hard decision for every soft input value (LLR).

    A negative LLR maps to bit 1, a non-negative LLR to bit 0.  The
    extra parameters are unused; they keep the signature uniform with
    the other decoding methods.
    """
    bits = np.zeros(llr.size, dtype=np.int8)
    for i in range(llr.size):
        if llr[i] < 0:
            bits[i] = 1
    return bits
30+
31+
32+
@numba.njit
def single_parity_check(
    llr: np.array,
    mask_steps: int = 0,
    last_chunk_type: int = 0,
) -> np.array:
    """Compute bits for a Single Parity Check node.

    Hard decisions are taken first; if their parity is odd, the least
    reliable bit (smallest |LLR|) is flipped to restore even parity.

    Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, A.

    """
    decisions = make_hard_decision(llr)
    if np.sum(decisions) % 2 != 0:
        least_reliable = np.abs(llr).argmin()
        decisions[least_reliable] = decisions[least_reliable] ^ 1
    return decisions
48+
49+
50+
@numba.njit
def repetition(
    llr: np.array,
    mask_steps: int = 0,
    last_chunk_type: int = 0,
) -> np.array:
    """Compute bits for a Repetition node.

    All bits share one value, chosen by the sign of the summed LLRs:
    a non-negative sum gives all zeros, a negative sum gives all ones.

    Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, B.

    """
    if np.sum(llr) >= 0:
        return np.zeros(llr.size, dtype=np.int8)
    return np.ones(llr.size, dtype=np.int8)
65+
66+
67+
@numba.njit
def g_repetition(
    llr: np.array,
    mask_steps: int,
    last_chunk_type: int,
) -> np.array:
    """Compute bits for Generalized Repetition node.

    The N LLRs are treated as `mask_steps` chunks of size
    `N // mask_steps`.  Matching positions of all chunks are summed into
    a single combined chunk, which is decoded either with a plain hard
    decision (`last_chunk_type == 1`) or as a Single Parity Check node,
    and the decoded chunk is then replicated over the whole output.

    Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, A.

    """
    N = llr.size
    step = N // mask_steps  # step is equal to a chunk size

    last_alpha = np.zeros(step)
    for i in range(step):
        last_alpha[i] = np.sum(np.array([
            llr[i + j * step] for j in range(mask_steps)
        ]))

    last_beta = (
        make_hard_decision(last_alpha) if last_chunk_type == 1
        else single_parity_check(last_alpha)
    )

    # int8 keeps the output dtype consistent with the other
    # hard-decision node methods, which all return int8 bit arrays
    # (previously this returned float64).
    result = np.zeros(N, dtype=np.int8)
    for i in range(0, N, step):
        result[i: i + step] = last_beta

    return result
97+
98+
99+
@numba.njit
def rg_parity(
    llr: np.array,
    mask_steps: int,
    last_chunk_type: int = 0,
) -> np.array:
    """Compute bits for Relaxed Generalized Parity Check node.

    The N LLRs are treated as `mask_steps` chunks of size
    `N // mask_steps`.  For each position inside a chunk, the LLRs at
    that position across all chunks form one Single Parity Check node;
    its decoded bits are scattered back to the same positions.

    Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, B.

    """
    N = llr.size
    step = N // mask_steps  # step is equal to a chunk size
    # int8 keeps the output dtype consistent with the other
    # hard-decision node methods, which all return int8 bit arrays
    # (previously this returned float64).
    result = np.zeros(N, dtype=np.int8)

    for i in range(step):
        alpha = np.zeros(mask_steps)
        for j in range(mask_steps):
            alpha[j] = llr[i + j * step]

        beta = single_parity_check(alpha)
        result[i:N:step] = beta

    return result
123+
124+
125+
# Mapping between decoding node types and corresponding decoding methods.
# compute_beta_hard dispatches through this table instead of an if/elif
# chain; every registered method shares the same
# (llr, mask_steps, last_chunk_type) signature.
_methods_map = {
    NodeTypes.ZERO: zero,
    NodeTypes.ONE: make_hard_decision,
    NodeTypes.SINGLE_PARITY_CHECK: single_parity_check,
    NodeTypes.REPETITION: repetition,
    NodeTypes.RG_PARITY: rg_parity,
    NodeTypes.G_REPETITION: g_repetition,
}
134+
135+
136+
def compute_beta_hard(
    node_type: str,
    llr: np.array,
    mask_steps: int = 0,
    last_chunk_type: int = 0,
    *args, **kwargs,
) -> np.array:
    """Dispatch hard-decision decoding to the method for `node_type`.

    Looks the decoding function up in `_methods_map` and forwards all
    remaining arguments to it.

    Raises:
        KeyError: if `node_type` has no registered decoding method.
    """
    decode = _methods_map[node_type]
    return decode(llr, mask_steps, last_chunk_type, *args, **kwargs)
146+
147+
148+
@numba.njit
def compute_parent_beta_hard(left: np.array, right: np.array) -> np.array:
    """Combine child Beta values into the parent Node's Beta.

    The first half is the mod-2 sum (XOR) of both children; the second
    half is a copy of the right child.
    """
    half = left.size
    combined = np.zeros(2 * half, dtype=np.int8)
    combined[:half] = (left + right) % 2
    combined[half:] = right
    return combined

0 commit comments

Comments
 (0)