Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions model.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def from_name(cls, name: str):


transformer_configs = {
"gemma-2b": dict(dim=2048, vocab_size=256000, n_layer=18, n_head=8, n_local_heads=1, intermediate_size=16384),
"CodeLlama-7b-Python-hf": dict(block_size=16384, vocab_size=32000, n_layer=32, dim = 4096, rope_base=1000000),
"7B": dict(n_layer=32, n_head=32, dim=4096),
"13B": dict(n_layer=40, n_head=40, dim=5120),
Expand Down Expand Up @@ -109,6 +110,7 @@ def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
mask = self.causal_mask[None, None, input_pos]
freqs_cis = self.freqs_cis[input_pos]
x = self.tok_embeddings(idx)
x = (self.config.dim ** 0.5) * x

for i, layer in enumerate(self.layers):
x = layer(x, input_pos, freqs_cis, mask)
Expand Down Expand Up @@ -195,7 +197,7 @@ def __init__(self, config: ModelArgs) -> None:
self.w2 = nn.Linear(config.intermediate_size, config.dim, bias=False)

def forward(self, x: Tensor) -> Tensor:
return self.w2(F.silu(self.w1(x)) * self.w3(x))
return self.w2(F.gelu(self.w1(x)) * self.w3(x))


class RMSNorm(nn.Module):
Expand All @@ -209,7 +211,7 @@ def _norm(self, x):

def forward(self, x: Tensor) -> Tensor:
output = self._norm(x.float()).type_as(x)
return output * self.weight
return output * (1 + self.weight)


def precompute_freqs_cis(
Expand Down
10 changes: 8 additions & 2 deletions scripts/convert_hf_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,10 @@ def convert_hf_checkpoint(
config = ModelArgs.from_name(model_name)
print(f"Model config {config.__dict__}")

from safetensors import safe_open

# Load the json file containing weight mapping
model_map_json = checkpoint_dir / "pytorch_model.bin.index.json"
model_map_json = checkpoint_dir / "model.safetensors.index.json"

assert model_map_json.is_file()

Expand Down Expand Up @@ -65,7 +67,8 @@ def permute(w, n_head):

merged_result = {}
for file in sorted(bin_files):
state_dict = torch.load(str(file), map_location="cpu", mmap=True, weights_only=True)
state_dict = safe_open(str(file), framework="pt", device='cpu')
state_dict = {k: state_dict.get_tensor(k) for k in state_dict.keys()}
merged_result.update(state_dict)
final_result = {}
for key, value in merged_result.items():
Expand All @@ -92,6 +95,9 @@ def permute(w, n_head):
del final_result[key]
del final_result[key.replace("wq", "wk")]
del final_result[key.replace("wq", "wv")]
if "output.weight" not in final_result:
final_result["output.weight"] = final_result["tok_embeddings.weight"]

print(f"Saving checkpoint to {checkpoint_dir / 'model.pth'}")
torch.save(final_result, checkpoint_dir / "model.pth")

Expand Down