diff --git a/backends/iluvatar_gpu/CMakeLists.txt b/backends/iluvatar_gpu/CMakeLists.txt index d71fa59857b..974231c247e 100644 --- a/backends/iluvatar_gpu/CMakeLists.txt +++ b/backends/iluvatar_gpu/CMakeLists.txt @@ -219,6 +219,7 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/tril_triu_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/unbind_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/uniform_kernel.cu + ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/unique_consecutive_kernel.cu ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/gpu/where_kernel.cu # kernels/selected_rows ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/selected_rows/gpu/adamw_kernel.cu @@ -932,7 +933,7 @@ file( ${PADDLE_SOURCE_DIR}/paddle/phi/kernels/array_kernel.cc) set(CUDA_SRCS ${CUDA_SRCS1} ${CUDA_SRCS2}) -list(REMOVE_DUPLICATES CUDA_SRCS1) +list(REMOVE_DUPLICATES CUDA_SRCS) list( REMOVE_ITEM diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_grad_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_grad_kernel.cu index cc8ddf5a4e7..032b23ac5bd 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_grad_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_grad_kernel.cu @@ -22,7 +22,7 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" @@ -277,4 +277,5 @@ PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax_grad, ALL_LAYOUT, phi::CrossEntropyWithSoftmaxGradKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_kernel.cu index 3438e48549e..aacfe924ab5 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/cross_entropy_kernel.cu @@ -23,7 +23,7 @@ limitations under the License. */ namespace cub = hipcub; #endif -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_device_function.h" #include "paddle/phi/backends/gpu/gpu_dnn.h" #include "paddle/phi/common/amp_type_traits.h" @@ -1412,4 +1412,5 @@ PD_REGISTER_PLUGIN_KERNEL(cross_entropy_with_softmax, ALL_LAYOUT, phi::CrossEntropyWithSoftmaxKernel, float, - phi::dtype::float16) {} + phi::dtype::float16, + phi::dtype::bfloat16) {} diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_grad_kernel_register.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_grad_kernel_register.cu index c84b650803b..c600f784dfe 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_grad_kernel_register.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_grad_kernel_register.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/index_elementwise_put_grad_kernel.cu" // NOLINT #include "paddle/phi/kernels/index_elementwise_put_grad_kernel.h" PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put_grad, @@ -21,13 +22,26 @@ PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put_grad, phi::IndexElementwisePutGradKernel, bool, float, - double, int, int8_t, int64_t, int16_t, uint8_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64) {} + +PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put_with_tensor_grad, + iluvatar_gpu, + ALL_LAYOUT, + phi::IndexElementwisePutWithTensorGradKernel, + bool, + float, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::float16, + phi::bfloat16, + phi::complex64) {} diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_kernel_register.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_kernel_register.cu index eac81613553..750d5ef102f 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_kernel_register.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/index_elementwise_put_kernel_register.cu @@ -13,6 +13,7 @@ // limitations under the License. #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/gpu/index_elementwise_put_kernel.cu" // NOLINT #include "paddle/phi/kernels/index_elementwise_put_kernel.h" PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put, @@ -21,13 +22,26 @@ PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put, phi::IndexElementwisePutKernel, bool, float, - double, int, int8_t, int64_t, int16_t, uint8_t, - phi::dtype::float16, - phi::dtype::bfloat16, - phi::dtype::complex, - phi::dtype::complex) {} + phi::float16, + phi::bfloat16, + phi::complex64) {} + +PD_CUSTOM_KERNEL_REGISTER(index_elementwise_put_with_tensor, + iluvatar_gpu, + ALL_LAYOUT, + phi::IndexElementwisePutWithTensorKernel, + bool, + float, + int, + int8_t, + int64_t, + int16_t, + uint8_t, + phi::float16, + phi::bfloat16, + phi::complex64) {} diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_grad_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_grad_kernel.cu index 2f2b4a302ac..8535f257b4e 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_grad_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_grad_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_kernel.cu index 6347bfb75c2..4c8fff808b7 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/log_softmax_kernel.cu @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_grad_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_grad_kernel.cu index af391c7cd98..28f46bb24b2 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_grad_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_grad_kernel.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_kernel.cu b/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_kernel.cu index 9658fd01e23..5aad43f7e34 100644 --- a/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_kernel.cu +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/softmax_kernel.cu @@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "../gpudnn/softmax_gpudnn.h" +#include "kernels/gpudnn/softmax_gpudnn.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/backends/iluvatar_gpu/kernels/cuda_kernels/unique_consecutive_kernel_register.cc b/backends/iluvatar_gpu/kernels/cuda_kernels/unique_consecutive_kernel_register.cc new file mode 100644 index 00000000000..e1be85609d6 --- /dev/null +++ b/backends/iluvatar_gpu/kernels/cuda_kernels/unique_consecutive_kernel_register.cc @@ -0,0 +1,29 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/unique_consecutive_kernel.h" + +PD_CUSTOM_KERNEL_REGISTER(unique_consecutive, + iluvatar_gpu, + ALL_LAYOUT, + phi::UniqueConsecutiveKernel, + float, + int32_t, + int64_t) { + kernel->OutputAt(1).SetDataType(kernel_key.dtype()); + kernel->OutputAt(2).SetDataType(kernel_key.dtype()); +} diff --git a/backends/iluvatar_gpu/kernels/gpudnn/softmax_gpudnn.h b/backends/iluvatar_gpu/kernels/gpudnn/softmax_gpudnn.h index 559ac826aae..1fadee8ed62 100644 --- a/backends/iluvatar_gpu/kernels/gpudnn/softmax_gpudnn.h +++ b/backends/iluvatar_gpu/kernels/gpudnn/softmax_gpudnn.h @@ -30,6 +30,9 @@ limitations under the License. */ #define MATRIX_SOFTMAX_ALIGN_BYTES 16 #define MATRIX_SOFTMAX_THRESHOLD 100000 +#ifdef PADDLE_WITH_COREX +#define MAX_YZ_DIM_SIZE 65535 +#endif namespace phi { @@ -845,6 +848,10 @@ static void GetGridDim( grid_x = std::min(grid_x, max_num_blocks); int grid_y = (max_num_blocks + grid_x - 1) / grid_x; grid_y = std::min(grid_y, high_dim); +#ifdef PADDLE_WITH_COREX + grid_y = std::min(grid_y, + std::max(MAX_YZ_DIM_SIZE / static_cast(block.y), 1)); +#endif grid->x = grid_x; grid->y = grid_y; } @@ -1211,7 +1218,7 @@ void SoftmaxForwardCUDAKernelDriverImpl(const GPUContext& dev_ctx, IndexType dim = tensor_dims[1]; int D = tensor_dims[2]; - if (D == 1) { + if (D == 1 && x.dtype() != phi::DataType::BFLOAT16) { if (!UseCudnnSoftmax(dev_ctx, dim, true)) { int dim_log2 = static_cast(Log2Ceil(dim)); IndexType dim_ceil = 1 << dim_log2; @@ -1278,7 +1285,7 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx, int dim = tensor_dims[1]; int D = tensor_dims[2]; - if (D == 1) { + if (D == 1 && out.dtype() != phi::DataType::BFLOAT16) { if (!UseCudnnSoftmax(dev_ctx, dim, true)) { int dim_log2 = Log2Ceil(dim); int dim_ceil = 1 << dim_log2; diff --git a/backends/iluvatar_gpu/runtime/runtime.cc b/backends/iluvatar_gpu/runtime/runtime.cc index 9d08d8e82e8..665a15c3756 100644 --- a/backends/iluvatar_gpu/runtime/runtime.cc +++ b/backends/iluvatar_gpu/runtime/runtime.cc @@ -555,6 +555,10 @@ C_Status Allocate(const C_Device device, void **ptr, size_t size) { err = cudaMalloc(ptr, size); if (err != cudaSuccess) { *ptr = NULL; + if (err == cudaErrorMemoryAllocation) { + VLOG(0) << "[RUNTIME] Failed to alloc hbm, size: " << size + << ", out of memory."; + } return C_ERROR; }