Skip to content

Commit 6b8cf22

Browse files
author
grigory
committed
WIP Extended Generalized polar codes
1 parent 3f6ee94 commit 6b8cf22

File tree

9 files changed

+359
-0
lines changed

9 files changed

+359
-0
lines changed

e_g_fast_ssc_analisys.ipynb

Lines changed: 118 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .codec import EGFastSSCPolarCodec
2+
from .decoder import EGFastSSCDecoder
3+
from .functions import *
4+
from .node import EGFastSSCNode
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from python_polar_coding.polar_codes.g_fast_ssc import (
2+
GeneralizedFastSSCPolarCodec,
3+
)
4+
5+
from .decoder import EGFastSSCDecoder
6+
7+
8+
class EGFastSSCPolarCodec(GeneralizedFastSSCPolarCodec):
9+
"""Extended Generalized Fast SSC codec."""
10+
decoder_class = EGFastSSCDecoder
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from python_polar_coding.polar_codes.g_fast_ssc import (
2+
GeneralizedFastSSCDecoder,
3+
)
4+
5+
from .node import EGFastSSCNode
6+
7+
8+
class EGFastSSCDecoder(GeneralizedFastSSCDecoder):
9+
""""""
10+
node_class = EGFastSSCNode
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import numba
2+
import numpy as np
3+
4+
from ..base import compute_left_alpha
5+
6+
7+
@numba.njit
8+
def compute_left_alpha_sign(alpha: np.array) -> np.array:
9+
""""""
10+
left_alpha = compute_left_alpha(alpha)
11+
return np.sign(np.sum(left_alpha))
12+
13+
14+
@numba.njit
15+
def compute_right_alpha(alpha: np.array, left_sign: int = 1) -> np.array:
16+
"""`left_sign` is 1 or -1"""
17+
N = alpha.size // 2
18+
left_alpha = alpha[:N]
19+
right_alpha = alpha[N:]
20+
return right_alpha + left_alpha * left_sign
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
import numpy as np
2+
3+
from python_polar_coding.polar_codes.g_fast_ssc import GeneralizedFastSSCNode
4+
5+
from .functions import compute_left_alpha_sign, compute_right_alpha
6+
7+
8+
class EGFastSSCNode(GeneralizedFastSSCNode):
9+
"""Decoder for Generalized Fast SSC code.
10+
11+
Based on: https://arxiv.org/pdf/1804.09508.pdf
12+
13+
"""
14+
ZERO_ANY = 'ZERO-ANY'
15+
REP_ANY = 'REP-ANY'
16+
17+
def __init__(self, *args, **kwargs):
18+
# Contains `ANY` node for ZERO_ANY or REP_ANY
19+
self.inner_node = None
20+
super().__init__(*args, **kwargs)
21+
22+
@property
23+
def is_any(self):
24+
return (
25+
self.is_zero or
26+
self.is_one or
27+
self.is_repetition or
28+
self.is_parity or
29+
self.is_g_repetition or
30+
self.is_rg_parity
31+
)
32+
33+
@property
34+
def is_zero_any(self):
35+
return self._node_type == self.ZERO_ANY
36+
37+
@property
38+
def is_rep_any(self):
39+
return self._node_type == self.REP_ANY
40+
41+
def get_node_type(self):
42+
ntype = super().get_node_type()
43+
if ntype != self.OTHER:
44+
return ntype
45+
if self._check_is_zero_any(self._mask):
46+
return self.ZERO_ANY
47+
if self._check_is_rep_any(self._mask):
48+
return self.REP_ANY
49+
return self.OTHER
50+
51+
def _check_is_zero_any(self, mask):
52+
""""""
53+
left, right = np.split(mask, 2)
54+
if not self._check_is_zero(left):
55+
return False
56+
inner_node = self.__class__(
57+
mask=right,
58+
name=self.ROOT,
59+
N_min=self.N_min,
60+
AF=self.AF
61+
)
62+
if not inner_node.is_any:
63+
return False
64+
65+
self.inner_node = inner_node
66+
return True
67+
68+
def _check_is_rep_any(self, mask):
69+
""""""
70+
left, right = np.split(mask, 2)
71+
if not self._check_is_rep(left):
72+
return False
73+
right_node = self.__class__(
74+
mask=right,
75+
name=self.ROOT,
76+
N_min=self.N_min,
77+
AF=self.AF
78+
)
79+
if not right_node.is_any:
80+
return False
81+
82+
self.inner_node = right_node
83+
return True
84+
85+
def compute_leaf_beta(self):
86+
super().compute_leaf_beta()
87+
klass = self.__class__
88+
89+
if self._node_type == klass.ZERO_ANY:
90+
self._beta = self.compute_zero_any()
91+
if self._node_type == klass.REP_ANY:
92+
self._beta = self.compute_rep_any()
93+
94+
def compute_zero_any(self):
95+
""""""
96+
right_alpha = compute_right_alpha(self.alpha, left_sign=1)
97+
98+
self.inner_node.alpha = right_alpha
99+
self.inner_node.compute_leaf_beta()
100+
101+
beta = np.zeros(self.N, dtype=np.int8)
102+
beta[:self.inner_node.N] = self.inner_node.beta
103+
beta[self.inner_node.N:] = self.inner_node.beta
104+
return beta
105+
106+
def compute_rep_any(self):
107+
""""""
108+
left_sign = compute_left_alpha_sign(self.alpha)
109+
right_alpha = compute_right_alpha(self.alpha, left_sign)
110+
111+
self.inner_node.alpha = right_alpha
112+
self.inner_node.compute_leaf_beta()
113+
114+
beta = np.zeros(self.N, dtype=np.int8)
115+
beta[:self.inner_node.N] = self.inner_node.beta
116+
beta[self.inner_node.N:] = self.inner_node.beta
117+
return beta

