Skip to content

Commit 48c127d

Browse files
koubaakoubaa
andauthored
add datatype enum to python module (#404)
* add datatype enum to python module Signed-off-by: koubaa <[email protected]> * add datatype enum to python module Signed-off-by: koubaa <[email protected]> * add doc Signed-off-by: koubaa <[email protected]> --------- Signed-off-by: koubaa <[email protected]> Co-authored-by: koubaa <[email protected]>
1 parent cb85d21 commit 48c127d

File tree

4 files changed

+51
-1
lines changed

4 files changed

+51
-1
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,7 @@ def kompute(shader):
187187
# Explicit type constructor supports uint32, int32, double, float and bool
188188
tensor_out_a = mgr.tensor_t(np.array([0, 0, 0], dtype=np.uint32))
189189
tensor_out_b = mgr.tensor_t(np.array([0, 0, 0], dtype=np.uint32))
190+
assert(t_data.data_type() == kp.DataTypes.uint)
190191
191192
params = [tensor_in_a, tensor_in_b, tensor_out_a, tensor_out_b]
192193

docs/overview/python-examples.rst

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,10 @@ Similarly you can find the same extended example as above:
6363
# Can be initialized with List[] or np.Array
6464
tensor_in_a = mgr.tensor([2, 2, 2])
6565
tensor_in_b = mgr.tensor([1, 2, 3])
66-
tensor_out = mgr.tensor([0, 0, 0])
66+
67+
# By default, tensors use a float type, but that can be explicitly specified
68+
tensor_out = mgr.tensor_t([0, 0, 0], dtype=np.float32)
69+
assert(tensor_out.data_type() == kp.DataTypes.float)
6770
6871
seq = mgr.sequence()
6972
seq.eval(kp.OpTensorSyncDevice([tensor_in_a, tensor_in_b, tensor_out]))

docs/overview/python-reference.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,16 @@ TensorType
3939
.. automodule:: kp
4040
:members:
4141

42+
MemoryTypes
43+
-------
44+
45+
.. automodule:: kp
46+
:members:
47+
48+
DataTypes
49+
-------
50+
51+
.. automodule:: kp
52+
:members:
53+
54+

python/src/main.cpp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,39 @@ PYBIND11_MODULE(kp, m)
6565

6666
py::module_ np = py::module_::import("numpy");
6767

68+
py::enum_<kp::Memory::DataTypes>(m, "DataTypes")
69+
.value("bool",
70+
kp::Memory::DataTypes::eBool,
71+
DOC(kp, Memory, DataTypes, eBool))
72+
.value("int",
73+
kp::Memory::DataTypes::eInt,
74+
DOC(kp, Memory, DataTypes, eInt))
75+
.value("uint",
76+
kp::Memory::DataTypes::eUnsignedInt,
77+
DOC(kp, Memory, DataTypes, eUnsignedInt))
78+
.value("float",
79+
kp::Memory::DataTypes::eFloat,
80+
DOC(kp, Memory, DataTypes, eFloat))
81+
.value("double",
82+
kp::Memory::DataTypes::eDouble,
83+
DOC(kp, Memory, DataTypes, eDouble))
84+
.value("custom",
85+
kp::Memory::DataTypes::eCustom,
86+
DOC(kp, Memory, DataTypes, eCustom))
87+
.value("short",
88+
kp::Memory::DataTypes::eShort,
89+
DOC(kp, Memory, DataTypes, eShort))
90+
.value("ushort",
91+
kp::Memory::DataTypes::eUnsignedShort,
92+
DOC(kp, Memory, DataTypes, eUnsignedShort))
93+
.value("char",
94+
kp::Memory::DataTypes::eChar,
95+
DOC(kp, Memory, DataTypes, eChar))
96+
.value("uchar",
97+
kp::Memory::DataTypes::eUnsignedChar,
98+
DOC(kp, Memory, DataTypes, eUnsignedChar))
99+
.export_values();
100+
68101
py::enum_<kp::Memory::MemoryTypes>(m, "MemoryTypes")
69102
.value("device",
70103
kp::Memory::MemoryTypes::eDevice,

0 commit comments

Comments
 (0)