
metal: implement cross-entropy and count-equal ops for MNIST #1390

Open

xi-guo-0 wants to merge 1 commit into ggml-org:master from xi-guo-0:feat/metal-cross-entropy

Conversation

@xi-guo-0

This PR implements the missing Metal backend operations required to run examples/mnist.

@iliailmer (Contributor) left a comment:

Hey!

I am working on a similar PR and was pointed towards your work. I wanted to ask a couple of questions to better understand the differences and similarities between our code. Maybe we can join forces and add cross-entropy together. I'm relatively new to Metal dev work.

I left a couple of comments.

const ggml_tensor * src1 = op->src[1];

GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, src0, nb);
Contributor:
Curious, why not use int32_t here directly instead of casting later?

Author:
Thanks for the review!

You are absolutely right. ne00 was already int32_t from the macro, so promoting it to int64_t just to cast it back was redundant. I've cleaned it up.

template<typename T>
kernel void kernel_cross_entropy_loss(
constant ggml_metal_kargs_cross_entropy_loss & args,
device const char * logits_ptr,
Contributor:

Why use char instead of accepting float type argument for logits?

Author:

I'm sticking with char* to support the template type-erasure pattern used in ggml-metal. Since I'm using decltype to enforce a single function signature across all template specializations (f32, f16, i32), the arguments must match exactly. device const char* allows me to bind the same kernel signature regardless of the underlying data type.
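To illustrate the pattern being described, here is a minimal host-side C++ sketch (a hypothetical stand-in, not the PR's actual kernel code; `sum_logits` is an invented name): every specialization shares one byte-pointer signature, so a single decltype alias covers all of them.

```cpp
#include <cstddef>

// Type-erasure sketch: the entry point takes an untyped byte pointer so
// that every specialization has an identical signature; each one then
// reinterprets the bytes as its concrete element type.
template <typename T>
float sum_logits(const char * logits_ptr, std::size_t n) {
    const T * logits = reinterpret_cast<const T *>(logits_ptr);
    float acc = 0.0f;
    for (std::size_t i = 0; i < n; ++i) {
        acc += static_cast<float>(logits[i]);
    }
    return acc;
}

// One alias describes every specialization, mirroring
// typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
using sum_logits_t = decltype(&sum_logits<float>);
```

Because every specialization has the same erased signature, `&sum_logits<float>` and any other specialization can be stored in the same `sum_logits_t` pointer, which is what lets one pipeline-binding signature serve all data types.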

@xi-guo-0 force-pushed the feat/metal-cross-entropy branch from 5c2b8d5 to 3f83fa0 (December 20, 2025, 06:46)
@iliailmer (Contributor) left a comment:

Thanks for your responses! I left a couple more clarifications and questions.

if (tgpig.x == 0 && sgitg == 0 && tiisg == 0) {
float total_loss = 0.0f;
for (int i = 0; i < nrows; i++) {
total_loss += dst[i];
Contributor:

If I am not mistaken, this part tries to collect the loss from all positions of the dst array on thread 0, but I couldn't find where we write into dst. Also, wouldn't this line cause a race condition, where thread 0 sums locations that haven't been written yet?

Author:

Good catch. You are absolutely right: the previous logic caused a race condition and failed to write to dst correctly. I have removed the global accumulation loop. The kernel now computes the loss for each row and writes it directly to dst[tgpig.x] (row index).
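A CPU-side sketch of that one-row-per-threadgroup scheme (a hypothetical reference implementation, not the Metal kernel itself): each "threadgroup" owns one row, computes a numerically stable cross-entropy for it, and writes exactly one output slot, so no two writers ever touch the same element.

```cpp
#include <cmath>

// Reference semantics for dst[tgpig.x] = -row_loss: one loss per row,
// no cross-row accumulation, hence no race between rows.
void cross_entropy_rows(const float * logits, const float * labels,
                        float * dst, int nrows, int ncols) {
    for (int row = 0; row < nrows; ++row) {        // one "threadgroup" per row
        const float * x = logits + row * ncols;
        const float * y = labels + row * ncols;
        // numerically stable log-softmax: subtract the row max first
        float maxv = x[0];
        for (int i = 1; i < ncols; ++i) maxv = std::fmax(maxv, x[i]);
        float sum = 0.0f;
        for (int i = 0; i < ncols; ++i) sum += std::exp(x[i] - maxv);
        const float logsum = std::log(sum);
        float row_loss = 0.0f;
        for (int i = 0; i < ncols; ++i) {
            row_loss += y[i] * (x[i] - maxv - logsum);  // sum of y_i * log softmax_i
        }
        dst[row] = -row_loss;  // each row writes only its own slot
    }
}
```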

char base[256];
char name[256];

snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type));
Contributor:

Suggested change
snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type));
snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[1]->type));

Just replacing the grad type with the logits type.

Author:

Fixed. I switched it to op->src[1]->type to correctly use the logits type.


template [[host_name("kernel_cross_entropy_loss_f32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
template [[host_name("kernel_cross_entropy_loss_f16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<half>;
template [[host_name("kernel_cross_entropy_loss_i32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<int>;
Contributor:

Can you help me understand the use case for i32 here? Would that mean we are casting logits_ptr into int *? I think it should only be f32 and f16, since logits are always floats.

Author:

You are right. I have removed the i32 and i16 templates since logits are always floating-point numbers (f32 or f16).

x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

typedef decltype(kernel_cross_entropy_loss<int>) kernel_cross_entropy_loss_t;
Contributor:

Suggested change
typedef decltype(kernel_cross_entropy_loss<int>) kernel_cross_entropy_loss_t;
typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;

Author:

Fixed. I updated the typedefs to use float as the base type for both the forward and backward kernels.

template [[host_name("kernel_cross_entropy_loss_i32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<int>;
template [[host_name("kernel_cross_entropy_loss_i16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<short>;

typedef decltype(kernel_cross_entropy_loss_back<int>) kernel_cross_entropy_loss_back_t;
Contributor:

Suggested change
typedef decltype(kernel_cross_entropy_loss_back<int>) kernel_cross_entropy_loss_back_t;
typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;

@ggerganov (Member):

Note that the count_equal op has just been merged.

@xi-guo-0 force-pushed the feat/metal-cross-entropy branch from 3f83fa0 to d7beb52 (January 1, 2026, 08:28)
}

if (sgitg == 0 && tiisg == 0) {
dst[tgpig.x] = -row_loss;
Contributor:

I think this is still potentially out of bounds.

Author:

Thanks for pointing that out. I've added a boundary check (if (tgpig.x >= nrows) return;) at the beginning of the kernel. This ensures that any extra threadgroups dispatched beyond the number of rows will return early, preventing the out-of-bounds access on dst.
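The effect of such a guard can be pictured with a CPU loop standing in for the dispatched threadgroups (a sketch only; `write_row_losses` is an invented name, and this does not settle whether the guard is the right fix):

```cpp
// Simulated dispatch: gid plays the role of tgpig.x. Groups past the row
// count skip the write, so dst is only touched for indices < nrows.
void write_row_losses(float * dst, const float * row_loss,
                      int nrows, int dispatched_groups) {
    for (int gid = 0; gid < dispatched_groups; ++gid) {
        if (gid >= nrows) {
            continue;  // the kernel returns early here
        }
        dst[gid] = -row_loss[gid];
    }
}
```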

Contributor:

I think this check won't work, because tgpig.x is already between 0 and nrows-1. I'm trying to see what can be done to help fix this; might we need a different reduction strategy?
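One possible direction, sketched purely as an illustration of the hinted-at alternative (not code from this PR): a two-pass scheme, where pass one writes one loss per row and a separate second pass reduces those values, so the sum only ever reads fully written slots.

```cpp
// Pass 2 of a two-pass reduction: it runs only after every per-row loss
// has been written (a separate kernel launch on the GPU side), so it
// cannot observe partially written values.
float reduce_losses(const float * per_row, int nrows) {
    float total = 0.0f;
    for (int i = 0; i < nrows; ++i) {
        total += per_row[i];
    }
    return total;
}
```

Whether the op should report this sum directly, or a mean, or fold the reduction into a single kernel with threadgroup barriers, is left open here.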

@xi-guo-0 force-pushed the feat/metal-cross-entropy branch from d7beb52 to a3520fb (January 17, 2026, 03:00)
@xi-guo-0 force-pushed the feat/metal-cross-entropy branch 2 times, most recently from 40b8484 to 48fb04b (January 30, 2026, 14:20)
@xi-guo-0 force-pushed the feat/metal-cross-entropy branch from 48fb04b to a613156 (January 31, 2026, 08:15)