@@ -9,6 +9,7 @@ namespace at {
     // Forward declaration of the CUDA kernels.
     template <typename T, typename SizeT>
     __global__ void LayerNormForwardCUDAKernel(SizeT N,
+                                               SizeT axis_size,
                                                float eps,
                                                const T* X,
                                                const T* gamma,
@@ -30,13 +31,15 @@ namespace ctranslate2 {
                             const dim_t axis,
                             const dim_t outer_size,
                             const dim_t axis_size,
-                            const dim_t,
+                            const dim_t inner_size,
+                            const bool multi_axis,
                             StorageView& output) const {
-      if (axis != input.rank() - 1 || !beta || !gamma)
+      if ((!multi_axis && axis != input.rank() - 1) || !beta || !gamma)
         throw std::invalid_argument("Generalized LayerNorm is currently not implemented on GPU");
 
       at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
         <<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
+          inner_size * axis_size,
           axis_size,
           _epsilon,
           cuda::device_cast(input.data<T>()),
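For context on the launch above: one thread block runs per outer slice, and the kernel now receives `inner_size * axis_size` as its per-block element count `N` alongside `axis_size`, so it can recover the inner stride as `N / axis_size`. A minimal sketch of how the three sizes are conventionally derived from a row-major shape (the helper `compute_sizes` below is illustrative, not part of this change):

    #include <cstdint>
    #include <vector>

    using dim_t = std::int64_t;

    // Hypothetical helper: split a shape into the sizes the kernel expects.
    // Everything before `axis` is batch-like (outer), `axis` itself is the
    // normalized dimension, and everything after it is the inner stride.
    void compute_sizes(const std::vector<dim_t>& shape, dim_t axis,
                       dim_t& outer_size, dim_t& axis_size, dim_t& inner_size) {
      outer_size = 1;
      for (dim_t i = 0; i < axis; ++i)
        outer_size *= shape[i];
      axis_size = shape[axis];
      inner_size = 1;
      for (dim_t i = axis + 1; i < static_cast<dim_t>(shape.size()); ++i)
        inner_size *= shape[i];
    }

With `axis == rank - 1` this reduces to the previous behavior (`inner_size == 1`, so `N == axis_size`); for an earlier axis, each block now normalizes a contiguous run of `axis_size * inner_size` elements.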
@@ -54,6 +57,7 @@ namespace ctranslate2 {
                      const dim_t outer_size,                              \
                      const dim_t axis_size,                               \
                      const dim_t inner_size,                              \
+                     const bool multi_axis,                               \
                      StorageView& output) const;
 
     DECLARE_IMPL(float)
@@ -147,6 +151,7 @@ namespace at {
 
     template <typename T, typename SizeT>
     __global__ void LayerNormForwardCUDAKernel(SizeT N,
+                                               SizeT axis_size,
                                                float eps,
                                                const T* X,
                                                const T* gamma,
@@ -179,11 +184,13 @@ namespace at {
 
       __syncthreads();
 
-      for (SizeT j = threadIdx.x; j < N; j += blockDim.x) {
-        const SizeT index = i * N + j;
-        Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[j]) + float(beta[j]);
+      SizeT inner_dim = N / axis_size;
+      for (SizeT j = 0; j < inner_dim; j++) {
+        for (SizeT k = threadIdx.x; k < axis_size; k += blockDim.x) {
+          const SizeT index = i * N + k * inner_dim + j;
+          Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[k]) + float(beta[k]);
+        }
       }
     }
-
   }
 }
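The reduction that fills `s_mean` and `s_variance` sits outside this hunk; judging by the multiply above, `s_variance` is assumed to hold the reciprocal standard deviation of the slice. A hedged CPU reference (the helper `layer_norm_ref` is hypothetical, for verification only) that mirrors the new indexing, including the broadcast of `gamma`/`beta` over the inner dimension:

    #include <cmath>
    #include <cstdint>

    // Illustrative CPU reference for the new indexing; not part of the change.
    // Each outer slice i covers N = axis_size * inner_dim contiguous elements,
    // element (k, j) sits at i * N + k * inner_dim + j, and gamma/beta depend
    // only on the axis position k (broadcast over the inner dimension j).
    template <typename T>
    void layer_norm_ref(std::int64_t outer_size, std::int64_t axis_size,
                        std::int64_t inner_dim, float eps,
                        const T* X, const T* gamma, const T* beta, T* Y) {
      const std::int64_t N = axis_size * inner_dim;
      for (std::int64_t i = 0; i < outer_size; ++i) {
        // Statistics over the whole slice, matching the one-block-per-slice launch.
        float sum = 0.f, sq_sum = 0.f;
        for (std::int64_t n = 0; n < N; ++n) {
          const float x = float(X[i * N + n]);
          sum += x;
          sq_sum += x * x;
        }
        const float mean = sum / N;
        const float rstd = 1.f / std::sqrt(sq_sum / N - mean * mean + eps);
        for (std::int64_t j = 0; j < inner_dim; ++j)
          for (std::int64_t k = 0; k < axis_size; ++k) {
            const std::int64_t index = i * N + k * inner_dim + j;
            Y[index] = (float(X[index]) - mean) * rstd * float(gamma[k]) + float(beta[k]);
          }
      }
    }

One observation on the kernel itself: for a fixed `j`, adjacent threads read elements `inner_dim` apart, so `X`/`Y` accesses are uncoalesced whenever `inner_dim > 1`; correct as written, but possibly worth a follow-up for performance.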