@@ -292,10 +292,10 @@ __device__ void VectorizedBroadcastKernelImpl(
292
292
const Array<const _ptr_ char *__restrict__, Arity> &ins,
293
293
Array<_ptr_ OutT *, NumOuts> outs,
294
294
const Array<bool , Arity> &use_broadcast,
295
- const int64_t numel,
295
+ const uint32_t numel,
296
296
const Array<kps::details::BroadcastConfig, Arity> &configs,
297
- int64_t num,
298
- int64_t block_offset,
297
+ int num,
298
+ int block_offset,
299
299
int read_lens,
300
300
Functor func) {
301
301
using Traits = phi::funcs::FunctionTraits<Functor>;
@@ -310,10 +310,10 @@ __device__ void VectorizedBroadcastKernelImpl(
310
310
if (LoadType == kBroadcast ) {
311
311
uint32_t index_bc[Arity][VecSize] = {0 };
312
312
Unroller<BroadcastDataInit, VecSize, Arity>::step (args);
313
- int64_t thread_offset = block_offset + threadIdx.x * VecSize;
313
+ uint32_t thread_offset = block_offset + threadIdx.x * VecSize;
314
314
#pragma unroll
315
315
for (int k = 0 ; k < VecSize; ++k) {
316
- int64_t idx = thread_offset + k;
316
+ uint32_t idx = thread_offset + k;
317
317
if (IsBoundary && idx == numel) break ;
318
318
#pragma unroll
319
319
for (int i = 0 ; i < phi::DDim::kMaxRank ; ++i) {
@@ -352,10 +352,10 @@ __global__ void VectorizedBroadcastKernel(
352
352
Array<const _ptr_ char *__restrict__, Arity> ins,
353
353
Array<_ptr_ OutT *, NumOuts> outs,
354
354
Array<bool , Arity> use_broadcast,
355
- int64_t numel,
355
+ uint32_t numel,
356
356
Array<kps::details::BroadcastConfig, Arity> configs,
357
- int64_t main_offset,
358
- int64_t tail_tid,
357
+ int main_offset,
358
+ int tail_tid,
359
359
int read_lens,
360
360
Functor func) {
361
361
#ifdef PADDLE_WITH_XPU_KP
@@ -440,13 +440,13 @@ void LaunchBroadcastKernel(
440
440
const BroadcastTypeClassifier<OutT, Functor, Arity, NumOuts> &classifier,
441
441
Functor func) {
442
442
#ifdef PADDLE_WITH_XPU_KP
443
- int64_t numel = classifier.numel ;
443
+ int numel = classifier.numel ;
444
444
const int threads = 64 ;
445
445
const int blocks = 8 ;
446
446
int read_lens = configs[0 ].buf_len ;
447
447
auto stream = ctx.x_context ()->xpu_stream ;
448
- int64_t main_offset = (numel / (read_lens * threads)) * read_lens * threads;
449
- int64_t tail_tid = numel % (read_lens * threads);
448
+ int main_offset = (numel / (read_lens * threads)) * read_lens * threads;
449
+ int tail_tid = numel % (read_lens * threads);
450
450
451
451
VectorizedBroadcastKernel<Functor, OutT, Arity, NumOuts, VecSize, false >
452
452
<<<blocks, threads, 0 , stream>>>(classifier.ins_data ,
@@ -459,14 +459,14 @@ void LaunchBroadcastKernel(
459
459
read_lens,
460
460
func);
461
461
#else
462
- const auto &numel = classifier.numel ;
462
+ const int &numel = classifier.numel ;
463
463
auto gpu_config =
464
464
phi::backends::gpu::GetGpuLaunchConfig1D (ctx, numel, VecSize);
465
465
auto stream = ctx.stream ();
466
466
auto threads = gpu_config.GetBlockSize ();
467
467
auto blocks = gpu_config.block_per_grid ;
468
- int64_t main_offset = (numel / (VecSize * threads)) * VecSize * threads;
469
- int64_t tail_tid = numel % (VecSize * threads);
468
+ int main_offset = (numel / (VecSize * threads)) * VecSize * threads;
469
+ int tail_tid = numel % (VecSize * threads);
470
470
471
471
if (classifier.all_elementwise ) {
472
472
VectorizedBroadcastKernel<Functor,
0 commit comments