diff --git a/python/src/main.cpp b/python/src/main.cpp index d94b2ecc..94e4b134 100644 --- a/python/src/main.cpp +++ b/python/src/main.cpp @@ -534,7 +534,7 @@ PYBIND11_MODULE(kp, m) if (spec_consts.dtype().is(py::dtype::of<float>())) { std::vector<float> specConstsVec( (float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype().is(py::dtype::of<float>())) { + if (push_consts.dtype().is(py::dtype::of<float>())) { std::vector<float> pushConstsVec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -543,7 +543,7 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<int32_t>())) { std::vector<int32_t> pushConstsVec( (int32_t*)pushInfo.ptr, @@ -553,7 +553,7 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<uint32_t>())) { std::vector<uint32_t> pushConstsVec( (uint32_t*)pushInfo.ptr, @@ -563,7 +563,7 @@ PYBIND11_MODULE(kp, m) workgroup, specConstsVec, pushConstsVec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<double>())) { std::vector<double> pushConstsVec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + @@ -578,7 +578,7 @@ PYBIND11_MODULE(kp, m) std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr, ((int32_t*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype().is(py::dtype::of<float>())) { + if (push_consts.dtype().is(py::dtype::of<float>())) { std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -587,7 +587,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<int32_t>())) { std::vector<int32_t> pushconstsvec( (int32_t*)pushInfo.ptr, @@ -597,7 +597,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<uint32_t>())) { std::vector<uint32_t> pushconstsvec(
(uint32_t*)pushInfo.ptr, @@ -607,7 +607,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<double>())) { std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + @@ -622,7 +622,7 @@ PYBIND11_MODULE(kp, m) std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr, ((uint32_t*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype().is(py::dtype::of<float>())) { + if (push_consts.dtype().is(py::dtype::of<float>())) { std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -631,7 +631,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<int32_t>())) { std::vector<int32_t> pushconstsvec( (int32_t*)pushInfo.ptr, @@ -641,7 +641,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<uint32_t>())) { std::vector<uint32_t> pushconstsvec( (uint32_t*)pushInfo.ptr, @@ -651,7 +651,7 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( + } else if (push_consts.dtype().is( py::dtype::of<double>())) { std::vector<double> pushconstsvec((double*)pushInfo.ptr, ((double*)pushInfo.ptr) + @@ -666,7 +666,7 @@ PYBIND11_MODULE(kp, m) std::vector<double> specconstsvec((double*)specInfo.ptr, ((double*)specInfo.ptr) + specInfo.size); - if (spec_consts.dtype().is(py::dtype::of<float>())) { + if (push_consts.dtype().is(py::dtype::of<float>())) { std::vector<float> pushconstsvec((float*)pushInfo.ptr, ((float*)pushInfo.ptr) + pushInfo.size); @@ -675,31 +675,25 @@ PYBIND11_MODULE(kp, m) workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of<int32_t>())) { - std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, - ((int32_t*)pushInfo.ptr) + - pushInfo.size); + } else if (push_consts.dtype().is(py::dtype::of<int32_t>())) { + std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr, +
((int32_t*)pushInfo.ptr) + pushInfo.size); return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of<uint32_t>())) { - std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, - ((uint32_t*)pushInfo.ptr) + - pushInfo.size); + } else if (push_consts.dtype().is(py::dtype::of<uint32_t>())) { + std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr, + ((uint32_t*)pushInfo.ptr) + pushInfo.size); return self.algorithm(tensors, spirvVec, workgroup, specconstsvec, pushconstsvec); - } else if (spec_consts.dtype().is( - py::dtype::of<double>())) { - std::vector<double> pushconstsvec((double*)pushInfo.ptr, - ((double*)pushInfo.ptr) + - pushInfo.size); + } else if (push_consts.dtype().is(py::dtype::of<double>())) { + std::vector<double> pushconstsvec((double*)pushInfo.ptr, + ((double*)pushInfo.ptr) + pushInfo.size); return self.algorithm(tensors, spirvVec, workgroup, diff --git a/python/test/test_consts.py b/python/test/test_consts.py new file mode 100644 index 00000000..44783289 --- /dev/null +++ b/python/test/test_consts.py @@ -0,0 +1,259 @@ +import kp +import numpy as np +import pytest + +from .utils import compile_source + +workgroup = (3, 1, 1) + +def test_pushconsts_int32_spec_const_int32(): + shader = """ + #version 450 + layout(constant_id = 0) const int spec_val = 42; + layout(push_constant) uniform PushConsts { + int value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([1, 2, 3], dtype=np.int32) + spec_const = np.array([5], dtype=np.int32) # Will override spec_val to 5 + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) +
.record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_int32_spec_const_uint32(): + shader = """ + #version 450 + layout(constant_id = 0) const uint spec_val = 0u; + layout(push_constant) uniform PushConsts { + int value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([4, 5, 6], dtype=np.int32) + spec_const = np.array([7], dtype=np.uint32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_int32_spec_const_float(): + shader = """ + #version 450 + layout(constant_id = 0) const float spec_val = 0.0; + layout(push_constant) uniform PushConsts { + int value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([7, 8, 9], dtype=np.int32) + spec_const = np.array([3.3], dtype=np.float32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + 
spec_const.astype(np.float32)[0]) + +def test_pushconsts_uint32_spec_const_int32(): + shader = """ + #version 450 + layout(constant_id = 0) const int spec_val = 0; + layout(push_constant) uniform PushConsts { + uint value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([1, 2, 3], dtype=np.uint32) + spec_const = np.array([4], dtype=np.int32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_uint32_spec_const_uint32(): + shader = """ + #version 450 + layout(constant_id = 0) const uint spec_val = 0u; + layout(push_constant) uniform PushConsts { + uint value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([4, 5, 6], dtype=np.uint32) + spec_const = np.array([8], dtype=np.uint32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_uint32_spec_const_float(): + shader = """ + #version 450 + 
layout(constant_id = 0) const float spec_val = 0.0; + layout(push_constant) uniform PushConsts { + uint value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([7, 8, 9], dtype=np.uint32) + spec_const = np.array([3.3], dtype=np.float32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_float_spec_const_int32(): + shader = """ + #version 450 + layout(constant_id = 0) const int spec_val = 42; + layout(push_constant) uniform PushConsts { + float value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([1.1, 2.2, 3.3], dtype=np.float32) + spec_const = np.array([11], dtype=np.int32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_float_spec_const_uint32(): + shader = """ + #version 450 + layout(constant_id = 0) const uint spec_val = 0u; + layout(push_constant) uniform PushConsts { + float value[3]; + } pc; 
+ layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([4.4, 5.5, 6.6], dtype=np.float32) + spec_const = np.array([2], dtype=np.uint32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0]) + +def test_pushconsts_float_spec_const_float(): + shader = """ + #version 450 + layout(constant_id = 0) const float spec_val = 0.0; + layout(push_constant) uniform PushConsts { + float value[3]; + } pc; + layout(set = 0, binding = 0) buffer Output { + float outData[]; + }; + void main() { + uint idx = gl_GlobalInvocationID.x; + outData[idx] = float(pc.value[idx]) + float(spec_val); + } + """ + spirv = compile_source(shader) + mgr = kp.Manager() + arr_out = np.array([0.0, 0.0, 0.0], dtype=np.float32) + tensor_out = mgr.tensor_t(arr_out) + push_consts = np.array([7.7, 8.8, 9.9], dtype=np.float32) + spec_const = np.array([1.1], dtype=np.float32) + algo = mgr.algorithm([tensor_out], spirv, workgroup, spec_const, push_consts) + (mgr.sequence() + .record(kp.OpAlgoDispatch(algo)) + .record(kp.OpSyncLocal([tensor_out])) + .eval()) + assert np.array_equal(tensor_out.data(), push_consts.astype(np.float32) + spec_const.astype(np.float32)[0])