@@ -43,7 +43,7 @@ __device__ inline float sigmoidf(const float x) { return __frcp_rn(1.0f + __expf
 
 template <bool IS_DGATED, typename ParamOP, float (*ActOP)(float, const ParamOP &),
           float (*DActOP)(float, const ParamOP &), typename IType, typename OType,
-          size_t SCALE_DIM_Y, size_t SCALE_DIM_X>
+          size_t SCALE_DIM_Y, size_t SCALE_DIM_X, bool IS_ALIGNED>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     cast_mxfp8_gated_kernel(const IType *grad_ptr,
                             const IType *input_act,
@@ -76,7 +76,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
   const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X;
   const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X;
 
-  constexpr size_t VECTOR_WIDTH = 16 / sizeof(OType);
+  constexpr size_t VECTOR_WIDTH = (IS_ALIGNED ? 2 : 1) * 8 / sizeof(OType);
 
   const int thread_offset_Y = tid_Y;
   const int thread_offset_X = tid_X;
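
With this change the aligned path keeps the original 16-byte vectors while the unaligned path drops to 8-byte ones. A quick compile-time sanity check of the resulting element counts (a standalone sketch; the 1- and 2-byte element sizes stand in for FP8 and BF16 output types):

#include <cstddef>
#include <cstdint>

// Mirrors the kernel's VECTOR_WIDTH computation: 16-byte transactions when
// IS_ALIGNED, 8-byte transactions otherwise.
template <typename OType, bool IS_ALIGNED>
constexpr std::size_t vector_width() {
  return (IS_ALIGNED ? 2 : 1) * 8 / sizeof(OType);
}

static_assert(vector_width<std::uint8_t, true>() == 16);   // FP8, aligned
static_assert(vector_width<std::uint8_t, false>() == 8);   // FP8, unaligned
static_assert(vector_width<std::uint16_t, true>() == 8);   // BF16, aligned
static_assert(vector_width<std::uint16_t, false>() == 4);  // BF16, unaligned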
@@ -136,16 +136,16 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
 
     // Initiate bulk tensor copy
     if constexpr (IS_DGATED) {
-      copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
+      copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_grad_sh[0], grad_ptr, chunk_it_offset_x, chunk_it_offset_y,
                                                     cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
     }
 
     // Act
-    copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
+    copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_act_sh[0], input_act, chunk_it_offset_x, chunk_it_offset_y,
                                                   2 * cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
 
     // Gate
-    copy_2d_to_shared<IType, VECTOR_WIDTH, false>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
+    copy_2d_to_shared<IType, VECTOR_WIDTH, IS_ALIGNED>(&in_gate_sh[0], input_gate, chunk_it_offset_x, chunk_it_offset_y,
                                                   2 * cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
 
     __syncthreads();
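
The hard-coded `false` in the copy helpers becomes the kernel-wide `IS_ALIGNED` flag. The helper's internals are not part of this diff; below is only a hypothetical sketch of the kind of compile-time branch such a flag typically selects (name and body invented for illustration, only the template parameters mirror the diff):

#include <cstddef>

// Hypothetical illustration, not the actual copy_2d_to_shared: with
// IS_ALIGNED known at compile time, if constexpr keeps exactly one of the
// two paths, so the aligned instantiation carries no runtime branch.
template <typename T, std::size_t VECTOR_WIDTH, bool IS_ALIGNED>
__device__ void copy_row_sketch(T *dst, const T *src, std::size_t n) {
  if constexpr (IS_ALIGNED) {
    // One wide transaction per VECTOR_WIDTH elements; valid only because
    // the caller guaranteed 16-byte alignment and n % VECTOR_WIDTH == 0.
    struct alignas(sizeof(T) * VECTOR_WIDTH) Vec { T elts[VECTOR_WIDTH]; };
    Vec *vdst = reinterpret_cast<Vec *>(dst);
    const Vec *vsrc = reinterpret_cast<const Vec *>(src);
    for (std::size_t i = threadIdx.x; i < n / VECTOR_WIDTH; i += blockDim.x) {
      vdst[i] = vsrc[i];
    }
  } else {
    // Unaligned fallback: element-wise copies, no alignment assumptions.
    for (std::size_t i = threadIdx.x; i < n; i += blockDim.x) {
      dst[i] = src[i];
    }
  }
}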
@@ -347,19 +347,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK)
     __syncthreads();
 
     if constexpr (USE_ROWWISE_SCALING) {
-      bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
+      bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_rowwise_sh[0], output_act_rowwise, chunk_it_offset_x,
                                                                   chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
       if constexpr (IS_DGATED) {
-        bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
+        bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_gate_rowwise_sh[0], output_gate_rowwise, chunk_it_offset_x,
                                                                     chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
       }
     }
 
     if constexpr (USE_COLWISE_SCALING) {
-      bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
+      bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_act_colwise_sh[0], output_act_colwise, chunk_it_offset_x,
                                                                   chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
       if constexpr (IS_DGATED) {
-        bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, false>(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
+        bulk_tensor_2d_shared_to_global<OType, VECTOR_WIDTH, IS_ALIGNED>(&out_gate_colwise_sh[0], output_gate_colwise, chunk_it_offset_x,
                                                                     chunk_it_offset_y, output_cols, SHMEM_DIM_Y, SHMEM_DIM_X, rows, cols);
       }
     }
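
Since `IS_ALIGNED` is now a template parameter, the alignment decision has to happen once at launch time. A hypothetical host-side dispatch (the helper, the argument bundle, and the 16-byte threshold are assumptions for illustration; the diff itself only shows the kernel side):

#include <cstddef>
#include <cstdint>

struct GatedKernelArgs {  // hypothetical argument bundle
  const void *grad_ptr, *input_act, *input_gate;
  std::size_t rows, cols, elem_size;
};

// True when a buffer's base pointer and row pitch both permit 16-byte
// vectorized accesses.
inline bool aligned_16B(const void *ptr, std::size_t row_bytes) {
  return reinterpret_cast<std::uintptr_t>(ptr) % 16 == 0 &&
         row_bytes % 16 == 0;
}

template <bool IS_ALIGNED>
void launch_gated_kernel(const GatedKernelArgs &args);  // wraps the <<<...>>> launch

void dispatch(const GatedKernelArgs &a) {
  // Act and gate are interleaved, so their row pitch is 2 * cols elements
  // (matching the 2 * cols stride passed to copy_2d_to_shared above).
  const bool ok = aligned_16B(a.grad_ptr, a.cols * a.elem_size) &&
                  aligned_16B(a.input_act, 2 * a.cols * a.elem_size) &&
                  aligned_16B(a.input_gate, 2 * a.cols * a.elem_size);
  if (ok) {
    launch_gated_kernel<true>(a);   // wide 16-byte vectors
  } else {
    launch_gated_kernel<false>(a);  // 8-byte fallback
  }
}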