Skip to content

Commit 87fece2

Browse files
authored
Enable aligned vectorized memory ops for MXFP8 cast (#342)
* Enable aligned vectorized memory ops for MXFP8 cast
* Optimized vector sizes and alignment conditions
1 parent e9c7361 commit 87fece2

File tree

8 files changed

+64
-44
lines changed

8 files changed

+64
-44
lines changed

transformer_engine/common/normalization/common.h

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -459,15 +459,20 @@ void rocm_norm_mxfp8_quantize(LaunchParams<ForwardKernelParams> &launch_params)
459459
scale_dim_Y_colwise, SCALE_DIM_Y,
460460
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
461461
launch_params.z_tensor->dtype(), OType,
462-
cast_mxfp8_2D_kernel<false, false, false, Empty, {}, compute_t, OType,
463-
SCALE_DIM_Y, scale_dim_X_rowwise, true><<<grid, block, 0, launch_params.stream>>>(
464-
reinterpret_cast<const compute_t*>(launch_params.params.z),
465-
nullptr,
466-
reinterpret_cast<OType *>(launch_params.z_tensor->data.dptr),
467-
reinterpret_cast<OType *>(launch_params.z_tensor->columnwise_data.dptr),
468-
scales_rowwise_ptr, scales_colwise_ptr,
469-
nullptr, nullptr, nullptr,
470-
rows, cols, scale_stride_rowwise, scale_stride_colwise);););
462+
TRANSFORMER_ENGINE_SWITCH_CONDITION(
463+
!(cols % (32 * sizeof(compute_t))), IS_ALIGNED,
464+
cast_mxfp8_2D_kernel<false, false, false, Empty, {}, compute_t, OType,
465+
SCALE_DIM_Y, scale_dim_X_rowwise, IS_ALIGNED, true><<<grid, block, 0, launch_params.stream>>>(
466+
reinterpret_cast<const compute_t*>(launch_params.params.z),
467+
nullptr,
468+
reinterpret_cast<OType *>(launch_params.z_tensor->data.dptr),
469+
reinterpret_cast<OType *>(launch_params.z_tensor->columnwise_data.dptr),
470+
scales_rowwise_ptr, scales_colwise_ptr,
471+
nullptr, nullptr, nullptr,
472+
rows, cols, scale_stride_rowwise, scale_stride_colwise);
473+
);
474+
);
475+
);
471476
}
472477
#endif
473478

transformer_engine/common/util/cast_gated_kernels.cuh

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -847,8 +847,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
847847
gated_input.dtype(), IType,
848848
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
849849
output->dtype(), OType,
850-
851850
#ifdef __HIP_PLATFORM_AMD__
851+
TRANSFORMER_ENGINE_SWITCH_CONDITION(
852+
!(cols % (32 * sizeof(IType))), IS_ALIGNED,
852853
const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast<const IType *>(grad.data.dptr) : nullptr;
853854
const IType *tensor_map_input_act = reinterpret_cast<const IType *>(gated_input.data.dptr);
854855
const IType *tensor_map_input_gate = reinterpret_cast<const IType *>(gated_input.data.dptr) + cols;
@@ -918,11 +919,19 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
918919

919920
NVTE_CHECK_CUDA(cudaFuncSetAttribute(
920921
(const void*)cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
921-
SCALE_DIM_Y, SCALE_DIM_X>,
922+
SCALE_DIM_Y, SCALE_DIM_X
923+
#ifdef __HIP_PLATFORM_AMD__
924+
, IS_ALIGNED
925+
#endif
926+
>,
922927
cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size));
923928

924929
cast_mxfp8_gated_kernel<IS_DGATED, ParamOP, ActOP, DActOP, IType, OType,
925-
SCALE_DIM_Y, SCALE_DIM_X>
930+
SCALE_DIM_Y, SCALE_DIM_X
931+
#ifdef __HIP_PLATFORM_AMD__
932+
, IS_ALIGNED
933+
#endif
934+
>
926935
<<<grid_dim, block_dim, shmem_size, stream>>>(
927936
tensor_map_grad, tensor_map_input_act, tensor_map_input_gate,
928937
tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise,
@@ -932,6 +941,9 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out
932941
); // NOLINT(*)
933942
); // NOLINT(*)
934943
); // NOLINT(*)
944+
#ifdef __HIP_PLATFORM_AMD__
945+
); // NOLINT(*)
946+
#endif
935947
}
936948

937949
template <typename ParamOP, float (*ActOP)(float, const ParamOP &)>

