Skip to content

Commit 6f8baef

Browse files
committed
shark migration update
1 parent 038a000 commit 6f8baef

File tree

7 files changed

+33
-29
lines changed

7 files changed

+33
-29
lines changed

amdsharktank/amdsharktank/utils/export.py

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66

77
from typing import Callable, Optional, Any
88
from os import PathLike
9+
from pathlib import Path
910
import functools
1011

1112
import torch
@@ -16,7 +17,13 @@
1617
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
1718
from amdsharktank.types.tensors import ShardedTensor
1819
from amdsharktank.types.theta import mark_export_external_theta
19-
from amdsharktank.layers import BaseLayer, ThetaLayer
20+
21+
from typing import TYPE_CHECKING
22+
23+
if TYPE_CHECKING:
24+
from amdsharktank.layers import BaseLayer, ThetaLayer
25+
26+
# from amdsharktank.layers import BaseLayer, ThetaLayer
2027

2128

2229
def flatten_signature(
@@ -180,7 +187,7 @@ def flat_fn(*args, **kwargs):
180187

181188

182189
def export_model_mlir(
183-
model: BaseLayer,
190+
model: "BaseLayer",
184191
output_path: PathLike,
185192
*,
186193
function_batch_sizes_map: Optional[dict[Optional[str], list[int]]] = None,
@@ -202,7 +209,7 @@ def export_model_mlir(
202209

203210
assert not (function_batch_sizes_map is not None and batch_sizes is not None)
204211

205-
if isinstance(model, ThetaLayer):
212+
if isinstance(model, "ThetaLayer"):
206213
mark_export_external_theta(model.theta)
207214

208215
if batch_sizes is not None:
@@ -317,7 +324,7 @@ def export_torch_module_to_mlir_file(
317324
def _(module, *fn_args):
318325
return module.forward(*fn_args)
319326

320-
export_output = export(fxb, import_symbolic_shape_expressions=True)
327+
export_output = aot.export(fxb, import_symbolic_shape_expressions=True)
321328
export_output.save_mlir(mlir_path)
322329

323330
return export_output

amdsharktank/amdsharktank/utils/iree.py

Lines changed: 0 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,6 @@
3535
from iree.runtime import FileHandle
3636
import iree.runtime
3737

38-
from iree.turbine import aot
39-
from iree.turbine.aot import export
40-
4138

4239
if TYPE_CHECKING:
4340
from ..layers import ModelConfig

amdsharktank/tests/layers/ffn_with_iree_test.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -5,9 +5,9 @@
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
import torch
77
import pytest
8-
from sharktank.utils.iree import run_iree_vs_torch_eager
9-
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
10-
from sharktank.utils.testing import is_hip_condition
8+
from amdsharktank.utils.iree import run_iree_vs_torch_eager
9+
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
10+
from amdsharktank.utils.testing import is_hip_condition
1111

1212

1313
class FFN(torch.nn.Module):

amdsharktank/tests/layers/linear_with_iree_test.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@
77
import torch
88
import pytest
99
import gc
10-
from sharktank.utils.iree import run_iree_vs_torch_eager
11-
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
12-
from sharktank.utils.testing import is_hip_condition
10+
from amdsharktank.utils.iree import run_iree_vs_torch_eager
11+
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
12+
from amdsharktank.utils.testing import is_hip_condition
1313

1414

1515
class Linear(torch.nn.Module):

amdsharktank/tests/layers/output_lm_head_with_iree_test.py

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -7,14 +7,14 @@
77
import torch
88
import pytest
99
from pathlib import Path
10-
from sharktank.utils.iree import (
10+
from amdsharktank.utils.iree import (
1111
run_iree_vs_torch_eager,
1212
)
13-
from sharktank.layers import LinearLayer, RMSNormLayer
14-
from sharktank.types import Dataset, Theta
15-
from sharktank.layers.configs import LlamaModelConfig
16-
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
17-
from sharktank.utils.testing import is_hip_condition, validate_and_get_irpa_path
13+
from amdsharktank.layers import LinearLayer, RMSNormLayer
14+
from amdsharktank.types import Dataset, Theta
15+
from amdsharktank.layers.configs import LlamaModelConfig
16+
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
17+
from amdsharktank.utils.testing import is_hip_condition, validate_and_get_irpa_path
1818

1919

2020
class OutputLMHead(torch.nn.Module):
@@ -136,7 +136,7 @@ def test_output_lm_head_mock():
136136
torch.manual_seed(42)
137137

138138
# Mock configuration - provide all required parameters
139-
from sharktank.layers.configs import LlamaHParams
139+
from amdsharktank.layers.configs import LlamaHParams
140140

141141
# Create LlamaHParams with all required parameters
142142
hp = LlamaHParams(
@@ -160,7 +160,7 @@ def test_output_lm_head_mock():
160160
)
161161

162162
# Create mock theta with synthetic weights
163-
from sharktank.types import DefaultPrimitiveTensor
163+
from amdsharktank.types import DefaultPrimitiveTensor
164164

165165
# Mock output_norm weights
166166
output_norm_weight = torch.randn(hp.embedding_length, dtype=torch.float32)

amdsharktank/tests/layers/rms_norm_with_iree_test.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -7,9 +7,9 @@
77
import torch
88
import pytest
99
import gc
10-
from sharktank.utils.iree import run_iree_vs_torch_eager
11-
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
12-
from sharktank.utils.testing import is_hip_condition
10+
from amdsharktank.utils.iree import run_iree_vs_torch_eager
11+
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
12+
from amdsharktank.utils.testing import is_hip_condition
1313

1414

1515
class RMSNorm(torch.nn.Module):

amdsharktank/tests/layers/token_embedding_with_iree_test.py

Lines changed: 5 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -7,13 +7,13 @@
77
import torch
88
import pytest
99
from pathlib import Path
10-
from sharktank.layers.token_embedding import TokenEmbeddingLayer
11-
from sharktank.types.theta import Dataset
12-
from sharktank.utils.iree import (
10+
from amdsharktank.layers.token_embedding import TokenEmbeddingLayer
11+
from amdsharktank.types.theta import Dataset
12+
from amdsharktank.utils.iree import (
1313
run_iree_vs_torch_eager,
1414
)
15-
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
16-
from sharktank.utils.testing import is_hip_condition, validate_and_get_irpa_path
15+
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
16+
from amdsharktank.utils.testing import is_hip_condition, validate_and_get_irpa_path
1717

1818

1919
class TokenEmbeddingSmall(torch.nn.Module):

0 commit comments

Comments (0)