Skip to content

Commit 1dcdc24

Browse files
authored
fix performence bug (#73501)
1 parent 3ba4aa0 commit 1dcdc24

File tree

2 files changed

+16
-16
lines changed

2 files changed

+16
-16
lines changed

paddle/phi/kernels/funcs/broadcast_function.h

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -292,10 +292,10 @@ __device__ void VectorizedBroadcastKernelImpl(
292292
const Array<const _ptr_ char *__restrict__, Arity> &ins,
293293
Array<_ptr_ OutT *, NumOuts> outs,
294294
const Array<bool, Arity> &use_broadcast,
295-
const int64_t numel,
295+
const uint32_t numel,
296296
const Array<kps::details::BroadcastConfig, Arity> &configs,
297-
int64_t num,
298-
int64_t block_offset,
297+
int num,
298+
int block_offset,
299299
int read_lens,
300300
Functor func) {
301301
using Traits = phi::funcs::FunctionTraits<Functor>;
@@ -310,10 +310,10 @@ __device__ void VectorizedBroadcastKernelImpl(
310310
if (LoadType == kBroadcast) {
311311
uint32_t index_bc[Arity][VecSize] = {0};
312312
Unroller<BroadcastDataInit, VecSize, Arity>::step(args);
313-
int64_t thread_offset = block_offset + threadIdx.x * VecSize;
313+
uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
314314
#pragma unroll
315315
for (int k = 0; k < VecSize; ++k) {
316-
int64_t idx = thread_offset + k;
316+
uint32_t idx = thread_offset + k;
317317
if (IsBoundary && idx == numel) break;
318318
#pragma unroll
319319
for (int i = 0; i < phi::DDim::kMaxRank; ++i) {
@@ -352,10 +352,10 @@ __global__ void VectorizedBroadcastKernel(
352352
Array<const _ptr_ char *__restrict__, Arity> ins,
353353
Array<_ptr_ OutT *, NumOuts> outs,
354354
Array<bool, Arity> use_broadcast,
355-
int64_t numel,
355+
uint32_t numel,
356356
Array<kps::details::BroadcastConfig, Arity> configs,
357-
int64_t main_offset,
358-
int64_t tail_tid,
357+
int main_offset,
358+
int tail_tid,
359359
int read_lens,
360360
Functor func) {
361361
#ifdef PADDLE_WITH_XPU_KP
@@ -440,13 +440,13 @@ void LaunchBroadcastKernel(
440440
const BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts> &classifier,
441441
Functor func) {
442442
#ifdef PADDLE_WITH_XPU_KP
443-
int64_t numel = classifier.numel;
443+
int numel = classifier.numel;
444444
const int threads = 64;
445445
const int blocks = 8;
446446
int read_lens = configs[0].buf_len;
447447
auto stream = ctx.x_context()->xpu_stream;
448-
int64_t main_offset = (numel / (read_lens * threads)) * read_lens * threads;
449-
int64_t tail_tid = numel % (read_lens * threads);
448+
int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
449+
int tail_tid = numel % (read_lens * threads);
450450

451451
VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false>
452452
<<<blocks, threads, 0, stream>>>(classifier.ins_data,
@@ -459,14 +459,14 @@ void LaunchBroadcastKernel(
459459
read_lens,
460460
func);
461461
#else
462-
const auto &numel = classifier.numel;
462+
const int &numel = classifier.numel;
463463
auto gpu_config =
464464
phi::backends::gpu::GetGpuLaunchConfig1D(ctx, numel, VecSize);
465465
auto stream = ctx.stream();
466466
auto threads = gpu_config.GetBlockSize();
467467
auto blocks = gpu_config.block_per_grid;
468-
int64_t main_offset = (numel / (VecSize * threads)) * VecSize * threads;
469-
int64_t tail_tid = numel % (VecSize * threads);
468+
int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
469+
int tail_tid = numel % (VecSize * threads);
470470

471471
if (classifier.all_elementwise) {
472472
VectorizedBroadcastKernel<Functor,

paddle/phi/kernels/primitive/datamover_primitives.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ struct alignas(sizeof(T) * VecSize) VectorType {
4040
* must be [dim1, dim0].
4141
*/
4242
struct BroadcastConfig {
43-
phi::funcs::FastDivMod<int64_t> divmoders[phi::DDim::kMaxRank];
43+
phi::funcs::FastDivMod<int> divmoders[phi::DDim::kMaxRank];
4444
uint64_t strides[phi::DDim::kMaxRank];
4545
int rank{0};
4646

@@ -51,7 +51,7 @@ struct BroadcastConfig {
5151
const std::vector<int64_t>& in_dims,
5252
int dim_size) {
5353
for (int i = 0; i < dim_size; ++i) {
54-
divmoders[i] = phi::funcs::FastDivMod<int64_t>(out_dims[i]);
54+
divmoders[i] = phi::funcs::FastDivMod<int>(out_dims[i]);
5555
}
5656

5757
for (int i = 0; i < dim_size; ++i) {

0 commit comments

Comments
 (0)