Skip to content

Commit 05fcaa9

Browse files
[sharktank] Refactor sharktank - import cleanup (#1220)
Refactor sharktank - Expand relative imports - Cleanup unused imports - Correct the source module of import (eg: Import `Theta` from `types` instead of `layers`) Requires #1207 to merge
1 parent 977d804 commit 05fcaa9

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+302
-346
lines changed

sharktank/sharktank/examples/export_paged_llm_v1.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
"""Export support for the PagedLLMV1 protocol of models."""
88

9+
import os
910
import json
1011
from typing import Any, Dict
1112
import torch
@@ -16,14 +17,13 @@
1617
from sharktank.types import *
1718
from sharktank.utils.math import ceildiv
1819
from sharktank import ops
20+
from sharktank.utils import cli
1921

2022
# TODO: Should be using a base class with the protocol supported.
2123
from sharktank.models.llm import *
2224

2325

2426
def main():
25-
from ..utils import cli
26-
import os
2727

2828
parser = cli.create_parser()
2929
cli.add_input_dataset_options(parser)

sharktank/sharktank/examples/paged_llm_v1.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,17 @@
1212
from typing import Optional
1313
import torch
1414
import numpy as np
15-
from ..layers import *
16-
from ..types import *
1715

18-
from ..ops import replicate, unshard
16+
from sharktank.layers import *
17+
from sharktank.types import *
18+
from sharktank.ops import replicate, unshard
1919

2020
# TODO: Should be using a base class with the protocol supported.
21-
from ..models.llm import *
22-
from ..models.llama.sharding import shard_theta
23-
from ..utils.debugging import trace_tensor
24-
from ..utils.tokenizer import InferenceTokenizer
25-
from ..utils import cli
21+
from sharktank.models.llm import *
22+
from sharktank.models.llama.sharding import shard_theta
23+
from sharktank.utils.debugging import trace_tensor
24+
from sharktank.utils.tokenizer import InferenceTokenizer
25+
from sharktank.utils import cli
2626

2727

2828
class TorchGenerator:
@@ -428,7 +428,7 @@ def main():
428428
model = PagedLlmModelV1(dataset.root_theta, config)
429429

430430
if args.save_intermediates_path:
431-
from ..utils.patching import SaveModuleResultTensorsPatch
431+
from sharktank.utils.patching import SaveModuleResultTensorsPatch
432432

433433
intermediates_saver = SaveModuleResultTensorsPatch()
434434
intermediates_saver.patch_child_modules(model)

sharktank/sharktank/examples/pipeline/export_ppffn_net.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@
1414
--output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
1515
"""
1616

17+
import os
1718
import math
1819

1920
import torch
2021

21-
from ...layers import *
22-
from ... import ops
23-
from ...types import *
22+
from sharktank.utils import cli
23+
from sharktank.layers import *
24+
from sharktank import ops
25+
from sharktank.types import *
2426

2527
from iree.turbine.aot import DeviceAffinity, DeviceTensorTrait, export
2628

@@ -69,9 +71,6 @@ def forward(self, x: torch.Tensor):
6971

7072

7173
def main(raw_args=None):
72-
from ...utils import cli
73-
import os
74-
7574
parser = cli.create_parser()
7675
parser.add_argument(
7776
"output_file",

sharktank/sharktank/examples/sharding/export_ffn_net.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@
1313
python -m sharktank.examples.sharding.export_ffn_net \
1414
--output-irpa-file=/tmp/ffn.irpa /tmp/ffn.mlir
1515
"""
16+
import os
1617

1718
import torch
18-
import torch.nn as nn
1919

20-
from ...layers import *
21-
from ... import ops
22-
from ...types import *
20+
from sharktank.layers import *
21+
from sharktank import ops
22+
from sharktank.types import *
23+
from sharktank.utils import cli
2324

2425

2526
def create_theta(
@@ -63,8 +64,6 @@ def forward(self, x: torch.Tensor):
6364

6465

6566
def main(raw_args=None):
66-
from ...utils import cli
67-
import os
6867

6968
parser = cli.create_parser()
7069
parser.add_argument(

sharktank/sharktank/examples/sharding/shard_llm_dataset.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@
1010
weights of an LLM by converting the RHS of all eligible layers to a sharded
1111
form.
1212
"""
13-
from ...models.llama.sharding import shard_theta
14-
from ...layers import LlamaHParams, LlamaModelConfig
15-
from ...types import *
13+
from sharktank.models.llama.sharding import shard_theta
14+
from sharktank.layers import LlamaHParams, LlamaModelConfig
15+
from sharktank.types import *
16+
from sharktank.utils import cli
1617

1718

1819
def main(raw_args=None):
19-
from ...utils import cli
2020

2121
parser = cli.create_parser()
2222
cli.add_input_dataset_options(parser)

sharktank/sharktank/examples/validate_paged_llama_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111
from sharktank.layers import *
1212
from sharktank.types import *
1313
from sharktank.models.llm import *
14+
from sharktank.utils import cli
1415

1516

1617
def main(args: list[str]):
17-
from ..utils import cli
1818

1919
torch.no_grad().__enter__()
2020

sharktank/sharktank/export_layer/export_kv_cache.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from sharktank.types import SplitPrimitiveTensor
1212
from sharktank.ops import reshard_split, replicate
1313
from sharktank.layers.paged_attention import PagedAttention
14-
from ..utils import cli
14+
from sharktank.utils import cli
1515

1616

1717
def main():

sharktank/sharktank/export_layer/export_moe.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from sharktank.models.llama.testing import make_moe_block_theta, make_rand_torch
1313
from sharktank.layers.mixture_of_experts_block import MoeBlock
14-
from ..utils import cli
14+
from sharktank.utils import cli
1515

1616

1717
def main():

sharktank/sharktank/kernels/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434

3535
from iree.turbine.transforms.merger import Merger
3636

37-
from ..utils.logging import get_logger
37+
from sharktank.utils.logging import get_logger
3838

3939
LIBRARY = def_library("sharktank")
4040
TEMPLATES_DIR = Path(__file__).parent / "templates"

sharktank/sharktank/layers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,9 @@
1414
from .token_embedding import TokenEmbeddingLayer
1515
from .paged_llama_attention_block import PagedLlamaAttentionBlock
1616
from .ffn_block import FFN
17-
from .ffn_moe_block import FFNMOE
17+
from .ffn_moe_block import FFNMOE, PreGatherFFNMOE
1818
from .mixture_of_experts_block import MoeBlock
1919
from .mmdit import MMDITDoubleBlock, MMDITSingleBlock
20+
from .modulation import ModulationLayer
2021

2122
from .configs import *

0 commit comments

Comments
 (0)