Skip to content

Commit 8c7c450

Browse files
committed
Update on "Remove sharded ckpt from export_llama"
Sharded checkpoint isn't used anymore; removing it and simplifying export_llama. Differential Revision: [D87828518](https://our.internmc.facebook.com/intern/diff/D87828518/) [ghstack-poisoned]
1 parent 243125f commit 8c7c450

File tree

1 file changed

+2
-6
lines changed

1 file changed

+2
-6
lines changed

examples/models/llama/model.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,7 @@
77
# pyre-unsafe
88

99
import json
10-
import os
11-
from typing import Dict, Optional, Tuple
10+
from typing import Optional
1211

1312
import torch
1413
from executorch.examples.models.checkpoint import (
@@ -18,7 +17,6 @@
1817

1918
from executorch.examples.models.llama.llama_transformer import construct_transformer
2019
from executorch.examples.models.llama.model_args import ModelArgs
21-
from executorch.examples.models.llama.rope import Rope
2220

2321
from executorch.extension.llm.export.config.llm_config import LlmConfig
2422
from torchao.utils import TorchAOBaseTensor
@@ -39,12 +37,9 @@ def convert_to_llama_checkpoint(**kwargs):
3937

4038
class Llama2Model(EagerModelBase):
4139
def __init__(self, llm_config: Optional[LlmConfig] = None):
42-
resource_dir = get_default_model_resource_dir(__file__)
43-
4440
self.llm_config = llm_config if llm_config else LlmConfig()
4541

4642
checkpoint_path = self.llm_config.base.checkpoint
47-
checkpoint_dir = self.llm_config.base.checkpoint_dir
4843
params_path = self.llm_config.base.params
4944

5045
# Adapter checkpoint and config.
@@ -71,6 +66,7 @@ def __init__(self, llm_config: Optional[LlmConfig] = None):
7166
# The example is using a dummy small model with random weights for demo purpose only.
7267
# Follow the instruction in https://github.com/facebookresearch/llama to download the model.
7368
device = "cpu"
69+
# flake8: noqa: TOR102
7470
if checkpoint_path:
7571
checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
7672

0 commit comments

Comments
 (0)