Commit c55e233

tianshub authored and The tunix Authors committed
add llama3 70 & 405b
PiperOrigin-RevId: 820873916
1 parent 05093d9 commit c55e233

1 file changed (+28, -0)

tunix/models/llama3/model.py

Lines changed: 28 additions & 0 deletions
@@ -141,6 +141,34 @@ def llama3_1_8b(cls):
         weight_tying=False,
     )
 
+  @classmethod
+  def llama3_70b(cls):
+    return cls(
+        num_layers=80,
+        vocab_size=128256,
+        embed_dim=8192,
+        hidden_dim=28672,
+        num_heads=64,
+        head_dim=128,
+        num_kv_heads=8,
+        norm_eps=1e-05,
+        rope_theta=500_000,
+    )
+
+  @classmethod
+  def llama3_405b(cls):
+    return cls(
+        num_layers=126,
+        vocab_size=128256,
+        embed_dim=16384,
+        hidden_dim=53248,
+        num_heads=128,
+        head_dim=128,
+        num_kv_heads=8,
+        norm_eps=1e-05,
+        rope_theta=500_000,
+    )
+
 
 def shard(x: jnp.ndarray, s: Tuple[str, ...]):
   mesh = pxla.thread_resources.env.physical_mesh
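
As a rough sanity check on the values added above, the parameter count implied by each config can be estimated from the shape fields alone. The sketch below is not part of the commit. `LlamaShape` and `param_count` are hypothetical names, and the arithmetic assumes the usual Llama layout: bias-free projections, a gated MLP with three weight matrices, scale-only RMSNorm, and untied input/output embeddings. The diff does not set `weight_tying` for the new configs, so untied embeddings are an assumption here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class LlamaShape:
  """Hypothetical stand-in holding only the shape fields set in this commit."""
  num_layers: int
  vocab_size: int
  embed_dim: int
  hidden_dim: int
  num_heads: int
  head_dim: int
  num_kv_heads: int


def param_count(c: LlamaShape, tied_embeddings: bool = False) -> int:
  """Estimates total parameters under the assumptions stated in the lead-in."""
  # Attention: Q and O projections span all query heads.
  q_and_o = 2 * c.embed_dim * c.num_heads * c.head_dim
  # Grouped-query attention: K and V projections span only the KV heads.
  k_and_v = 2 * c.embed_dim * c.num_kv_heads * c.head_dim
  # Gated MLP: gate, up, and down projections.
  mlp = 3 * c.embed_dim * c.hidden_dim
  # Two RMSNorm scale vectors per layer (assumed, no bias terms).
  norms = 2 * c.embed_dim
  per_layer = q_and_o + k_and_v + mlp + norms
  # Input embedding plus a separate output head unless tied.
  embeddings = c.vocab_size * c.embed_dim * (1 if tied_embeddings else 2)
  final_norm = c.embed_dim
  return c.num_layers * per_layer + embeddings + final_norm


# Shape values copied from the diff above.
LLAMA3_70B = LlamaShape(
    num_layers=80, vocab_size=128256, embed_dim=8192, hidden_dim=28672,
    num_heads=64, head_dim=128, num_kv_heads=8,
)
LLAMA3_405B = LlamaShape(
    num_layers=126, vocab_size=128256, embed_dim=16384, hidden_dim=53248,
    num_heads=128, head_dim=128, num_kv_heads=8,
)

for name, shape in [("llama3_70b", LLAMA3_70B), ("llama3_405b", LLAMA3_405B)]:
  print(f"{name}: {param_count(shape) / 1e9:.2f}B parameters")
```

Under these assumptions the estimates come out at roughly 70.55B and 405.85B parameters, in line with the model names. Both configs also have `num_heads` divisible by `num_kv_heads`, as grouped-query attention requires.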
