47 changes: 46 additions & 1 deletion convert_flux_to_gguf.py
@@ -8,6 +8,7 @@
import argparse
import contextlib
import json
import safetensors.torch
import os
import re
import sys
@@ -177,6 +178,31 @@ def write(self) -> None:
self.gguf_writer.write_tensors_to_file(progress=True)
self.gguf_writer.close()

# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
def _merge_sharded_checkpoints(folder: Path) -> dict:
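    """Merge a sharded Diffusers checkpoint into a single state dict.

    Reads diffusion_pytorch_model.safetensors.index.json from `folder`; its
    "weight_map" maps each tensor name to the shard file that stores it, e.g.
    {"transformer_blocks.0.attn.to_q.weight": "diffusion_pytorch_model-00001-of-00002.safetensors"}
    (illustrative names). Each listed shard is opened on CPU and its tensors
    are collected into one {name: tensor} dict.
    """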
with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f:
ckpt_metadata = json.load(f)
weight_map = ckpt_metadata.get("weight_map", None)
if weight_map is None:
raise KeyError("'weight_map' key not found in the shard index file.")

# Collect all unique safetensors files from weight_map
files_to_load = set(weight_map.values())
merged_state_dict = {}

# Load tensors from each unique file
for file_name in files_to_load:
part_file_path = folder / file_name
        if not part_file_path.exists():
            raise FileNotFoundError(f"Part file {file_name} not found.")

with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
for tensor_key in f.keys():
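                # Keep only tensors listed in the index; skip anything else a shard may contain.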
if tensor_key in weight_map:
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)

return merged_state_dict


def parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
@@ -216,10 +242,29 @@ def main() -> None:
else:
logging.basicConfig(level=logging.INFO)

-    if not args.model.is_file():
+    if not args.model.is_dir() and not args.model.is_file():
logging.error(f"Model path {args.model} does not exist.")
sys.exit(1)

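    # A model directory may hold either a single .safetensors file or a sharded
    # Diffusers checkpoint, e.g. (illustrative shard names):
    #   model_dir/
    #     diffusion_pytorch_model.safetensors.index.json
    #     diffusion_pytorch_model-00001-of-00002.safetensors
    #     diffusion_pytorch_model-00002-of-00002.safetensors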
if args.model.is_dir():
logging.info("Supplied a directory.")
merged_state_dict = None
files = list(args.model.glob('*.safetensors'))
n = len(files)
if n == 0:
logging.error("No safetensors files found.")
sys.exit(1)
if n == 1:
logging.info(f"Assinging {files[0]} to `args.model`")
args.model = files[0]
        if n > 1:
            index_file = args.model / "diffusion_pytorch_model.safetensors.index.json"
            assert index_file.exists(), "Shard index file not found in the model directory."
merged_state_dict = _merge_sharded_checkpoints(args.model)
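            # Serialize the merged dict to the current working directory so the
            # single-file conversion path below can consume it unchanged.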
filepath = "merged_state_dict.safetensors"
safetensors.torch.save_file(merged_state_dict, filepath)
logging.info(f"Serialized merged state dict to {filepath}")
args.model = Path(filepath)

if args.model.suffix != ".safetensors":
logging.error(f"Model path {args.model} is not a safetensors file.")
sys.exit(1)