@@ -22,14 +22,17 @@ __global__ void quant_symm_row(
 #if TURBOMIND_ARCH_SM90
     static_assert(group_size % vec_size == 0);
     constexpr int threads = group_size / vec_size;
+    const int dim1 = round_up(dim, WARP_SIZE * vec_size);
     for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {
-        for (int di = threadIdx.x * vec_size; di < dim; di += blockDim.x * vec_size) {
-            Array<T, vec_size> vec;
-            Ldg(vec, src + ti * src_ld + di);
+        for (int di = threadIdx.x * vec_size; di < dim1; di += blockDim.x * vec_size) {
+            Array<T, vec_size> vec{};
+            if (di < dim) {
+                Ldg(vec, src + ti * src_ld + di);
+            }
             auto absmax = fmaxf(static_cast<Tscale>(find_absmax<threads>(vec)), 1e-8f);
             const Tscale scale = absmax / qmax;
             const Tscale inv_scale = qmax / absmax;
-            if (threadIdx.x % threads == 0) {
+            if (threadIdx.x % threads == 0 && di < dim) {
                 // column-major
                 scales[(di / group_size) * scales_ld + ti] = scale;
             }
@@ -38,7 +41,9 @@ __global__ void quant_symm_row(
             for (int c = 0; c < vec_size; ++c) {
                 tmp[c] = Tout(static_cast<Tscale>(vec[c]) * inv_scale);
             }
-            Store(out + ti * out_ld + di, tmp);
+            if (di < dim) {
+                Store(out + ti * out_ld + di, tmp);
+            }
         }
     }
 #endif
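The two hunks above change the per-row loop bound from `dim` to `dim1 = round_up(dim, WARP_SIZE * vec_size)`, so every lane of a quantization group still reaches the warp-level `find_absmax` reduction when `dim` is not a multiple of the warp tile; out-of-range lanes contribute a zero-initialized `vec` and skip the actual load, scale write, and store via the `di < dim` guards. A minimal standalone sketch of the same pattern, with hypothetical names and plain float loads instead of TurboMind's `Ldg`/`Array` helpers (assuming `dim % vec_size == 0` and a block size that is a multiple of 32):

    #include <cstdint>
    #include <cuda_runtime.h>

    // Sketch only: per-row, group-wise absmax with a padded loop bound.
    template<int vec_size, int group_size>
    __global__ void rowwise_absmax_sketch(float* scales, const float* src, int num, int dim)
    {
        constexpr int threads = group_size / vec_size;  // lanes cooperating on one group
        static_assert(threads > 0 && 32 % threads == 0, "group must fit in one warp");
        const int step = blockDim.x * vec_size;
        const int dim1 = (dim + 32 * vec_size - 1) / (32 * vec_size) * (32 * vec_size);
        for (int ti = blockIdx.x; ti < num; ti += gridDim.x) {
            for (int di = threadIdx.x * vec_size; di < dim1; di += step) {
                float absmax = 0.f;
                if (di < dim) {  // guard the load; padded lanes keep absmax == 0
                    for (int c = 0; c < vec_size; ++c) {
                        absmax = fmaxf(absmax, fabsf(src[(int64_t)ti * dim + di + c]));
                    }
                }
                // every lane reaches the group reduction, so the shuffles stay converged
                for (int mask = threads / 2; mask > 0; mask /= 2) {
                    absmax = fmaxf(absmax, __shfl_xor_sync(0xffffffffu, absmax, mask));
                }
                if (threadIdx.x % threads == 0 && di < dim) {  // guard the store
                    const int groups = (dim + group_size - 1) / group_size;
                    scales[(int64_t)ti * groups + di / group_size] = fmaxf(absmax, 1e-8f);
                }
            }
        }
    }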
@@ -69,11 +74,13 @@ void QuantizeSymm(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st

     const int aligned_num = round_up<int>(num, alignment);

+    const int s_dim = cdiv<ssize_t>(dim, group_size);
+
     if (!scale) {
-        scale = Tensor_<Tscale>({{dim / group_size, num}, {aligned_num, 1}}, kDEVICE);
+        scale = Tensor_<Tscale>({{s_dim, num}, {aligned_num, 1}}, kDEVICE);
     }
     else {
-        TM_CHECK(std::make_tuple(dim / group_size, num) == scale.shapes(0, 1));
+        TM_CHECK(std::make_tuple(s_dim, num) == scale.shapes(0, 1));
         TM_CHECK(scale.stride(1) == 1);
         TM_CHECK(scale.stride(0) % alignment == 0);
     }
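Switching from `dim / group_size` to `cdiv(dim, group_size)` only matters when `dim` is not a multiple of `group_size`: floor division would under-allocate the scale tensor, and the kernel's guarded tail writes to `scales[(di / group_size) * scales_ld + ti]` would otherwise land outside it. A small host-side illustration of the difference (plain C++, not the TurboMind `cdiv` itself):

    #include <cstdio>

    int main()
    {
        const int dim = 300, group_size = 128;
        const int floor_groups = dim / group_size;                     // 2: columns 256..299 get no scale slot
        const int ceil_groups  = (dim + group_size - 1) / group_size;  // 3: trailing partial group covered
        std::printf("floor = %d, ceil = %d\n", floor_groups, ceil_groups);
        return 0;
    }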
@@ -159,17 +166,17 @@ __global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale
     __shared__ typename BlockReduce::TempStorage temp_storage;
     __shared__ T shared_inv_scale;

-    const int ti = blockIdx.x * block_size;
-    const int di = blockIdx.y * block_size;
-    const int col = threadIdx.x % threads;
     const int row = threadIdx.x / threads;
+    const int col = threadIdx.x % threads;
+    const int ti = blockIdx.x * block_size;
+    const int di = blockIdx.y * block_size + col * vec_size;

     T absmax{};
     Array<T, vec_size> xs[S]{};
     PRAGMA_UNROLL
     for (int s = 0; s < S; ++s) {
-        if (auto r = ti + s * rows + row; r < num) {
-            Ldg(xs[s], src + (int64_t)r * dim + di + col * vec_size);
+        if (auto r = ti + s * rows + row; r < num && di < dim) {
+            Ldg(xs[s], src + (int64_t)r * dim + di);
         }
         PRAGMA_UNROLL
         for (int i = 0; i < vec_size; ++i) {
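The reordered index math folds `col * vec_size` into `di`, so a single `di < dim` comparison guards both the `Ldg` here and the `Store` further down. The guard only matters on the last column of blocks when `dim` is not a multiple of `block_size`: for example, with `block_size = 128` and `vec_size = 8`, a hypothetical `dim = 200` gives the blocks at `blockIdx.y = 1` offsets `di = 128 + col * 8` up to 255, so lanes with `di >= 200` must skip the memory accesses while still contributing their zero-initialized `xs[s]` to the block-wide absmax reduction.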
@@ -193,14 +200,14 @@ __global__ void quant_symm_block(Tout* out, Tscale* scales, const T* src, Tscale
         for (int i = 0; i < vec_size; ++i) {
             ys[s][i] = Tout(static_cast<Tscale>(xs[s][i]) * inv_scale);
         }
-        if (auto r = ti + s * rows + row; r < num) {
-            Store(out + (int64_t)r * dim + di + col * vec_size, ys[s]);
+        if (auto r = ti + s * rows + row; r < num && di < dim) {
+            Store(out + (int64_t)r * dim + di, ys[s]);
         }
     }
 #endif
 }

-void QuantizeSymmBlock(Tensor& out, Tensor& scale, const Tensor& src, cudaStream_t st)
+void QuantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> scale_, const Tensor& src, cudaStream_t st)
 {
     TM_CHECK(src.is_contiguous());
     TM_CHECK_EQ(src.ndim(), 2);
@@ -220,6 +227,9 @@ void QuantizeSymmBlock(Tensor& out, Tensor& scale, const Tensor& src, cudaStream
     constexpr int cta_size = 1024;
     const dim3 grid(bnum, bdim);

+    auto& out = out_.get();
+    auto& scale = scale_.get();
+
     if (!out) {
         out = Tensor_<Tout>{src.layout(), kDEVICE};
     }
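Switching the output parameters from `Tensor&` to `Ref<Tensor>` keeps the allocate-if-empty behaviour (`if (!out) { out = ...; }`) while making it explicit at the call site that the argument is an out-parameter. As a rough analogy only (assuming `Ref<T>` behaves like `std::reference_wrapper<T>`, which may not match the actual TurboMind type), the `.get()` idiom looks like this:

    #include <functional>

    struct Blob {
        bool allocated = false;
        explicit operator bool() const { return allocated; }
    };

    // Hypothetical stand-in for the out-parameter handling above.
    void fill_if_empty(std::reference_wrapper<Blob> out_)
    {
        Blob& out = out_.get();  // same .get() idiom as in the diff
        if (!out) {
            out = Blob{true};    // the allocation lands on the caller's object
        }
    }

    int main()
    {
        Blob b;
        fill_if_empty(std::ref(b));
        return b ? 0 : 1;
    }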
@@ -259,7 +269,7 @@ __global__ void dequant_symm_block(Tout* out, const T* src, const Tscale* scales
     PRAGMA_UNROLL
     for (int s = 0; s < S; ++s) {
         const auto ti = blockIdx.x * block_size + s * rows + row;
-        if (ti < num) {
+        if (ti < num && di < dim) {
             Array<T, vec_size> x;
             Ldg(x, src + (int64_t)ti * dim + di);
             Array<Tout, vec_size> y;
@@ -273,7 +283,7 @@ __global__ void dequant_symm_block(Tout* out, const T* src, const Tscale* scales
 #endif
 }

-void DequantizeSymmBlock(Tensor& out, const Tensor& src, const Tensor& scale, cudaStream_t st)
+void DequantizeSymmBlock(Ref<Tensor> out_, Ref<Tensor> src_, const Tensor& scale, cudaStream_t st)
 {
     using T = fp8_e4m3_t;
     using Tout = bfloat16_t;
@@ -282,6 +292,9 @@ void DequantizeSymmBlock(Tensor& out, const Tensor& src, const Tensor& scale, cu
     constexpr int block_size = 128;
    constexpr int vec_size = 8;

+    auto& out = out_.get();
+    auto& src = src_.get();
+
     if (!out) {
         out = Tensor_<Tout>{src.layout(), kDEVICE};
     }