
Commit 6c97db4

Add multi-axis LayerNorm CUDA implementation
1 parent 9506b57 commit 6c97db4

File tree

1 file changed: 13 additions, 6 deletions

src/ops/layer_norm_gpu.cu

Lines changed: 13 additions & 6 deletions
@@ -9,6 +9,7 @@ namespace at {
   // Forward declaration of the CUDA kernels.
   template <typename T, typename SizeT>
   __global__ void LayerNormForwardCUDAKernel(SizeT N,
+                                             SizeT axis_size,
                                              float eps,
                                              const T* X,
                                              const T* gamma,

@@ -30,13 +31,15 @@ namespace ctranslate2 {
                             const dim_t axis,
                             const dim_t outer_size,
                             const dim_t axis_size,
-                            const dim_t,
+                            const dim_t inner_size,
+                            const bool multi_axis,
                             StorageView& output) const {
-      if (axis != input.rank() - 1 || !beta || !gamma)
+      if (!multi_axis && axis != input.rank() - 1 || !beta || !gamma)
         throw std::invalid_argument("Generalized LayerNorm is currently not implemented on GPU");

       at::native::LayerNormForwardCUDAKernel<cuda::device_type<T>, cuda::index_t>
         <<<outer_size, CUDA_NUM_THREADS, 0, cuda::get_cuda_stream()>>>(
+          inner_size * axis_size,
           axis_size,
           _epsilon,
           cuda::device_cast(input.data<T>()),

@@ -54,6 +57,7 @@ namespace ctranslate2 {
                             const dim_t outer_size,                      \
                             const dim_t axis_size,                       \
                             const dim_t inner_size,                      \
+                            const bool multi_axis,                       \
                             StorageView& output) const;

   DECLARE_IMPL(float)

@@ -147,6 +151,7 @@ namespace at {
   template <typename T, typename SizeT>
   __global__ void LayerNormForwardCUDAKernel(SizeT N,
+                                             SizeT axis_size,
                                              float eps,
                                              const T* X,
                                              const T* gamma,

@@ -179,11 +184,13 @@ namespace at {

     __syncthreads();

-    for (SizeT j = threadIdx.x; j < N; j += blockDim.x) {
-      const SizeT index = i * N + j;
-      Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[j]) + float(beta[j]);
+    SizeT inner_dim = N / axis_size;
+    for (SizeT j = 0; j < inner_dim; j++) {
+      for (SizeT k = threadIdx.x; k < axis_size; k += blockDim.x) {
+        const SizeT index = i * N + k * inner_dim + j;
+        Y[index] = (float(X[index]) - s_mean) * s_variance * float(gamma[k]) + float(beta[k]);
+      }
     }
   }
-
 }
 }
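For readers tracing the new addressing: the kernel still launches one block per outer slice, but the slice length passed as N is now inner_size * axis_size, so the per-block mean and variance (computed earlier in the kernel, outside this diff) appear to cover the normalized axis and every axis after it, while gamma and beta (of length axis_size) are applied by splitting each offset into k * inner_dim + j. Below is a minimal host-side C++ sketch of that scheme; the helper name layer_norm_multi_axis_ref is hypothetical and not part of the commit.

#include <cmath>
#include <cstddef>
#include <vector>

// Host-side reference of the indexing used by the updated kernel.
// Illustrative only: names mirror the kernel, but this function is
// not part of the commit.
void layer_norm_multi_axis_ref(std::size_t outer_size,
                               std::size_t axis_size,
                               std::size_t inner_size,
                               float eps,
                               const std::vector<float>& X,
                               const std::vector<float>& gamma,  // axis_size elements
                               const std::vector<float>& beta,   // axis_size elements
                               std::vector<float>& Y) {
  const std::size_t N = inner_size * axis_size;  // slice size per outer index
  for (std::size_t i = 0; i < outer_size; ++i) {
    // Statistics over the whole slice of N elements, i.e. over the
    // normalized axis and every axis after it.
    float mean = 0.f;
    for (std::size_t n = 0; n < N; ++n)
      mean += X[i * N + n];
    mean /= static_cast<float>(N);
    float var = 0.f;
    for (std::size_t n = 0; n < N; ++n) {
      const float d = X[i * N + n] - mean;
      var += d * d;
    }
    const float rstd = 1.f / std::sqrt(var / static_cast<float>(N) + eps);
    // Same decomposition as the kernel: index = i * N + k * inner_dim + j,
    // where k walks the normalized axis and selects gamma[k]/beta[k].
    for (std::size_t j = 0; j < inner_size; ++j)
      for (std::size_t k = 0; k < axis_size; ++k) {
        const std::size_t index = i * N + k * inner_size + j;
        Y[index] = (X[index] - mean) * rstd * gamma[k] + beta[k];
      }
  }
}

One note on the guard: with multi_axis set, the axis != input.rank() - 1 check is skipped, and because && binds tighter than ||, the condition still rejects a missing gamma or beta in either mode.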
