Skip to content

Commit 5c2b8d5

Browse files
committed
metal: remove redundant casting in cross entropy ops
1 parent d82759b commit 5c2b8d5

File tree

1 file changed

+5
-7
lines changed

1 file changed

+5
-7
lines changed

src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3666,12 +3666,11 @@ int ggml_metal_op_cross_entropy_loss(ggml_metal_op_t ctx, int idx) {
36663666
GGML_TENSOR_LOCALS( int32_t, ne1, src1, ne);
36673667
GGML_TENSOR_LOCALS(uint64_t, nb1, src1, nb);
36683668

3669-
const int64_t nclasses = ne00;
3670-
const int64_t nrows = ggml_nrows(src0);
3669+
const int32_t nrows = (int32_t) ggml_nrows(src0);
36713670

36723671
ggml_metal_kargs_cross_entropy_loss args = {
3673-
/*.n_classes =*/ (int32_t) nclasses,
3674-
/*.n_rows =*/ (int32_t) nrows,
3672+
/*.n_classes =*/ ne00,
3673+
/*.n_rows =*/ nrows,
36753674
};
36763675

36773676
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss(lib, op);
@@ -3704,11 +3703,10 @@ int ggml_metal_op_cross_entropy_loss_back(ggml_metal_op_t ctx, int idx) {
37043703

37053704
GGML_TENSOR_LOCALS( int32_t, ne0, src0, ne);
37063705

3707-
const int64_t nclasses = ne00;
3708-
const int64_t nrows = ggml_nrows(src0);
3706+
const int32_t nrows = (int32_t) ggml_nrows(src0);
37093707

37103708
ggml_metal_kargs_cross_entropy_loss_back args = {
3711-
/*.n_classes =*/ (int32_t) nclasses,
3709+
/*.n_classes =*/ ne00,
37123710
};
37133711

37143712
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cross_entropy_loss_back(lib, op);

0 commit comments

Comments
 (0)