@@ -534,7 +534,7 @@ PYBIND11_MODULE(kp, m)
534534 if (spec_consts.dtype ().is (py::dtype::of<std::float_t >())) {
535535 std::vector<float > specConstsVec (
536536 (float *)specInfo.ptr , ((float *)specInfo.ptr ) + specInfo.size );
537- if (spec_consts .dtype ().is (py::dtype::of<std::float_t >())) {
537+ if (push_consts .dtype ().is (py::dtype::of<std::float_t >())) {
538538 std::vector<float > pushConstsVec ((float *)pushInfo.ptr ,
539539 ((float *)pushInfo.ptr ) +
540540 pushInfo.size );
@@ -543,7 +543,7 @@ PYBIND11_MODULE(kp, m)
543543 workgroup,
544544 specConstsVec,
545545 pushConstsVec);
546- } else if (spec_consts .dtype ().is (
546+ } else if (push_consts .dtype ().is (
547547 py::dtype::of<std::int32_t >())) {
548548 std::vector<int32_t > pushConstsVec (
549549 (int32_t *)pushInfo.ptr ,
@@ -553,7 +553,7 @@ PYBIND11_MODULE(kp, m)
553553 workgroup,
554554 specConstsVec,
555555 pushConstsVec);
556- } else if (spec_consts .dtype ().is (
556+ } else if (push_consts .dtype ().is (
557557 py::dtype::of<std::uint32_t >())) {
558558 std::vector<uint32_t > pushConstsVec (
559559 (uint32_t *)pushInfo.ptr ,
@@ -563,7 +563,7 @@ PYBIND11_MODULE(kp, m)
563563 workgroup,
564564 specConstsVec,
565565 pushConstsVec);
566- } else if (spec_consts .dtype ().is (
566+ } else if (push_consts .dtype ().is (
567567 py::dtype::of<std::double_t >())) {
568568 std::vector<double > pushConstsVec ((double *)pushInfo.ptr ,
569569 ((double *)pushInfo.ptr ) +
@@ -578,7 +578,7 @@ PYBIND11_MODULE(kp, m)
578578 std::vector<int32_t > specconstsvec ((int32_t *)specInfo.ptr ,
579579 ((int32_t *)specInfo.ptr ) +
580580 specInfo.size );
581- if (spec_consts .dtype ().is (py::dtype::of<std::float_t >())) {
581+ if (push_consts .dtype ().is (py::dtype::of<std::float_t >())) {
582582 std::vector<float > pushconstsvec ((float *)pushInfo.ptr ,
583583 ((float *)pushInfo.ptr ) +
584584 pushInfo.size );
@@ -587,7 +587,7 @@ PYBIND11_MODULE(kp, m)
587587 workgroup,
588588 specconstsvec,
589589 pushconstsvec);
590- } else if (spec_consts .dtype ().is (
590+ } else if (push_consts .dtype ().is (
591591 py::dtype::of<std::int32_t >())) {
592592 std::vector<int32_t > pushconstsvec (
593593 (int32_t *)pushInfo.ptr ,
@@ -597,7 +597,7 @@ PYBIND11_MODULE(kp, m)
597597 workgroup,
598598 specconstsvec,
599599 pushconstsvec);
600- } else if (spec_consts .dtype ().is (
600+ } else if (push_consts .dtype ().is (
601601 py::dtype::of<std::uint32_t >())) {
602602 std::vector<uint32_t > pushconstsvec (
603603 (uint32_t *)pushInfo.ptr ,
@@ -607,7 +607,7 @@ PYBIND11_MODULE(kp, m)
607607 workgroup,
608608 specconstsvec,
609609 pushconstsvec);
610- } else if (spec_consts .dtype ().is (
610+ } else if (push_consts .dtype ().is (
611611 py::dtype::of<std::double_t >())) {
612612 std::vector<double > pushconstsvec ((double *)pushInfo.ptr ,
613613 ((double *)pushInfo.ptr ) +
@@ -622,7 +622,7 @@ PYBIND11_MODULE(kp, m)
622622 std::vector<uint32_t > specconstsvec ((uint32_t *)specInfo.ptr ,
623623 ((uint32_t *)specInfo.ptr ) +
624624 specInfo.size );
625- if (spec_consts .dtype ().is (py::dtype::of<std::float_t >())) {
625+ if (push_consts .dtype ().is (py::dtype::of<std::float_t >())) {
626626 std::vector<float > pushconstsvec ((float *)pushInfo.ptr ,
627627 ((float *)pushInfo.ptr ) +
628628 pushInfo.size );
@@ -631,7 +631,7 @@ PYBIND11_MODULE(kp, m)
631631 workgroup,
632632 specconstsvec,
633633 pushconstsvec);
634- } else if (spec_consts .dtype ().is (
634+ } else if (push_consts .dtype ().is (
635635 py::dtype::of<std::int32_t >())) {
636636 std::vector<int32_t > pushconstsvec (
637637 (int32_t *)pushInfo.ptr ,
@@ -641,7 +641,7 @@ PYBIND11_MODULE(kp, m)
641641 workgroup,
642642 specconstsvec,
643643 pushconstsvec);
644- } else if (spec_consts .dtype ().is (
644+ } else if (push_consts .dtype ().is (
645645 py::dtype::of<std::uint32_t >())) {
646646 std::vector<uint32_t > pushconstsvec (
647647 (uint32_t *)pushInfo.ptr ,
@@ -651,7 +651,7 @@ PYBIND11_MODULE(kp, m)
651651 workgroup,
652652 specconstsvec,
653653 pushconstsvec);
654- } else if (spec_consts .dtype ().is (
654+ } else if (push_consts .dtype ().is (
655655 py::dtype::of<std::double_t >())) {
656656 std::vector<double > pushconstsvec ((double *)pushInfo.ptr ,
657657 ((double *)pushInfo.ptr ) +
@@ -666,7 +666,7 @@ PYBIND11_MODULE(kp, m)
666666 std::vector<double > specconstsvec ((double *)specInfo.ptr ,
667667 ((double *)specInfo.ptr ) +
668668 specInfo.size );
669- if (spec_consts .dtype ().is (py::dtype::of<std::float_t >())) {
669+ if (push_consts .dtype ().is (py::dtype::of<std::float_t >())) {
670670 std::vector<float > pushconstsvec ((float *)pushInfo.ptr ,
671671 ((float *)pushInfo.ptr ) +
672672 pushInfo.size );
@@ -675,31 +675,25 @@ PYBIND11_MODULE(kp, m)
675675 workgroup,
676676 specconstsvec,
677677 pushconstsvec);
678- } else if (spec_consts.dtype ().is (
679- py::dtype::of<std::int32_t >())) {
680- std::vector<float > pushconstsvec ((int32_t *)pushInfo.ptr ,
681- ((int32_t *)pushInfo.ptr ) +
682- pushInfo.size );
678+ } else if (push_consts.dtype ().is (py::dtype::of<std::int32_t >())) {
679+ std::vector<int32_t > pushconstsvec ((int32_t *)pushInfo.ptr ,
680+ ((int32_t *)pushInfo.ptr ) + pushInfo.size );
683681 return self.algorithm (tensors,
684682 spirvVec,
685683 workgroup,
686684 specconstsvec,
687685 pushconstsvec);
688- } else if (spec_consts.dtype ().is (
689- py::dtype::of<std::uint32_t >())) {
690- std::vector<float > pushconstsvec ((uint32_t *)pushInfo.ptr ,
691- ((uint32_t *)pushInfo.ptr ) +
692- pushInfo.size );
686+ } else if (push_consts.dtype ().is (py::dtype::of<std::uint32_t >())) {
687+ std::vector<uint32_t > pushconstsvec ((uint32_t *)pushInfo.ptr ,
688+ ((uint32_t *)pushInfo.ptr ) + pushInfo.size );
693689 return self.algorithm (tensors,
694690 spirvVec,
695691 workgroup,
696692 specconstsvec,
697693 pushconstsvec);
698- } else if (spec_consts.dtype ().is (
699- py::dtype::of<std::double_t >())) {
700- std::vector<float > pushconstsvec ((double *)pushInfo.ptr ,
701- ((double *)pushInfo.ptr ) +
702- pushInfo.size );
694+ } else if (push_consts.dtype ().is (py::dtype::of<std::double_t >())) {
695+ std::vector<double > pushconstsvec ((double *)pushInfo.ptr ,
696+ ((double *)pushInfo.ptr ) + pushInfo.size );
703697 return self.algorithm (tensors,
704698 spirvVec,
705699 workgroup,
0 commit comments