Commit fd7d7e4

[sharktank] Refactor sharktank llm modeling (#1207)

- Refactor the various sharktank llms into a single `llm.py` script.
- Remove all references to direct cache scripts and tests.
- Make `--attention-kernel="torch"` the default for all llms and tests, except grok, as softcap + sdpa isn't supported yet.
1 parent: 41141b7


59 files changed: +284 −2868 lines
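The grok carve-out in the commit message comes down to attention logit soft-capping: grok squashes the raw QK logits with a tanh before the softmax, and the fused `torch` SDPA path exposes no hook between the matmul and the softmax, so grok stays on the decomposed kernel. A minimal illustrative sketch of the decomposed form, assuming the usual tanh-style soft-cap (the exact formulation in sharktank may differ):

```python
import torch

def attention_with_softcap(q: torch.Tensor, k: torch.Tensor,
                           v: torch.Tensor, softcap: float) -> torch.Tensor:
    # Decomposed attention: the soft-cap must be applied between the QK^T
    # matmul and the softmax, a step a fused SDPA kernel does not expose.
    scale = q.shape[-1] ** -0.5
    logits = torch.matmul(q, k.transpose(-2, -1)) * scale
    logits = softcap * torch.tanh(logits / softcap)  # bound logits to (-softcap, softcap)
    return torch.matmul(torch.softmax(logits, dim=-1), v)
```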

docs/model_cookbook.md

Lines changed: 0 additions & 11 deletions
````diff
@@ -269,14 +269,3 @@ python -m sharktank.examples.paged_llm_v1 \
   --dump-decode-steps=1 \
   --dump-path='/tmp'
 ```
-
-## Generating data for llama models
-
-```bash
-set TURBINE_DEBUG=log_level=info
-python -m sharktank.models.llama.tools.generate_data \
-  --tokenizer=openlm-research/open_llama_3b_v2 \
-  --config=/tmp/open_llama_3b_v2/open-llama-3b-v2-f16.json \
-  --output-dir=/tmp/open_llama_3b_v2/inputs \
-  --prompt="What is the meaning of life?"
-```
````

sharktank/sharktank/evaluate/perplexity_iree.py

Lines changed: 3 additions & 12 deletions
```diff
@@ -21,11 +21,8 @@
 from torch.nn import CrossEntropyLoss
 import iree.runtime
 
-from sharktank.models.llama.llama import *
-from sharktank.models.mixtral.mixtral import *
-from sharktank.models.grok.grok import *
-
-from ..models.llama.sharding import shard_theta
+from sharktank.models.llm import *
+from sharktank.models.llama.sharding import shard_theta
 
 from sharktank.layers import *
 from sharktank.types import *
@@ -187,13 +184,7 @@ def load_model(self, weight_path, tokenizer):
 
         theta = weight_path.root_theta
 
-        if self.config.hp.expert_count:
-            if self.config.hp.model_arch == "grok":
-                model = PagedGrokModelV1(theta, self.config)
-            else:
-                model = PagedMixtralModelV1(theta, self.config)
-        else:
-            model = PagedLlamaModelV1(theta, self.config)
+        model = PagedLlmModelV1(theta, self.config)
 
         self.generator = TorchGenerator(model, tokenizer)
 
```
sharktank/sharktank/evaluate/perplexity_prefill.py

Lines changed: 0 additions & 276 deletions
This file was deleted.
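Before this change, each evaluator picked the concrete model class from `config.hp`; the deleted branches above suggest that dispatch now lives behind the single `PagedLlmModelV1` constructor in `llm.py`. A hypothetical sketch of the idea; apart from `PagedLlmModelV1` and the `config.hp` fields visible in the diffs, every name here is illustrative rather than sharktank's actual internals:

```python
# Hypothetical sketch only, not the real llm.py internals.
class PagedLlmModelV1:
    def __init__(self, theta, config):
        self.theta = theta
        self.config = config
        # The branching the callers used to do now happens once, inside the
        # model: MoE blocks when expert_count is set, grok-specific behavior
        # (logit softcap, decomposed attention) keyed off model_arch.
        self.is_moe = bool(config.hp.expert_count)
        self.is_grok = config.hp.model_arch == "grok"
```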

sharktank/sharktank/evaluate/perplexity_torch.py

Lines changed: 3 additions & 12 deletions
```diff
@@ -18,11 +18,8 @@
 from sharktank.layers import *
 from sharktank.types import *
 
-from sharktank.models.llama.llama import *
-from sharktank.models.mixtral.mixtral import *
-from sharktank.models.grok.grok import *
-
-from ..models.llama.sharding import shard_theta
+from sharktank.models.llm import *
+from sharktank.models.llama.sharding import shard_theta
 
 from sharktank.utils import cli
 from sharktank.utils.load_llm import *
@@ -106,13 +103,7 @@ def load_model(self, dataset, tokenizer, tensor_parallelism_size, attention_kern
 
         theta = dataset.root_theta
 
-        if self.config.hp.expert_count:
-            if self.config.hp.model_arch == "grok":
-                model = PagedGrokModelV1(theta, self.config)
-            else:
-                model = PagedMixtralModelV1(theta, self.config)
-        else:
-            model = PagedLlamaModelV1(theta, self.config)
+        model = PagedLlmModelV1(theta, self.config)
 
         self.generator = TorchGenerator(model, tokenizer)
 
```
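Taken together, the post-refactor load path in both surviving perplexity scripts reduces to the same few lines. A sketch assembled from the hunks above, assuming the star imports expose `PagedLlmModelV1` and `TorchGenerator` as the usage implies:

```python
from sharktank.models.llm import *      # assumed to expose PagedLlmModelV1
from sharktank.utils.load_llm import *  # assumed to expose TorchGenerator

def load_model(self, dataset, tokenizer):
    # Method-body shape shared by both evaluators after the refactor:
    # one constructor now covers llama, mixtral, and grok alike.
    theta = dataset.root_theta
    model = PagedLlmModelV1(theta, self.config)
    self.generator = TorchGenerator(model, tokenizer)
```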
