Skip to content

Commit db35dbf

Browse files
author
grigory
committed
Reorganize code of Generalized Fast SSC Polar codec
1 parent 64471ec commit db35dbf

File tree

11 files changed

+126
-171
lines changed

11 files changed

+126
-171
lines changed
Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +0,0 @@
1-
from .fast_ssc import FastSSCPolarCode
2-
from .g_fast_ssc import GeneralizedFastSSCPolarCode
3-
from .rc_scan import RCSCANPolarCode
4-
from .sc import SCPolarCode
5-
from .sc_list import SCListPolarCode

python_polar_coding/polar_codes/base/functions.py

Lines changed: 0 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -28,102 +28,7 @@ def compute_alpha(a, b):
2828
return c
2929

3030

31-
@numba.njit
32-
def function_1(a, b, c):
33-
"""Function 1.
34-
35-
Source: doi:10.1007/s12243-018-0634-7, formula 1.
36-
37-
"""
38-
return compute_alpha(a, b + c)
39-
40-
41-
@numba.njit
42-
def function_2(a, b, c):
43-
"""Function 2.
44-
45-
Source: doi:10.1007/s12243-018-0634-7, formula 2.
46-
47-
"""
48-
return compute_alpha(a, b) + c
49-
50-
5131
@numba.njit
5232
def make_hard_decision(soft_input):
5333
"""Makes hard decision based on soft input values (LLR)."""
5434
return np.array([s < 0 for s in soft_input], dtype=np.int8)
55-
56-
57-
@numba.njit
58-
def compute_single_parity_check(llr):
59-
"""Compute bits for Single Parity Check node.
60-
61-
Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, A.
62-
63-
"""
64-
bits = make_hard_decision(llr)
65-
parity = np.sum(bits) % 2
66-
arg_min = np.abs(llr).argmin()
67-
bits[arg_min] = (bits[arg_min] + parity) % 2
68-
return bits
69-
70-
71-
@numba.njit
72-
def compute_repetition(llr):
73-
"""Compute bits for Repetition node.
74-
75-
Based on: https://arxiv.org/pdf/1307.7154.pdf, Section IV, B.
76-
77-
"""
78-
return (
79-
np.zeros(llr.size, dtype=np.int8) if np.sum(llr) >= 0
80-
else np.ones(llr.size, dtype=np.int8)
81-
)
82-
83-
84-
@numba.njit
85-
def compute_g_repetition(llr, mask_steps, last_chunk_type, N):
86-
"""Compute bits for Generalized Repetition node.
87-
88-
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, A.
89-
90-
"""
91-
step = N // mask_steps # step is equal to a chunk size
92-
93-
last_alpha = np.zeros(step)
94-
for i in range(step):
95-
last_alpha[i] = np.sum(np.array([
96-
llr[i + j * step] for j in range(mask_steps)
97-
]))
98-
99-
last_beta = (
100-
make_hard_decision(last_alpha) if last_chunk_type == 1
101-
else compute_single_parity_check(last_alpha)
102-
)
103-
104-
result = np.zeros(N)
105-
for i in range(0, N, step):
106-
result[i: i + step] = last_beta
107-
108-
return result
109-
110-
111-
@numba.njit
112-
def compute_rg_parity(llr, mask_steps, N):
113-
"""Compute bits for Relaxed Generalized Parity Check node.
114-
115-
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, B.
116-
117-
"""
118-
step = N // mask_steps # step is equal to a chunk size
119-
result = np.zeros(N)
120-
121-
for i in range(step):
122-
alpha = np.zeros(mask_steps)
123-
for j in range(mask_steps):
124-
alpha[j] = llr[i + j * step]
125-
126-
beta = compute_single_parity_check(alpha)
127-
result[i:N:step] = beta
128-
129-
return result

python_polar_coding/polar_codes/decoders/__init__.py

Lines changed: 0 additions & 6 deletions
This file was deleted.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .codec import GeneralizedFastSSCPolarCodec
2+
from .decoder import GeneralizedFastSSCDecoder
3+
from .functions import *
4+
from .node import GeneralizedFastSSCNode
Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
from typing import Union
22

3-
from .decoders.g_fast_ssc_decoder import GeneralizedFastSSCDecoder
4-
from .fast_ssc import FastSSCPolarCode
3+
from python_polar_coding.polar_codes.fast_ssc import FastSSCPolarCodec
54

5+
from .decoder import GeneralizedFastSSCDecoder
66

7-
class GeneralizedFastSSCPolarCode(FastSSCPolarCode):
7+
8+
class GeneralizedFastSSCPolarCodec(FastSSCPolarCodec):
89
"""Generalized Fast SSC code.
910
1011
Based on: https://arxiv.org/pdf/1804.09508.pdf
1112
1213
"""
1314
decoder_class = GeneralizedFastSSCDecoder
1415

15-
def __init__(self, N: int, K: int,
16-
design_snr: float = 0.0,
17-
is_systematic: bool = True,
18-
mask: Union[str, None] = None,
19-
pcc_method: str = FastSSCPolarCode.BHATTACHARYYA,
20-
Ns: int = 1,
21-
AF: int = 1):
16+
def __init__(
17+
self,
18+
N: int,
19+
K: int,
20+
design_snr: float = 0.0,
21+
is_systematic: bool = True,
22+
mask: Union[str, None] = None,
23+
pcc_method: str = FastSSCPolarCodec.BHATTACHARYYA,
24+
Ns: int = 1,
25+
AF: int = 1,
26+
):
2227

