Commit 9ca50de

Clean up some broken references to fix turbine-llm scripts (#632)
Cleans up some broken references in the turbine-llm scripts: `gguf` was used where `gguf_interop` should be.
1 parent cf01650 commit 9ca50de
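
For context, the corrected loading flow that these scripts now share looks roughly like the sketch below. It uses only calls visible in the diffs; the one assumption (suggested by the newly added star import) is that `gguf_interop` is made available via `turbine_llm.types`, and the model path is illustrative:

```python
from turbine_llm.layers import *  # per the scripts: provides configs and model classes
from turbine_llm.types import *   # assumption: re-exports gguf_interop

# The stale scripts called gguf.load_file(); the working helper lives
# in the gguf_interop module.
dataset = gguf_interop.load_file("/path/to/model.gguf")  # illustrative path

# Build hyperparameters from GGUF metadata, then the paged model.
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
model = PagedLlamaModelV1(dataset.root_theta, LlamaModelConfig(hp))
```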

File tree

4 files changed (+11 lines, -7 lines)

llm/requirements.txt

Lines changed: 1 addition & 0 deletions
```diff
@@ -0,0 +1 @@
+gguf
```

llm/turbine_llm/examples/export_paged_llm_v1.py

Lines changed: 5 additions & 4 deletions
```diff
@@ -10,10 +10,11 @@
 
 from shark_turbine.aot import *
 
-from ..layers import *
+from turbine_llm.layers import *
+from turbine_llm.types import *
 
 # TODO: Should be using a base class with the protocol supported.
-from ..models.llama.llama import PagedLlamaModelV1
+from ..models.llama.llama import LlamaModelConfig, PagedLlamaModelV1
 
 
 def main():
@@ -24,10 +25,10 @@ def main():
     args = cli.parse(parser)
 
     data_files = cli.get_gguf_data_files(args)
-    dataset = gguf.load_file(data_files["gguf"])
+    dataset = gguf_interop.load_file(data_files["gguf"])
 
     hp = configs.LlamaHParams.from_gguf_props(dataset.properties)
-    model = PagedLlamaModelV1(dataset.root_theta, hp)
+    model = PagedLlamaModelV1(dataset.root_theta, LlamaModelConfig(hp))
 
     # Unrolling cache updates by batch row makes dynamo sad without an
     # override. There may be a better way to do this.
```
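
Beyond the `gguf_interop` rename, this diff also tracks a constructor change: `PagedLlamaModelV1` now takes a `LlamaModelConfig` wrapping the hyperparameters rather than the raw `LlamaHParams`. A minimal before/after sketch of just that call:

```python
hp = configs.LlamaHParams.from_gguf_props(dataset.properties)

# Before (broken): raw hparams passed straight through.
# model = PagedLlamaModelV1(dataset.root_theta, hp)

# After: hparams wrapped in the new config object.
model = PagedLlamaModelV1(dataset.root_theta, LlamaModelConfig(hp))
```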

llm/turbine_llm/examples/validate_llama_ref_model.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -15,12 +15,13 @@
 import torch
 
 from turbine_llm.layers import *
+from turbine_llm.types import *
 from turbine_llm.models.llama.llama_ref import *
 
 
 def main(args: list[str]):
     torch.no_grad().__enter__()
-    config = gguf.load_file(args[0])
+    config = gguf_interop.load_file(args[0])
     hp = configs.LlamaHParams.from_gguf_props(config.properties)
     model = DirectCacheLlamaModelV1(config.root_theta, hp)
 
```
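
Note the asymmetry between the two validation scripts: the reference model keeps its old signature, while the paged model (next diff) needs the config wrapper. Post-fix, the two constructor calls look like:

```python
config = gguf_interop.load_file("/path/to/model.gguf")  # illustrative path
hp = configs.LlamaHParams.from_gguf_props(config.properties)

# Reference model: still takes the raw hparams.
ref_model = DirectCacheLlamaModelV1(config.root_theta, hp)

# Paged model: now takes LlamaModelConfig(hp) instead (see the next diff).
```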

llm/turbine_llm/examples/validate_paged_llama_model.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -9,14 +9,15 @@
 import torch
 
 from turbine_llm.layers import *
+from turbine_llm.types import *
 from turbine_llm.models.llama.llama import *
 
 
 def main(args: list[str]):
     torch.no_grad().__enter__()
-    config = gguf.load_file(args[0])
+    config = gguf_interop.load_file(args[0])
     hp = configs.LlamaHParams.from_gguf_props(config.properties)
-    model = PagedLlamaModelV1(config.root_theta, hp)
+    model = PagedLlamaModelV1(config.root_theta, LlamaModelConfig(hp))
     cache_state = model.cache.paged.allocate(128, torch.float32)
     start_index = 0
     next_batch = torch.tensor(
```
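
Putting the pieces together, the fixed paged validation path loads the dataset, builds the model, and allocates cache state; the sizes below come straight from the script, the path is illustrative:

```python
import torch

config = gguf_interop.load_file("/path/to/model.gguf")
hp = configs.LlamaHParams.from_gguf_props(config.properties)
model = PagedLlamaModelV1(config.root_theta, LlamaModelConfig(hp))

# Allocate paged KV-cache state (128 and float32 as in the script).
cache_state = model.cache.paged.allocate(128, torch.float32)
```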
