24 changes: 24 additions & 0 deletions keras_hub/api/models/__init__.py
@@ -390,6 +390,9 @@
from keras_hub.src.models.llama3.llama3_backbone import (
    Llama3Backbone as Llama3Backbone,
)
from keras_hub.src.models.llama3.llama3_backbone import (
    Llama3BackboneConfig as Llama3BackboneConfig,
)
from keras_hub.src.models.llama3.llama3_causal_lm import (
    Llama3CausalLM as Llama3CausalLM,
)
@@ -399,6 +402,27 @@
from keras_hub.src.models.llama3.llama3_tokenizer import (
    Llama3Tokenizer as Llama3Tokenizer,
)
from keras_hub.src.models.llama3.llama3_vision_backbone import (
    Llama3VisionBackbone as Llama3VisionBackbone,
)
from keras_hub.src.models.llama3.llama3_vision_causal_lm import (
    Llama3VisionCausalLM as Llama3VisionCausalLM,
)
from keras_hub.src.models.llama3.llama3_vision_cross_attention import (
    Llama3VisionCrossAttention as Llama3VisionCrossAttention,
)
from keras_hub.src.models.llama3.llama3_vision_encoder import (
    Llama3VisionEncoder as Llama3VisionEncoder,
)
from keras_hub.src.models.llama3.llama3_vision_image_converter import (
    Llama3VisionImageConverter as Llama3VisionImageConverter,
)
from keras_hub.src.models.llama3.llama3_vision_preprocessor import (
    Llama3VisionPreprocessor as Llama3VisionPreprocessor,
)
from keras_hub.src.models.llama3.llama3_vision_projector import (
    Llama3VisionProjector as Llama3VisionProjector,
)
from keras_hub.src.models.masked_lm import MaskedLM as MaskedLM
from keras_hub.src.models.masked_lm_preprocessor import (
    MaskedLMPreprocessor as MaskedLMPreprocessor,
15 changes: 15 additions & 0 deletions keras_hub/src/models/llama3/__init__.py
@@ -1,5 +1,20 @@
from keras_hub.src.models.llama3.llama3_backbone import Llama3Backbone
from keras_hub.src.models.llama3.llama3_presets import backbone_presets
from keras_hub.src.models.llama3.llama3_vision_backbone import (
    Llama3VisionBackbone,
)
from keras_hub.src.models.llama3.llama3_vision_causal_lm import (
    Llama3VisionCausalLM,
)
from keras_hub.src.models.llama3.llama3_vision_cross_attention import (
    Llama3VisionCrossAttention,
)
from keras_hub.src.models.llama3.llama3_vision_encoder import (
    Llama3VisionEncoder,
)
from keras_hub.src.models.llama3.llama3_vision_projector import (
    Llama3VisionProjector,
)
from keras_hub.src.utils.preset_utils import register_presets

register_presets(backbone_presets, Llama3Backbone)
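
With the presets registered, any entry in `backbone_presets` becomes loadable by name. A minimal sketch of that path, assuming the pre-existing `llama3_8b_en` preset name for illustration:

import keras_hub

# Load a registered Llama 3 checkpoint by preset name.
# "llama3_8b_en" is an assumed example preset, not one added in this PR.
backbone = keras_hub.models.Llama3Backbone.from_preset("llama3_8b_en")
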
64 changes: 64 additions & 0 deletions keras_hub/src/models/llama3/llama3_backbone.py
@@ -2,6 +2,70 @@
from keras_hub.src.models.llama.llama_backbone import LlamaBackbone


# Config class for Llama3Backbone
@keras_hub_export("keras_hub.models.Llama3BackboneConfig")
class Llama3BackboneConfig:
    """Configuration for Llama3Backbone.

    Args:
        vocabulary_size: int. Size of the token vocabulary.
        num_layers: int. Number of transformer layers.
        num_query_heads: int. Number of query attention heads.
        hidden_dim: int. Size of the transformer encoding layers.
        intermediate_dim: int. Output dimension of feedforward network.
        num_key_value_heads: int. Number of key/value attention heads.
        rope_max_wavelength: int. Maximum angular wavelength for RoPE.
        rope_scaling_factor: float. Scaling factor for RoPE.
        layer_norm_epsilon: float. Epsilon for layer normalization.
        dtype: str or DTypePolicy. Dtype for computations and weights.
    """

    def __init__(
        self,
        vocabulary_size=128256,
        num_layers=32,
        num_query_heads=32,
        hidden_dim=4096,
        intermediate_dim=14336,
        num_key_value_heads=8,
        rope_max_wavelength=500000,
        rope_scaling_factor=8.0,
        layer_norm_epsilon=1e-5,
        dtype=None,
        **kwargs,
    ):
        self.vocabulary_size = vocabulary_size
        self.num_layers = num_layers
        self.num_query_heads = num_query_heads
        self.hidden_dim = hidden_dim
        self.intermediate_dim = intermediate_dim
        self.num_key_value_heads = num_key_value_heads
        self.rope_max_wavelength = rope_max_wavelength
        self.rope_scaling_factor = rope_scaling_factor
        self.layer_norm_epsilon = layer_norm_epsilon
        self.dtype = dtype
        # Store any extra kwargs
        self._kwargs = kwargs
        for k, v in kwargs.items():
            setattr(self, k, v)

    def get_config(self):
        config = {
            "vocabulary_size": self.vocabulary_size,
            "num_layers": self.num_layers,
            "num_query_heads": self.num_query_heads,
            "hidden_dim": self.hidden_dim,
            "intermediate_dim": self.intermediate_dim,
            "num_key_value_heads": self.num_key_value_heads,
            "rope_max_wavelength": self.rope_max_wavelength,
            "rope_scaling_factor": self.rope_scaling_factor,
            "layer_norm_epsilon": self.layer_norm_epsilon,
            "dtype": self.dtype,
        }
        config.update(self._kwargs)
        return config
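
# Illustrative usage (a sketch, not part of the library file): the config
# round-trips through `get_config`, including any extra kwargs passed to
# the constructor.
#
#   config = Llama3BackboneConfig(num_layers=4, hidden_dim=512)
#   assert config.get_config()["num_layers"] == 4
#   restored = Llama3BackboneConfig(**config.get_config())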


# LLaMA 3 shares the same architecture as its predecessors,
# so we simply create an alias for API consistency.
@keras_hub_export("keras_hub.models.Llama3Backbone")
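
Since `Llama3Backbone` is an alias of `LlamaBackbone`, it keeps the same constructor and call convention. A minimal smoke-test sketch, assuming the standard keras_hub backbone inputs (`token_ids`, `padding_mask`):

import numpy as np

from keras_hub.models import Llama3Backbone

# Tiny configuration for a quick functional check; real checkpoints use
# the defaults captured in Llama3BackboneConfig above.
backbone = Llama3Backbone(
    vocabulary_size=1000,
    num_layers=2,
    num_query_heads=4,
    hidden_dim=64,
    intermediate_dim=128,
    num_key_value_heads=2,
)
outputs = backbone(
    {
        "token_ids": np.ones((1, 12), dtype="int32"),
        "padding_mask": np.ones((1, 12), dtype="int32"),
    }
)
print(outputs.shape)  # (1, 12, 64): (batch, sequence, hidden_dim)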