Skip to content

Commit b9dc7c1

Browse files
authored
[BIG tensor]fix crossentropy (#73529)
1 parent 7226f3f commit b9dc7c1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

paddle/phi/kernels/gpu/cross_entropy_kernel.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,10 @@ __global__ void WarpSoftmaxForwardSoftLabel(T* loss,
536536
constexpr int kIterations = kDimCeil / kWarpSize;
537537
constexpr int kIterationsV =
538538
(kIterations >= kVSize) ? (kIterations / kVSize) : 1;
539-
constexpr int kBatchSize = (kDimCeil <= 128) ? 2 : 1;
539+
constexpr int64_t kBatchSize = (kDimCeil <= 128) ? 2 : 1;
540540

541-
int first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
542-
int local_batches = batch_size - first_batch;
541+
int64_t first_batch = (blockDim.y * blockIdx.x + threadIdx.y) * kBatchSize;
542+
int64_t local_batches = batch_size - first_batch;
543543
if (local_batches > kBatchSize) {
544544
local_batches = kBatchSize;
545545
}

0 commit comments

Comments
 (0)