eturchenkov/llama3-in-jax

A Llama 3 implementation in JAX, with the shape of every tensor annotated in the code.

Shape parameters:

  • B: batch size
  • T: sequence length
  • C: model dimension/embedding size
  • n_heads: number of attention heads
  • n_kv_heads: number of key-value heads (used in grouped query attention)
  • head_dim: dimension of each attention head (= C / n_heads)
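To show how these shape parameters fit together, here is a minimal sketch of grouped-query attention with every tensor's shape annotated. It is not taken from the repository: the sizes, the random weights, and the names `wq`, `wk`, `wv`, and `n_rep` are assumptions chosen for illustration, and rotary embeddings, the output projection, and KV caching are left out.

```python
import jax
import jax.numpy as jnp

# Illustrative sizes only (assumed, not the repository's configuration).
B, T, C = 2, 5, 32
n_heads, n_kv_heads = 8, 2
head_dim = C // n_heads            # 4
n_rep = n_heads // n_kv_heads      # query heads that share one key-value head

k_x, k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(k_x, (B, T, C))                      # (B, T, C)
wq = jax.random.normal(k_q, (C, n_heads * head_dim))       # (C, n_heads * head_dim)
wk = jax.random.normal(k_k, (C, n_kv_heads * head_dim))    # (C, n_kv_heads * head_dim)
wv = jax.random.normal(k_v, (C, n_kv_heads * head_dim))    # (C, n_kv_heads * head_dim)

q = (x @ wq).reshape(B, T, n_heads, head_dim)              # (B, T, n_heads, head_dim)
k = (x @ wk).reshape(B, T, n_kv_heads, head_dim)           # (B, T, n_kv_heads, head_dim)
v = (x @ wv).reshape(B, T, n_kv_heads, head_dim)           # (B, T, n_kv_heads, head_dim)

# Repeat each key-value head so it lines up with its group of query heads.
k = jnp.repeat(k, n_rep, axis=2)                           # (B, T, n_heads, head_dim)
v = jnp.repeat(v, n_rep, axis=2)                           # (B, T, n_heads, head_dim)

# Scaled dot-product scores between every query and key position, per head.
scores = jnp.einsum("bqhd,bkhd->bhqk", q, k) / jnp.sqrt(head_dim)  # (B, n_heads, T, T)

# Causal mask: position t may only attend to positions <= t.
causal = jnp.tril(jnp.ones((T, T), dtype=bool))            # (T, T)
scores = jnp.where(causal, scores, -jnp.inf)               # (B, n_heads, T, T)
weights = jax.nn.softmax(scores, axis=-1)                  # (B, n_heads, T, T)

# Weighted sum of values, then merge the heads back into the model dimension.
out = jnp.einsum("bhqk,bkhd->bqhd", weights, v)            # (B, T, n_heads, head_dim)
out = out.reshape(B, T, C)                                 # (B, T, C)
```

With `n_kv_heads < n_heads`, the key and value projections are `n_rep` times narrower than the query projection, which is where grouped-query attention saves parameters and cache memory.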
