@@ -584,14 +584,7 @@ void mini_jit::TensorOperation::execute(void const *tensor_in0, void const *tens
584584 char const *ptr_in1 = static_cast <char const *>(tensor_in1);
585585 char *ptr_out = static_cast <char *>(tensor_out);
586586
587- if (isParallel)
588- {
589- execute_dimension_parallel (0 , ptr_in0, ptr_in1, ptr_out, true , true );
590- }
591- else
592- {
593- execute_dimension (0 , ptr_in0, ptr_in1, ptr_out, true , true );
594- }
587+ execute_dimension (0 , ptr_in0, ptr_in1, ptr_out, true , true );
595588}
596589
597590void mini_jit::TensorOperation::execute_dimension (int64_t index_dim, char const *ptr_in0, char const *ptr_in1, char *ptr_out,
@@ -603,102 +596,6 @@ void mini_jit::TensorOperation::execute_dimension(int64_t index_dim, char const
603596 int64_t stride_in1 = isUnary (prim_main) ? 1 : strides_in1[index_dim];
604597 int64_t stride_out = strides_out[index_dim];
605598
606- if (exec_types[index_dim] == TensorConfig::exec_t ::seq)
607- {
608- release_assert (exec_types[index_dim] == TensorConfig::exec_t ::seq, " Expected a sequential loop" );
609-
610- bool is_first = first_access;
611- bool is_last = last_access;
612-
613- for (int64_t iDim = 0 ; iDim < dim_size; iDim++)
614- {
615- if (dim_types[index_dim] == TensorConfig::dim_t ::k)
616- {
617- is_first = first_access && (iDim == 0 );
618- is_last = last_access && (iDim == (dim_size - 1 ));
619- }
620-
621- char const *rec_ptr_in0 = ptr_in0 + iDim * stride_in0 * dtype_bytes;
622- char const *rec_ptr_in1 = ptr_in1 + iDim * stride_in1 * dtype_bytes;
623- char *rec_ptr_out = ptr_out + iDim * stride_out * dtype_bytes;
624- execute_dimension (index_dim + 1 , rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
625- }
626- }
627- else
628- {
629- release_assert (exec_types[index_dim] == TensorConfig::exec_t ::prim, " Expected a primitive loop" );
630-
631- // call first touch kernel if necessary
632- if (first_access && prim_first != TensorConfig::prim_t ::none)
633- {
634- if (std::holds_alternative<Unary>(first_touch))
635- {
636- Unary::kernel_t kernel = std::get<Unary>(first_touch).get_kernel ();
637- kernel (ptr_out, ptr_out, strides_out[indexPrimN], strides_out[indexPrimN]);
638- }
639- else
640- {
641- release_assert (false , " Unexpected first touch primitive" );
642- }
643- }
644-
645- // call main_kernel kernel
646- if (prim_main != TensorConfig::prim_t ::none)
647- {
648- if (std::holds_alternative<Unary>(main_kernel))
649- {
650- Unary::kernel_t kernel = std::get<Unary>(main_kernel).get_kernel ();
651- kernel (ptr_in0, ptr_out, strides_in0[indexPrimN], strides_out[indexPrimN]);
652- }
653- else if (std::holds_alternative<Brgemm>(main_kernel))
654- {
655- Brgemm::kernel_t kernel = std::get<Brgemm>(main_kernel).get_kernel ();
656-
657- if (prim_main == TensorConfig::prim_t ::gemm)
658- {
659- kernel (ptr_in0, ptr_in1, ptr_out, strides_in0[indexPrimK], strides_in1[indexPrimN], strides_out[indexPrimN], 1 , 1 );
660- }
661- else if (prim_main == TensorConfig::prim_t ::brgemm)
662- {
663- kernel (ptr_in0, ptr_in1, ptr_out, strides_in0[indexPrimK], strides_in1[indexPrimN], strides_out[indexPrimN],
664- strides_in0[indexPrimBatch], strides_in1[indexPrimBatch]);
665- }
666- else
667- {
668- release_assert (false , " Unexpected Brgemm primitive." );
669- }
670- }
671- else
672- {
673- release_assert (false , " Unexpected main primitive." );
674- }
675- }
676-
677- // call last touch kernel if necessary
678- if (last_access && prim_last != TensorConfig::prim_t ::none)
679- {
680- if (std::holds_alternative<Unary>(last_touch))
681- {
682- Unary::kernel_t kernel = std::get<Unary>(last_touch).get_kernel ();
683- kernel (ptr_out, ptr_out, strides_out[indexPrimN], strides_out[indexPrimN]);
684- }
685- else
686- {
687- release_assert (false , " Unexpected last touch primitive" );
688- }
689- }
690- }
691- }
692-
693- void mini_jit::TensorOperation::execute_dimension_parallel (int64_t index_dim, char const *ptr_in0, char const *ptr_in1, char *ptr_out,
694- bool first_access, bool last_access)
695- {
696- uint32_t dtype_bytes = 4 ;
697- int64_t dim_size = dim_sizes[index_dim];
698- int64_t stride_in0 = strides_in0[index_dim];
699- int64_t stride_in1 = isUnary (prim_main) ? 1 : strides_in1[index_dim];
700- int64_t stride_out = strides_out[index_dim];
701-
702599 if (exec_types[index_dim] == TensorConfig::exec_t ::shared)
703600 {
704601 // Parallel execution with OpenMP
@@ -719,7 +616,7 @@ void mini_jit::TensorOperation::execute_dimension_parallel(int64_t index_dim, ch
719616 char const *rec_ptr_in0 = ptr_in0 + iDim * stride_in0 * dtype_bytes;
720617 char const *rec_ptr_in1 = ptr_in1 + iDim * stride_in1 * dtype_bytes;
721618 char *rec_ptr_out = ptr_out + iDim * stride_out * dtype_bytes;
722- execute_dimension_parallel (index_dim + 1 , rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
619+ execute_dimension (index_dim + 1 , rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
723620 }
724621 }
725622 else if (exec_types[index_dim] == TensorConfig::exec_t ::seq)
@@ -739,7 +636,7 @@ void mini_jit::TensorOperation::execute_dimension_parallel(int64_t index_dim, ch
739636 char const *rec_ptr_in0 = ptr_in0 + iDim * stride_in0 * dtype_bytes;
740637 char const *rec_ptr_in1 = ptr_in1 + iDim * stride_in1 * dtype_bytes;
741638 char *rec_ptr_out = ptr_out + iDim * stride_out * dtype_bytes;
742- execute_dimension_parallel (index_dim + 1 , rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
639+ execute_dimension (index_dim + 1 , rec_ptr_in0, rec_ptr_in1, rec_ptr_out, is_first, is_last);
743640 }
744641 }
745642 else