Skip to content

Commit fb826c6

Browse files
committed
ggml-cuda : fix padding in timestep embedding kernel
This commit fixes the zero padding for odd dimensions in the timestep embedding kernel similar to the fix that was applied to the cpu backend in Commit 9de447d ("ggml-cpu : fix padding in ggml_timestep_embedding (#15917)"). The motivation for this is that currently if an odd dimension is used, the padding check incorrectly uses the dimension value for indexing. For example, with dim=15: Elements 0-6 are set to cosine values Elements 7-13 are set to sine values Element 14 is left uninitialized (contains garbage) Element 15 is correctly set to zero This fix changes embed_data[dim] to embed_data[2 * half] so that element 14 (the first unused element) is properly set to zero as well as the last element.
1 parent 51abc96 commit fb826c6

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

ggml/src/ggml-cuda/tsembd.cu

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,12 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d
77
int j = threadIdx.x + blockIdx.x * blockDim.x;
88
float * embed_data = (float *)((char *)dst + i*nb1);
99

10+
int half = dim / 2;
1011
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
12+
embed_data[2 * half] = 0.f;
1113
embed_data[dim] = 0.f;
1214
}
1315

14-
int half = dim / 2;
1516
if (j >= half) {
1617
return;
1718
}

0 commit comments

Comments
 (0)