diff --git a/convert_flux_to_gguf.py b/convert_flux_to_gguf.py
index 28a3019..aef8eb6 100644
--- a/convert_flux_to_gguf.py
+++ b/convert_flux_to_gguf.py
@@ -8,6 +8,7 @@
 import argparse
 import contextlib
 import json
+import safetensors.torch
 import os
 import re
 import sys
@@ -177,6 +178,31 @@ def write(self) -> None:
         self.gguf_writer.write_tensors_to_file(progress=True)
         self.gguf_writer.close()
 
+# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
+def _merge_sharded_checkpoints(folder: Path):
+    with open(folder / "diffusion_pytorch_model.safetensors.index.json", "r") as f:
+        ckpt_metadata = json.load(f)
+    weight_map = ckpt_metadata.get("weight_map", None)
+    if weight_map is None:
+        raise KeyError("'weight_map' key not found in the shard index file.")
+
+    # Collect all unique safetensors files from weight_map
+    files_to_load = set(weight_map.values())
+    merged_state_dict = {}
+
+    # Load tensors from each unique file
+    for file_name in files_to_load:
+        part_file_path = folder / file_name
+        if not part_file_path.exists():
+            raise FileNotFoundError(f"Part file {file_name} not found.")
+
+        with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
+            for tensor_key in f.keys():
+                if tensor_key in weight_map:
+                    merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
+
+    return merged_state_dict
+
 
 def parse_args() -> argparse.Namespace:
     parser = argparse.ArgumentParser(
@@ -216,10 +242,29 @@ def main() -> None:
     else:
         logging.basicConfig(level=logging.INFO)
 
-    if not args.model.is_file():
+    if not args.model.is_dir() and not args.model.is_file():
         logging.error(f"Model path {args.model} does not exist.")
         sys.exit(1)
 
+    if args.model.is_dir():
+        logging.info(f"Model path {args.model} is a directory.")
+        merged_state_dict = None
+        files = list(args.model.glob("*.safetensors"))
+        n = len(files)
+        if n == 0:
+            logging.error("No safetensors files found.")
+            sys.exit(1)
+        elif n == 1:
+            logging.info(f"Assigning {files[0]} to `args.model`")
+            args.model = files[0]
+        else:
+            assert (args.model / "diffusion_pytorch_model.safetensors.index.json").exists(), "Shard index file not found."
+            merged_state_dict = _merge_sharded_checkpoints(args.model)
+            filepath = "merged_state_dict.safetensors"
+            safetensors.torch.save_file(merged_state_dict, filepath)
+            logging.info(f"Serialized merged state dict to {filepath}")
+            args.model = Path(filepath)
+
     if args.model.suffix != ".safetensors":
         logging.error(f"Model path {args.model} is not a safetensors file.")
         sys.exit(1)
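
For context, a minimal sketch of how the new sharded-checkpoint path could be exercised on its own, assuming a diffusers-style layout (numbered diffusion_pytorch_model-*.safetensors shards next to diffusion_pytorch_model.safetensors.index.json) and that convert_flux_to_gguf.py is importable from the working directory; the checkpoint directory below is hypothetical:

from pathlib import Path

import safetensors.torch

from convert_flux_to_gguf import _merge_sharded_checkpoints

# Hypothetical directory holding a sharded FLUX transformer checkpoint:
#   diffusion_pytorch_model-00001-of-00002.safetensors
#   diffusion_pytorch_model-00002-of-00002.safetensors
#   diffusion_pytorch_model.safetensors.index.json
ckpt_dir = Path("flux-dev/transformer")

# Merge every tensor referenced by the index's weight_map into one dict,
# then serialize it so the existing single-file conversion path can consume it.
state_dict = _merge_sharded_checkpoints(ckpt_dir)
safetensors.torch.save_file(state_dict, "merged_state_dict.safetensors")
print(f"Merged {len(state_dict)} tensors")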