keras_hub/src/utils/transformers/export/gemma.py (108 changes: 56 additions & 52 deletions)
@@ -1,9 +1,14 @@
-import keras.ops as ops
+import numpy as np


-def get_gemma_config(backbone):
+def get_gemma_config(backbone, include_lm_head=False):
+    token_embedding_layer = backbone.get_layer("token_embedding")
+    if include_lm_head:
+        architectures = ["GemmaForCausalLM"]  # Full model with LM head
+    else:
+        architectures = ["GemmaForBackbone"]  # Just backbone
     hf_config = {
+        "architectures": architectures,
         "vocab_size": backbone.vocabulary_size,
         "num_hidden_layers": backbone.num_layers,
         "num_attention_heads": backbone.num_query_heads,

@@ -22,79 +27,78 @@ def get_gemma_config(backbone):


 def get_gemma_weights_map(backbone, include_lm_head=False):
-    weights_dict = {}
-
     # Map token embedding
     token_embedding_layer = backbone.get_layer("token_embedding")
-    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]
-
+    yield "model.embed_tokens.weight", token_embedding_layer.weights[0]
     for i in range(backbone.num_layers):
         decoder_layer = backbone.get_layer(f"decoder_block_{i}")
-
         # Pre-attention normalization
-        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
-            decoder_layer.pre_attention_norm.weights[0]
+        yield (
+            f"model.layers.{i}.input_layernorm.weight",
+            decoder_layer.pre_attention_norm.weights[0],
         )
-
         # Attention query projection
         query_kernel = decoder_layer.attention.query_dense.weights[0]
-        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
-        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
-        query_kernel = ops.transpose(query_kernel)
-        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel
-
+        yield f"model.layers.{i}.self_attn.q_proj.weight", query_kernel
         # Attention key projection
-        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
-        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
-            ops.transpose(key_kernel)
-        )
-
+        key_kernel = decoder_layer.attention.key_dense.weights[0]
+        yield f"model.layers.{i}.self_attn.k_proj.weight", key_kernel
         # Attention value projection
-        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
-        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
-            ops.transpose(value_kernel)
-        )
-
+        value_kernel = decoder_layer.attention.value_dense.weights[0]
+        yield f"model.layers.{i}.self_attn.v_proj.weight", value_kernel
         # Attention output projection
         out_kernel = decoder_layer.attention.output_dense.weights[0]
-        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
-        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
-        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel
-
+        yield f"model.layers.{i}.self_attn.o_proj.weight", out_kernel
         # Post-attention normalization
-        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
-            decoder_layer.pre_ffw_norm.weights[0]
+        yield (
+            f"model.layers.{i}.post_attention_layernorm.weight",
+            decoder_layer.pre_ffw_norm.weights[0],
         )
-
         # MLP gate projection
         gate_kernel = decoder_layer.gating_ffw.weights[0]
-        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
-            gate_kernel
-        )
-
+        yield f"model.layers.{i}.mlp.gate_proj.weight", gate_kernel
         # MLP up projection
         up_kernel = decoder_layer.gating_ffw_2.weights[0]
-        weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose(
-            up_kernel
-        )
-
+        yield f"model.layers.{i}.mlp.up_proj.weight", up_kernel
         # MLP down projection
         down_kernel = decoder_layer.ffw_linear.weights[0]
-        weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose(
-            down_kernel
-        )
-
+        yield f"model.layers.{i}.mlp.down_proj.weight", down_kernel
     # Map final normalization
-    weights_dict["model.norm.weight"] = backbone.get_layer(
-        "final_normalization"
-    ).weights[0]
-
+    yield (
+        "model.norm.weight",
+        backbone.get_layer("final_normalization").weights[0],
+    )
     # Map lm_head if embeddings are not tied
     if include_lm_head and not token_embedding_layer.tie_weights:
-        weights_dict["lm_head.weight"] = ops.transpose(
-            token_embedding_layer.reverse_embeddings
-        )
-    return weights_dict
+        lm_head = token_embedding_layer.reverse_embeddings
+        yield "lm_head.weight", lm_head
+
+
+def get_gemma_transform_fn(backbone):
+    """Return a transform function for Gemma weights."""
+
+    def transform(name, np_tensor):
+        if name.endswith("q_proj.weight"):
+            np_tensor = np.transpose(np_tensor, axes=(1, 0, 2))
+            np_tensor = np.reshape(np_tensor, (-1, backbone.hidden_dim))
+            np_tensor = np.transpose(np_tensor)
+        elif name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
+            np_tensor = np.transpose(np_tensor, axes=(0, 2, 1))
+            np_tensor = np.reshape(np_tensor, (-1, backbone.hidden_dim))
+        elif name.endswith("o_proj.weight"):
+            np_tensor = np.transpose(np_tensor, axes=(2, 0, 1))
+            np_tensor = np.reshape(np_tensor, (backbone.hidden_dim, -1))
+        elif (
+            name.endswith("gate_proj.weight")
+            or name.endswith("up_proj.weight")
+            or name.endswith("down_proj.weight")
+        ):
+            np_tensor = np.transpose(np_tensor)
+        elif name == "lm_head.weight":
+            np_tensor = np.transpose(np_tensor)
+        return np_tensor
+
+    return transform


 def get_gemma_tokenizer_config(tokenizer):
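The net effect of this file's changes: `get_gemma_weights_map` is now a generator that streams raw `(name, tensor)` pairs, and all the Hugging Face layout reshapes move into the NumPy function returned by `get_gemma_transform_fn`, so transposed copies are never materialized on the accelerator. A minimal sketch of how the two compose, assuming keras_hub is installed and a Gemma preset is available locally (the preset name is illustrative, not part of this diff):

```python
# Sketch only: iterate the streamed weights and apply the HF transform.
import keras_hub

from keras_hub.src.utils.transformers.export.gemma import get_gemma_transform_fn
from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map

backbone = keras_hub.models.GemmaBackbone.from_preset("gemma_2b_en")
transform = get_gemma_transform_fn(backbone)

for name, tensor in get_gemma_weights_map(backbone, include_lm_head=True):
    # Convert to NumPy, then reshape to the HF layout, e.g. q_proj -> (out, in)
    np_tensor = transform(name, tensor.numpy())
    print(name, np_tensor.shape)
```

Because the generator yields one tensor at a time, the caller controls peak memory: only the tensor currently being transformed and written needs a host-side copy.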
keras_hub/src/utils/transformers/export/hf_exporter.py (122 changes: 91 additions & 31 deletions)
@@ -4,11 +4,13 @@
 import warnings

 import keras
+from safetensors.numpy import save_file

 from keras_hub.src.utils.transformers.export.gemma import get_gemma_config
 from keras_hub.src.utils.transformers.export.gemma import (
     get_gemma_tokenizer_config,
 )
+from keras_hub.src.utils.transformers.export.gemma import get_gemma_transform_fn
 from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map

 MODEL_CONFIGS = {
@@ -27,16 +29,21 @@
     # get_mistral_tokenizer_config
 }

+MODEL_TRANSFORMERS = {
+    "GemmaBackbone": get_gemma_transform_fn,
+    # Add for future models, e.g., "MistralBackbone": get_mistral_transform_fn
+}
+

-def export_backbone(backbone, path, include_lm_head=False):
+def export_backbone(backbone, path, include_lm_head=False, max_shard_size=2.0):
     """Export the backbone model to HuggingFace format.

     Args:
         backbone: The Keras backbone model to convert.
         path: str. Path to save the exported model.
         include_lm_head: bool. If True, include lm_head weights if applicable.
+        max_shard_size: float. Maximum size in GB for each shard.
     """
-    backend = keras.config.backend()
     model_type = backbone.__class__.__name__
     if model_type not in MODEL_CONFIGS:
         raise ValueError(
@@ -46,41 +53,91 @@ def export_backbone(backbone, path, include_lm_head=False):
         raise ValueError(
             f"Export to Transformers format not implemented for {model_type}"
         )
+    if model_type not in MODEL_TRANSFORMERS:
+        raise ValueError(f"Transformations not implemented for {model_type}")
+
+    def to_numpy(tensor):
+        return tensor.numpy()
+
     # Get config
     get_config_fn = MODEL_CONFIGS[model_type]
-    hf_config = get_config_fn(backbone)
-    # Get weights
-    get_weights_fn = MODEL_EXPORTERS[model_type]
-    weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head)
-    if not weights_dict:
-        raise ValueError("No weights to save.")
+    hf_config = get_config_fn(backbone, include_lm_head=include_lm_head)
     # Save config
     os.makedirs(path, exist_ok=True)
     config_path = os.path.join(path, "config.json")
     with open(config_path, "w") as f:
         json.dump(hf_config, f)
-    # Save weights based on backend
-    weights_path = os.path.join(path, "model.safetensors")
-    if backend == "torch":
-        from safetensors.torch import save_file
-
-        weights_dict_contiguous = {
-            k: v.value.contiguous() if hasattr(v, "value") else v.contiguous()
-            for k, v in weights_dict.items()
-        }
-        save_file(
-            weights_dict_contiguous, weights_path, metadata={"format": "pt"}
-        )
-    elif backend == "tensorflow":
-        from safetensors.tensorflow import save_file
-
-        save_file(weights_dict, weights_path, metadata={"format": "pt"})
-    elif backend == "jax":
-        from safetensors.flax import save_file
-
-        save_file(weights_dict, weights_path, metadata={"format": "pt"})
-    else:
-        raise ValueError(f"Unsupported backend: {backend}")
+    # Get model-specific transform function
+    get_transform_fn = MODEL_TRANSFORMERS[model_type]
+    transform_fn = get_transform_fn(backbone)
+    # Single pass: dynamic sharding based on actual NumPy sizes,
+    # processing one tensor at a time
+    get_weights_fn = MODEL_EXPORTERS[model_type]
+    weights_generator = get_weights_fn(
+        backbone, include_lm_head=include_lm_head
+    )
+    shard_num = 1
+    current_shard_dict = {}
+    current_size_gb = 0.0
+    weight_map = {}
+    total_size_bytes = 0
+    temp_shard_files = []
+    current_temp_file = None
+    for name, backend_tensor in weights_generator:
+        np_tensor = to_numpy(backend_tensor)
+        np_tensor = transform_fn(name, np_tensor)  # Model-specific transform
+        tensor_size_gb = np_tensor.nbytes / (1024**3)
+        total_size_bytes += np_tensor.nbytes
+        if (
+            current_size_gb + tensor_size_gb > max_shard_size
+            and current_shard_dict
+        ):
+            # Save current shard as temp
+            current_temp_file = f"temp_shard_{shard_num}.safetensors"
+            weights_path = os.path.join(path, current_temp_file)
+            save_file(
+                current_shard_dict, weights_path, metadata={"format": "pt"}
+            )
+            temp_shard_files.append(
+                (current_temp_file, list(current_shard_dict.keys()))
+            )
+            del current_shard_dict
+            current_shard_dict = {}
+            current_size_gb = 0.0
+            shard_num += 1
+        current_shard_dict[name] = np_tensor
+        current_size_gb += tensor_size_gb
+        del np_tensor  # Explicitly del to aid GC after adding to shard
+    # Save last shard
+    if current_shard_dict:
+        current_temp_file = f"temp_shard_{shard_num}.safetensors"
+        weights_path = os.path.join(path, current_temp_file)
+        save_file(current_shard_dict, weights_path, metadata={"format": "pt"})
+        temp_shard_files.append(
+            (current_temp_file, list(current_shard_dict.keys()))
+        )
+        del current_shard_dict
+    num_shards = shard_num
+    # Rename temp files to final format and build weight_map
+    for i, (temp_file, keys) in enumerate(temp_shard_files, 1):
+        if num_shards == 1:
+            final_file = "model.safetensors"
+        else:
+            final_file = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
+        shutil.move(
+            os.path.join(path, temp_file), os.path.join(path, final_file)
+        )
+        for key in keys:
+            weight_map[key] = final_file
+    # Save index
+    index = {
+        "metadata": {"total_size": total_size_bytes},
+        "weight_map": weight_map,
+    }
+    index_path = os.path.join(path, "model.safetensors.index.json")
+    with open(index_path, "w") as f:
+        json.dump(index, f)


 def export_tokenizer(tokenizer, path):
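For a checkpoint larger than `max_shard_size`, the loop above leaves a sharded layout plus an index file in `path`. A hypothetical two-shard result is sketched below; the file names follow the Hugging Face convention built by the rename loop, while the byte count and tensor-to-shard assignments are illustrative only:

```python
# Hypothetical on-disk result for a two-shard export (illustrative values):
#
#   path/
#     config.json
#     model-00001-of-00002.safetensors
#     model-00002-of-00002.safetensors
#     model.safetensors.index.json
#
# model.safetensors.index.json maps every tensor name to the shard holding it:
index = {
    "metadata": {"total_size": 5012344832},  # sum of all tensor nbytes
    "weight_map": {
        "model.embed_tokens.weight": "model-00001-of-00002.safetensors",
        "model.layers.0.self_attn.q_proj.weight": "model-00001-of-00002.safetensors",
        # ... one entry per exported tensor ...
        "model.norm.weight": "model-00002-of-00002.safetensors",
    },
}
```

When everything fits in one shard, the code keeps the flat `model.safetensors` name but still writes the index, which the index's `weight_map` then points at.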
@@ -118,7 +175,7 @@ def export_tokenizer(tokenizer, path):
     )


-def export_to_safetensors(keras_model, path):
+def export_to_safetensors(keras_model, path, max_shard_size=2.0):
     """Converts a Keras model to Hugging Face Transformers format.

     It does the following:
@@ -129,9 +186,12 @@ def export_to_safetensors(keras_model, path):
         keras_model: The Keras model to convert.
         path: str. Path of the directory to which the safetensors file,
             config and tokenizer will be saved.
+        max_shard_size: float. Maximum size in GB for each shard during export.
     """
     backbone = keras_model.backbone
-    export_backbone(backbone, path, include_lm_head=True)
+    export_backbone(
+        backbone, path, include_lm_head=True, max_shard_size=max_shard_size
+    )
     if (
         keras_model.preprocessor is not None
         and keras_model.preprocessor.tokenizer is None
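An end-to-end sketch of the exporter as it now stands, assuming `torch` and `transformers` are installed for the reload step; the preset name and output path are illustrative, not part of this diff:

```python
# Sketch: export a Keras Gemma causal LM, then reload it with transformers.
import keras_hub
from transformers import AutoModelForCausalLM, AutoTokenizer

from keras_hub.src.utils.transformers.export.hf_exporter import (
    export_to_safetensors,
)

keras_lm = keras_hub.models.GemmaCausalLM.from_preset("gemma_2b_en")
export_to_safetensors(keras_lm, "./gemma_hf", max_shard_size=2.0)

# transformers resolves the sharded weights through the index file.
hf_model = AutoModelForCausalLM.from_pretrained("./gemma_hf")
hf_tokenizer = AutoTokenizer.from_pretrained("./gemma_hf")
```

Since `export_to_safetensors` passes `include_lm_head=True`, the saved config advertises `GemmaForCausalLM`, which is what lets `AutoModelForCausalLM` pick the right class on reload.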