transformer_engine/common/util/cast_kernels.cuh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -999,8 +999,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
999999
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
10001000
output->dtype(), OType,
10011001
#ifdef __HIP_PLATFORM_AMD__
1002-
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
1003-
SCALE_DIM_Y, SCALE_DIM_X><<<grid, block, 0, stream>>>(
1002+
TRANSFORMER_ENGINE_SWITCH_CONDITION(
1003+
!(cols % (32 * sizeof(IType))), IS_ALIGNED,
1004+
cast_mxfp8_2D_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
1005+
SCALE_DIM_Y, SCALE_DIM_X, IS_ALIGNED><<<grid, block, 0, stream>>>(
10041006
reinterpret_cast<const IType *>(input.data.dptr),
10051007
(IS_DACT) ? reinterpret_cast<const IType *>(act_input->data.dptr) : nullptr,
10061008
reinterpret_cast<OType *>(output->data.dptr),
@@ -1051,6 +1053,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
10511053
); // NOLINT(*)
10521054
); // NOLINT(*)
10531055
); // NOLINT(*)
1056+
#ifdef __HIP_PLATFORM_AMD__
1057+
); // NOLINT(*)
1058+
#endif
10541059
}
10551060

10561061
namespace detail {

transformer_engine/common/util/dequantize_kernels.cuh

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -310,8 +310,10 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
310310
TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
311311
output->dtype(), OType,
312312
#ifdef __HIP_PLATFORM_AMD__
313-
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X>
314-
<<<grid, block, 0, stream>>>(reinterpret_cast<const IType *>(input_data.dptr), reinterpret_cast<OType *>(output->data.dptr), scales_ptr,
313+
TRANSFORMER_ENGINE_SWITCH_CONDITION(
314+
!(cols % (32 * sizeof(OType))), IS_ALIGNED,
315+
dequantize_mxfp8_kernel<IType, OType, SCALE_DIM_Y, SCALE_DIM_X, IS_ALIGNED>
316+
<<<grid, block, 0, stream>>>(reinterpret_cast<const IType *>(input_data.dptr), reinterpret_cast<OType *>(output->data.dptr), scales_ptr,
315317
rows, cols, scales_stride);); // NOLINT(*)
316318
#else // #ifdef __HIP_PLATFORM_AMD__
317319
alignas(64) CUtensorMap tensor_map_input{};
@@ -329,6 +331,9 @@ static void mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s
329331
); // NOLINT(*)
330332
); // NOLINT(*)
331333
); // NOLINT(*)
334+
#ifdef __HIP_PLATFORM_AMD__
335+
); // NOLINT(*)
336+
#endif
332337
}
333338
} // namespace dequantization
334339

transformer_engine/common/util/rocm_cast_gated_kernels.cuh

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf
4343

4444
template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
4545
float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
46-
size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
46+
size_t SCALE_DIM_Y, size_t SCALE_DIM_X, bool IS_ALIGNED>
4747
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
4848
cast_mxfp8_gated_kernel(const IType *grad_ptr,
4949
const IType *input_act,
@@ -76,7 +76,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
7676
const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
7777
const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
7878

79-
constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
79+
constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType);
8080

8181
const int thread_offset_Y = tid_Y;
8282
const int thread_offset_X = tid_X;
@@ -136,16 +136,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
136136

137137
// Initiate bulk tensor copy
138138
if constexpr (IS_DGATED) {
139-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
139+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
140140
cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
141141
}
142142

143143
// Act
144-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
144+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
145145
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
146146

147147
// Gate
148-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
148+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
149149
2*cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
150150

151151
__syncthreads();
@@ -347,19 +347,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
347347
__syncthreads();
348348

349349
if constexpr (USE_ROWWISE_SCALING) {
350-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
350+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
351351
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
352352
if constexpr (IS_DGATED) {
353-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
353+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
354354
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
355355
}
356356
}
357357

358358
if constexpr (USE_COLWISE_SCALING) {
359-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
359+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
360360
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
361361
if constexpr (IS_DGATED) {
362-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
362+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
363363
chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
364364
}
365365
}

transformer_engine/common/util/rocm_cast_kernels.cuh

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,6 @@ constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1;
2727
constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1;
2828
constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X;
2929
constexpr size_t MXFP8_THREADS_PER_CHUNK = 64;
30-
constexpr size_t MXFP8_BUFFERS_NUM = 2;
31-
constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1;
32-
static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM);
3330

3431
constexpr size_t ELEMS_PER_THREAD = 16;
3532
constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported
@@ -45,11 +42,10 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64
4542
constexpr size_t MXFP8_BUFF_STAGES_NUM =
4643
MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16
4744
constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32
48-
static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM);
4945

