
Conversation

@danbev
Member

@danbev danbev commented Sep 11, 2025

This commit applies the same changes as were applied in #15917 to the cuda, metal, opencl, sycl, and vulkan backends.

Refs: #15917

@danbev danbev requested a review from 0cc4m as a code owner September 11, 2025 04:25
@github-actions github-actions bot added the Nvidia GPU, Vulkan, ggml, SYCL, Apple Metal, and OpenCL labels Sep 11, 2025
Member

@ggerganov ggerganov left a comment


I think we should add a comment to the declaration of this function to clarify that it will pad the resulting tensor if the requested dim is odd. So here:

llama.cpp/ggml/include/ggml.h

Lines 2123 to 2130 in 28b5f19

// Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
// timesteps: [N,]
// return: [N, dim]
GGML_API struct ggml_tensor * ggml_timestep_embedding(
        struct ggml_context * ctx,
        struct ggml_tensor  * timesteps,
        int                   dim,
        int                   max_period);

Change to:

// return: [N, PAD(dim, 2)]

Btw, I'm not sure what the reasoning was for this padding logic. It does not seem to be present in the reference Python implementation. Maybe @leejet can clarify?
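
For reference, PAD(dim, 2) here just means rounding dim up to the nearest multiple of 2. A minimal sketch of that arithmetic (the PAD macro below is written out for illustration in the style of ggml's round-up macro; treat the exact name as an assumption):

#include <stdio.h>

// Round x up to the next multiple of n (n must be a power of two).
// Written out here for illustration; ggml defines a similar round-up macro.
#define PAD(x, n) (((x) + (n) - 1) & ~((n) - 1))

int main(void) {
    printf("PAD(14, 2) = %d\n", PAD(14, 2)); // 14: even dims are unchanged
    printf("PAD(15, 2) = %d\n", PAD(15, 2)); // 16: odd dims round up by one
    return 0;
}

With such a comment, a caller asking for dim = 15 would know to expect a [N, 16] tensor from the current implementation.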

@danbev
Member Author

danbev commented Sep 11, 2025

I forgot to add [no ci] so I manually cancelled the workflows, hence the "Some checks haven't completed yet".

@leejet
Contributor

leejet commented Sep 11, 2025

> I think we should add a comment to the declaration of this function to clarify that it will pad the resulting tensor if the requested dim is odd. So here:
>
> llama.cpp/ggml/include/ggml.h
>
> Lines 2123 to 2130 in 28b5f19
>
> // Ref: https://github.com/CompVis/stable-diffusion/blob/main/ldm/modules/diffusionmodules/util.py#L151
> // timesteps: [N,]
> // return: [N, dim]
> GGML_API struct ggml_tensor * ggml_timestep_embedding(
>         struct ggml_context * ctx,
>         struct ggml_tensor  * timesteps,
>         int                   dim,
>         int                   max_period);
>
> Change to:
>
> // return: [N, PAD(dim, 2)]
>
> Btw, I'm not sure what the reasoning was for this padding logic. It does not seem to be present in the reference Python implementation. Maybe @leejet can clarify?

It seems that the padding is incorrect behavior. I don't really remember why I added it in the first place; it was too long ago. Given the current situation, I suggest removing the padding.

@danbev
Member Author

danbev commented Sep 12, 2025

@leejet Thanks for the feedback!
I'll take a closer look at removing the padding next week.

@danbev
Member Author

danbev commented Sep 15, 2025

@ggerganov @leejet I've taken a closer look at this and I think that the padding might be required after all.

Looking at the reference Python implementation, we have the following:

import math

import torch
from einops import repeat


def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
    if not repeat_only:
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
        ).to(device=timesteps.device)
        args = timesteps[:, None].float() * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    else:
        embedding = repeat(timesteps, 'b -> b d', d=dim)
    return embedding

Specifically these lines:

        if dim % 2:                                                             
            embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)

The if dim % 2 check tests whether the dimension is odd; if so, an extra column of zeros is appended so that the final embedding has the requested dimension.

This padding is necessary because sinusoidal embeddings produce pairs of cosine/sine values. For dim=15, you get 7 complete pairs (14 values) plus 1 padding zero to reach the target dimension. Without this padding, the GGML implementation leaves uninitialized memory that can contain garbage values, which is what we observed in the test failures. I believe the padding in GGML is required to match this reference behavior.
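
To make the shapes concrete, here is a small standalone sketch in C of what the Python reference computes for one timestep (an illustration only, not ggml or any backend's kernel code): for dim = 15 it writes 7 cosine values, 7 sine values, and a single trailing zero, so the row is exactly 15 elements wide.

#include <math.h>
#include <stdio.h>

// Standalone illustration of the reference sinusoidal timestep embedding
// for a single timestep. The output row has exactly `dim` elements:
// cos values in [0, half), sin values in [half, 2*half), and one trailing
// zero when dim is odd.
static void timestep_embedding_ref(float t, int dim, int max_period, float * out) {
    const int half = dim / 2;
    for (int i = 0; i < half; ++i) {
        const float freq = expf(-logf((float) max_period) * (float) i / (float) half);
        out[i]        = cosf(t * freq);
        out[half + i] = sinf(t * freq);
    }
    if (dim % 2 != 0) {
        out[dim - 1] = 0.0f; // explicit zero pad; the row width stays at dim
    }
}

int main(void) {
    float emb[15];
    timestep_embedding_ref(3.0f, 15, 10000, emb);
    for (int i = 0; i < 15; ++i) {
        printf("%8.4f%c", emb[i], i + 1 < 15 ? ' ' : '\n');
    }
    return 0;
}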

@danbev danbev force-pushed the ggml-timestep-embedding-backends branch from 004eaba to d304f6b Compare September 16, 2025 04:20
@ggerganov
Member

@danbev I think the problem is that for an input dim=15 the Python code produces an output dimension of 15, while ggml produces 16. At least this is what I understand from reading the implementation in ggml.c:

llama.cpp/ggml/src/ggml.c

Lines 4919 to 4932 in 261e6a2

// ggml_timestep_embedding

struct ggml_tensor * ggml_timestep_embedding(
        struct ggml_context * ctx,
        struct ggml_tensor  * timesteps,
        int                   dim,
        int                   max_period) {
    int actual_dim = dim;
    if (dim % 2 != 0) {
        actual_dim = dim + 1;
    }

    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);

@danbev
Member Author

danbev commented Sep 16, 2025

@ggerganov Oh right, I misunderstood the comment about removing the padding and took it to mean the padding in the kernels. It does indeed look like adding an extra dimension here is not needed. I'll update this and test it.
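
Concretely, the shape change under discussion amounts to dropping the round-up in ggml.c so that the result has exactly dim columns. A sketch against the lines quoted above (not the merged patch; each backend kernel also needs to handle the odd-dim zero column itself rather than assuming a padded width):

 struct ggml_tensor * ggml_timestep_embedding(
         struct ggml_context * ctx,
         struct ggml_tensor  * timesteps,
         int                   dim,
         int                   max_period) {
-    int actual_dim = dim;
-    if (dim % 2 != 0) {
-        actual_dim = dim + 1;
-    }
-
-    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, actual_dim, timesteps->ne[0]);
+    struct ggml_tensor * result = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, dim, timesteps->ne[0]);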

@danbev danbev force-pushed the ggml-timestep-embedding-backends branch from d304f6b to aa2c30c Compare September 16, 2025 09:42
@danbev danbev requested review from ggerganov and removed request for 0cc4m September 16, 2025 09:43
@danbev danbev merged commit 3913f87 into ggml-org:master Sep 16, 2025
47 of 48 checks passed
angt pushed a commit to angt/llama.cpp that referenced this pull request Sep 17, 2025
* ggml : remove adding extra dim timestep embedding

This commit updates the ggml_timestep_embedding function to no longer
add an extra dimension when the specified dimension is odd.

The motivation for this change is that this introduces an unnecessary
dimension when the dimension is odd, which caused an issue in the
kernels which were not expecting this extra dimension and it resulted in
uninitialized memory for the second to last dimension.

* ggml-cuda : fix padding in timestep embedding kernel

This commit removes the zeroing out of the last dimension now that we
are not adding the extra padding dimension.

* ggml-metal : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-opencl : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-sycl : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-vulkan : fix padding in timestep embedding kernel

This commit fixes the zero padding for odd dimensions in
the timestep embedding kernel.

* ggml-cpu : fix padding in timestep embedding function

This commit removes the zeroing out of the last dimension now that we
are not adding the extra padding dimension.
@danbev danbev deleted the ggml-timestep-embedding-backends branch September 24, 2025 06:19