Skip to content

Commit 5b8558b

Browse files
authored
Fix python SpecConst and PushConst bugs (#430)
* Fix push const inconsistent type bug Signed-off-by: Alejandro Saucedo <[email protected]> * Fix push const bug Signed-off-by: Alejandro Saucedo <[email protected]> * Added test to consts Signed-off-by: Alejandro Saucedo <[email protected]> * Added tests for const types@ Signed-off-by: Alejandro Saucedo <[email protected]> * Added test to consts Signed-off-by: Alejandro Saucedo <[email protected]> * Added test to consts Signed-off-by: Alejandro Saucedo <[email protected]> * Added test to consts Signed-off-by: Alejandro Saucedo <[email protected]> --------- Signed-off-by: Alejandro Saucedo <[email protected]>
1 parent b22d4a2 commit 5b8558b

File tree

2 files changed

+281
-28
lines changed

2 files changed

+281
-28
lines changed

python/src/main.cpp

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ PYBIND11_MODULE(kp, m)
534534
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
535535
std::vector<float> specConstsVec(
536536
(float*)specInfo.ptr, ((float*)specInfo.ptr) + specInfo.size);
537-
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
537+
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
538538
std::vector<float> pushConstsVec((float*)pushInfo.ptr,
539539
((float*)pushInfo.ptr) +
540540
pushInfo.size);
@@ -543,7 +543,7 @@ PYBIND11_MODULE(kp, m)
543543
workgroup,
544544
specConstsVec,
545545
pushConstsVec);
546-
} else if (spec_consts.dtype().is(
546+
} else if (push_consts.dtype().is(
547547
py::dtype::of<std::int32_t>())) {
548548
std::vector<int32_t> pushConstsVec(
549549
(int32_t*)pushInfo.ptr,
@@ -553,7 +553,7 @@ PYBIND11_MODULE(kp, m)
553553
workgroup,
554554
specConstsVec,
555555
pushConstsVec);
556-
} else if (spec_consts.dtype().is(
556+
} else if (push_consts.dtype().is(
557557
py::dtype::of<std::uint32_t>())) {
558558
std::vector<uint32_t> pushConstsVec(
559559
(uint32_t*)pushInfo.ptr,
@@ -563,7 +563,7 @@ PYBIND11_MODULE(kp, m)
563563
workgroup,
564564
specConstsVec,
565565
pushConstsVec);
566-
} else if (spec_consts.dtype().is(
566+
} else if (push_consts.dtype().is(
567567
py::dtype::of<std::double_t>())) {
568568
std::vector<double> pushConstsVec((double*)pushInfo.ptr,
569569
((double*)pushInfo.ptr) +
@@ -578,7 +578,7 @@ PYBIND11_MODULE(kp, m)
578578
std::vector<int32_t> specconstsvec((int32_t*)specInfo.ptr,
579579
((int32_t*)specInfo.ptr) +
580580
specInfo.size);
581-
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
581+
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
582582
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
583583
((float*)pushInfo.ptr) +
584584
pushInfo.size);
@@ -587,7 +587,7 @@ PYBIND11_MODULE(kp, m)
587587
workgroup,
588588
specconstsvec,
589589
pushconstsvec);
590-
} else if (spec_consts.dtype().is(
590+
} else if (push_consts.dtype().is(
591591
py::dtype::of<std::int32_t>())) {
592592
std::vector<int32_t> pushconstsvec(
593593
(int32_t*)pushInfo.ptr,
@@ -597,7 +597,7 @@ PYBIND11_MODULE(kp, m)
597597
workgroup,
598598
specconstsvec,
599599
pushconstsvec);
600-
} else if (spec_consts.dtype().is(
600+
} else if (push_consts.dtype().is(
601601
py::dtype::of<std::uint32_t>())) {
602602
std::vector<uint32_t> pushconstsvec(
603603
(uint32_t*)pushInfo.ptr,
@@ -607,7 +607,7 @@ PYBIND11_MODULE(kp, m)
607607
workgroup,
608608
specconstsvec,
609609
pushconstsvec);
610-
} else if (spec_consts.dtype().is(
610+
} else if (push_consts.dtype().is(
611611
py::dtype::of<std::double_t>())) {
612612
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
613613
((double*)pushInfo.ptr) +
@@ -622,7 +622,7 @@ PYBIND11_MODULE(kp, m)
622622
std::vector<uint32_t> specconstsvec((uint32_t*)specInfo.ptr,
623623
((uint32_t*)specInfo.ptr) +
624624
specInfo.size);
625-
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
625+
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
626626
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
627627
((float*)pushInfo.ptr) +
628628
pushInfo.size);
@@ -631,7 +631,7 @@ PYBIND11_MODULE(kp, m)
631631
workgroup,
632632
specconstsvec,
633633
pushconstsvec);
634-
} else if (spec_consts.dtype().is(
634+
} else if (push_consts.dtype().is(
635635
py::dtype::of<std::int32_t>())) {
636636
std::vector<int32_t> pushconstsvec(
637637
(int32_t*)pushInfo.ptr,
@@ -641,7 +641,7 @@ PYBIND11_MODULE(kp, m)
641641
workgroup,
642642
specconstsvec,
643643
pushconstsvec);
644-
} else if (spec_consts.dtype().is(
644+
} else if (push_consts.dtype().is(
645645
py::dtype::of<std::uint32_t>())) {
646646
std::vector<uint32_t> pushconstsvec(
647647
(uint32_t*)pushInfo.ptr,
@@ -651,7 +651,7 @@ PYBIND11_MODULE(kp, m)
651651
workgroup,
652652
specconstsvec,
653653
pushconstsvec);
654-
} else if (spec_consts.dtype().is(
654+
} else if (push_consts.dtype().is(
655655
py::dtype::of<std::double_t>())) {
656656
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
657657
((double*)pushInfo.ptr) +
@@ -666,7 +666,7 @@ PYBIND11_MODULE(kp, m)
666666
std::vector<double> specconstsvec((double*)specInfo.ptr,
667667
((double*)specInfo.ptr) +
668668
specInfo.size);
669-
if (spec_consts.dtype().is(py::dtype::of<std::float_t>())) {
669+
if (push_consts.dtype().is(py::dtype::of<std::float_t>())) {
670670
std::vector<float> pushconstsvec((float*)pushInfo.ptr,
671671
((float*)pushInfo.ptr) +
672672
pushInfo.size);
@@ -675,31 +675,25 @@ PYBIND11_MODULE(kp, m)
675675
workgroup,
676676
specconstsvec,
677677
pushconstsvec);
678-
} else if (spec_consts.dtype().is(
679-
py::dtype::of<std::int32_t>())) {
680-
std::vector<float> pushconstsvec((int32_t*)pushInfo.ptr,
681-
((int32_t*)pushInfo.ptr) +
682-
pushInfo.size);
678+
} else if (push_consts.dtype().is(py::dtype::of<std::int32_t>())) {
679+
std::vector<int32_t> pushconstsvec((int32_t*)pushInfo.ptr,
680+
((int32_t*)pushInfo.ptr) + pushInfo.size);
683681
return self.algorithm(tensors,
684682
spirvVec,
685683
workgroup,
686684
specconstsvec,
687685
pushconstsvec);
688-
} else if (spec_consts.dtype().is(
689-
py::dtype::of<std::uint32_t>())) {
690-
std::vector<float> pushconstsvec((uint32_t*)pushInfo.ptr,
691-
((uint32_t*)pushInfo.ptr) +
692-
pushInfo.size);
686+
} else if (push_consts.dtype().is(py::dtype::of<std::uint32_t>())) {
687+
std::vector<uint32_t> pushconstsvec((uint32_t*)pushInfo.ptr,
688+
((uint32_t*)pushInfo.ptr) + pushInfo.size);
693689
return self.algorithm(tensors,
694690
spirvVec,
695691
workgroup,
696692
specconstsvec,
697693
pushconstsvec);
698-
} else if (spec_consts.dtype().is(
699-
py::dtype::of<std::double_t>())) {
700-
std::vector<float> pushconstsvec((double*)pushInfo.ptr,
701-
((double*)pushInfo.ptr) +
702-
pushInfo.size);
694+
} else if (push_consts.dtype().is(py::dtype::of<std::double_t>())) {
695+
std::vector<double> pushconstsvec((double*)pushInfo.ptr,
696+
((double*)pushInfo.ptr) + pushInfo.size);
703697
return self.algorithm(tensors,
704698
spirvVec,
705699
workgroup,

0 commit comments

Comments
 (0)