
Commit 733c74d

Merge branch 'ggml-org:master' into mradermacher

2 parents: ed75d79 + 4c32832

File tree: 6 files changed (+22, -4 lines)


ggml/src/ggml-cuda/fattn-vec-f16.cuh

Lines changed: 1 addition & 0 deletions
@@ -212,6 +212,7 @@ static __global__ void flash_attn_vec_ext_f16(
             }
         }
         if (__all_sync(0xFFFFFFFF, skip)) {
+            __syncthreads();
             continue;
         }
 #endif // GGML_USE_HIP

ggml/src/ggml-cuda/fattn-vec-f32.cuh

Lines changed: 1 addition & 0 deletions
@@ -217,6 +217,7 @@ static __global__ void flash_attn_vec_ext_f32(
             }
         }
         if (__all_sync(0xFFFFFFFF, skip)) {
+            __syncthreads();
             continue;
         }
 #endif // GGML_USE_HIP
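These two hunks are the same one-line fix applied to the f16 and f32 variants of the kernel. `__all_sync(0xFFFFFFFF, skip)` is true only when every thread in the current warp wants to skip the iteration, but other warps in the same block may still run the full loop body, which contains further `__syncthreads()` barriers. A warp that takes the early `continue` without synchronizing leaves the block's barrier counts mismatched, which is undefined behavior and typically hangs the kernel. The sketch below reproduces the hazard in a minimal, hypothetical kernel; it is not the actual flash-attention code:

// Minimal sketch of the barrier-divergence hazard (hypothetical kernel,
// not the real flash_attn_vec code).
__global__ void loop_with_barrier(const int * skip_flags, float * acc, int n_iters) {
    for (int i = 0; i < n_iters; ++i) {
        const bool skip = skip_flags[i*blockDim.x + threadIdx.x] != 0;

        if (__all_sync(0xFFFFFFFF, skip)) {
            __syncthreads(); // the fix: keep this warp's barrier count in
                             // step with warps that run the full body
            continue;        // without the barrier above, this warp would
                             // skip the __syncthreads() below: undefined behavior
        }

        acc[threadIdx.x] += 1.0f; // stand-in for the real per-iteration work

        __syncthreads();          // every warp must reach one barrier per iteration
    }
}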

ggml/src/ggml-cuda/ggml-cuda.cu

Lines changed: 4 additions & 0 deletions
@@ -2207,6 +2207,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
         case GGML_UNARY_OP_SILU:
             ggml_cuda_op_silu(ctx, dst);
             break;
+        case GGML_UNARY_OP_GELU_ERF:
+            ggml_cuda_op_gelu_erf(ctx, dst);
+            break;
         case GGML_UNARY_OP_GELU_QUICK:
             ggml_cuda_op_gelu_quick(ctx, dst);
             break;
@@ -2992,6 +2995,7 @@ static bool ggml_backend_cuda_device_supports_op(ggml_backend_dev_t dev, const g
         case GGML_UNARY_OP_SIGMOID:
         case GGML_UNARY_OP_HARDSIGMOID:
         case GGML_UNARY_OP_HARDSWISH:
+        case GGML_UNARY_OP_GELU_ERF:
         case GGML_UNARY_OP_GELU_QUICK:
         case GGML_UNARY_OP_TANH:
         case GGML_UNARY_OP_EXP:
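Registering the new op takes two mirrored edits: the first hunk executes GGML_UNARY_OP_GELU_ERF in ggml_cuda_compute_forward, and the second whitelists it in ggml_backend_cuda_device_supports_op. Both are needed; if supports_op does not report the case, the scheduler routes the node to another backend (typically the CPU) and the new compute case is never reached. A stripped-down sketch of the pattern, with hypothetical names standing in for the real backend plumbing:

#include <math.h>

// Hypothetical, simplified version of the two mirrored switches.
enum unary_op { OP_SILU, OP_GELU_ERF, OP_GELU_QUICK };

// supports_op: consulted by the scheduler before assigning work.
static bool device_supports_op(unary_op op) {
    switch (op) {
        case OP_SILU:
        case OP_GELU_ERF:   // must be added here...
        case OP_GELU_QUICK:
            return true;
        default:
            return false;
    }
}

// compute_forward: only reached for ops accepted above.
static float compute_forward(unary_op op, float x) {
    switch (op) {
        case OP_GELU_ERF:   // ...and here, or this case stays dead code
            return 0.5f*x*(1.0f + erff(x*0.70710678f));
        default:
            return x;
    }
}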

ggml/src/ggml-cuda/unary.cu

Lines changed: 10 additions & 0 deletions
@@ -23,6 +23,12 @@ static __device__ __forceinline__ float op_gelu(float x) {
     return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
 }
 
+static __device__ __forceinline__ float op_gelu_erf(float x) {
+    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
+
+    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
+}
+
 static __device__ __forceinline__ float op_gelu_quick(float x) {
     const float GELU_QUICK_COEF = -1.702f;
 
@@ -134,6 +140,10 @@ void ggml_cuda_op_gelu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_gelu>(ctx, dst);
 }
 
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
+    ggml_cuda_op_unary<op_gelu_erf>(ctx, dst);
+}
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     ggml_cuda_op_unary<op_gelu_quick>(ctx, dst);
 }
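op_gelu_erf evaluates the exact GELU, 0.5*x*(1 + erf(x/sqrt(2))), while the existing op_gelu uses the common tanh approximation. The two agree closely over typical activation ranges but are not bit-identical, which is why models trained against exact GELU get their own op. A host-side comparison sketch; the harness is illustrative, and the constants assume the values defined in unary.cu (the diff context above uses the same names):

#include <math.h>
#include <stdio.h>

// Exact GELU, mirroring op_gelu_erf above.
static float gelu_erf(float x) {
    const float SQRT_2_INV = 0.70710678118654752440084436210484f;
    return 0.5f*x*(1.0f + erff(x*SQRT_2_INV));
}

// Tanh approximation, mirroring the existing op_gelu.
static float gelu_tanh(float x) {
    const float GELU_COEF_A    = 0.044715f;
    const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f;
    return 0.5f*x*(1.0f + tanhf(SQRT_2_OVER_PI*x*(1.0f + GELU_COEF_A*x*x)));
}

int main(void) {
    // Print both variants and their difference over a typical range.
    for (float x = -4.0f; x <= 4.0f; x += 1.0f) {
        printf("x=%+.1f  erf=%+.6f  tanh=%+.6f  diff=%+.2e\n",
               x, gelu_erf(x), gelu_tanh(x), gelu_erf(x) - gelu_tanh(x));
    }
    return 0;
}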

ggml/src/ggml-cuda/unary.cuh

Lines changed: 2 additions & 0 deletions
@@ -30,6 +30,8 @@ void ggml_cuda_op_silu(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_silu_back(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
+void ggml_cuda_op_gelu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
+
 void ggml_cuda_op_gelu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
 
 void ggml_cuda_op_tanh(ggml_backend_cuda_context & ctx, ggml_tensor * dst);

src/llama-vocab.cpp

Lines changed: 4 additions & 4 deletions
@@ -835,7 +835,7 @@ struct llm_tokenizer_ugm_session {
     }
 
     // initialize score_sum to -FLT_MAX so it will be always lower than sums of token scores
-    std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -FLT_MAX});
+    std::vector<struct best_tokenization> tokenization_results(input_len + 1, {vocab.token_unk(), 0, -DBL_MAX});
     // at the beginning tokenization score is zero
     tokenization_results[0] = { vocab.token_unk(), 0, 0 };
 
@@ -867,7 +867,7 @@ struct llm_tokenizer_ugm_session {
                 const double challenger_score = current_best.score_sum + token_score;
                 struct best_tokenization & current_champ = tokenization_results[prefix_offset];
                 if (challenger_score > current_champ.score_sum) {
-                    struct best_tokenization challenger = { token_id, input_offset, (float) challenger_score };
+                    struct best_tokenization challenger = { token_id, input_offset, challenger_score };
                     current_champ = challenger;
                 }
             }
@@ -881,7 +881,7 @@ struct llm_tokenizer_ugm_session {
             prefix_offset = input_offset + n_utf8_code_units;
             struct best_tokenization & current_champ = tokenization_results[prefix_offset];
             if (challenger_score > current_champ.score_sum) {
-                struct best_tokenization challenger = { vocab.token_unk(), input_offset, (float) challenger_score };
+                struct best_tokenization challenger = { vocab.token_unk(), input_offset, challenger_score };
                 current_champ = challenger;
             }
         }
@@ -1007,7 +1007,7 @@ struct llm_tokenizer_ugm_session {
     struct best_tokenization {
        llama_token token_id;
        size_t input_offset;
-       float score_sum;
+       double score_sum;
     };
 
     struct normalization_result normalize_prefix(const std::string & input, size_t input_offset) {
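The llama-vocab.cpp change is a precision fix in the UGM (unigram) tokenizer: challenger_score is already computed in double, but best_tokenization::score_sum stored it as float, so every partial sum along the tokenization lattice was truncated on store and genuinely different paths could collapse to the same score. Widening score_sum to double (and switching the initialization sentinel to the matching -DBL_MAX) keeps the comparisons exact. A minimal sketch of the truncation, with illustrative numbers rather than real vocabulary scores:

#include <stdio.h>

int main(void) {
    // A per-token log-probability that differs from -10.0 by less than half
    // a float ulp at that magnitude (illustrative value, not a real score).
    const double token_score = -10.000000238418579;

    double sum_d = 0.0;  // what `double score_sum` now keeps
    float  sum_f = 0.0f; // what `float score_sum` used to keep
    for (int i = 0; i < 1000; ++i) {
        sum_d += token_score;
        sum_f  = (float)(sum_f + token_score); // the old (float) cast on store
    }

    // The double chain preserves the accumulated difference; the float chain
    // is indistinguishable from 1000 tokens of exactly -10.0.
    printf("double: %.9f\n", sum_d);          // about -10000.000238419
    printf("float : %.9f\n", (double)sum_f);  // exactly -10000.000000000
    return 0;
}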
