We read every piece of feedback and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 05093d9 · commit c55e233 (copy full SHA: c55e233)
tunix/models/llama3/model.py
@@ -141,6 +141,34 @@ def llama3_1_8b(cls):
141
weight_tying=False,
142
)
143
144
@classmethod
def llama3_70b(cls):
  """Alternate constructor: configuration for the Llama 3 70B variant.

  Returns:
    A model config instance populated with the published Llama 3 70B
    architecture hyperparameters (80 layers, 8192 embed dim, GQA with
    64 query / 8 KV heads).
  """
  return cls(
      num_layers=80,
      vocab_size=128256,
      embed_dim=8192,
      hidden_dim=28672,
      num_heads=64,
      head_dim=128,
      num_kv_heads=8,
      norm_eps=1e-05,
      rope_theta=500_000,
  )
@classmethod
def llama3_405b(cls):
  """Alternate constructor: configuration for the Llama 3.1 405B variant.

  NOTE(review): only num_layers, embed_dim, hidden_dim and num_heads were
  legible in the source diff; the decorator, `return cls(` opener, the
  remaining keyword arguments, and the closing paren were missing. They are
  reconstructed here to match the sibling ``llama3_70b`` config and the
  published Llama 3.1 405B architecture — confirm against the reference
  checkpoint before relying on this config.

  Returns:
    A model config instance populated with the Llama 3.1 405B architecture
    hyperparameters (126 layers, 16384 embed dim, GQA with 128 query /
    8 KV heads).
  """
  return cls(
      num_layers=126,
      vocab_size=128256,
      embed_dim=16384,
      hidden_dim=53248,
      num_heads=128,
      head_dim=128,
      num_kv_heads=8,
      norm_eps=1e-05,
      rope_theta=500_000,
  )
def shard(x: jnp.ndarray, s: Tuple[str, ...]):
174
mesh = pxla.thread_resources.env.physical_mesh
0 commit comments