
[feat](kt-kernel): Add utility script to merge loose layer weights to safetensors #1886

Open
DocShotgun wants to merge 3 commits into kvcache-ai:main from DocShotgun:merge_cpu_weights

Conversation

@DocShotgun
Contributor

merge_cpu_weights.py

This is a simple utility script that merges CPU weights currently in the loose layers format (created using convert_cpu_weights.py with --no-merge-safetensor for memory savings) into sharded safetensors for more convenient upload.

The relevant args are:

  • --input-path: Input directory with nested _layer_* folders
  • --output: Output directory for merged safetensors
  • --original-path: Original model folder with config.json and tokenizer files to copy
  • --max-tensors: Maximum tensors per safetensors shard (default: 3000)

I decided to keep the default shard size as 3000 tensors to match the existing behavior of convert_cpu_weights.py.
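As a rough illustration of how the 3000-tensor default plays out (a hypothetical helper, not part of the script): the number of output shards is simply the ceiling of the total tensor count over the cap.

```python
import math

def shard_count(total_tensors: int, max_tensors: int = 3000) -> int:
    """Hypothetical helper: shards needed at the default 3000-tensor cap."""
    return math.ceil(total_tensors / max_tensors)

print(shard_count(7500))  # two full shards plus a 1500-tensor remainder -> 3
```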


@gemini-code-assist
Contributor

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a new Python utility script designed to streamline the process of consolidating CPU-optimized model weights. Previously, weights could be saved in a fragmented, loose layer format for memory efficiency during conversion. This new script addresses the subsequent need to merge these fragmented weights into a more manageable and upload-friendly sharded safetensors format, improving the overall workflow for model deployment and sharing.

Highlights

  • New Utility Script: Added a new utility script, merge_cpu_weights.py, to consolidate CPU-optimized model weights.
  • Weight Merging: Enabled merging of weights from a loose layer folder structure (generated with --no-merge-safetensor) into sharded safetensors.
  • Quantization Method Detection: Implemented detection for various quantization methods (INT4, INT8, MOE_INT4, MOE_INT8) to correctly process weight files.
  • Safetensor Sharding: Provided functionality to shard output safetensors based on a configurable maximum number of tensors per shard.
  • Configuration File Copying: Included an option to copy essential configuration and tokenizer files from an original model directory to the output.


Changelog
  • kt-kernel/scripts/merge_cpu_weights.py
    • Implemented functions to discover layer and NUMA folders containing loose weights.
    • Added logic to detect the quantization method (e.g., INT4, INT8) from file names.
    • Developed a function to load binary tensor files (.kt format) into PyTorch tensors.
    • Created a core processing function to iterate through layers and NUMA folders, collecting all tensors.
    • Implemented sharding logic to write accumulated tensors into multiple safetensor files based on a maximum tensor count.
    • Included a utility to copy model configuration and tokenizer files to the output directory.
    • Provided command-line argument parsing for input/output paths, original model path, and max tensors per shard.


@gemini-code-assist gemini-code-assist bot left a comment

Code Review

This pull request introduces a new utility script, merge_cpu_weights.py, to merge loose layer weights into sharded safetensors. The script is well-structured and handles file discovery, processing, and sharding. My review focuses on improving memory efficiency, error handling practices, and overall code clarity. I've identified a high-severity memory issue in the sharding logic and a medium-severity issue regarding error message output. The suggested changes will make the script more robust and efficient, especially for large models.

Comment on lines +125 to +179
def write_shards(accumulated_tensors: dict, output_path: str, shard_counter: dict, keep_remainder: bool = True):
    """Write accumulated tensors to one or more shard files.

    Args:
        accumulated_tensors: Dict of tensors to write
        output_path: Output directory
        shard_counter: Dict with 'shard' and 'max_tensors' keys
        keep_remainder: If True, keep leftover tensors in accumulator for next batch
    """
    if not accumulated_tensors:
        return

    max_tensors = shard_counter["max_tensors"]
    current_shard = shard_counter["shard"]
    total_tensors = len(accumulated_tensors)

    if total_tensors <= max_tensors:
        if not keep_remainder:
            output_file = os.path.join(output_path, f"model-{current_shard:05d}.safetensors")
            save_file(accumulated_tensors, output_file)
            print(f"  Saved {total_tensors} tensors to {output_file}")
            shard_counter["shard"] = current_shard + 1
            accumulated_tensors.clear()
        else:
            pass  # Keep accumulating until we hit max_tensors
    else:
        full_shards = total_tensors // max_tensors
        remainder = total_tensors % max_tensors

        items = list(accumulated_tensors.items())

        # Write full shards
        for i in range(full_shards):
            batch = dict(items[i * max_tensors : (i + 1) * max_tensors])
            output_file = os.path.join(output_path, f"model-{current_shard:05d}.safetensors")
            save_file(batch, output_file)
            print(f"  Saved {len(batch)} tensors to {output_file}")
            current_shard += 1

        # Keep remainder for next batch if enabled
        if keep_remainder and remainder > 0:
            remainder_items = dict(items[full_shards * max_tensors:])
            accumulated_tensors.clear()
            accumulated_tensors.update(remainder_items)
            print(f"  Rolled over {remainder} tensors to next batch")
        elif remainder > 0:
            # Write remainder as final shard
            batch = dict(items[full_shards * max_tensors:])
            output_file = os.path.join(output_path, f"model-{current_shard:05d}.safetensors")
            save_file(batch, output_file)
            print(f"  Saved {len(batch)} tensors to {output_file}")
            current_shard += 1
            accumulated_tensors.clear()

        shard_counter["shard"] = current_shard
Contributor

Severity: high

The current implementation of write_shards converts the entire accumulated_tensors dictionary to a list using list(accumulated_tensors.items()). For large models, this can consume a significant amount of memory by duplicating the tensor references and keys, which seems to contradict the memory-saving goal of this script. This could lead to an OutOfMemoryError for very large models.

Additionally, the logic for handling sharding and remainders is quite complex and could be simplified for better readability and maintainability.

I suggest refactoring the function to process tensors in chunks without creating a large intermediate list of all items. This will make the script more memory-efficient and robust.

def write_shards(accumulated_tensors: dict, output_path: str, shard_counter: dict, keep_remainder: bool = True):
    """Write accumulated tensors to one or more shard files.
    
    Args:
        accumulated_tensors: Dict of tensors to write
        output_path: Output directory
        shard_counter: Dict with 'shard' and 'max_tensors' keys
        keep_remainder: If True, keep leftover tensors in accumulator for next batch
    """
    if not accumulated_tensors:
        return

    max_tensors = shard_counter["max_tensors"]
    
    # In keep_remainder mode, only write if we have at least one full shard.
    if keep_remainder and len(accumulated_tensors) < max_tensors:
        return

    # Process in chunks without creating a giant list of all items
    while len(accumulated_tensors) >= max_tensors:
        keys_to_write = []
        for i, key in enumerate(accumulated_tensors):
            if i >= max_tensors:
                break
            keys_to_write.append(key)
        
        batch = {key: accumulated_tensors.pop(key) for key in keys_to_write}
        
        output_file = os.path.join(output_path, f"model-{shard_counter['shard']:05d}.safetensors")
        save_file(batch, output_file)
        print(f"  Saved {len(batch)} tensors to {output_file}")
        shard_counter["shard"] += 1

    # Handle remainder
    if not keep_remainder and accumulated_tensors:
        output_file = os.path.join(output_path, f"model-{shard_counter['shard']:05d}.safetensors")
        save_file(accumulated_tensors, output_file)
        print(f"  Saved {len(accumulated_tensors)} tensors to {output_file}")
        shard_counter["shard"] += 1
        accumulated_tensors.clear()
    elif keep_remainder and accumulated_tensors:
        print(f"  Rolled over {len(accumulated_tensors)} tensors to next batch")

Contributor Author

My understanding is that list(accumulated_tensors.items()) creates references to existing tensors, not copies, so memory duplication is minimal. The keys themselves are small compared to the tensor data which is already stored in memory.
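That claim is easy to verify with plain Python (no torch needed): `list(d.items())` allocates new tuples, but the values inside them are the very same objects stored in the dict.

```python
# Check that list(d.items()) copies references, not the underlying buffers.
blob = bytearray(10 * 1024 * 1024)  # stand-in for a large tensor
d = {"weight": blob}
items = list(d.items())
assert items[0][1] is d["weight"]  # same object; no 10 MB duplication
```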

@ErvinXie
Collaborator

Bug: MOE prefix causes expert index parsing to fail

In process_layer, the expert index is parsed by splitting the filename on _ and taking parts[2]:

parts = filename.split("_")
expert_idx = int(parts[2])

This works for INT4 / INT8 prefixes but breaks for MOE_INT4 / MOE_INT8 because the prefix itself contains an underscore, shifting all indices by one:

Quant method   Filename                          parts[2]
INT4           INT4_down_0_4Byte_quant_.kt       "0"
MOE_INT4       MOE_INT4_down_0_4Byte_quant_.kt   "down"

int("down") raises ValueError, which gets silently caught by the except block — so all MOE tensors are skipped and the output is empty.

Suggested fix

Since amx_prefix and proj_name are already known in the loop, strip them from the filename and parse the remainder:

# Instead of:
parts = filename.split("_")
expert_idx = int(parts[2])

# Use:
remainder = filename[len(f"{amx_prefix}_{proj_name}_"):]
expert_idx = int(remainder.split("_")[0])

This works regardless of whether the prefix is INT4 or MOE_INT4. The same fix applies to both the quant and scale file loops.
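A standalone sanity check of the suggested parsing (the function wrapper and sample filenames are illustrative; the two-line body is the fix above):

```python
def parse_expert_idx(filename: str, amx_prefix: str, proj_name: str) -> int:
    """Strip the known prefix, then parse the expert index that follows."""
    remainder = filename[len(f"{amx_prefix}_{proj_name}_"):]
    return int(remainder.split("_")[0])

# Works for both plain and MOE prefixes:
print(parse_expert_idx("INT4_down_0_4Byte_quant_.kt", "INT4", "down"))          # 0
print(parse_expert_idx("MOE_INT4_down_0_4Byte_quant_.kt", "MOE_INT4", "down"))  # 0
```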

@ErvinXie
Collaborator

Thanks for the contribution! This is a useful utility for the two-stage conversion workflow.

The MOE prefix parsing issue mentioned above is the only blocker — once that's fixed, we'll merge this in. 🙏

@DocShotgun
Contributor Author

I pushed the suggested change. Tried to test it locally, but I don't think my native build of kt-kernel is compiled with support for moe_int4 or moe_int8.

