Skip to content

Commit 9841d00

Browse files
committed
add python interface
1 parent 4376cea commit 9841d00

File tree

6 files changed

+865
-0
lines changed

6 files changed

+865
-0
lines changed

python/__init__.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
from warnings import warn
2+
from range_coder._range_coder import RangeEncoder, RangeDecoder # noqa: F401
3+
4+
try:
5+
import numpy as np
6+
except ImportError:
7+
pass
8+
9+
10+
def prob_to_cum_freq(prob, resolution=1024):
11+
"""
12+
Converts probability distribution into a cumulative frequency table.
13+
14+
Makes sure that non-zero probabilities are represented by non-zero frequencies,
15+
provided that :samp:`len({prob}) <= {resolution}`.
16+
17+
Parameters
18+
----------
19+
prob : ndarray or list
20+
A one-dimensional array representing a probability distribution
21+
22+
resolution : int
23+
Number of hypothetical samples used to generate integer frequencies
24+
25+
Returns
26+
-------
27+
list
28+
Cumulative frequency table
29+
"""
30+
31+
if len(prob) > resolution:
32+
warn('Resolution smaller than number of symbols.')
33+
34+
prob = np.asarray(prob, dtype=np.float64)
35+
freq = np.zeros(prob.size, dtype=int)
36+
37+
# this is similar to gradient descent in KL divergence (convex)
38+
with np.errstate(divide='ignore', invalid='ignore'):
39+
for _ in range(resolution):
40+
freq[np.nanargmax(prob / freq)] += 1
41+
42+
return [0] + np.cumsum(freq).tolist()
43+
44+
45+
def cum_freq_to_prob(cumFreq):
46+
"""
47+
Converts a cumulative frequency table into a probability distribution.
48+
49+
Parameters
50+
----------
51+
cumFreq : list
52+
Cumulative frequency table
53+
54+
Returns
55+
-------
56+
ndarray
57+
Probability distribution
58+
"""
59+
return np.diff(cumFreq).astype(np.float64) / cumFreq[-1]

python/src/module.cpp

