Skip to content

Commit cf76817

Browse files
joerundemaxdebaysernjhill
authored
Also convert .index.json file (IBM#35)
The conversion to safetensors wasn't taking into account the pytorch_model.bin.index.json file that needs to be converted to model.safetensors.index.json. This also converts `bin.index.json` files Signed-off-by: Joe Runde <[email protected]> Co-authored-by: Maximilien Philippe Marie de Bayser <[email protected]> Co-authored-by: Nick Hill <[email protected]>
1 parent 6c56f8f commit cf76817

File tree

5 files changed

+42
-4
lines changed

5 files changed

+42
-4
lines changed

server/tests/utils/test_convert.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ def test_convert_files():
1313
local_pt_files = download_weights(model_id, extension=".bin")
1414
local_pt_files = [Path(p) for p in local_pt_files]
1515
local_st_files = [
16-
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors" for p in local_pt_files
16+
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors" for p in local_pt_files
1717
]
1818
convert_files(local_pt_files, local_st_files, discard_names=[])
1919

server/text_generation_server/cli.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,16 +138,21 @@ def convert_to_safetensors(
138138
# Get local pytorch file paths
139139
model_path = utils.get_model_path(model_name, revision)
140140
local_pt_files = utils.local_weight_files(model_path, ".bin")
141+
local_pt_index_files = utils.local_index_files(model_path, ".bin")
142+
if len(local_pt_index_files) > 1:
143+
print(f"Found more than one .bin.index.json file: {local_pt_index_files}")
144+
return
141145

142146
if not local_pt_files:
143147
print("No pytorch .bin files found to convert")
144148
return
145149

146150
local_pt_files = [Path(f) for f in local_pt_files]
151+
local_pt_index_file = local_pt_index_files[0] if local_pt_index_files else None
147152

148153
# Safetensors final filenames
149154
local_st_files = [
150-
p.parent / f"{p.stem.lstrip('pytorch_')}.safetensors"
155+
p.parent / f"{p.stem.removeprefix('pytorch_')}.safetensors"
151156
for p in local_pt_files
152157
]
153158

@@ -173,6 +178,16 @@ def convert_to_safetensors(
173178
except Exception:
174179
discard_names = []
175180

181+
if local_pt_index_file:
182+
local_pt_index_file = Path(local_pt_index_file)
183+
local_st_index_file = local_pt_index_file.parent / f"{local_pt_index_file.stem.removeprefix('pytorch_').rstrip('.bin.index')}.safetensors.index.json"
184+
185+
if os.path.exists(local_st_index_file):
186+
print("Existing .safetensors.index.json file found, remove it first to reconvert")
187+
return
188+
189+
utils.convert_index_file(local_pt_index_file, local_st_index_file, local_pt_files, local_st_files)
190+
176191
# Convert pytorch weights to safetensors
177192
utils.convert_files(local_pt_files, local_st_files, discard_names)
178193

server/text_generation_server/utils/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from text_generation_server.utils.convert import convert_file, convert_files
1+
from text_generation_server.utils.convert import convert_file, convert_files, convert_index_file
22
from text_generation_server.utils.dist import (
33
initialize_torch_distributed,
44
run_rank_n,
@@ -10,6 +10,7 @@
1010
from text_generation_server.utils.hub import (
1111
get_model_path,
1212
local_weight_files,
13+
local_index_files,
1314
weight_files,
1415
weight_hub_files,
1516
download_weights,
@@ -41,13 +42,15 @@
4142
__all__ = [
4243
"convert_file",
4344
"convert_files",
45+
"convert_index_file",
4446
"initialize_torch_distributed",
4547
"run_rank_n",
4648
"print_rank_n",
4749
"get_torch_dtype",
4850
"RANK",
4951
"get_model_path",
5052
"local_weight_files",
53+
"local_index_files",
5154
"weight_files",
5255
"weight_hub_files",
5356
"download_weights",

server/text_generation_server/utils/convert.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import datetime
22
import torch
33
import os
4+
import json
45

56
from loguru import logger
67
from pathlib import Path
@@ -87,12 +88,25 @@ def convert_file(pt_file: Path, sf_file: Path, discard_names: List[str]):
8788
raise RuntimeError(f"The output tensors do not match for key {k}")
8889

8990

91+
def convert_index_file(source_file: Path, dest_file: Path, pt_files: List[Path], sf_files: List[Path]):
92+
weight_file_map = {s.name: d.name for s, d in zip(pt_files, sf_files)}
93+
94+
logger.info(f"Converting pytorch .bin.index.json files to .safetensors.index.json")
95+
with open(source_file, "r") as f:
96+
index = json.load(f)
97+
98+
index["weight_map"] = {k: weight_file_map[v] for k, v in index["weight_map"].items()}
99+
100+
with open(dest_file, "w") as f:
101+
json.dump(index, f, indent=4)
102+
103+
90104
def convert_files(pt_files: List[Path], sf_files: List[Path], discard_names: List[str] = None):
91105
assert len(pt_files) == len(sf_files)
92106

93107
# Filter non-inference files
94108
pairs = [p for p in zip(pt_files, sf_files) if not any(
95-
s in p[0].name for s in ["arguments", "args", "training", "optimizer", "scheduler"]
109+
s in p[0].name for s in ["arguments", "args", "training", "optimizer", "scheduler", "index"]
96110
)]
97111

98112
N = len(pairs)

server/text_generation_server/utils/hub.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,3 +100,9 @@ def local_weight_files(model_path: str, extension=".safetensors"):
100100
"""Get the local safetensors filenames"""
101101
ext = "" if extension is None else extension
102102
return glob.glob(f"{model_path}/*{ext}")
103+
104+
105+
def local_index_files(model_path: str, extension=".safetensors"):
106+
"""Get the local .index.json filename"""
107+
ext = "" if extension is None else extension
108+
return glob.glob(f"{model_path}/*{ext}.index.json")

0 commit comments

Comments
 (0)