diff --git a/src/ATen/native/xpu/sycl/Indexing.cpp b/src/ATen/native/xpu/sycl/Indexing.cpp
index 47a7f76e5b..0a5a31cbd5 100644
--- a/src/ATen/native/xpu/sycl/Indexing.cpp
+++ b/src/ATen/native/xpu/sycl/Indexing.cpp
@@ -77,6 +77,47 @@ class IndexSelectScalarFunctor {
   }
 };
 
+template <typename scalar_t, typename index_t>
+static inline void _embedding(
+    scalar_t* output,
+    const scalar_t* weight,
+    const index_t* index,
+    int64_t num_embeddings,
+    int64_t embedding_dim,
+    int64_t indices_length) {
+  using KernelClass = EmbeddingKernelFunctor<scalar_t, index_t>;
+  using SLMKernelClass = EmbeddingKernelSLMFunctor<scalar_t, index_t>;
+  int64_t work_group_size = syclDeviceMaxWorkGroupSize();
+  int64_t num_xe_core = syclGpuEuCount() / syclGpuEUCountPerSubslice();
+
+  // 2 work groups on 1 Xe core to reach 100% occupancy
+  int64_t num_work_group = std::min(
+      num_xe_core * 2,
+      ceil_div(
+          static_cast<int64_t>(indices_length * embedding_dim),
+          static_cast<int64_t>(work_group_size)));
+  auto kfn = KernelClass(
+      output, weight, index, num_embeddings, embedding_dim, indices_length);
+  auto slmkfn = SLMKernelClass(
+      output, weight, index, num_embeddings, embedding_dim, indices_length);
+  // 2 work groups share 1 Xe core, so each gets half the SLM (64KB)
+  if (static_cast<int64_t>(num_embeddings) *
+          static_cast<int64_t>(embedding_dim) *
+          static_cast<int64_t>(sizeof(scalar_t)) <=
+      static_cast<int64_t>(syclLocalMemSize() / 2))
+    sycl_kernel_submit(
+        num_work_group * work_group_size,
+        work_group_size,
+        getCurrentSYCLQueue(),
+        slmkfn);
+  else
+    sycl_kernel_submit(
+        num_work_group * work_group_size,
+        work_group_size,
+        getCurrentSYCLQueue(),
+        kfn);
+}
+
 template <
     class SrcInfo,
     class DstInfo,
@@ -202,14 +243,24 @@ void index_select_kernel(
 
   // Improve efficiency of generated native instructions for contiguous.
   // See comm/TensorInfo.h
-  if (dst.is_contiguous() && indices.is_contiguous())
-    _index_select_kernel<
-        SrcInfo,
-        DstInfo,
-        IdxInfo,
-        /* TrivialOffCal */ true>(
-        src_info, dst_info, index_info, new_indexing_dim);
-  else
+  if (dst.is_contiguous() && indices.is_contiguous()) {
+    if (src.dim() == 2 && indices.dim() == 1 && src.is_contiguous()) {
+      _embedding(
+          dst.mutable_data_ptr<scalar_t>(),
+          src.const_data_ptr<scalar_t>(),
+          indices.const_data_ptr<index_t>(),
+          src.size(0),
+          src.size(1),
+          indices.size(0));
+    } else {
+      _index_select_kernel<
+          SrcInfo,
+          DstInfo,
+          IdxInfo,
+          /* TrivialOffCal */ true>(
+          src_info, dst_info, index_info, new_indexing_dim);
+    }
+  } else
     _index_select_kernel<
         SrcInfo,
         DstInfo,
diff --git a/src/ATen/native/xpu/sycl/Indexing.h b/src/ATen/native/xpu/sycl/Indexing.h
index acc277ec16..89d644b68b 100644
--- a/src/ATen/native/xpu/sycl/Indexing.h
+++ b/src/ATen/native/xpu/sycl/Indexing.h
@@ -27,6 +27,88 @@ TensorInfo<T, IndexType> tensorInfoIfScalar(TensorInfo<T, IndexType> ti) {
   return ti;
 }
 
+template <typename scalar_t, typename index_t>
+struct EmbeddingKernelFunctor {
+  void operator()(sycl::nd_item<1> item) const {
+    for (auto thread_id = item.get_global_linear_id();
+         thread_id < indices_length_ * embedding_dim_;
+         thread_id += item.get_local_range(0) * item.get_group_range(0)) {
+      SYCL_KERNEL_ASSERT(index_[thread_id / embedding_dim_] < num_embeddings_);
+      output_[thread_id] = weight_
+          [index_[thread_id / embedding_dim_] * embedding_dim_ +
+           thread_id % embedding_dim_];
+    }
+  }
+  EmbeddingKernelFunctor(
+      scalar_t* output,
+      const scalar_t* weight,
+      const index_t* index,
+      int64_t num_embeddings,
+      int64_t embedding_dim,
+      int64_t indices_length)
+      : output_(output),
+        weight_(weight),
+        index_(index),
+        num_embeddings_(num_embeddings),
+        embedding_dim_(embedding_dim),
+        indices_length_(indices_length) {}
+
+ private:
+  scalar_t* output_;
+  const scalar_t* weight_;
+  const index_t* index_;
+  int64_t num_embeddings_;
+  int64_t embedding_dim_;
+  int64_t indices_length_;
+};
+
+template <typename scalar_t, typename index_t>
+struct EmbeddingKernelSLMFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
+  void operator()(sycl::nd_item<1> item) const {
+    for (auto local_id = item.get_local_id(0);
+         local_id < num_embeddings_ * embedding_dim_;
+         local_id += item.get_local_range(0)) {
+      cached_weight_[local_id] = weight_[local_id];
+    }
+    item.barrier(sycl_local_fence);
+    for (auto thread_id = item.get_global_linear_id();
+         thread_id < indices_length_ * embedding_dim_;
+         thread_id += item.get_local_range(0) * item.get_group_range(0)) {
+      SYCL_KERNEL_ASSERT(index_[thread_id / embedding_dim_] < num_embeddings_);
+      output_[thread_id] = cached_weight_
+          [index_[thread_id / embedding_dim_] * embedding_dim_ +
+           thread_id % embedding_dim_];
+    }
+  }
+  void sycl_ker_config_convention(sycl::handler& cgh) {
+    cached_weight_ =
+        sycl_local_acc_t<scalar_t>(num_embeddings_ * embedding_dim_, cgh);
+  }
+  EmbeddingKernelSLMFunctor(
+      scalar_t* output,
+      const scalar_t* weight,
+      const index_t* index,
+      int64_t num_embeddings,
+      int64_t embedding_dim,
+      int64_t indices_length)
+      : output_(output),
+        weight_(weight),
+        index_(index),
+        num_embeddings_(num_embeddings),
+        embedding_dim_(embedding_dim),
+        indices_length_(indices_length),
+        cached_weight_() {}
+
+ private:
+  scalar_t* output_;
+  const scalar_t* weight_;
+  const index_t* index_;
+  int64_t num_embeddings_;
+  int64_t embedding_dim_;
+  int64_t indices_length_;
+  sycl_local_acc_t<scalar_t> cached_weight_;
+};
+
 template
 class IndexKernelConfig : public BatchKernelConfig {
  public:
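
Note (not part of the patch): a minimal host-side C++ sketch of the launch heuristic that _embedding() applies above. The device values are made up for illustration (they stand in for syclGpuEuCount(), syclGpuEUCountPerSubslice(), syclLocalMemSize() and syclDeviceMaxWorkGroupSize()), and ceil_div mirrors ATen's helper. Two work-groups are placed per Xe core for full occupancy, and the SLM-caching kernel is chosen only when the whole weight table fits in half of the shared local memory.

// Standalone sketch, not the patch's code: hypothetical device properties
// drive the same work-group-count and SLM-fit decisions as _embedding().
#include <algorithm>
#include <cstdint>
#include <cstdio>

static int64_t ceil_div(int64_t a, int64_t b) {
  return (a + b - 1) / b;
}

int main() {
  // Hypothetical device properties.
  const int64_t eu_count = 512;           // stand-in for syclGpuEuCount()
  const int64_t eu_per_subslice = 8;      // syclGpuEUCountPerSubslice()
  const int64_t local_mem_size = 131072;  // syclLocalMemSize(), 128KB
  const int64_t work_group_size = 1024;   // syclDeviceMaxWorkGroupSize()

  // Problem shape: weight is [num_embeddings, embedding_dim].
  const int64_t num_embeddings = 512;
  const int64_t embedding_dim = 64;
  const int64_t indices_length = 1 << 20;
  const int64_t elem_size = 2;            // sizeof(scalar_t), e.g. half

  // 2 work groups per Xe core for 100% occupancy, but never more groups
  // than are needed to cover all output elements.
  const int64_t num_xe_core = eu_count / eu_per_subslice;
  const int64_t num_work_group = std::min(
      num_xe_core * 2,
      ceil_div(indices_length * embedding_dim, work_group_size));

  // 2 work groups share 1 Xe core, so each may cache at most half the SLM.
  const bool use_slm_kernel =
      num_embeddings * embedding_dim * elem_size <= local_mem_size / 2;

  std::printf(
      "work groups: %lld, SLM kernel: %s\n",
      static_cast<long long>(num_work_group),
      use_slm_kernel ? "yes" : "no");
  return 0;
}

With two work-groups resident per Xe core, each can only count on half of the reported local memory, which is why the patch compares against syclLocalMemSize() / 2 (64KB per the comment) rather than the full size.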
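Note (not part of the patch): a minimal CPU reference, for illustration only, of the gather that EmbeddingKernelFunctor and EmbeddingKernelSLMFunctor perform. Each flattened element id picks its weight row through the index table (id / embedding_dim) and its column directly (id % embedding_dim); the SLM variant reads the same addresses from the cached copy of weight instead of global memory. The helper name embedding_reference is hypothetical.

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Hypothetical reference helper mirroring the kernel's index arithmetic.
template <typename scalar_t, typename index_t>
void embedding_reference(
    scalar_t* output,
    const scalar_t* weight,
    const index_t* index,
    int64_t num_embeddings,
    int64_t embedding_dim,
    int64_t indices_length) {
  for (int64_t id = 0; id < indices_length * embedding_dim; ++id) {
    const index_t row = index[id / embedding_dim];
    assert(row < num_embeddings);  // counterpart of SYCL_KERNEL_ASSERT
    output[id] = weight[row * embedding_dim + id % embedding_dim];
  }
}

int main() {
  const int64_t num_embeddings = 4, embedding_dim = 3, indices_length = 2;
  // weight rows: {0,1,2}, {10,11,12}, {20,21,22}, {30,31,32}
  std::vector<float> weight = {0, 1, 2, 10, 11, 12, 20, 21, 22, 30, 31, 32};
  std::vector<int64_t> index = {3, 1};
  std::vector<float> out(indices_length * embedding_dim);
  embedding_reference(
      out.data(),
      weight.data(),
      index.data(),
      num_embeddings,
      embedding_dim,
      indices_length);
  for (float v : out)
    std::printf("%g ", v);  // prints: 30 31 32 10 11 12
  std::printf("\n");
  return 0;
}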