
Commit 1d863bf

Merge branch 'ggml-org:master' into master

2 parents 1528dd3 + 5016b72

19 files changed: +349 -42 lines changed

ggml/src/ggml-cann/aclnn_ops.cpp

Lines changed: 6 additions & 8 deletions

@@ -146,9 +146,7 @@ void ggml_cann_op_unary_gated(
     unary_op(ctx, acl_src0, acl_dst);
     GGML_CANN_CALL_ACLNN_OP(ctx, InplaceMul, acl_dst, acl_src1);
 
-    ggml_cann_release_resources(ctx, acl_src0, acl_dst);
-    if(src1)
-        ggml_cann_release_resources(ctx, acl_src1);
+    ggml_cann_release_resources(ctx, acl_src0, acl_src1, acl_dst);
 }
 
 /**
@@ -1851,7 +1849,7 @@ void ggml_cann_get_rows(ggml_backend_cann_context& ctx, ggml_tensor* dst) {
                 dst->data, dst->ne, dst->nb,
                 src1, dst->type);
 
-            ggml_cann_release_resources(ctx, dequant_tensor);
+            ggml_cann_release_resources(ctx, acl_weight_tensor, acl_scale_tensor, dequant_tensor);
             break;
         }
         default:
@@ -3290,8 +3288,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     aclTensor* acl_q_tensor = acl_src0_f16_tensor;
     aclTensor* acl_k_tensors[] = {acl_src1_f16_tensor};
     aclTensor* acl_v_tensors[] = {acl_src2_f16_tensor};
-    auto acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
-    auto acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
+    aclTensorList* acl_k_tensor_list = aclCreateTensorList(acl_k_tensors, kvTensorNum);
+    aclTensorList* acl_v_tensor_list = aclCreateTensorList(acl_v_tensors, kvTensorNum);
 
     int64_t numHeads = src0->ne[2]; // N
     int64_t numKeyValueHeads = src1->ne[2];
@@ -3362,8 +3360,8 @@ void ggml_cann_flash_attn_ext(ggml_backend_cann_context& ctx, ggml_tensor* dst){
     }
 
     ggml_cann_release_resources(ctx, acl_src0_f16_tensor,
-                                acl_src1_f16_tensor,
-                                acl_src2_f16_tensor,
+                                acl_k_tensor_list,
+                                acl_v_tensor_list,
                                 fa_dst_tensor,
                                 acl_dst_tensor,
                                 bcast_pse_tensor);
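Note on the first hunk: collapsing the two release calls into one assumes the release helper copes with optional tensors, since the old code only released acl_src1 behind an `if(src1)` guard. A minimal C++ sketch of that pattern — `acl_tensor` and `release_all` are made-up stand-ins for illustration, not the real ggml_cann_release_resources:

    #include <cstdio>
    #include <initializer_list>

    // Stand-in for an ACL tensor handle (hypothetical, illustration only).
    struct acl_tensor { const char * name; };

    // Release every handle in one call; null entries are skipped, so optional
    // inputs (like the second source of a gated unary op) need no separate guard.
    static void release_all(std::initializer_list<acl_tensor *> handles) {
        for (acl_tensor * t : handles) {
            if (t) {
                std::printf("releasing %s\n", t->name); // real code would destroy the handle here
            }
        }
    }

    int main() {
        acl_tensor src0{"src0"}, dst{"dst"};
        release_all({&src0, nullptr /* optional src1 absent */, &dst});
    }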

ggml/src/ggml-cpu/ggml-cpu-impl.h

Lines changed: 1 addition & 1 deletion

@@ -68,7 +68,7 @@ struct ggml_compute_params {
 #endif // __VXE2__
 #endif // __s390x__ && __VEC__
 
-#if defined(__ARM_FEATURE_SVE)
+#if defined(__ARM_FEATURE_SVE) && defined(__linux__)
 #include <sys/prctl.h>
 #endif
 

ggml/src/ggml-cpu/ggml-cpu.c

Lines changed: 6 additions & 1 deletion

@@ -689,8 +689,13 @@ bool ggml_is_numa(void) {
 #endif
 
 static void ggml_init_arm_arch_features(void) {
-#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE)
+#if defined(__linux__)
     ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL);
+#else
+// TODO: add support of SVE for non-linux systems
+#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here."
+#endif
 #endif
 }
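For context on the Linux branch kept above: the SVE vector length is queried through the kernel's prctl interface, and PR_SVE_VL_LEN_MASK extracts the length in bytes from the returned value. A small standalone C++ sketch using the same calls as the patch (Linux/AArch64 only):

    #include <cstdio>
    #include <sys/prctl.h>

    int main() {
    #if defined(__aarch64__) && defined(PR_SVE_GET_VL)
        // Ask the kernel for the current SVE vector length configuration.
        const int ret = prctl(PR_SVE_GET_VL);
        if (ret < 0) {
            perror("prctl(PR_SVE_GET_VL)");
            return 1;
        }
        // Same mask ggml uses to strip flag bits and keep the length in bytes.
        const int sve_vl_bytes = ret & PR_SVE_VL_LEN_MASK;
        std::printf("SVE vector length: %d bytes\n", sve_vl_bytes);
    #else
        std::printf("SVE length query via prctl is only available on Linux/AArch64\n");
    #endif
        return 0;
    }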

ggml/src/ggml-cpu/vec.cpp

Lines changed: 1 addition & 1 deletion

@@ -463,9 +463,9 @@ ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const floa
 #endif
     for (; i < n; ++i) {
         float val = x[i] - mean;
+        y[i] = val;
         val *= val;
         sum += (ggml_float)val;
-        y[i] = val;
     }
     return sum/n;
 }
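The swap matters because `val` is squared in place on the next line: `y[i]` has to receive the centered value `x[i] - mean`, not its square. A tiny standalone sketch of the corrected scalar tail (names are illustrative, not the ggml implementation):

    #include <cstdio>

    // Scalar tail of a centered-variance pass: writes the centered values to y
    // and accumulates the sum of squares, mirroring the fixed loop order.
    static double centered_variance(int n, float * y, const float * x, float mean) {
        double sum = 0.0;
        for (int i = 0; i < n; ++i) {
            float val = x[i] - mean;
            y[i] = val;            // store the centered value *before* squaring
            val *= val;
            sum += (double) val;
        }
        return sum / n;
    }

    int main() {
        const float x[4] = {1.0f, 2.0f, 3.0f, 4.0f};
        float y[4];
        std::printf("variance = %f\n", centered_variance(4, y, x, 2.5f)); // prints 1.250000
    }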

ggml/src/ggml-cuda/fattn-tile.cuh

Lines changed: 17 additions & 27 deletions

@@ -540,10 +540,12 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
                 KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
             }
 
-            KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) && (!oob_check || i_KQ < k_VKQ_sup) ?
-                slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
+            if (!oob_check || i_KQ < k_VKQ_sup) {
+                KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ?
+                    slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f;
 
-            KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+                KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]);
+            }
         }
 
         KQ_max_new[jc0] = warp_reduce_max<warp_size>(KQ_max_new[jc0]);
@@ -581,10 +583,9 @@ static __device__ __forceinline__ void flash_attn_tile_iter(
         float KQ_sum_add = 0.0f;
 #pragma unroll
         for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) {
-            const float val = expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]);
-            if (!oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup) {
-                KQ_sum_add += val;
-            }
+            const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ?
+                expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f;
+            KQ_sum_add += val;
             tmp[i0/(np*warp_size)][jc1] = val;
         }
         KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add;
@@ -975,26 +976,6 @@ static __global__ void flash_attn_tile(
         }
     }
 
-    if (gridDim.y == 1) {
-#pragma unroll
-        for (int jc0 = 0; jc0 < cpw; ++jc0) {
-#ifdef FAST_FP16_AVAILABLE
-            const half2 KQ_sum_jc_inv = make_half2(1.0f/KQ_sum[jc0], 1.0f/KQ_sum[jc0]);
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i] *= KQ_sum_jc_inv;
-            }
-#else
-            const float KQ_sum_jc_inv = 1.0f/KQ_sum[jc0];
-#pragma unroll
-            for (int i = 0; i < (DVp/2)/warp_size; ++i) {
-                VKQ[jc0*((DVp/2)/warp_size) + i].x *= KQ_sum_jc_inv;
-                VKQ[jc0*((DVp/2)/warp_size) + i].y *= KQ_sum_jc_inv;
-            }
-#endif // FAST_FP16_AVAILABLE
-        }
-    }
-
     // Write back results:
 #pragma unroll
     for (int jc0 = 0; jc0 < cpw; ++jc0) {
@@ -1007,6 +988,8 @@ static __global__ void flash_attn_tile(
             return;
         }
 
+        const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f;
+
         const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y;
 
 #ifdef FAST_FP16_AVAILABLE
@@ -1017,6 +1000,8 @@ static __global__ void flash_attn_tile(
 #pragma unroll
             for (int i1 = 0; i1 < cpy_ne_D; ++i1) {
                 tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]);
+                tmp[i1].x *= scale;
+                tmp[i1].y *= scale;
             }
             if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) {
                 ggml_cuda_memcpy_1<sizeof(tmp)>(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp);
@@ -1027,6 +1012,11 @@ static __global__ void flash_attn_tile(
 #pragma unroll
         for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) {
             if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) {
+#pragma unroll
+                for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) {
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale;
+                    VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale;
+                }
                 ggml_cuda_memcpy_1<cpy_ne_D*4>(
                     &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D],
                     &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]);
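Taken together, these hunks drop the separate normalization pass over VKQ and fold the `1.0f/KQ_sum[jc0]` factor into the write-back as `scale` (applied only when `gridDim.y == 1`, i.e. when no cross-block combine pass follows). Scaling at the final store is equivalent to scaling the accumulator beforehand because the accumulator is read exactly once more. A simplified scalar sketch of the idea (names are illustrative, not the kernel's):

    #include <cstdio>

    // Simplified single-row flash-attention epilogue: instead of normalizing the
    // accumulator in a dedicated loop, apply scale = 1/sum once while writing out.
    int main() {
        const float acc[4]    = {0.2f, 1.0f, 3.5f, 0.3f}; // unnormalized V*softmax accumulator
        const float kq_sum    = 5.0f;                      // running softmax denominator
        const bool  last_pass = true;                      // analogue of gridDim.y == 1

        const float scale = last_pass ? 1.0f / kq_sum : 1.0f;

        float dst[4];
        for (int i = 0; i < 4; ++i) {
            dst[i] = acc[i] * scale;  // normalization fused into the store
        }
        std::printf("%f %f %f %f\n", dst[0], dst[1], dst[2], dst[3]);
    }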

ggml/src/ggml-metal/ggml-metal-device.cpp

Lines changed: 19 additions & 0 deletions

@@ -1519,3 +1519,22 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_
 
     return res;
 }
+
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) {
+    assert(op->op == GGML_OP_OPT_STEP_SGD);
+
+    char base[256];
+    char name[256];
+
+    snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type));
+    snprintf(name, 256, "%s", base);
+
+    ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
+    if (res) {
+        return res;
+    }
+
+    res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
+
+    return res;
+}

ggml/src/ggml-metal/ggml-metal-device.h

Lines changed: 1 addition & 0 deletions

@@ -136,6 +136,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_me
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange            (ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op);
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw    (ggml_metal_library_t lib, const struct ggml_tensor * op);
+ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd      (ggml_metal_library_t lib, const struct ggml_tensor * op);
 
 ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad(
         ggml_metal_library_t lib,

ggml/src/ggml-metal/ggml-metal-device.m

Lines changed: 1 addition & 0 deletions

@@ -800,6 +800,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
                 };
             }
         case GGML_OP_OPT_STEP_ADAMW:
+        case GGML_OP_OPT_STEP_SGD:
             return has_simdgroup_reduction;
         default:
             return false;

ggml/src/ggml-metal/ggml-metal-impl.h

Lines changed: 4 additions & 0 deletions

@@ -781,4 +781,8 @@ typedef struct {
     int64_t np;
 } ggml_metal_kargs_opt_step_adamw;
 
+typedef struct {
+    int64_t np;
+} ggml_metal_kargs_opt_step_sgd;
+
 #endif // GGML_METAL_IMPL

ggml/src/ggml-metal/ggml-metal-ops.cpp

Lines changed: 38 additions & 0 deletions

@@ -418,6 +418,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) {
             {
                 n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx);
             } break;
+        case GGML_OP_OPT_STEP_SGD:
+            {
+                n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx);
+            } break;
         default:
             {
                 GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op));
@@ -3469,3 +3473,37 @@ int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) {
 
     return 1;
 }
+
+int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) {
+    ggml_tensor * op = ctx->node(idx);
+
+    ggml_metal_library_t lib = ctx->lib;
+    ggml_metal_encoder_t enc = ctx->enc;
+
+    GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne);
+    GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb);
+    GGML_TENSOR_LOCALS( int32_t, ne,  op,         ne);
+    GGML_TENSOR_LOCALS(uint32_t, nb,  op,         nb);
+
+    ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op);
+
+    const int64_t np = ggml_nelements(op->src[0]);
+    ggml_metal_kargs_opt_step_sgd args = {
+        /*.np =*/ np,
+    };
+
+    int ida = 0;
+
+    ggml_metal_encoder_set_pipeline(enc, pipeline);
+    ggml_metal_encoder_set_bytes   (enc, &args, sizeof(args), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[0]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[1]), ida++);
+    ggml_metal_encoder_set_buffer  (enc, ggml_metal_get_buffer_id(op->src[2]), ida++);
+
+    const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0);
+    const int64_t n = (np + nth - 1) / nth;
+
+    ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1);
+
+    return 1;
+}
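The host-side encoder above only packs `np`, binds three buffers, and dispatches; the `kernel_opt_step_sgd_*` Metal source is not part of this diff. As a rough reference for what an elementwise SGD step computes, here is a plain C++ sketch assuming parameters in src[0], gradients in src[1], and a small src[2] tensor carrying {learning rate alpha, weight decay wd} — these semantics mirror ggml's other optimizer ops but are an assumption here, not taken from the Metal kernel:

    #include <cstdio>

    // Hypothetical host-side reference for an SGD step: w <- w*(1 - alpha*wd) - alpha*g.
    // The Metal kernel would run this elementwise across np threads.
    static void opt_step_sgd_ref(float * w, const float * g, const float * params, long np) {
        const float alpha = params[0];            // learning rate
        const float wd    = params[1];            // weight decay (0 disables it)
        const float keep  = 1.0f - alpha * wd;
        for (long i = 0; i < np; ++i) {
            w[i] = w[i] * keep - alpha * g[i];
        }
    }

    int main() {
        float w[3]           = {1.0f, -2.0f, 0.5f};
        const float g[3]     = {0.1f, 0.2f, -0.3f};
        const float params[2] = {0.01f, 0.0f};    // alpha = 0.01, no weight decay
        opt_step_sgd_ref(w, g, params, 3);
        std::printf("%f %f %f\n", w[0], w[1], w[2]); // 0.999000 -2.002000 0.503000
    }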