python_polar_coding/tests/test_e_g_fast_ssc/__init__.py

Whitespace-only changes.

python_polar_coding/tests/test_e_g_fast_ssc/test_codec.py

Whitespace-only changes.
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
from unittest import TestCase
2+
3+
import numpy as np
4+
5+
from python_polar_coding.polar_codes.e_g_fast_ssc import EGFastSSCNode
6+
7+
8+
class TestEGFastSSC(TestCase):
9+
10+
@classmethod
11+
def setUpClass(cls):
12+
cls.check_node = EGFastSSCNode(mask=np.zeros(4))
13+
cls.alpha8 = np.array([
14+
15+
])
16+
17+
# ZERO-ANY node
18+
19+
def test_check_is_zero_any_right_g_rep(self):
20+
""""""
21+
mask = np.array([
22+
0, 0, 0, 0, 0, 0, 0, 0,
23+
0, 0, 0, 0, 0, 0, 1, 1,
24+
])
25+
self.assertTrue(self.check_node._check_is_zero_any(mask))
26+
27+
def test_check_is_zero_any_right_rg_par(self):
28+
""""""
29+
mask = np.array([
30+
0, 0, 0, 0, 0, 0, 0, 0,
31+
0, 0, 1, 1, 1, 1, 1, 1,
32+
])
33+
self.assertTrue(self.check_node._check_is_zero_any(mask))
34+
35+
def test_compute_zero_any_right_rg_parity(self):
36+
""""""
37+
mask = np.array([0, 0, 0, 0, 0, 0, 0, 0,
38+
0, 0, 1, 1, 1, 1, 1, 1])
39+
alpha = np.array([
40+
-2.7273, -8.7327, 0.1087, -1.6463,
41+
2.7273, -8.7327, -0.1087, 1.6463,
42+
-2.7273, -8.7327, -0.1087, 1.6463,
43+
2.7273, 8.7326, 1.1087, -1.6463,
44+
])
45+
expected = np.array([1, 1, 1, 0, 0, 1, 0, 0,
46+
1, 1, 1, 0, 0, 1, 0, 0, ])
47+
48+
node = EGFastSSCNode(mask=mask)
49+
node.alpha = alpha
50+
node.compute_leaf_beta()
51+
52+
np.testing.assert_equal(node.beta, expected)
53+
54+
# REP-ANY node
55+
56+
def test_check_is_rep_any_right_one(self):
57+
""""""
58+
mask = np.array([0, 0, 0, 1, 1, 1, 1, 1])
59+
self.assertTrue(self.check_node._check_is_rep_any(mask))
60+
61+
def test_check_is_rep_any_right_spc(self):
62+
""""""
63+
mask = np.array([0, 0, 0, 1, 0, 1, 1, 1])
64+
self.assertTrue(self.check_node._check_is_rep_any(mask))
65+
66+
def test_check_is_rep_any_right_g_rep(self):
67+
""""""
68+
mask = np.array([
69+
0, 0, 0, 0, 0, 0, 0, 1,
70+
0, 0, 0, 0, 0, 0, 1, 1,
71+
])
72+
self.assertTrue(self.check_node._check_is_rep_any(mask))
73+
74+
def test_check_is_rep_any_right_rg_par(self):
75+
""""""
76+
mask = np.array([
77+
0, 0, 0, 0, 0, 0, 0, 1,
78+
0, 0, 1, 1, 1, 1, 1, 1,
79+
])
80+
self.assertTrue(self.check_node._check_is_rep_any(mask))

0 commit comments

Comments
 (0)