/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -49,57 +49,100 @@ __device__ __nv_bfloat16 negativeInfinity<__nv_bfloat16>()
4949}
5050
// Builds a PackedT vector whose every T-sized lane holds negativeInfinity<T>().
// sizeof(PackedT) must be an integral multiple of sizeof(T).
template <typename T, typename PackedT>
__device__ PackedT packedNegativeInfinity()
{
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    // Fix: a plain `T packed[...]` only guarantees alignof(T); the
    // reinterpret_cast load below needs alignof(PackedT) to be well-defined.
    alignas(PackedT) T packed[kAlignment];
#pragma unroll
    for (int i = 0; i < kAlignment; i++)
    {
        packed[i] = negativeInfinity<T>();
    }
    return *reinterpret_cast<PackedT*>(packed);
}
63+
// Applies a token bitmask to a batch of logits rows in place.
// Bitmask convention: bit == 1 keeps the logit; bit == 0 sets it to -inf
// (the kernel complements the mask word, so a set bit in `bitmaskVal`
// below means "mask out").
//
// Launch layout: grid = (ceilDiv(vocabSizePadded, kThreadsPerBlock * kBitsPerThread), batchSize),
// block = kThreadsPerBlock. Each thread owns kBitsPerThread logits, processed
// in packed chunks of kAlignment elements. The dispatcher guarantees
// vocabSizePadded is a multiple of kAlignment, so the packed global accesses
// below stay aligned — NOTE(review): verify callers never bypass the dispatcher.
// `bitmaskSize` is currently unused here; kept for interface stability.
template <typename T, typename PackedT, int32_t kBitsPerThread>
__global__ void __launch_bounds__(kThreadsPerBlock) logitsBitmaskKernel(
    T** __restrict__ logits, uint32_t const** __restrict__ bitmask, int32_t vocabSizePadded, int32_t bitmaskSize)
{
    // Number of T elements per PackedT vector load/store.
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    // Selects the kAlignment mask bits covered by one packed access.
    // Fix: unsigned literal avoids signed-shift overflow if kAlignment
    // ever reaches the width of int.
    uint32_t constexpr kPackedMask = (1u << kAlignment) - 1;

    int const batchIdx = blockIdx.y;

    // First vocab index owned by this block.
    int const blockOffset = blockIdx.x * kThreadsPerBlock * kBitsPerThread;
    T* logitsGmemPtr = logits[batchIdx] + blockOffset;

    uint32_t const* bitmaskGmemPtr = bitmask[batchIdx] + blockOffset / kBitsPerMaskElement;
    // Position of this thread's kAlignment-bit slice inside a 32-bit mask word.
    int const bitmaskInnerIdx = threadIdx.x % (kBitsPerMaskElement / kAlignment);
    // Fix: alignas guarantees this buffer satisfies PackedT's alignment, making
    // the reinterpret_cast packed load/store below well-defined (a plain T
    // array only guarantees alignof(T)).
    alignas(PackedT) T logitsReg[kAlignment];

#pragma unroll
    for (int offset = threadIdx.x * kAlignment; offset < kThreadsPerBlock * kBitsPerThread;
         offset += kThreadsPerBlock * kAlignment)
    {
        if (blockOffset + offset >= vocabSizePadded)
        {
            break;
        }

        // Complemented slice: bit i set => logit i of this chunk is masked out.
        uint32_t const bitmaskVal
            = (~bitmaskGmemPtr[offset / kBitsPerMaskElement] >> (bitmaskInnerIdx * kAlignment)) & kPackedMask;

        // Fast path: nothing masked in this chunk — skip the read/write entirely.
        if (bitmaskVal == 0)
        {
            continue;
        }

        // Fast path: whole chunk masked — one packed store of -inf, no load.
        if (bitmaskVal == kPackedMask)
        {
            *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = packedNegativeInfinity<T, PackedT>();
            continue;
        }

        // Mixed chunk: packed load, per-lane mask, packed store.
        *reinterpret_cast<PackedT*>(logitsReg) = *reinterpret_cast<PackedT*>(logitsGmemPtr + offset);
#pragma unroll
        for (int i = 0; i < kAlignment; i++)
        {
            if (((bitmaskVal >> i) & 1))
            {
                logitsReg[i] = negativeInfinity<T>();
            }
        }
        *reinterpret_cast<PackedT*>(logitsGmemPtr + offset) = *reinterpret_cast<PackedT*>(logitsReg);
    }
}
115+
// Picks the kBitsPerThread template instantiation for logitsBitmaskKernel and
// launches it. The heuristic sizes the grid so roughly 2048/kThreadsPerBlock*128
// blocks are shared across the batch, then chooses the smallest per-thread
// width (>= the packed width kAlignment) that covers the padded vocabulary.
template <typename T, typename PackedT>
void logitsBitmaskDispatchToBitsPerThread(
    T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
{
    int constexpr kAlignment = sizeof(PackedT) / sizeof(T);
    int32_t const numBlocksPerRow = ceilDiv(2048 / kThreadsPerBlock * 128, batchSize);
    int32_t const numBitsPerThread = ceilDiv(vocabSizePadded, kThreadsPerBlock * numBlocksPerRow);
    int32_t const bitmaskSize = ceilDiv(vocabSizePadded, kBitsPerMaskElement);

    dim3 const block(kThreadsPerBlock);

    if (numBitsPerThread <= 4 && kAlignment <= 4)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 4), batchSize);
        logitsBitmaskKernel<T, PackedT, 4><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
        return;
    }
    if (numBitsPerThread <= 8 && kAlignment <= 8)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 8), batchSize);
        logitsBitmaskKernel<T, PackedT, 8><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
        return;
    }
    if (numBitsPerThread <= 16 && kAlignment <= 16)
    {
        dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 16), batchSize);
        logitsBitmaskKernel<T, PackedT, 16><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
        return;
    }
    // Widest supported per-thread width.
    dim3 const grid(ceilDiv(vocabSizePadded, kThreadsPerBlock * 32), batchSize);
    logitsBitmaskKernel<T, PackedT, 32><<<grid, block, 0, stream>>>(logits, bitmask, vocabSizePadded, bitmaskSize);
}
105148} // namespace
// Public entry point: masks logits for a whole batch on `stream`.
// Chooses the widest vector type (float4 > float2 > float > T) whose element
// count evenly divides vocabSizePadded, so every packed access in the kernel
// stays within the row, then defers the per-thread-width choice to
// logitsBitmaskDispatchToBitsPerThread.
template <typename T>
void invokeLogitsBitmask(
    T** logits, uint32_t const** bitmask, int32_t batchSize, int32_t vocabSizePadded, cudaStream_t stream)
{
    // Dispatch to PackedT
    auto const elemsPerFloat4 = sizeof(float4) / sizeof(T);
    auto const elemsPerFloat2 = sizeof(float2) / sizeof(T);
    auto const elemsPerFloat = sizeof(float) / sizeof(T);

    if (vocabSizePadded % elemsPerFloat4 == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float4>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else if (vocabSizePadded % elemsPerFloat2 == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float2>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else if (vocabSizePadded % elemsPerFloat == 0)
    {
        logitsBitmaskDispatchToBitsPerThread<T, float>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
    else
    {
        // Scalar fallback: PackedT == T, one element per "packed" access.
        logitsBitmaskDispatchToBitsPerThread<T, T>(logits, bitmask, batchSize, vocabSizePadded, stream);
    }
}
132172
0 commit comments