 Functions for downloading pre-trained DiT models
 """
 import os
-
+import json
 import torch
 from torchvision.datasets.utils import download_url

@@ -22,11 +22,46 @@ def find_model(model_name):
     if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
         return download_model(model_name)
     else:  # Load a custom DiT checkpoint:
-        assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
-        checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
-        if "ema" in checkpoint:  # supports checkpoints from train.py
-            checkpoint = checkpoint["ema"]
-        return checkpoint
+        if not os.path.isfile(model_name):
+            # If model_name is a directory, assume it holds a Hugging Face-style sharded
+            # checkpoint: the weights are split across multiple shard files, and an
+            # index.json file maps each weight name to the shard that contains it.
+            index_file = [os.path.join(model_name, f) for f in os.listdir(model_name) if "index.json" in f]
+            assert len(index_file) == 1, f"Could not find index.json in {model_name}"
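+            # Illustrative index.json layout (file and tensor names are examples only):
+            #   {"metadata": {"total_size": ...},
+            #    "weight_map": {"x_embedder.proj.weight": "pytorch_model-00001-of-00002.bin", ...}}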
+
+            # process index json
+            with open(index_file[0], "r") as f:
+                index_data = json.load(f)
+
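+            # Invert the weight -> shard mapping so each shard file is read only once,
+            # e.g. {"pytorch_model-00001-of-00002.bin": ["x_embedder.proj.weight", ...], ...}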
+            bin_to_weight_mapping = dict()
+            for k, v in index_data['weight_map'].items():
+                if v in bin_to_weight_mapping:
+                    bin_to_weight_mapping[v].append(k)
+                else:
+                    bin_to_weight_mapping[v] = [k]
+
+            # Load each shard and merge its weights into a single state dict
+            state_dict = dict()
+            for bin_name, weight_list in bin_to_weight_mapping.items():
+                bin_path = os.path.join(model_name, bin_name)
+                bin_state_dict = torch.load(bin_path, map_location=lambda storage, loc: storage)
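+                # Copy only this shard's weights into the merged dict; the shard's own
+                # dict is released on the next loop iteration, keeping peak memory bounded.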
+                for weight in weight_list:
+                    state_dict[weight] = bin_state_dict[weight]
+            return state_dict
+        else:
+            # If it is a file, load it directly in the typical PyTorch manner
+            assert os.path.exists(model_name), f"Could not find DiT checkpoint at {model_name}"
+            checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+            if "ema" in checkpoint:  # supports checkpoints from train.py
+                checkpoint = checkpoint["ema"]
+            return checkpoint


 def download_model(model_name):
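
A minimal usage sketch of the updated loader (paths are illustrative; the pre-trained name is one of the repo's pretrained_models entries):

    from download import find_model

    state_dict = find_model("DiT-XL-2-256x256.pt")      # named pre-trained checkpoint (auto-downloaded)
    state_dict = find_model("results/ckpt.pt")          # single-file custom checkpoint; EMA weights used if present
    state_dict = find_model("sharded_checkpoint_dir/")  # directory holding an index.json plus shard files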