From 5734de30d2ad92e961732c96c16a9ab2087770a5 Mon Sep 17 00:00:00 2001
From: Dhiraj BM
Date: Sat, 13 Sep 2025 00:56:35 +0530
Subject: [PATCH 1/3] Optimize RAM usage

---
 .../src/utils/transformers/export/gemma.py   | 103 +++++++-------
 .../utils/transformers/export/hf_exporter.py | 128 ++++++++++++++----
 2 files changed, 151 insertions(+), 80 deletions(-)

diff --git a/keras_hub/src/utils/transformers/export/gemma.py b/keras_hub/src/utils/transformers/export/gemma.py
index 846e391937..d9a5f698b3 100644
--- a/keras_hub/src/utils/transformers/export/gemma.py
+++ b/keras_hub/src/utils/transformers/export/gemma.py
@@ -1,9 +1,14 @@
-import keras.ops as ops
+import numpy as np
 
 
-def get_gemma_config(backbone):
+def get_gemma_config(backbone, include_lm_head=False):
     token_embedding_layer = backbone.get_layer("token_embedding")
+    if include_lm_head:
+        architectures = ["GemmaForCausalLM"]  # Full model with LM head
+    else:
+        architectures = ["GemmaForBackbone"]  # Just backbone
     hf_config = {
+        "architectures": architectures,
         "vocab_size": backbone.vocabulary_size,
         "num_hidden_layers": backbone.num_layers,
         "num_attention_heads": backbone.num_query_heads,
@@ -22,79 +27,77 @@ def get_gemma_config(backbone):
 
 
 def get_gemma_weights_map(backbone, include_lm_head=False):
-    weights_dict = {}
-
     # Map token embedding
     token_embedding_layer = backbone.get_layer("token_embedding")
-    weights_dict["model.embed_tokens.weight"] = token_embedding_layer.weights[0]
-
+    yield "model.embed_tokens.weight", token_embedding_layer.weights[0]
     for i in range(backbone.num_layers):
         decoder_layer = backbone.get_layer(f"decoder_block_{i}")
-
         # Pre-attention normalization
-        weights_dict[f"model.layers.{i}.input_layernorm.weight"] = (
-            decoder_layer.pre_attention_norm.weights[0]
+        yield (
+            f"model.layers.{i}.input_layernorm.weight",
+            decoder_layer.pre_attention_norm.weights[0],
         )
-
         # Attention query projection
         query_kernel = decoder_layer.attention.query_dense.weights[0]
-        query_kernel = ops.transpose(query_kernel, axes=(1, 0, 2))
-        query_kernel = ops.reshape(query_kernel, (-1, backbone.hidden_dim))
-        query_kernel = ops.transpose(query_kernel)
-        weights_dict[f"model.layers.{i}.self_attn.q_proj.weight"] = query_kernel
-
+        yield f"model.layers.{i}.self_attn.q_proj.weight", query_kernel
         # Attention key projection
         key_kernel = decoder_layer.attention.key_dense.weights[0][0]
-        weights_dict[f"model.layers.{i}.self_attn.k_proj.weight"] = (
-            ops.transpose(key_kernel)
-        )
-
+        yield f"model.layers.{i}.self_attn.k_proj.weight", key_kernel
         # Attention value projection
         value_kernel = decoder_layer.attention.value_dense.weights[0][0]
-        weights_dict[f"model.layers.{i}.self_attn.v_proj.weight"] = (
-            ops.transpose(value_kernel)
-        )
-
+        yield f"model.layers.{i}.self_attn.v_proj.weight", value_kernel
         # Attention output projection
         out_kernel = decoder_layer.attention.output_dense.weights[0]
-        out_kernel = ops.transpose(out_kernel, axes=(2, 0, 1))
-        out_kernel = ops.reshape(out_kernel, (backbone.hidden_dim, -1))
-        weights_dict[f"model.layers.{i}.self_attn.o_proj.weight"] = out_kernel
-
+        yield f"model.layers.{i}.self_attn.o_proj.weight", out_kernel
         # Post-attention normalization
-        weights_dict[f"model.layers.{i}.post_attention_layernorm.weight"] = (
-            decoder_layer.pre_ffw_norm.weights[0]
+        yield (
+            f"model.layers.{i}.post_attention_layernorm.weight",
+            decoder_layer.pre_ffw_norm.weights[0],
        )
-
         # MLP gate projection
         gate_kernel = decoder_layer.gating_ffw.weights[0]
-        weights_dict[f"model.layers.{i}.mlp.gate_proj.weight"] = ops.transpose(
-            gate_kernel
-        )
-
yield f"model.layers.{i}.mlp.gate_proj.weight", gate_kernel # MLP up projection up_kernel = decoder_layer.gating_ffw_2.weights[0] - weights_dict[f"model.layers.{i}.mlp.up_proj.weight"] = ops.transpose( - up_kernel - ) - + yield f"model.layers.{i}.mlp.up_proj.weight", up_kernel # MLP down projection down_kernel = decoder_layer.ffw_linear.weights[0] - weights_dict[f"model.layers.{i}.mlp.down_proj.weight"] = ops.transpose( - down_kernel - ) - + yield f"model.layers.{i}.mlp.down_proj.weight", down_kernel # Map final normalization - weights_dict["model.norm.weight"] = backbone.get_layer( - "final_normalization" - ).weights[0] - + yield ( + "model.norm.weight", + backbone.get_layer("final_normalization").weights[0], + ) # Map lm_head if embeddings are not tied if include_lm_head and not token_embedding_layer.tie_weights: - weights_dict["lm_head.weight"] = ops.transpose( - token_embedding_layer.reverse_embeddings - ) - return weights_dict + lm_head = token_embedding_layer.reverse_embeddings + yield "lm_head.weight", lm_head + + +def get_gemma_transform_fn(backbone): + """Return a transform function for Gemma weights.""" + + def transform(name, np_tensor): + if name.endswith("q_proj.weight"): + np_tensor = np.transpose(np_tensor, axes=(1, 0, 2)) + np_tensor = np.reshape(np_tensor, (-1, backbone.hidden_dim)) + np_tensor = np.transpose(np_tensor) + elif name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): + np_tensor = np.transpose(np_tensor) + elif name.endswith("o_proj.weight"): + np_tensor = np.transpose(np_tensor, axes=(2, 0, 1)) + np_tensor = np.reshape(np_tensor, (backbone.hidden_dim, -1)) + elif ( + name.endswith("gate_proj.weight") + or name.endswith("up_proj.weight") + or name.endswith("down_proj.weight") + ): + np_tensor = np.transpose(np_tensor) + elif name == "lm_head.weight": + np_tensor = np.transpose(np_tensor) + return np_tensor + + return transform def get_gemma_tokenizer_config(tokenizer): diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py index 1593987ca9..b071b0fca0 100644 --- a/keras_hub/src/utils/transformers/export/hf_exporter.py +++ b/keras_hub/src/utils/transformers/export/hf_exporter.py @@ -4,11 +4,13 @@ import warnings import keras +from safetensors.numpy import save_file from keras_hub.src.utils.transformers.export.gemma import get_gemma_config from keras_hub.src.utils.transformers.export.gemma import ( get_gemma_tokenizer_config, ) +from keras_hub.src.utils.transformers.export.gemma import get_gemma_transform_fn from keras_hub.src.utils.transformers.export.gemma import get_gemma_weights_map MODEL_CONFIGS = { @@ -27,14 +29,20 @@ # get_mistral_tokenizer_config } +MODEL_TRANSFORMERS = { + "GemmaBackbone": get_gemma_transform_fn, + # Add for future models, e.g., "MistralBackbone": get_mistral_transform_fn +} + -def export_backbone(backbone, path, include_lm_head=False): +def export_backbone(backbone, path, include_lm_head=False, max_shard_size=2.0): """Export the backbone model to HuggingFace format. Args: backbone: The Keras backbone model to convert. path: str. Path to save the exported model. include_lm_head: bool. If True, include lm_head weights if applicable. + max_shard_size: float. Maximum size in GB for each shard. 
""" backend = keras.config.backend() model_type = backbone.__class__.__name__ @@ -46,41 +54,98 @@ def export_backbone(backbone, path, include_lm_head=False): raise ValueError( f"Export to Transformers format not implemented for {model_type}" ) + if model_type not in MODEL_TRANSFORMERS: + raise ValueError(f"Transformations not implemented for {model_type}") + + def to_numpy(tensor): + if backend == "jax": + return tensor.numpy() + elif backend == "torch": + return tensor.detach().cpu().numpy() + elif backend == "tensorflow": + return tensor.numpy() + else: + raise ValueError(f"Unsupported backend: {backend}") + # Get config get_config_fn = MODEL_CONFIGS[model_type] - hf_config = get_config_fn(backbone) - # Get weights - get_weights_fn = MODEL_EXPORTERS[model_type] - weights_dict = get_weights_fn(backbone, include_lm_head=include_lm_head) - if not weights_dict: - raise ValueError("No weights to save.") + hf_config = get_config_fn(backbone, include_lm_head=include_lm_head) # Save config os.makedirs(path, exist_ok=True) config_path = os.path.join(path, "config.json") with open(config_path, "w") as f: json.dump(hf_config, f) - # Save weights based on backend - weights_path = os.path.join(path, "model.safetensors") - if backend == "torch": - from safetensors.torch import save_file - - weights_dict_contiguous = { - k: v.value.contiguous() if hasattr(v, "value") else v.contiguous() - for k, v in weights_dict.items() - } - save_file( - weights_dict_contiguous, weights_path, metadata={"format": "pt"} - ) - elif backend == "tensorflow": - from safetensors.tensorflow import save_file + # Get model-specific transform function + get_transform_fn = MODEL_TRANSFORMERS[model_type] + transform_fn = get_transform_fn(backbone) + # Single pass: dynamic sharding based on actual NumPy sizes, + # processing one tensor at a time - save_file(weights_dict, weights_path, metadata={"format": "pt"}) - elif backend == "jax": - from safetensors.flax import save_file - - save_file(weights_dict, weights_path, metadata={"format": "pt"}) - else: - raise ValueError(f"Unsupported backend: {backend}") + get_weights_fn = MODEL_EXPORTERS[model_type] + weights_generator = get_weights_fn( + backbone, include_lm_head=include_lm_head + ) + shard_num = 1 + current_shard_dict = {} + current_size_gb = 0.0 + weight_map = {} + total_size_bytes = 0 + temp_shard_files = [] + current_temp_file = None + for name, backend_tensor in weights_generator: + np_tensor = to_numpy(backend_tensor) + np_tensor = transform_fn(name, np_tensor) # Model-specific transform + tensor_size_gb = np_tensor.nbytes / (1024**3) + total_size_bytes += np_tensor.nbytes + if ( + current_size_gb + tensor_size_gb > max_shard_size + and current_shard_dict + ): + # Save current shard as temp + current_temp_file = f"temp_shard_{shard_num}.safetensors" + weights_path = os.path.join(path, current_temp_file) + save_file( + current_shard_dict, weights_path, metadata={"format": "pt"} + ) + temp_shard_files.append( + (current_temp_file, list(current_shard_dict.keys())) + ) + del current_shard_dict + current_shard_dict = {} + current_size_gb = 0.0 + shard_num += 1 + current_shard_dict[name] = np_tensor + current_size_gb += tensor_size_gb + del np_tensor # Explicitly del to aid GC after adding to shard + # Save last shard + if current_shard_dict: + current_temp_file = f"temp_shard_{shard_num}.safetensors" + weights_path = os.path.join(path, current_temp_file) + save_file(current_shard_dict, weights_path, metadata={"format": "pt"}) + temp_shard_files.append( + (current_temp_file, 
+            (current_temp_file, list(current_shard_dict.keys()))
+        )
+        del current_shard_dict
+    num_shards = shard_num
+    # Rename temp files to final format and build weight_map
+    for i, (temp_file, keys) in enumerate(temp_shard_files, 1):
+        if num_shards == 1:
+            final_file = "model.safetensors"
+        else:
+            final_file = f"model-{i:05d}-of-{num_shards:05d}.safetensors"
+        shutil.move(
+            os.path.join(path, temp_file), os.path.join(path, final_file)
+        )
+        for key in keys:
+            weight_map[key] = final_file
+    # Save index
+    index = {
+        "metadata": {"total_size": total_size_bytes},
+        "weight_map": weight_map,
+    }
+    index_path = os.path.join(path, "model.safetensors.index.json")
+    with open(index_path, "w") as f:
+        json.dump(index, f)
 
 
 def export_tokenizer(tokenizer, path):
@@ -118,7 +183,7 @@ def export_tokenizer(tokenizer, path):
         )
 
 
-def export_to_safetensors(keras_model, path):
+def export_to_safetensors(keras_model, path, max_shard_size=2.0):
     """Converts a Keras model to Hugging Face Transformers format.
 
     It does the following:
@@ -129,9 +194,12 @@ def export_to_safetensors(keras_model, path):
         keras_model: The Keras model to convert.
        path: str. Path of the directory to which the safetensors file,
            config and tokenizer will be saved.
+        max_shard_size: float. Maximum size in GB for each shard during export.
     """
     backbone = keras_model.backbone
-    export_backbone(backbone, path, include_lm_head=True)
+    export_backbone(
+        backbone, path, include_lm_head=True, max_shard_size=max_shard_size
+    )
     if (
         keras_model.preprocessor is not None
         and keras_model.preprocessor.tokenizer is None

From 792620b17ceebaae69b17fffec931092dac2bf16 Mon Sep 17 00:00:00 2001
From: Dhiraj BM
Date: Sat, 13 Sep 2025 01:21:30 +0530
Subject: [PATCH 2/3] Address review comments

---
 keras_hub/src/utils/transformers/export/gemma.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/keras_hub/src/utils/transformers/export/gemma.py b/keras_hub/src/utils/transformers/export/gemma.py
index d9a5f698b3..52c66931e4 100644
--- a/keras_hub/src/utils/transformers/export/gemma.py
+++ b/keras_hub/src/utils/transformers/export/gemma.py
@@ -41,10 +41,10 @@ def get_gemma_weights_map(backbone, include_lm_head=False):
         query_kernel = decoder_layer.attention.query_dense.weights[0]
         yield f"model.layers.{i}.self_attn.q_proj.weight", query_kernel
         # Attention key projection
-        key_kernel = decoder_layer.attention.key_dense.weights[0][0]
+        key_kernel = decoder_layer.attention.key_dense.weights[0]
         yield f"model.layers.{i}.self_attn.k_proj.weight", key_kernel
         # Attention value projection
-        value_kernel = decoder_layer.attention.value_dense.weights[0][0]
+        value_kernel = decoder_layer.attention.value_dense.weights[0]
         yield f"model.layers.{i}.self_attn.v_proj.weight", value_kernel
         # Attention output projection
         out_kernel = decoder_layer.attention.output_dense.weights[0]
@@ -83,7 +83,8 @@ def transform(name, np_tensor):
             np_tensor = np.reshape(np_tensor, (-1, backbone.hidden_dim))
             np_tensor = np.transpose(np_tensor)
         elif name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
-            np_tensor = np.transpose(np_tensor)
+            np_tensor = np.transpose(np_tensor, axes=(0, 2, 1))
+            np_tensor = np.reshape(np_tensor, (-1, backbone.hidden_dim))
         elif name.endswith("o_proj.weight"):
             np_tensor = np.transpose(np_tensor, axes=(2, 0, 1))
             np_tensor = np.reshape(np_tensor, (backbone.hidden_dim, -1))

From 3d2d59592d31cf474f9331f97858a9a39175947c Mon Sep 17 00:00:00 2001
From: Dhiraj BM
Date: Sat, 13 Sep 2025 02:58:04 +0530
Subject: [PATCH 3/3] Make export compatible with all backends

---
 keras_hub/src/utils/transformers/export/hf_exporter.py | 10 +---------
 1 file changed, 1 insertion(+), 9 deletions(-)

diff --git a/keras_hub/src/utils/transformers/export/hf_exporter.py b/keras_hub/src/utils/transformers/export/hf_exporter.py
index b071b0fca0..e77ecc524c 100644
--- a/keras_hub/src/utils/transformers/export/hf_exporter.py
+++ b/keras_hub/src/utils/transformers/export/hf_exporter.py
@@ -44,7 +44,6 @@ def export_backbone(backbone, path, include_lm_head=False, max_shard_size=2.0):
         include_lm_head: bool. If True, include lm_head weights if applicable.
         max_shard_size: float. Maximum size in GB for each shard.
     """
-    backend = keras.config.backend()
     model_type = backbone.__class__.__name__
     if model_type not in MODEL_CONFIGS:
         raise ValueError(
@@ -58,14 +57,7 @@ def export_backbone(backbone, path, include_lm_head=False, max_shard_size=2.0):
         raise ValueError(f"Transformations not implemented for {model_type}")
 
     def to_numpy(tensor):
-        if backend == "jax":
-            return tensor.numpy()
-        elif backend == "torch":
-            return tensor.detach().cpu().numpy()
-        elif backend == "tensorflow":
-            return tensor.numpy()
-        else:
-            raise ValueError(f"Unsupported backend: {backend}")
+        return tensor.numpy()
 
     # Get config
     get_config_fn = MODEL_CONFIGS[model_type]
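
Usage sketch (not part of this series): the snippet below shows how the sharded
export added here could be driven end to end. The preset name and output
directory are illustrative assumptions rather than values taken from the
patches; export_to_safetensors and its max_shard_size argument are the ones
introduced above.

    import os

    from keras_hub.models import GemmaCausalLM
    from keras_hub.src.utils.transformers.export.hf_exporter import (
        export_to_safetensors,
    )

    # Any Gemma preset should work; "gemma_2b_en" is only an illustrative name.
    model = GemmaCausalLM.from_preset("gemma_2b_en")

    # Export with ~2 GB shards (the default max_shard_size added by this series).
    export_to_safetensors(model, "./gemma_hf_export", max_shard_size=2.0)

    # Expected contents: config.json, tokenizer files (when a tokenizer is
    # attached), either a single model.safetensors or several
    # model-XXXXX-of-XXXXX.safetensors shards, plus model.safetensors.index.json
    # mapping each weight name to its shard file.
    print(sorted(os.listdir("./gemma_hf_export")))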
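
The RAM saving comes from streaming one tensor at a time into size-bounded
shards instead of materializing the whole state dict before saving. A minimal
standalone sketch of that scheme, assuming plain NumPy arrays and mirroring the
temp-file-then-rename and index-file layout used by the patch (the function
name, demo tensors, and output directory are illustrative, not library API):

    import json
    import os
    import shutil

    import numpy as np
    from safetensors.numpy import save_file


    def shard_and_save(named_tensors, path, max_shard_size_gb=2.0):
        """Stream (name, np.ndarray) pairs into size-bounded safetensors shards."""
        os.makedirs(path, exist_ok=True)
        limit = int(max_shard_size_gb * 1024**3)
        current, current_bytes, total_bytes = {}, 0, 0
        temp_files, weight_map = [], {}

        def flush():
            # Write the accumulated shard to a temp file and drop it from memory.
            nonlocal current, current_bytes
            if not current:
                return
            temp_name = f"temp_shard_{len(temp_files) + 1}.safetensors"
            save_file(
                current, os.path.join(path, temp_name), metadata={"format": "pt"}
            )
            temp_files.append((temp_name, list(current)))
            current, current_bytes = {}, 0

        for name, tensor in named_tensors:
            if current and current_bytes + tensor.nbytes > limit:
                flush()
            current[name] = tensor
            current_bytes += tensor.nbytes
            total_bytes += tensor.nbytes
        flush()

        # Rename temp shards to the final naming scheme and build the index.
        num_shards = len(temp_files)
        for i, (temp_name, keys) in enumerate(temp_files, 1):
            final_name = (
                "model.safetensors"
                if num_shards == 1
                else f"model-{i:05d}-of-{num_shards:05d}.safetensors"
            )
            shutil.move(
                os.path.join(path, temp_name), os.path.join(path, final_name)
            )
            for key in keys:
                weight_map[key] = final_name
        index = {"metadata": {"total_size": total_bytes}, "weight_map": weight_map}
        with open(os.path.join(path, "model.safetensors.index.json"), "w") as f:
            json.dump(index, f)


    # Tiny demo with dummy tensors; the real exporter streams backbone weights.
    dummy = (
        (f"layer_{i}.weight", np.ones((1024, 1024), dtype="float32"))
        for i in range(4)
    )
    shard_and_save(dummy, "./sharded_demo", max_shard_size_gb=0.005)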