metal: implement cross-entropy and count-equal ops for MNIST#1390
xi-guo-0 wants to merge 1 commit into ggml-org:master
Conversation
iliailmer left a comment
Hey!
I am working on a similar PR and was pointed towards your work. I wanted to ask a couple of questions to better understand the differences and similarities between our implementations. Maybe we can join forces and add cross-entropy together. I'm relatively new to Metal development.
I left a couple of comments.
const ggml_tensor * src1 = op->src[1];

GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne);
GGML_TENSOR_LOCALS(uint64_t, nb0, src0, nb);
Curious, why not use int32_t here directly instead of casting later?
Thanks for the review!
You are absolutely right. ne00 was already int32_t from the macro, so promoting it to int64_t just to cast it back was redundant. I've cleaned it up.
src/ggml-metal/ggml-metal.metal
Outdated
template<typename T>
kernel void kernel_cross_entropy_loss(
        constant ggml_metal_kargs_cross_entropy_loss & args,
        device const char * logits_ptr,
Why use char instead of accepting a float-typed argument for logits?
I'm sticking with char* to support the template type-erasure pattern used in ggml-metal. Since I'm using decltype to enforce a single function signature across all template specializations (f32, f16, i32), the arguments must match exactly. device const char* allows me to bind the same kernel signature regardless of the underlying data type.
Force-pushed from 5c2b8d5 to 3f83fa0
iliailmer left a comment
Thanks for your responses! I left a couple more clarifications and questions.
src/ggml-metal/ggml-metal.metal
Outdated
if (tgpig.x == 0 && sgitg == 0 && tiisg == 0) {
    float total_loss = 0.0f;
    for (int i = 0; i < nrows; i++) {
        total_loss += dst[i];
If I am not mistaken, this part tries to collect the loss from all positions of the dst array on thread 0, but I couldn't find where we write into dst. Also, wouldn't this line cause a race condition where thread 0 sums locations that haven't been written yet?
Good catch. You are absolutely right—the previous logic caused a race condition and failed to write to dst correctly. I have removed the global accumulation loop. The kernel now computes the loss for each row and writes it directly to dst[tgpig.x] (row index).
src/ggml-metal/ggml-metal-device.cpp
Outdated
char base[256];
char name[256];

snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type));
Suggested change:
- snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[0]->type));
+ snprintf(base, 256, "kernel_cross_entropy_loss_back_%s", ggml_type_name(op->src[1]->type));

Just replacing the grad type with the logits type.
Fixed. I switched it to op->src[1]->type to correctly use the logits type.
src/ggml-metal/ggml-metal.metal
Outdated
template [[host_name("kernel_cross_entropy_loss_f32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<float>;
template [[host_name("kernel_cross_entropy_loss_f16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<half>;
template [[host_name("kernel_cross_entropy_loss_i32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<int>;
Can you help me understand the use case for i32 here? Would that mean we are casting logits_ptr into int *? I think it should only be f32 and f16, since logits are always floats.
You are right. I have removed the i32 and i16 templates since logits are always floating-point numbers (f32 or f16).
src/ggml-metal/ggml-metal.metal
Outdated
    x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid];
}

typedef decltype(kernel_cross_entropy_loss<int>) kernel_cross_entropy_loss_t;
Suggested change:
- typedef decltype(kernel_cross_entropy_loss<int>) kernel_cross_entropy_loss_t;
+ typedef decltype(kernel_cross_entropy_loss<float>) kernel_cross_entropy_loss_t;
Fixed. I updated the typedefs to use float as the base type for both the forward and backward kernels.
src/ggml-metal/ggml-metal.metal
Outdated
template [[host_name("kernel_cross_entropy_loss_i32")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<int>;
template [[host_name("kernel_cross_entropy_loss_i16")]] kernel kernel_cross_entropy_loss_t kernel_cross_entropy_loss<short>;

typedef decltype(kernel_cross_entropy_loss_back<int>) kernel_cross_entropy_loss_back_t;
Suggested change:
- typedef decltype(kernel_cross_entropy_loss_back<int>) kernel_cross_entropy_loss_back_t;
+ typedef decltype(kernel_cross_entropy_loss_back<float>) kernel_cross_entropy_loss_back_t;
Note that the count_equal op has just been merged. |
Force-pushed from 3f83fa0 to d7beb52
src/ggml-metal/ggml-metal.metal
Outdated
}

if (sgitg == 0 && tiisg == 0) {
    dst[tgpig.x] = -row_loss;
I think this is still potentially out of bounds.
Thanks for pointing that out. I've added a boundary check (if (tgpig.x >= nrows) return;) at the beginning of the kernel. This ensures that any extra threadgroups dispatched beyond the number of rows will return early, preventing the out-of-bounds access on dst.
I think this check won't work, because tgpig.x is already between 0 and nrows-1. I'm trying to see what can be done to help fix this; it might need a different reduction strategy?
Force-pushed from d7beb52 to a3520fb
Force-pushed from 40b8484 to 48fb04b
Force-pushed from 48fb04b to a613156
This PR implements the missing Metal backend operations required to run examples/mnist.