Skip to content

Commit 3913f87

Browse files
authored
ggml : fix padding in timestep embedding kernels (#15932)
* ggml : remove adding extra dim timestep embedding

  This commit updates the ggml_timestep_embedding function so that it no longer adds an extra dimension when the specified dimension is odd. The motivation for this change is that the extra dimension is unnecessary, and it caused an issue in the kernels, which were not expecting it; this resulted in uninitialized memory for the second-to-last dimension.

* ggml-cuda : fix padding in timestep embedding kernel

  This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.

* ggml-metal : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-opencl : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-sycl : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-vulkan : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-cpu : fix padding in timestep embedding function

  This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.
1 parent 76888d2 commit 3913f87

File tree

7 files changed

+15
-18
lines changed

7 files changed

+15
-18
lines changed

ggml/src/ggml-cpu/ops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8599,7 +8599,6 @@ static void ggml_compute_forward_timestep_embedding_f32(
85998599
}
86008600
if (dim % 2 != 0 && ith == 0) {
86018601
embed_data[2 * half] = 0.f;
8602-
embed_data[dim] = 0.f;
86038602
}
86048603
}
86058604
}

ggml/src/ggml-cuda/tsembd.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,11 @@ static __global__ void timestep_embedding_f32(const float * timesteps, float * d
77
int j = threadIdx.x + blockIdx.x * blockDim.x;
88
float * embed_data = (float *)((char *)dst + i*nb1);
99

10-
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
11-
embed_data[dim] = 0.f;
10+
int half = dim / 2;
11+
if (dim % 2 != 0 && j == half) {
12+
embed_data[2 * half] = 0.f;
1213
}
1314

14-
int half = dim / 2;
1515
if (j >= half) {
1616
return;
1717
}

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4167,7 +4167,7 @@ kernel void kernel_timestep_embedding_f32(
41674167
}
41684168

41694169
if (args.dim % 2 != 0 && tpitg.x == 0) {
4170-
embed_data[args.dim] = 0.f;
4170+
embed_data[2 * half_] = 0.f;
41714171
}
41724172
}
41734173

ggml/src/ggml-opencl/kernels/tsembd.cl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ kernel void kernel_timestep_embedding(
2626
local_half_dim = logical_dim / 2;
2727
local_embed_data_ptr = (global float *)((global char *)local_dst_output_base_ptr + local_i * dst_nb1_bytes);
2828

29-
if (logical_dim % 2 != 0 && local_j == ((logical_dim + 1) / 2)) {
30-
local_embed_data_ptr[logical_dim] = 0.0f;
29+
if (logical_dim % 2 != 0 && local_j == local_half_dim) {
30+
local_embed_data_ptr[2 * local_half_dim] = 0.0f;
3131
}
3232

3333
if (local_j >= local_half_dim) {

ggml/src/ggml-sycl/tsembd.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,12 @@ static void timestep_embedding_f32(
2121
int j = item_ct1.get_local_id(2) + item_ct1.get_group(2) * item_ct1.get_local_range(2);
2222
float * embed_data = (float *)((char *)dst + i*nb1);
2323

24-
if (dim % 2 != 0 && j == ((dim + 1) / 2)) {
25-
embed_data[dim] = 0.f;
24+
int half = dim / 2;
25+
26+
if (dim % 2 != 0 && j == half) {
27+
embed_data[2 * half] = 0.f;
2628
}
2729

28-
int half = dim / 2;
2930
if (j >= half) {
3031
return;
3132
}

ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,12 @@ void main() {
2424
const uint j = gl_GlobalInvocationID.x;
2525
const uint d_offset = i * p.nb1;
2626

27-
if (p.dim % 2 != 0 && j == ((p.dim + 1) / 2)) {
28-
data_d[d_offset + p.dim] = 0.f;
27+
const uint half_dim = p.dim / 2;
28+
29+
if (p.dim % 2 != 0 && j == half_dim) {
30+
data_d[d_offset + 2 * half_dim] = 0.f;
2931
}
3032

31-
const uint half_dim = p.dim / 2;
3233
if (j >= half_dim) {
3334
return;
3435
}

ggml/src/ggml.c

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4923,12 +4923,8 @@ struct ggml_tensor * ggml_timestep_embedding(
49234923
struct ggml_tensor * timesteps,
49244924
int dim,
49254925
int max_period) {
4926-
int actual_dim = dim;
4927-
if (dim % 2 != 0) {
4928-
actual_dim = dim + 1;
4929-
}
49304926

4931-
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
4927+
struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);
49324928

49334929
ggml_set_op_params_i32(result, 0, dim);
49344930
ggml_set_op_params_i32(result, 1, max_period);

0 commit comments

Comments
 (0)