5046
template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
5147
float (*OP)(float, const ParamOP &), typename IType, typename OType, size_t SCALE_DIM_Y,
52-
size_t SCALE_DIM_X, bool IS_NORM = false>
48+
size_t SCALE_DIM_X, bool IS_ALIGNED, bool IS_NORM = false>
5349
__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
5450
cast_mxfp8_2D_kernel(const IType *input_ptr,
5551
const IType *act_input_ptr,
@@ -83,7 +79,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
8379
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
8480
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
8581
constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2
86-
constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
82+
constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(OType);
8783

8884
const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y;
8985
const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X;
@@ -161,11 +157,11 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
161157
const int chunk_it_offset_x = chunk_offset_X;
162158
const size_t row_base = chunk_it_offset_y;
163159
if constexpr (IS_DACT) {
164-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&act_in_sh[0][0], act_input_ptr,
160+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&act_in_sh[0][0], act_input_ptr,
165161
chunk_it_offset_x, chunk_it_offset_y, cols,
166162
MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, rows, cols);
167163
}
168-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
164+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
169165
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
170166
MXFP8_SHMEM_DIM_X, rows, cols);
171167
__syncthreads();
@@ -301,12 +297,12 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK)
301297
__syncthreads();
302298

303299
if constexpr (USE_ROWWISE_SCALING) {
304-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
300+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_rowwise_sh[0][0], output_rowwise, chunk_it_offset_x,
305301
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
306302
MXFP8_SHMEM_DIM_X, rows, cols);
307303
}
308304
if constexpr (USE_COLWISE_SCALING) {
309-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x,
305+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_colwise_sh[0][0], output_colwise, chunk_it_offset_x,
310306
chunk_it_offset_y, cols, MXFP8_SHMEM_DIM_Y,
311307
MXFP8_SHMEM_DIM_X, rows, cols);
312308
}

transformer_engine/common/util/rocm_dequantize_kernels.cuh

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ constexpr size_t THREADS_PER_CHUNK_X_COLWISE = CHUNK_DIM_X;
4242
constexpr size_t ITERATIONS = CHUNK_DIM_Y / BUFFER_DIM_Y; // 8 = 128 / 16
4343
static_assert(ITERATIONS >= 1);
4444

45-
template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
45+
template <typename IType, typename OType, size_t SCALE_DIM_Y, size_t SCALE_DIM_X, bool IS_ALIGNED>
4646
__global__ void __launch_bounds__(THREADS_PER_CHUNK)
4747
dequantize_mxfp8_kernel(const IType *input_ptr,
4848
OType *output_ptr,
@@ -59,7 +59,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
5959

6060
constexpr size_t THREADS_PER_SCALE_X_ROWWISE =
6161
DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16
62-
constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
62+
constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ?: 2) * 8 / sizeof(IType);
6363

6464
const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y;
6565
const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X;
@@ -86,7 +86,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
8686
const int chunk_it_offset_y = chunk_offset_Y + iter * BUFFER_DIM_Y;
8787
const int chunk_it_offset_x = chunk_offset_X;
8888

89-
copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
89+
copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_sh[0][0], input_ptr, chunk_it_offset_x,
9090
chunk_it_offset_y, cols, SHMEM_DIM_Y,
9191
SHMEM_DIM_X, rows, cols);
9292
__syncthreads();
@@ -127,7 +127,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
127127

128128
__syncthreads();
129129

130-
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_sh[0][0], output_ptr, chunk_it_offset_x,
130+
bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_sh[0][0], output_ptr, chunk_it_offset_x,
131131
chunk_it_offset_y, cols, SHMEM_DIM_Y,
132132
SHMEM_DIM_X, rows, cols);
133133

transformer_engine/common/util/rocm_vectorized_2d.cuh

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,11 @@
1010

1111
namespace transformer_engine {
1212
// These 2d copy functions replace TMA tensormap async copies for AMD GPUs.
13-
template <typename T, int N_VEC, bool aligned = false>
13+
template <typename T, int N_VEC, bool ALIGNED_ACCESS>
1414
__device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t g_start_col,
1515
size_t g_start_row, size_t g_stride, size_t chunk_dim_y,
1616
size_t chunk_dim_x, size_t total_rows,
1717
size_t total_cols) {
18-
// TODO: Manage edge cases where "aligned = true" causes into issues
19-
constexpr bool ALIGNED_ACCESS = aligned;
2018
size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC;
2119
const size_t l_idx = threadIdx.x;
2220

@@ -51,12 +49,11 @@ __device__ inline void copy_2d_to_shared(T *sh_ptr_base, const T *g_ptr, size_t
5149
}
5250
}
5351

54-
template <typename T, int N_VEC, bool aligned = false>
52+
template <typename T, int N_VEC, bool ALIGNED_ACCESS>
5553
__device__ inline void bulk_tensor_2d_shared_to_global(const T *sh_ptr_base, T *g_ptr, size_t g_start_col,
5654
size_t g_start_row, size_t g_stride, size_t chunk_dim_y,
5755
size_t chunk_dim_x, size_t total_rows,
5856
size_t total_cols) {
59-
constexpr bool ALIGNED_ACCESS = aligned;
6057
const size_t chunk_dim_x_vec_elements = (chunk_dim_x + N_VEC - 1) / N_VEC;
6158
const size_t l_idx = threadIdx.x;
6259

0 commit comments

Comments
 (0)