
Commit b3ec875

Cosmos convert logic
1 parent a2b7597 commit b3ec875

3 files changed: 66 additions, 22 deletions


loader.py

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@
 from .ops import GGMLTensor
 from .dequant import is_quantized, dequantize_tensor
 
-IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "ltxv", "hyvid", "wan"}
+IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
 TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}
 
 def get_orig_shape(reader, tensor_name):
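
Note: IMG_ARCH_LIST gates which GGUF architecture strings the loader treats as image/video diffusion models, so adding "cosmos" here is what lets converted Cosmos files load at all. A minimal sketch of that kind of gate, assuming the architecture string has already been read from the file's metadata (the helper name and error handling below are illustrative, not taken from loader.py):

```python
IMG_ARCH_LIST = {"flux", "sd1", "sdxl", "sd3", "aura", "hidream", "cosmos", "ltxv", "hyvid", "wan"}
TXT_ARCH_LIST = {"t5", "t5encoder", "llama"}

def classify_arch(arch: str) -> str:
    # hypothetical helper: route a GGUF file by its "general.architecture" value
    if arch in IMG_ARCH_LIST:
        return "image"  # diffusion/DiT checkpoint, e.g. "cosmos" after this commit
    if arch in TXT_ARCH_LIST:
        return "text"   # text encoder path
    raise ValueError(f"Unsupported GGUF architecture: {arch!r}")
```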

tools/convert.py

Lines changed: 36 additions & 8 deletions
@@ -18,6 +18,7 @@ class ModelTemplate:
     keys_detect = [] # list of lists to match in state dict
     keys_banned = [] # list of keys that should mark model as invalid for conversion
     keys_hiprec = [] # list of keys that need to be kept in fp32 for some reason
+    keys_ignore = [] # list of strings to ignore keys by when found
 
     def handle_nd_tensor(self, key, data):
         raise NotImplementedError(f"Tensor detected that exceeds dims supported by C++ code! ({key} @ {data.shape})")
@@ -60,6 +61,17 @@ class ModelHiDream(ModelTemplate):
         "img_emb.emb_pos"
     ]
 
+class CosmosPredict2(ModelTemplate):
+    arch = "cosmos"
+    keys_detect = [
+        (
+            "blocks.0.mlp.layer1.weight",
+            "blocks.0.adaln_modulation_cross_attn.1.weight",
+        )
+    ]
+    keys_hiprec = ["pos_embedder"]
+    keys_ignore = ["_extra_state", "accum_"]
+
 class ModelHyVid(ModelTemplate):
     arch = "hyvid"
     keys_detect = [
@@ -128,7 +140,7 @@ class ModelSD1(ModelTemplate):
     ]
 
 # The architectures are checked in order and the first successful match terminates the search.
-arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]
+arch_list = [ModelFlux, ModelSD3, ModelAura, ModelHiDream, CosmosPredict2, ModelLTXV, ModelHyVid, ModelWan, ModelSDXL, ModelSD1]
 
 def is_model_arch(model, state_dict):
     # check if model is correct
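
For readers skimming the diff: detection checks each template's keys_detect groups against the checkpoint, and the first template whose group is fully present wins, which is why CosmosPredict2 sits before the generic SDXL/SD1 templates in arch_list. A simplified sketch of that matching loop, assuming keys_detect holds tuples of required key names (the real is_model_arch in convert.py may differ in details such as keys_banned handling):

```python
def detect_arch(state_dict, templates=arch_list):
    # first template with at least one fully-present key group wins
    for template in templates:
        if any(all(key in state_dict for key in group) for group in template.keys_detect):
            return template()
    raise RuntimeError("Unknown model architecture")

# a Cosmos Predict2 checkpoint would match on the pair
# ("blocks.0.mlp.layer1.weight", "blocks.0.adaln_modulation_cross_attn.1.weight")
```
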
@@ -163,20 +175,32 @@ def parse_args():
     return args
 
 def strip_prefix(state_dict):
-    # only keep unet with no prefix!
+    # prefix for mixed state dict
     prefix = None
     for pfx in ["model.diffusion_model.", "model."]:
         if any([x.startswith(pfx) for x in state_dict.keys()]):
             prefix = pfx
             break
 
-    sd = {}
-    for k, v in state_dict.items():
-        if prefix and prefix not in k:
-            continue
-        if prefix:
+    # prefix for uniform state dict
+    if prefix is None:
+        for pfx in ["net."]:
+            if all([x.startswith(pfx) for x in state_dict.keys()]):
+                prefix = pfx
+                break
+
+    # strip prefix if found
+    if prefix is not None:
+        logging.info(f"State dict prefix found: '{prefix}'")
+        sd = {}
+        for k, v in state_dict.items():
+            if prefix not in k:
+                continue
             k = k.replace(prefix, "")
-        sd[k] = v
+            sd[k] = v
+    else:
+        logging.debug("State dict has no prefix")
+        sd = state_dict
 
     return sd
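
The practical effect of the strip_prefix change: checkpoints where every key starts with "net." (the uniform layout this commit appears to target for Cosmos Predict2) now have that prefix stripped, while the old behaviour for mixed "model.diffusion_model." checkpoints is unchanged. A rough usage sketch with toy keys (integers stand in for tensors):

```python
# mixed checkpoint: only the diffusion-model subtree is kept, prefix removed
strip_prefix({
    "model.diffusion_model.blocks.0.mlp.layer1.weight": 0,
    "first_stage_model.decoder.conv_in.weight": 1,
})
# -> {"blocks.0.mlp.layer1.weight": 0}

# uniform checkpoint: "net." is stripped from every key
strip_prefix({
    "net.blocks.0.mlp.layer1.weight": 0,
    "net.final_layer.linear.weight": 1,
})
# -> {"blocks.0.mlp.layer1.weight": 0, "final_layer.linear.weight": 1}
```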

@@ -209,6 +233,10 @@ def handle_tensors(writer, state_dict, model_arch):
     for key, data in tqdm(state_dict.items()):
         old_dtype = data.dtype
 
+        if any(x in key for x in model_arch.keys_ignore):
+            tqdm.write(f"Filtering ignored key: '{key}'")
+            continue
+
         if data.dtype == torch.bfloat16:
             data = data.to(torch.float32).numpy()
             # this is so we don't break torch 2.0.X
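
The keys_ignore filter runs before any dtype or shape handling, so for CosmosPredict2 any key containing "_extra_state" or "accum_" never reaches the GGUF writer. A standalone restatement of the same filter with toy key names (not taken from a real checkpoint):

```python
keys_ignore = ["_extra_state", "accum_"]  # CosmosPredict2.keys_ignore from above
state_dict = {
    "blocks.0.mlp.layer1.weight": "tensor",
    "blocks.0.self_attn._extra_state": "blob",  # toy serialized-state entry
    "accum_train_steps": "counter",             # toy training bookkeeping entry
}
kept = {k: v for k, v in state_dict.items()
        if not any(x in k for x in keys_ignore)}
assert kept == {"blocks.0.mlp.layer1.weight": "tensor"}
```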

tools/lcpp.patch

Lines changed: 29 additions & 13 deletions
@@ -39,10 +39,10 @@ index b16c462f..6d1568f1 100644
 const int idx = gguf_find_tensor(ctx, name);
 if (idx < 0) {
 diff --git a/src/llama.cpp b/src/llama.cpp
-index 24e1f1f0..39045ca5 100644
+index 24e1f1f0..9957ea30 100644
 --- a/src/llama.cpp
 +++ b/src/llama.cpp
-@@ -205,6 +205,15 @@ enum llm_arch {
+@@ -205,6 +205,16 @@ enum llm_arch {
 LLM_ARCH_GRANITE,
 LLM_ARCH_GRANITE_MOE,
 LLM_ARCH_CHAMELEON,
@@ -55,10 +55,11 @@ index 24e1f1f0..39045ca5 100644
 + LLM_ARCH_HYVID,
 + LLM_ARCH_WAN,
 + LLM_ARCH_HIDREAM,
++ LLM_ARCH_COSMOS,
 LLM_ARCH_UNKNOWN,
 };
 
-@@ -258,6 +267,15 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
+@@ -258,6 +268,16 @@ static const std::map<llm_arch, const char *> LLM_ARCH_NAMES = {
 { LLM_ARCH_GRANITE, "granite" },
 { LLM_ARCH_GRANITE_MOE, "granitemoe" },
 { LLM_ARCH_CHAMELEON, "chameleon" },
@@ -71,10 +72,11 @@ index 24e1f1f0..39045ca5 100644
 + { LLM_ARCH_HYVID, "hyvid" },
 + { LLM_ARCH_WAN, "wan" },
 + { LLM_ARCH_HIDREAM, "hidream" },
++ { LLM_ARCH_COSMOS, "cosmos" },
 { LLM_ARCH_UNKNOWN, "(unknown)" },
 };
 
-@@ -1531,6 +1549,15 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
+@@ -1531,6 +1551,16 @@ static const std::map<llm_arch, std::map<llm_tensor, const char *>> LLM_TENSOR_N
 { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" },
 },
 },
@@ -87,10 +89,11 @@ index 24e1f1f0..39045ca5 100644
 + { LLM_ARCH_HYVID, {}},
 + { LLM_ARCH_WAN, {}},
 + { LLM_ARCH_HIDREAM, {}},
++ { LLM_ARCH_COSMOS, {}},
 {
 LLM_ARCH_UNKNOWN,
 {
-@@ -5403,6 +5430,23 @@ static void llm_load_hparams(
+@@ -5403,6 +5433,24 @@ static void llm_load_hparams(
 // get general kv
 ml.get_key(LLM_KV_GENERAL_NAME, model.name, false);
 
@@ -105,6 +108,7 @@ index 24e1f1f0..39045ca5 100644
 + case LLM_ARCH_HYVID:
 + case LLM_ARCH_WAN:
 + case LLM_ARCH_HIDREAM:
++ case LLM_ARCH_COSMOS:
 + model.ftype = ml.ftype;
 + return;
 + default:
@@ -114,7 +118,7 @@ index 24e1f1f0..39045ca5 100644
 // get hparams kv
 ml.get_key(LLM_KV_VOCAB_SIZE, hparams.n_vocab, false) || ml.get_arr_n(LLM_KV_TOKENIZER_LIST, hparams.n_vocab);
 
-@@ -18016,6 +18060,129 @@ static void llama_tensor_dequantize_internal(
+@@ -18016,6 +18064,132 @@ static void llama_tensor_dequantize_internal(
 workers.clear();
 }
 
@@ -149,7 +153,8 @@ index 24e1f1f0..39045ca5 100644
 + (name.find(".to_v.weight") != std::string::npos) ||
 + (name.find(".v.weight") != std::string::npos) ||
 + (name.find(".attn.w1v.weight") != std::string::npos) ||
-+ (name.find(".attn.w2v.weight") != std::string::npos)
++ (name.find(".attn.w2v.weight") != std::string::npos) ||
++ (name.find("_attn.v_proj.weight") != std::string::npos)
 + ){
 + if (ftype == LLAMA_FTYPE_MOSTLY_Q2_K) {
 + new_type = GGML_TYPE_Q3_K;
@@ -184,7 +189,9 @@ index 24e1f1f0..39045ca5 100644
 + (name.find("ffn_down") != std::string::npos) ||
 + ((name.find("experts.") != std::string::npos) && (name.find(".w2.weight") != std::string::npos)) ||
 + (name.find(".ffn.2.weight") != std::string::npos) || // is this even the right way around?
-+ (name.find(".ff.net.2.weight") != std::string::npos)
++ (name.find(".ff.net.2.weight") != std::string::npos) ||
++ (name.find(".mlp.layer2.weight") != std::string::npos) ||
++ (name.find(".adaln_modulation_mlp.2.weight") != std::string::npos)
 + ) {
 + // TODO: add back `layer_info` with some model specific logic + logic further down
 + if (ftype == LLAMA_FTYPE_MOSTLY_Q3_K_M) {
@@ -244,7 +251,7 @@ index 24e1f1f0..39045ca5 100644
 static ggml_type llama_tensor_get_type(quantize_state_internal & qs, ggml_type new_type, const ggml_tensor * tensor, llama_ftype ftype) {
 const std::string name = ggml_get_name(tensor);
 
-@@ -18513,7 +18680,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -18513,7 +18687,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 if (llama_model_has_encoder(&model)) {
 n_attn_layer *= 3;
 }
@@ -255,7 +262,7 @@ index 24e1f1f0..39045ca5 100644
 }
 
 size_t total_size_org = 0;
-@@ -18547,6 +18716,51 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -18547,6 +18723,51 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 ctx_outs[i_split] = gguf_init_empty();
 }
 gguf_add_tensor(ctx_outs[i_split], tensor);
@@ -307,7 +314,7 @@ index 24e1f1f0..39045ca5 100644
 }
 
 // Set split info if needed
-@@ -18647,6 +18861,92 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -18647,6 +18868,101 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 // do not quantize relative position bias (T5)
 quantize &= name.find("attn_rel_b.weight") == std::string::npos;
 
@@ -392,6 +399,15 @@ index 24e1f1f0..39045ca5 100644
 + quantize &= name.find(".ff_i.gate.weight") == std::string::npos;
 + quantize &= name.find("caption_projection.") == std::string::npos;
 + }
++ if (model.arch == LLM_ARCH_COSMOS) {
++ image_model = true;
++ quantize &= name.find("p_embedder.") == std::string::npos;
++ quantize &= name.find("t_embedder.") == std::string::npos;
++ quantize &= name.find("t_embedding_norm.") == std::string::npos;
++ quantize &= name.find("x_embedder.") == std::string::npos;
++ quantize &= name.find("pos_embedder.") == std::string::npos;
++ quantize &= name.find("final_layer.") == std::string::npos;
++ }
 + // ignore 3D/4D tensors for image models as the code was never meant to handle these
 + if (image_model) {
 + quantize &= ggml_n_dims(tensor) == 2;
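
Summing up the Cosmos-specific rules added above: embedder, timestep, positional, norm and final-layer tensors stay unquantized, value projections and MLP down-projections are treated like attention-V / FFN-down tensors (bumped to a higher-precision K-quant under some target types), and any non-2D tensor is skipped for image models. The C++ in the patch is authoritative; the sketch below is only a readable Python restatement of the keep-in-high-precision part:

```python
COSMOS_SKIP_SUBSTRINGS = (
    "p_embedder.", "t_embedder.", "t_embedding_norm.",
    "x_embedder.", "pos_embedder.", "final_layer.",
)

def should_quantize_cosmos(name: str, n_dims: int) -> bool:
    # mirror of the name-based exclusions in the patch above
    if any(part in name for part in COSMOS_SKIP_SUBSTRINGS):
        return False
    # image models: only plain 2D weight matrices are quantized
    return n_dims == 2
```
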
@@ -400,7 +416,7 @@ index 24e1f1f0..39045ca5 100644
 enum ggml_type new_type;
 void * new_data;
 size_t new_size;
-@@ -18655,6 +18955,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -18655,6 +18971,9 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 new_type = default_type;
 
 // get more optimal quantization type based on the tensor shape, layer, etc.
@@ -410,7 +426,7 @@ index 24e1f1f0..39045ca5 100644
 if (!params->pure && ggml_is_quantized(default_type)) {
 new_type = llama_tensor_get_type(qs, new_type, tensor, ftype);
 }
-@@ -18664,6 +18967,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
+@@ -18664,6 +18983,7 @@ static void llama_model_quantize_internal(const std::string & fname_inp, const s
 if (params->output_tensor_type < GGML_TYPE_COUNT && strcmp(tensor->name, "output.weight") == 0) {
 new_type = params->output_tensor_type;
 }
