Commit d7942e7

fixed sharded weight loading (#75)
* fixed sharded weight loading
* polish
1 parent: 4e8beca

File tree

1 file changed: +34 -6 lines changed


opendit/utils/download.py

Lines changed: 34 additions & 6 deletions
@@ -8,7 +8,7 @@
 Functions for downloading pre-trained DiT models
 """
 import os
-
+import json
 import torch
 from torchvision.datasets.utils import download_url
 
@@ -22,11 +22,39 @@ def find_model(model_name):
     if model_name in pretrained_models:  # Find/download our pre-trained DiT checkpoints
         return download_model(model_name)
     else:  # Load a custom DiT checkpoint:
-        assert os.path.isfile(model_name), f"Could not find DiT checkpoint at {model_name}"
-        checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
-        if "ema" in checkpoint:  # supports checkpoints from train.py
-            checkpoint = checkpoint["ema"]
-        return checkpoint
+        if not os.path.isfile(model_name):
+            # if model_name is a directory, we assume the checkpoint is stored in the Hugging Face manner,
+            # i.e. the model weights are sharded into multiple files tied together by an index.json file;
+            # walk through the files in the directory to find that index file
+            index_file = [os.path.join(model_name, f) for f in os.listdir(model_name) if "index.json" in f]
+            assert len(index_file) == 1, f"Could not find index.json in {model_name}"
+
+            # process the index json, which maps each weight name to the shard file that stores it
+            with open(index_file[0], "r") as f:
+                index_data = json.load(f)
+
+            bin_to_weight_mapping = dict()
+            for k, v in index_data["weight_map"].items():
+                if v in bin_to_weight_mapping:
+                    bin_to_weight_mapping[v].append(k)
+                else:
+                    bin_to_weight_mapping[v] = [k]
+
+            # make the state dict: load each shard once and copy over the weights it holds
+            state_dict = dict()
+            for bin_name, weight_list in bin_to_weight_mapping.items():
+                bin_path = os.path.join(model_name, bin_name)
+                bin_state_dict = torch.load(bin_path, map_location=lambda storage, loc: storage)
+                for weight in weight_list:
+                    state_dict[weight] = bin_state_dict[weight]
+            return state_dict
+        else:
+            # if it is a file, we just load it directly in the typical PyTorch manner
+            assert os.path.exists(model_name), f"Could not find DiT checkpoint at {model_name}"
+            checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+            if "ema" in checkpoint:  # supports checkpoints from train.py
+                checkpoint = checkpoint["ema"]
+            return checkpoint
 
 
 def download_model(model_name):
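
For context, a minimal end-to-end sketch of the checkpoint layout this change targets and of the merging it performs. The shard and index file names below (pytorch_model-0000x-of-00002.bin, pytorch_model.bin.index.json) follow the usual Hugging Face convention but are illustrative assumptions rather than names taken from this commit; the merging loop mirrors the logic added to find_model.

    import json
    import os
    import tempfile

    import torch

    # Build a toy sharded checkpoint: two shard files plus an index whose
    # "weight_map" sends each weight name to the shard that stores it.
    # File names here are assumed, following the Hugging Face convention.
    ckpt_dir = tempfile.mkdtemp()
    torch.save({"layer1.weight": torch.zeros(2, 2)},
               os.path.join(ckpt_dir, "pytorch_model-00001-of-00002.bin"))
    torch.save({"layer2.weight": torch.ones(2, 2)},
               os.path.join(ckpt_dir, "pytorch_model-00002-of-00002.bin"))
    index = {"weight_map": {
        "layer1.weight": "pytorch_model-00001-of-00002.bin",
        "layer2.weight": "pytorch_model-00002-of-00002.bin",
    }}
    with open(os.path.join(ckpt_dir, "pytorch_model.bin.index.json"), "w") as f:
        json.dump(index, f)

    # Merge the shards the same way the patched find_model does: group weight
    # names by shard, load each shard once, and copy the listed weights into
    # a single state dict.
    with open(os.path.join(ckpt_dir, "pytorch_model.bin.index.json")) as f:
        weight_map = json.load(f)["weight_map"]

    bin_to_weights = {}
    for weight_name, bin_name in weight_map.items():
        bin_to_weights.setdefault(bin_name, []).append(weight_name)

    state_dict = {}
    for bin_name, weights in bin_to_weights.items():
        shard = torch.load(os.path.join(ckpt_dir, bin_name), map_location="cpu")
        for w in weights:
            state_dict[w] = shard[w]

    print(sorted(state_dict))  # ['layer1.weight', 'layer2.weight']

Grouping weight names by shard before loading means each .bin file is deserialized exactly once rather than once per weight, which is the point of the bin_to_weight_mapping inversion in the patch.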
