Skip to content

Commit ba57eaa

Browse files
author
grigory
committed
Refactor G-Fast SSC codec
1 parent ae5251c commit ba57eaa

File tree

10 files changed

+80
-279
lines changed

10 files changed

+80
-279
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,5 @@
88
apak/*
99
firestore/*credentials*
1010
src*
11-
tests/_trial_temp*
11+
tests/_trial_temp*
12+
venv/

python_polar_coding/polar_codes/base/functions/node_types.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ class NodeTypeDetector:
2424
# Minimal number of chunks in generalized nodes
2525
MIN_CHUNKS = 2
2626

27+
def __init__(self, *args, **kwargs):
28+
self.last_chunk_type = None
29+
self.mask_steps = None
30+
2731
def __call__(
2832
self,
2933
supported_nodes: list,

python_polar_coding/polar_codes/base/node.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ def __init__(self, mask: np.array, name: str = ROOT, AF: int = 0, **kwargs): #
4141
self._alpha = np.zeros(self.N, dtype=np.double)
4242
self._beta = np.zeros(self.N, dtype=np.int8)
4343

44+
# For generalized decoders
45+
self.last_chunk_type = get_node_type.last_chunk_type
46+
self.mask_steps = get_node_type.mask_steps
47+
4448
self.build_decoding_tree()
4549

4650
def __str__(self):
Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from .codec import GeneralizedFastSSCPolarCodec
22
from .decoder import GeneralizedFastSSCDecoder
3-
from .functions import *
4-
from .node import GeneralizedFastSSCNode
3+
from .node import GFastSSCNode

python_polar_coding/polar_codes/g_fast_ssc/codec.py

Lines changed: 9 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,26 +18,22 @@ def __init__(
1818
N: int,
1919
K: int,
2020
design_snr: float = 0.0,
21-
is_systematic: bool = True,
2221
mask: Union[str, None] = None,
2322
pcc_method: str = FastSSCPolarCodec.BHATTACHARYYA,
24-
Ns: int = 1,
25-
AF: int = 1,
23+
AF: int = 0,
2624
):
2725

28-
self.Ns = Ns
2926
self.AF = AF
30-
super().__init__(N=N, K=K,
31-
is_systematic=is_systematic,
32-
design_snr=design_snr,
33-
mask=mask,
34-
pcc_method=pcc_method)
27+
super().__init__(
28+
N=N,
29+
K=K,
30+
design_snr=design_snr,
31+
mask=mask,
32+
pcc_method=pcc_method,
33+
)
3534

3635
def init_decoder(self):
37-
return self.decoder_class(n=self.n, mask=self.mask,
38-
is_systematic=self.is_systematic,
39-
code_min_size=self.Ns,
40-
AF=self.AF)
36+
return self.decoder_class(n=self.n, mask=self.mask, AF=self.AF)
4137

4238
def to_dict(self):
4339
d = super().to_dict()

python_polar_coding/polar_codes/g_fast_ssc/decoder.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,28 +2,16 @@
22

33
from python_polar_coding.polar_codes.fast_ssc import FastSSCDecoder
44

5-
from .node import GeneralizedFastSSCNode
5+
from .node import GFastSSCNode
66

77

88
class GeneralizedFastSSCDecoder(FastSSCDecoder):
9-
node_class = GeneralizedFastSSCNode
9+
node_class = GFastSSCNode
1010

11-
def __init__(
12-
self,
13-
n: int,
14-
mask: np.array,
15-
is_systematic: bool = True,
16-
code_min_size: int = 0,
17-
AF: int = 1,
18-
):
11+
def __init__(self, n: int, mask: np.array, AF: int = 1):
1912
self.AF = AF
20-
super().__init__(
21-
n=n,
22-
mask=mask,
23-
is_systematic=is_systematic,
24-
code_min_size=code_min_size,
25-
)
13+
super().__init__(n=n, mask=mask)
2614

27-
def setup_decoding_tree(self, N_min, **kwargs):
15+
def _setup_decoding_tree(self):
2816
"""Setup decoding tree."""
29-
return self.node_class(mask=self.mask, N_min=N_min, AF=self.AF)
17+
return self.node_class(mask=self.mask, AF=self.AF)

python_polar_coding/polar_codes/g_fast_ssc/functions.py

Lines changed: 0 additions & 53 deletions
This file was deleted.

python_polar_coding/polar_codes/g_fast_ssc/node.py

100644100755
Lines changed: 19 additions & 122 deletions
Original file line numberDiff line numberDiff line change
@@ -1,132 +1,29 @@
1-
import numpy as np
1+
from typing import Dict
22

33
from python_polar_coding.polar_codes.fast_ssc import FastSSCNode
4-
from python_polar_coding.polar_codes.utils import splits
54

6-
from .functions import compute_g_repetition, compute_rg_parity
5+
from ..base import NodeTypes
76

87

9-
class GeneralizedFastSSCNode(FastSSCNode):
8+
class GFastSSCNode(FastSSCNode):
109
"""Decoder for Generalized Fast SSC code.
1110
1211
Based on: https://arxiv.org/pdf/1804.09508.pdf
1312
1413
"""
15-
G_REPETITION = 'G-REPETITION'
16-
RG_PARITY = 'RG-PARITY'
17-
18-
MIN_CHUNKS = 2
19-
20-
def __init__(self, AF=1, *args, **kwargs):
21-
self.AF = AF
22-
self.last_chunk_type = None
23-
self.mask_steps = None
24-
super().__init__(*args, **kwargs)
25-
26-
@property
27-
def is_g_repetition(self):
28-
return self._node_type == self.G_REPETITION
29-
30-
@property
31-
def is_rg_parity(self):
32-
return self._node_type == self.RG_PARITY
33-
34-
def get_node_type(self):
35-
ntype = super().get_node_type()
36-
if ntype != self.OTHER:
37-
return ntype
38-
if self._check_is_g_repetition(self._mask):
39-
return self.G_REPETITION
40-
if self._check_is_rg_parity(self._mask):
41-
return self.RG_PARITY
42-
return self.OTHER
43-
44-
def _build_decoding_tree(self):
45-
"""Build Generalized Fast SSC decoding tree."""
46-
if self.is_simplified_node:
47-
return
48-
49-
if self._mask.size == self.M:
50-
return
51-
52-
left_mask, right_mask = np.split(self._mask, 2)
53-
cls = self.__class__
54-
cls(mask=left_mask, name=self.LEFT, N_min=self.M, parent=self, AF=self.AF)
55-
cls(mask=right_mask, name=self.RIGHT, N_min=self.M, parent=self, AF=self.AF)
56-
57-
def _check_is_g_repetition(self, mask):
58-
"""Check the node is Generalized Repetition node.
59-
60-
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, A.
61-
62-
"""
63-
# 1. Split mask into T chunks, T in range [2, 4, ..., N/2]
64-
for t in splits(self.MIN_CHUNKS, self.N // 2):
65-
chunks = np.split(mask, t)
66-
67-
last = chunks[-1]
68-
last_ok = (
69-
(self._check_is_spc(last) and last.size >= self.REPETITION_MIN_SIZE)
70-
or self._check_is_one(last)
71-
)
72-
if not last_ok:
73-
continue
74-
75-
others_ok = all(self._check_is_zero(c) for c in chunks[:-1])
76-
if not others_ok:
77-
continue
78-
79-
self.last_chunk_type = 1 if self._check_is_one(last) else 0
80-
self.mask_steps = t
81-
return True
82-
83-
return False
84-
85-
def _check_is_rg_parity(self, mask):
86-
"""Check the node is Relaxed Generalized Parity Check node.
87-
88-
Based on: https://arxiv.org/pdf/1804.09508.pdf, Section III, B.
89-
90-
"""
91-
# 1. Split mask into T chunks, T in range [2, 4, ..., N/2]
92-
for t in splits(self.MIN_CHUNKS, self.N // 2):
93-
chunks = np.split(mask, t)
94-
95-
first = chunks[0]
96-
if not self._check_is_zero(first):
97-
continue
98-
99-
ones = 0
100-
spcs = 0
101-
for c in chunks[1:]:
102-
if self._check_is_one(c):
103-
ones += 1
104-
elif c.size >= self.SPC_MIN_SIZE and self._check_is_spc(c):
105-
spcs += 1
106-
107-
others_ok = (ones + spcs + 1) == t and spcs <= self.AF
108-
if not others_ok:
109-
continue
110-
111-
self.mask_steps = t
112-
return True
113-
114-
return False
115-
116-
def compute_leaf_beta(self):
117-
super().compute_leaf_beta()
118-
klass = self.__class__
119-
120-
if self._node_type == klass.G_REPETITION:
121-
self._beta = compute_g_repetition(
122-
llr=self.alpha,
123-
mask_steps=self.mask_steps,
124-
last_chunk_type=self.last_chunk_type,
125-
N=self.N,
126-
)
127-
if self._node_type == klass.RG_PARITY:
128-
self._beta = compute_rg_parity(
129-
llr=self.alpha,
130-
mask_steps=self.mask_steps,
131-
N=self.N,
132-
)
14+
supported_nodes = (
15+
NodeTypes.ZERO,
16+
NodeTypes.ONE,
17+
NodeTypes.SINGLE_PARITY_CHECK,
18+
NodeTypes.REPETITION,
19+
NodeTypes.RG_PARITY,
20+
NodeTypes.G_REPETITION,
21+
)
22+
23+
def get_decoding_params(self) -> Dict:
24+
return dict(
25+
node_type=self.node_type,
26+
llr=self.alpha,
27+
mask_steps=self.mask_steps,
28+
last_chunk_type=self.last_chunk_type,
29+
)

python_polar_coding/tests/test_e_g_fast_ssc/__init__.py renamed to python_polar_coding/polar_codes/parallel_rc_scan/__init__.py

File renamed without changes.

0 commit comments

Comments
 (0)