Skip to content
4 changes: 2 additions & 2 deletions sharktank/sharktank/models/llama/toy_llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def generate(
block_seq_stride = 16
max_blocks = 8
attention_head_count = 8
attn_head_dim = 32
attn_head_dim = 64
attention_head_count_kv = 4
rope_dimension_count = 32
rope_dimension_count = attn_head_dim
vocabulary_size = 256

config = LlamaModelConfig(
Expand Down
23 changes: 20 additions & 3 deletions sharktank/tests/models/llama/toy_llama_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,19 @@
TorchInstance,
llama_config_page_sizes,
)
from sharktank.utils.testing import is_cpu
from sharktank.utils.testing import is_mi300x


def get_iree_compile_flags(self):
flags = []

if self.iree_hal_target_device is not None:
flags.append(f"--iree-hal-target-device={self.iree_hal_target_device}")
if self.iree_hal_target_device == "hip":
flags.append(f"--iree-opt-level=O3")
flags.append(f"--iree-hal-indirect-command-buffers=true")
flags.append(f"--iree-stream-resource-memory-model=discrete")
flags.append(f"--iree-hal-memoization=true")

if self.iree_hal_target_device == "local":
flags.append("--iree-hal-local-target-device-backends=llvm-cpu")
Expand Down Expand Up @@ -89,12 +94,24 @@ def testDecodePerplexity(self):
torch.testing.assert_close(result.score, 0.583, atol=1e-2, rtol=1e-2)


# TODO: Investigate why the IREE flags added in get_iree_compile_flags are
# missing from the compile command that this test actually runs:
# iree-compile /home/aramalin/shark-ai/3.11.venv/lib/python3.11/site-packages/iree/compiler/tools/../_mlir_libs/iree-compile - --iree-input-type=auto --iree-vm-bytecode-module-output-format=flatbuffer-binary --mlir-print-debuginfo --mlir-print-op-on-diagnostic=false --iree-hal-target-device=hip --iree-hip-target=gfx942


@pytest.mark.usefixtures("iree_flags")
@is_cpu
@is_mi300x
@pytest.mark.parametrize(
"use_extend_attention",
[
True,
pytest.param(
True,
marks=pytest.mark.xfail(
raises=iree.compiler.CompilerToolError,
strict=True,
reason="https://github.com/iree-org/iree/issues/22329",
),
),
False,
],
)
Expand Down
Loading