Skip to content

Commit 64c6dcb

Browse files
committed
metal : make the NSG a function constant in mul_mv kernels
ggml-ci
1 parent 320f029 commit 64c6dcb

File tree

5 files changed

+156
-104
lines changed

5 files changed

+156
-104
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -615,7 +615,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
615615
return res;
616616
}
617617

618-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
618+
ggml_metal_cv_t cv = ggml_metal_cv_init();
619+
620+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
621+
622+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
623+
624+
ggml_metal_cv_free(cv);
619625

620626
ggml_metal_pipeline_set_nr0 (res, nr0);
621627
ggml_metal_pipeline_set_nr1 (res, nr1);
@@ -825,7 +831,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
825831
return res;
826832
}
827833

828-
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
834+
ggml_metal_cv_t cv = ggml_metal_cv_init();
835+
836+
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
837+
838+
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
839+
840+
ggml_metal_cv_free(cv);
829841

830842
ggml_metal_pipeline_set_nr0 (res, nr0);
831843
ggml_metal_pipeline_set_nr1 (res, nr1);

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ typedef struct ggml_metal_cv * ggml_metal_cv_t;
2222
ggml_metal_cv_t ggml_metal_cv_init(void);
2323
void ggml_metal_cv_free(ggml_metal_cv_t cv);
2424

25+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx);
2526
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx);
2627
void ggml_metal_cv_set_bool (ggml_metal_cv_t cv, bool value, int32_t idx);
2728

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,10 @@ void ggml_metal_cv_free(ggml_metal_cv_t cv) {
5151
free(cv);
5252
}
5353

54+
void ggml_metal_cv_set_int16(ggml_metal_cv_t cv, int16_t value, int32_t idx) {
55+
[cv->obj setConstantValue:&value type:MTLDataTypeShort atIndex:idx];
56+
}
57+
5458
void ggml_metal_cv_set_int32(ggml_metal_cv_t cv, int32_t value, int32_t idx) {
5559
[cv->obj setConstantValue:&value type:MTLDataTypeInt atIndex:idx];
5660
}

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
#define FC_FLASH_ATTN_EXT 100
7676
#define FC_FLASH_ATTN_EXT_VEC 200
7777
#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300
78+
#define FC_MUL_MV 400
7879

7980
// kernel argument structs
8081
//

0 commit comments

Comments
 (0)