Skip to content

Commit 480b404

Browse files
CISCggerganov
andauthored
Apply suggestions from code review
Co-authored-by: Georgi Gerganov <[email protected]>
1 parent 7657ec3 commit 480b404

File tree

4 files changed

+6
-6
lines changed

4 files changed

+6
-6
lines changed

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows(ggml_metal_librar
142142
return res;
143143
}
144144

145-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tdst, ggml_type tidx) {
145+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows(ggml_metal_library_t lib, ggml_type tidx, ggml_type tdst) {
146146
char base[256];
147147
char name[256];
148148

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_base (ggml_me
105105
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_cpy (ggml_metal_library_t lib, enum ggml_type tsrc, enum ggml_type tdst);
106106
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pool_2d (ggml_metal_library_t lib, const struct ggml_tensor * op, enum ggml_op_pool op_pool);
107107
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_get_rows (ggml_metal_library_t lib, enum ggml_type tsrc);
108-
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tdst, enum ggml_type tidx);
108+
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_metal_library_t lib, enum ggml_type tidx, enum ggml_type tdst);
109109
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc);
110110
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op);
111111
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op);

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -892,7 +892,7 @@ int ggml_metal_op_set_rows(ggml_metal_op_t ctx, int idx) {
892892
GGML_TENSOR_LOCALS( int32_t, ne, op, ne);
893893
GGML_TENSOR_LOCALS(uint32_t, nb, op, nb);
894894

895-
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->type, op->src[1]->type);
895+
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_set_rows(lib, op->src[1]->type, op->type);
896896

897897
const int32_t nk0 = ne0/ggml_blck_size(op->type);
898898

ggml/src/ggml-metal/ggml-metal.metal

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7743,7 +7743,7 @@ kernel void kernel_get_rows_i32(
77437743
}
77447744
}
77457745

7746-
template<typename idx_t, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
7746+
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
77477747
kernel void kernel_set_rows_q32(
77487748
constant ggml_metal_kargs_set_rows & args,
77497749
device const void * src0,
@@ -7774,7 +7774,7 @@ kernel void kernel_set_rows_q32(
77747774
}
77757775
}
77767776

7777-
template<typename T, typename idx_t>
7777+
template<typename T, typename TI>
77787778
kernel void kernel_set_rows_f(
77797779
constant ggml_metal_kargs_set_rows & args,
77807780
device const void * src0,
@@ -7795,7 +7795,7 @@ kernel void kernel_set_rows_f(
77957795
}
77967796

77977797
const int32_t i10 = i01;
7798-
const idx_t i1 = ((const device idx_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
7798+
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
77997799

78007800
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
78017801
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);

0 commit comments

Comments
 (0)