Lines changed: 162 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#include <Python.h>
2+
#include "range_coder_interface.h"
3+
4+
static PyMethodDef RangeEncoder_methods[] = {
5+
{"encode",
6+
(PyCFunction)RangeEncoder_encode,
7+
METH_VARARGS | METH_KEYWORDS,
8+
RangeEncoder_encode_doc},
9+
{"close",
10+
(PyCFunction)RangeEncoder_close,
11+
METH_VARARGS | METH_KEYWORDS,
12+
RangeEncoder_close_doc},
13+
{0}
14+
};
15+
16+
17+
static PyGetSetDef RangeEncoder_getset[] = {
18+
{0}
19+
};
20+
21+
22+
PyTypeObject RangeEncoder_type = {
23+
PyVarObject_HEAD_INIT(0, 0)
24+
"range_coder.RangeEncoder", /*tp_name*/
25+
sizeof(RangeEncoderObject), /*tp_basicsize*/
26+
0, /*tp_itemsize*/
27+
(destructor)RangeEncoder_dealloc, /*tp_dealloc*/
28+
0, /*tp_print*/
29+
0, /*tp_getattr*/
30+
0, /*tp_setattr*/
31+
0, /*tp_compare*/
32+
0, /*tp_repr*/
33+
0, /*tp_as_number*/
34+
0, /*tp_as_sequence*/
35+
0, /*tp_as_mapping*/
36+
0, /*tp_hash */
37+
0, /*tp_call*/
38+
0, /*tp_str*/
39+
0, /*tp_getattro*/
40+
0, /*tp_setattro*/
41+
0, /*tp_as_buffer*/
42+
Py_TPFLAGS_DEFAULT, /*tp_flags*/
43+
RangeEncoder_doc, /*tp_doc*/
44+
0, /*tp_traverse*/
45+
0, /*tp_clear*/
46+
0, /*tp_richcompare*/
47+
0, /*tp_weaklistoffset*/
48+
0, /*tp_iter*/
49+
0, /*tp_iternext*/
50+
RangeEncoder_methods, /*tp_methods*/
51+
0, /*tp_members*/
52+
RangeEncoder_getset, /*tp_getset*/
53+
0, /*tp_base*/
54+
0, /*tp_dict*/
55+
0, /*tp_descr_get*/
56+
0, /*tp_descr_set*/
57+
0, /*tp_dictoffset*/
58+
(initproc)RangeEncoder_init, /*tp_init*/
59+
0, /*tp_alloc*/
60+
RangeEncoder_new, /*tp_new*/
61+
};
62+
63+
static PyMethodDef RangeDecoder_methods[] = {
64+
{"decode",
65+
(PyCFunction)RangeDecoder_decode,
66+
METH_VARARGS | METH_KEYWORDS,
67+
RangeDecoder_decode_doc},
68+
{"close",
69+
(PyCFunction)RangeDecoder_close,
70+
METH_VARARGS | METH_KEYWORDS,
71+
RangeDecoder_close_doc},
72+
{0}
73+
};
74+
75+
76+
static PyGetSetDef RangeDecoder_getset[] = {
77+
{0}
78+
};
79+
80+
81+
PyTypeObject RangeDecoder_type = {
82+
PyVarObject_HEAD_INIT(0, 0)
83+
"range_coder.RangeDecoder", /*tp_name*/
84+
sizeof(RangeDecoderObject), /*tp_basicsize*/
85+
0, /*tp_itemsize*/
86+
(destructor)RangeDecoder_dealloc, /*tp_dealloc*/
87+
0, /*tp_print*/
88+
0, /*tp_getattr*/
89+
0, /*tp_setattr*/
90+
0, /*tp_compare*/
91+
0, /*tp_repr*/
92+
0, /*tp_as_number*/
93+
0, /*tp_as_sequdece*/
94+
0, /*tp_as_mapping*/
95+
0, /*tp_hash */
96+
0, /*tp_call*/
97+
0, /*tp_str*/
98+
0, /*tp_getattro*/
99+
0, /*tp_setattro*/
100+
0, /*tp_as_buffer*/
101+
Py_TPFLAGS_DEFAULT, /*tp_flags*/
102+
RangeDecoder_doc, /*tp_doc*/
103+
0, /*tp_traverse*/
104+
0, /*tp_clear*/
105+
0, /*tp_richcompare*/
106+
0, /*tp_weaklistoffset*/
107+
0, /*tp_iter*/
108+
0, /*tp_iternext*/
109+
RangeDecoder_methods, /*tp_methods*/
110+
0, /*tp_members*/
111+
RangeDecoder_getset, /*tp_getset*/
112+
0, /*tp_base*/
113+
0, /*tp_dict*/
114+
0, /*tp_descr_get*/
115+
0, /*tp_descr_set*/
116+
0, /*tp_dictoffset*/
117+
(initproc)RangeDecoder_init, /*tp_init*/
118+
0, /*tp_alloc*/
119+
RangeDecoder_new, /*tp_new*/
120+
};
121+
122+
#if PY_MAJOR_VERSION >= 3
123+
static PyModuleDef range_coder_module = {
124+
PyModuleDef_HEAD_INIT,
125+
"_range_coder",
126+
"A fast implementation of a range encoder and decoder."
127+
-1, 0, 0, 0, 0, 0
128+
};
129+
#endif
130+
131+
132+
#if PY_MAJOR_VERSION >= 3
133+
PyMODINIT_FUNC PyInit__range_coder() {
134+
// create module object
135+
PyObject* module = PyModule_Create(&range_coder_module);
136+
#define RETVAL 0;
137+
#else
138+
PyMODINIT_FUNC init_range_coder() {
139+
PyObject* module = Py_InitModule3(
140+
"_range_coder", 0, "A fast implementation of a range encoder and decoder.");
141+
#define RETVAL void();
142+
#endif
143+
144+
if(!module)
145+
return RETVAL;
146+
147+
// initialize types
148+
if(PyType_Ready(&RangeEncoder_type) < 0)
149+
return RETVAL;
150+
if(PyType_Ready(&RangeDecoder_type) < 0)
151+
return RETVAL;
152+
153+
// add types to module
154+
Py_INCREF(&RangeEncoder_type);
155+
PyModule_AddObject(module, "RangeEncoder", reinterpret_cast<PyObject*>(&RangeEncoder_type));
156+
Py_INCREF(&RangeDecoder_type);
157+
PyModule_AddObject(module, "RangeDecoder", reinterpret_cast<PyObject*>(&RangeDecoder_type));
158+
159+
#if PY_MAJOR_VERSION >= 3
160+
return module;
161+
#endif
162+
}

0 commit comments

Comments
 (0)