Skip to content

Commit 94f4f29

Browse files
committed
Merge branch 'master' into qwen_image
2 parents 178a415 + 35843c7 commit 94f4f29

File tree

2 files changed

+15
-5
lines changed

2 files changed

+15
-5
lines changed

clip.hpp

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -553,10 +553,9 @@ class CLIPEmbeddings : public GGMLBlock {
553553
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types = {}, const std::string prefix = "") {
554554
enum ggml_type token_wtype = GGML_TYPE_F32;
555555
if (!force_clip_f32) {
556-
auto tensor_type = tensor_types.find(prefix + "token_embedding.weight");
557-
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
558-
if (tensor_type != tensor_types.end() && allow_types.find(tensor_type->second) != allow_types.end()) {
559-
token_wtype = tensor_type->second;
556+
token_wtype = get_type(prefix + "token_embedding.weight", tensor_types, GGML_TYPE_F32);
557+
if (!support_get_rows(token_wtype)) {
558+
token_wtype = GGML_TYPE_F32;
560559
}
561560
}
562561
enum ggml_type position_wtype = GGML_TYPE_F32;

ggml_extend.hpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1980,13 +1980,24 @@ class Linear : public UnaryBlock {
19801980
}
19811981
};
19821982

1983+
__STATIC_INLINE__ bool support_get_rows(ggml_type wtype) {
1984+
std::set<ggml_type> allow_types = {GGML_TYPE_F16, GGML_TYPE_Q8_0, GGML_TYPE_Q5_1, GGML_TYPE_Q5_0, GGML_TYPE_Q4_1, GGML_TYPE_Q4_0};
1985+
if (allow_types.find(wtype) != allow_types.end()) {
1986+
return true;
1987+
}
1988+
return false;
1989+
}
1990+
19831991
class Embedding : public UnaryBlock {
19841992
protected:
19851993
int64_t embedding_dim;
19861994
int64_t num_embeddings;
19871995
void init_params(struct ggml_context* ctx, const String2GGMLType& tensor_types, const std::string prefix = "") {
19881996
enum ggml_type wtype = get_type(prefix + "weight", tensor_types, GGML_TYPE_F32);
1989-
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
1997+
if (!support_get_rows(wtype)) {
1998+
wtype = GGML_TYPE_F32;
1999+
}
2000+
params["weight"] = ggml_new_tensor_2d(ctx, wtype, embedding_dim, num_embeddings);
19902001
}
19912002

19922003
public:

0 commit comments

Comments
 (0)