import os

# Point the Hugging Face Hub cache at a local directory with enough free space.
os.environ['HF_HUB_CACHE'] = '/share/shitao/downloaded_models2'

from huggingface_hub import snapshot_download
from safetensors.torch import load_file
from transformers import Phi3Config

from diffusers.models import OmniGenTransformerModel

repo_id = "Shitao/OmniGen-v1"
config = Phi3Config.from_pretrained(repo_id)
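# Instantiate an empty diffusers transformer from the Phi-3 style config shipped in the repo.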
model = OmniGenTransformerModel(transformer_config=config)

# Download the original checkpoint, skipping weight formats we do not need.
cache_folder = os.getenv('HF_HUB_CACHE')
model_dir = snapshot_download(repo_id=repo_id,
                              cache_dir=cache_folder,
                              ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
print(model_dir)

model_path = os.path.join(model_dir, 'model.safetensors')
ckpt = load_file(model_path, device='cpu')

# Map the original OmniGen parameter names to the diffusers naming scheme.
mapping_dict = {
    "pos_embed": "patch_embedding.pos_embed",
    "x_embedder.proj.weight": "patch_embedding.output_image_proj.weight",
    "x_embedder.proj.bias": "patch_embedding.output_image_proj.bias",
    "input_x_embedder.proj.weight": "patch_embedding.input_image_proj.weight",
    "input_x_embedder.proj.bias": "patch_embedding.input_image_proj.bias",
    "final_layer.adaLN_modulation.1.weight": "norm_out.linear.weight",
    "final_layer.adaLN_modulation.1.bias": "norm_out.linear.bias",
    "final_layer.linear.weight": "proj_out.weight",
    "final_layer.linear.bias": "proj_out.bias",
}
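# Keys not listed above (e.g. the Phi-3 transformer block weights) already match
# the diffusers module names and are copied through unchanged.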

new_ckpt = {}
for k, v in ckpt.items():
    if k in mapping_dict:
        new_ckpt[mapping_dict[k]] = v
    else:
        new_ckpt[k] = v

# load_state_dict is strict by default, so any missing or unexpected key will raise.
model.load_state_dict(new_ckpt)
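
# A minimal follow-up sketch (not in the original script): persist the converted
# weights so they can be reloaded with `OmniGenTransformerModel.from_pretrained`.
# `save_pretrained` comes from diffusers' ModelMixin; the output path below is a
# placeholder, not something the original commit defines.
model.save_pretrained('./omnigen-transformer-converted')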