ggml : fix padding in timestep embedding kernels #15932
Conversation
I think we should add a comment to the declaration of this function to clarify that it will pad the tensor if the requested dim is odd. So here:
Lines 2123 to 2130 in 28b5f19

```c
// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
GGML_API struct ggml_tensor * ggml_timestep_embedding(
        struct ggml_context * ctx,
        struct ggml_tensor  * timesteps,
        int                   dim,
        int                   max_period);
```
Change to:

```c
// return: [N, PAD(dim, 2)]
```

Btw, not sure what the reasoning for this padding logic was. It does not seem to be present in the reference Python implementation. Maybe @leejet can clarify?
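For readers unfamiliar with the notation, PAD(dim, 2) here means the result dimension rounded up to the next even number. A minimal sketch, using a generic round-up helper rather than the exact GGML_PAD macro:

```c
#include <stdio.h>

// Generic round-up helper, assumed to be equivalent in effect to PAD(x, 2)
// in the comment above (not necessarily the exact ggml macro definition).
static int pad_up(int x, int n) {
    return ((x + n - 1) / n) * n;
}

int main(void) {
    printf("PAD(14, 2) = %d\n", pad_up(14, 2)); // 14: even dims are unchanged
    printf("PAD(15, 2) = %d\n", pad_up(15, 2)); // 16: odd dims get one extra element
    return 0;
}
```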
I forgot to add
It seems that the padding is incorrect behavior. I don't really remember why I added it in the first place; it was too long ago. Given the current situation, I suggest removing the padding.

@leejet Thanks for the feedback!
@ggerganov @leejet I've taken a closer look at this and I think that the padding might be required after all. Looking at the reference Python implementation we have the following:

```python
def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding
```

Specifically these lines:

```python
if dim % 2:
    embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
```

This padding is necessary because sinusoidal embeddings produce pairs of cosine/sine values. For dim=15, you get 7 complete pairs (14 values) plus 1 padding zero to reach the target dimension. Without this padding, the GGML implementation leaves uninitialized memory that can contain garbage values, which is what we observed in the test failures. I believe the padding in GGML is required to match this reference behavior.
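To make the cos/sin pairing and the trailing zero concrete, here is a small standalone C sketch that mirrors the reference layout for a single timestep (my own illustration with made-up names, not the actual ggml kernel code):

```c
#include <math.h>
#include <stdio.h>

// Reference layout for one timestep t: dim/2 cosines, then dim/2 sines,
// then a single explicit zero when dim is odd.
static void timestep_embedding_row(float t, int dim, int max_period, float * out) {
    const int half = dim / 2;
    for (int i = 0; i < half; ++i) {
        const float freq = expf(-logf((float) max_period) * (float) i / (float) half);
        const float arg  = t * freq;
        out[i]        = cosf(arg); // first half of the row: cos values
        out[half + i] = sinf(arg); // second half of the row: sin values
    }
    if (dim % 2 != 0) {
        out[dim - 1] = 0.0f; // explicit zero padding, matching the Python reference
    }
}

int main(void) {
    float row[15];
    timestep_embedding_row(100.0f, 15, 10000, row);
    // Expect 7 cos values, 7 sin values, and row[14] == 0.0f.
    for (int i = 0; i < 15; ++i) {
        printf("%2d: %f\n", i, row[i]);
    }
    return 0;
}
```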
@danbev I think the problem is that for an odd input dim we create the result tensor with an extra dimension here:

Lines 4919 to 4932 in 261e6a2
@ggerganov Oh right, I misunderstood the comment about removing the padding to mean in the kernels. It indeed looks like adding an extra dimension here is not needed. I'll update this and test it.
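For a concrete picture of what the change means for an odd dim, here is an illustration of the row layouts described in this thread (plain C, not ggml code):

```c
#include <stdio.h>

int main(void) {
    // Layout of one result row for dim = 15, based on the discussion above.
    const int dim  = 15;
    const int half = dim / 2; // 7 cos/sin frequency pairs

    // Old behavior: the row was padded to dim + 1 = 16 elements:
    //   indices 0..6   cos values
    //   indices 7..13  sin values
    //   index 14       left uninitialized by the kernels (the observed bug)
    //   index 15       zeroed as padding
    // New behavior: the row has exactly dim = 15 elements and the kernels
    // explicitly zero index 14 when dim is odd.
    printf("pairs: %d, old row size: %d, new row size: %d\n", half, dim + 1, dim);
    return 0;
}
```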
This commit updates the ggml_timestep_embedding function to no longer add an extra dimension when the specified dimension is odd. The motivation for this change is that the extra dimension is unnecessary and caused an issue in the kernels, which were not expecting it, resulting in uninitialized memory for the second to last dimension.
This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.
This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.
This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.
This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.
This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.
This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.
* ggml : remove adding extra dim timestep embedding

  This commit updates the ggml_timestep_embedding function to no longer add an extra dimension when the specified dimension is odd. The motivation for this change is that the extra dimension is unnecessary and caused an issue in the kernels, which were not expecting it, resulting in uninitialized memory for the second to last dimension.

* ggml-cuda : fix padding in timestep embedding kernel

  This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.

* ggml-metal : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-opencl : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-sycl : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-vulkan : fix padding in timestep embedding kernel

  This commit fixes the zero padding for odd dimensions in the timestep embedding kernel.

* ggml-cpu : fix padding in timestep embedding function

  This commit removes the zeroing out of the last dimension now that we are not adding the extra padding dimension.
This commit applies the same changes as were applied in #15917 to the cuda, metal, opencl, sycl, and vulkan backends.
Refs: #15917