2328
self.Ns = Ns
2429
self.AF = AF
@@ -28,7 +33,7 @@ def __init__(self, N: int, K: int,
2833
mask=mask,
2934
pcc_method=pcc_method)
3035

31-
def get_decoder(self):
36+
def init_decoder(self):
3237
return self.decoder_class(n=self.n, mask=self.mask,
3338
is_systematic=self.is_systematic,
3439
code_min_size=self.Ns,
@@ -38,11 +43,3 @@ def to_dict(self):
3843
d = super().to_dict()
3944
d.update({'AF': self.AF})
4045
return d
41-
42-
def show_tree(self):
43-
nodes = [
44-
(f'{leaf.name}, {leaf._node_type}, {leaf.N}: '
45-
f'{"".join([str(m) for m in leaf._mask])}')
46-
for leaf in self.tree.leaves
47-
]
48-
return '\n'.join(nodes)
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import numpy as np
2+
3+
from python_polar_coding.polar_codes.fast_ssc import FastSSCDecoder
4+
5+
from .node import GeneralizedFastSSCNode
6+
7+
8+
class GeneralizedFastSSCDecoder(FastSSCDecoder):
9+
node_class = GeneralizedFastSSCNode
10+
11+
def __init__(
12+
self,
13+
n: int,
14+
mask: np.array,
15+
is_systematic: bool = True,
16+
code_min_size: int = 0,
17+
AF: int = 1,
18+
):
19+
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
20+
self._decoding_tree = self.node_class(
21+
mask=self.mask,
22+
N_min=code_min_size,
23+
AF=AF,
24+
)
25+
self._position = 0
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import numba
2+
import numpy as np
3+
4+
from ..base import make_hard_decision
5+
from ..fast_ssc import compute_single_parity_check
6+
7+
8+
@numba.njit
9+
def compute_g_repetition(llr, mask_steps, last_chunk_type, N):
10+
"""Compute bits for Generalized Repetition node.
11+
12+
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, A.
13+
14+
"""
15+
step = N // mask_steps # step is equal to a chunk size
16+
17+
last_alpha = np.zeros(step)
18+
for i in range(step):
19+
last_alpha[i] = np.sum(np.array([
20+
llr[i + j * step] for j in range(mask_steps)
21+
]))
22+
23+
last_beta = (
24+
make_hard_decision(last_alpha) if last_chunk_type == 1
25+
else compute_single_parity_check(last_alpha)
26+
)
27+
28+
result = np.zeros(N)
29+
for i in range(0, N, step):
30+
result[i: i + step] = last_beta
31+
32+
return result
33+
34+
35+
@numba.njit
36+
def compute_rg_parity(llr, mask_steps, N):
37+
"""Compute bits for Relaxed Generalized Parity Check node.
38+
39+
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, B.
40+
41+
"""
42+
step = N // mask_steps # step is equal to a chunk size
43+
result = np.zeros(N)
44+
45+
for i in range(step):
46+
alpha = np.zeros(mask_steps)
47+
for j in range(mask_steps):
48+
alpha[j] = llr[i + j * step]
49+
50+
beta = compute_single_parity_check(alpha)
51+
result[i:N:step] = beta
52+
53+
return result

python_polar_coding/polar_codes/decoders/g_fast_ssc_decoder.py renamed to python_polar_coding/polar_codes/g_fast_ssc/node.py

Lines changed: 5 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
import numpy as np
22

3-
from ..base import functions
4-
from .fast_ssc_decoder import FastSSCDecoder, FastSSCNode
3+
from python_polar_coding.polar_codes.fast_ssc import FastSSCNode
4+
from python_polar_coding.polar_codes.utils import splits
55

6-
7-
def splits(start, end):
8-
while start <= end:
9-
yield start
10-
start *= 2
6+
from .functions import compute_g_repetition, compute_rg_parity
117

128

139
class GeneralizedFastSSCNode(FastSSCNode):
@@ -114,30 +110,15 @@ def compute_leaf_beta(self):
114110
klass = self.__class__
115111

116112
if self._node_type == klass.G_REPETITION:
117-
self._beta = functions.compute_g_repetition(
113+
self._beta = compute_g_repetition(
118114
llr=self.alpha,
119115
mask_steps=self.mask_steps,
120116
last_chunk_type=self.last_chunk_type,
121117
N=self.N,
122118
)
123119
if self._node_type == klass.RG_PARITY:
124-
self._beta = functions.compute_rg_parity(
120+
self._beta = compute_rg_parity(
125121
llr=self.alpha,
126122
mask_steps=self.mask_steps,
127123
N=self.N,
128124
)
129-
130-
131-
class GeneralizedFastSSCDecoder(FastSSCDecoder):
132-
node_class = GeneralizedFastSSCNode
133-
134-
def __init__(self, n: int,
135-
mask: np.array,
136-
is_systematic: bool = True,
137-
code_min_size: int = 0,
138-
AF: int = 1):
139-
super().__init__(n=n, mask=mask, is_systematic=is_systematic)
140-
self._decoding_tree = self.node_class(mask=self.mask,
141-
N_min=code_min_size,
142-
AF=AF)
143-
self._position = 0

python_polar_coding/tests/test_g_fast_ssc/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)