@@ -3666,12 +3666,11 @@ int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
36663666 GGML_TENSOR_LOCALS ( int32_t , ne1, src1, ne);
36673667 GGML_TENSOR_LOCALS (uint64_t , nb1, src1, nb);
36683668
3669- const int64_t nclasses = ne00;
3670- const int64_t nrows = ggml_nrows (src0);
3669+ const int32_t nrows = (int32_t ) ggml_nrows (src0);
36713670
36723671 ggml_metal_kargs_cross_entropy_loss args = {
3673- /* .n_classes =*/ ( int32_t ) nclasses ,
3674- /* .n_rows =*/ ( int32_t ) nrows,
3672+ /* .n_classes =*/ ne00 ,
3673+ /* .n_rows =*/ nrows,
36753674 };
36763675
36773676 ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss (lib, op);
@@ -3704,11 +3703,10 @@ int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx) {
37043703
37053704 GGML_TENSOR_LOCALS ( int32_t , ne0, src0, ne);
37063705
3707- const int64_t nclasses = ne00;
3708- const int64_t nrows = ggml_nrows (src0);
3706+ const int32_t nrows = (int32_t ) ggml_nrows (src0);
37093707
37103708 ggml_metal_kargs_cross_entropy_loss_back args = {
3711- /* .n_classes =*/ ( int32_t ) nclasses ,
3709+ /* .n_classes =*/ ne00 ,
37123710 };
37133711
37143712 ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss_back (lib, op);
0 commit comments