
Commit 4f0b012

xfail token emb mock test
1 parent 245b200 commit 4f0b012

3 files changed: +195, -2 lines changed


.github/workflows/ci-sharktank.yml

Lines changed: 2 additions & 2 deletions
@@ -185,8 +185,8 @@ jobs:
            --iree-hal-target-device=hip \
            --iree-hip-target=gfx942 \
            --iree-device=hip://0 \
-           --device=cuda:0 \
-           --ignore=sharktank/tests/layers
+           --ignore=sharktank/tests/layers \
+           --device=cuda:0

       - name: Run sharktank layers tests
         if: ${{ !cancelled() }}
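
For reference, this reordering does not change the command the shell ultimately runs: once the backslash continuations are joined, both orderings expand to the same invocation (presumably pytest, given the --ignore flag). A sketch of the joined form, with the arguments earlier in the step elided:

    pytest ... --iree-hal-target-device=hip --iree-hip-target=gfx942 \
        --iree-device=hip://0 --ignore=sharktank/tests/layers --device=cuda:0

The swap simply leaves --device=cuda:0 as the final, non-continued argument.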
Lines changed: 192 additions & 0 deletions
@@ -0,0 +1,192 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

import torch
import pytest
from pathlib import Path
from sharktank.utils._helpers import run_iree_vs_torch_fx, validate_and_get_irpa_path
from sharktank.layers import LinearLayer, RMSNormLayer
from sharktank.types import Dataset, Theta
from sharktank.layers.configs import LlamaModelConfig
from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
from sharktank.utils.testing import is_hip_condition


class OutputLMHead(torch.nn.Module):
    """Standalone output_lm_head block extracted from PagedLlmModelV1."""

    def __init__(self, theta: Theta, config: LlamaModelConfig):
        super().__init__()
        self.config = config
        self.hp = config.hp

        # Output normalization layer
        self.output_norm = RMSNormLayer(
            theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon
        )

        # Output linear layer (language model head)
        self.output_lm_head = LinearLayer(
            theta("output"),
            matmul_kernel=config.matmul_kernel,
        )

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # Apply normalization
        h_norm = self.output_norm(h)  # output fp16, weights float32

        # Apply final linear transformation
        logits = self.output_lm_head(h_norm)  # output and weights fp16

        return logits


def create_output_lm_head_from_irpa(
    irpa_path: str,
) -> tuple[OutputLMHead, torch.Tensor]:
    """
    Create an OutputLMHead module from an IRPA file and generate a sample input.

    Args:
        irpa_path: Path to the IRPA file

    Returns:
        Tuple of (OutputLMHead module, sample input tensor)
    """
    # Load dataset from IRPA file
    dataset = Dataset.load(Path(irpa_path))

    # Create model config from dataset
    llama_config = LlamaModelConfig.from_dataset(
        dataset=dataset,
        attention_kernel="torch",
        matmul_kernel="sharktank.asm;*",
        activation_dtype=torch.float16,
    )

    # Create the output LM head module
    output_lm_head = OutputLMHead(dataset.root_theta, llama_config)

    # Generate a sample input tensor matching the expected dimensions
    # Typical shape: [batch_size, seq_len, hidden_dim]
    # TODO: Check if there are other more suitable sizes to test.
    batch_size = 2
    seq_len = 8
    hidden_dim = (
        llama_config.hp.embedding_length
    )  # Use embedding_length instead of model_dim

    sample_input = torch.randn(
        batch_size, seq_len, hidden_dim, dtype=llama_config.activation_dtype
    )

    return output_lm_head, sample_input


# Test cases
@pytest.mark.skipif(f"not ({is_hip_condition})", reason="Test requires HIP device")
@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
def test_output_lm_head_iree_vs_eager(request, dtype, atol):
    """
    Test the OutputLMHead module, comparing IREE against PyTorch eager execution.

    Use the --parameters command-line argument to specify the IRPA file path.
    """
    # Validate and get IRPA path
    irpa_path = validate_and_get_irpa_path(request)

    try:
        # Create module and sample input from IRPA
        module, sample_input = create_output_lm_head_from_irpa(irpa_path)
    except Exception as e:
        pytest.skip(f"Failed to load model from IRPA: {e}")

    # Convert the input to the desired dtype
    # module = module.to(dtype)
    sample_input = sample_input.to(dtype)

    # Run IREE vs torch comparison
    run_iree_vs_torch_fx(
        module,
        input_args=(sample_input,),
        atol=atol,
        rtol=0,
        compile_flags=LLM_HIP_COMPILE_FLAGS,
        parameters_path=irpa_path,
    )


@pytest.mark.skipif(f"not ({is_hip_condition})", reason="Test requires HIP device")
def test_output_lm_head_mock():
    """
    Mock test of OutputLMHead functionality with synthetic weights.

    This test runs without requiring an IRPA file.
    """
    torch.manual_seed(42)

    # Mock configuration - provide all required parameters
    from sharktank.layers.configs import LlamaHParams

    # Create LlamaHParams with all required parameters
    hp = LlamaHParams(
        model_arch="llama",
        context_length=2048,
        embedding_length=512,  # hidden dimension
        block_count=6,
        feed_forward_length=2048,
        attention_head_count=8,
        attn_head_dim=64,
        attention_layer_norm_rms_epsilon=1e-6,
        attention_head_count_kv=8,
        vocab_size=32000,
    )

    # Create mock config
    config = LlamaModelConfig(
        hp=hp,
        activation_dtype=torch.float16,
        # attention_dtype=torch.float32,
    )

    # Create mock theta with synthetic weights
    from sharktank.types import DefaultPrimitiveTensor

    # Mock output_norm weights
    output_norm_weight = torch.randn(hp.embedding_length, dtype=torch.float32)

    # Mock output (lm_head) weights
    output_weight = torch.randn(hp.vocab_size, hp.embedding_length, dtype=torch.float16)

    # Create theta structure
    theta_dict = {
        "output_norm": {"weight": DefaultPrimitiveTensor(data=output_norm_weight)},
        "output": {"weight": DefaultPrimitiveTensor(data=output_weight)},
    }

    theta = Theta(theta_dict)

    # Create module
    module = OutputLMHead(theta, config)

    # Create sample input
    batch_size, seq_len = 2, 8
    sample_input = torch.randn(
        batch_size, seq_len, hp.embedding_length, dtype=torch.float32
    )

    # Run IREE vs torch comparison
    run_iree_vs_torch_fx(
        module,
        input_args=(sample_input,),
        atol=1e-4,
        rtol=0,
        compile_flags=LLM_HIP_COMPILE_FLAGS,
    )


if __name__ == "__main__":
    test_output_lm_head_mock()
    print("OutputLMHead mock test complete!")

sharktank/tests/layers/token_embedding_with_iree_test.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def test_token_embedding_iree_vs_eager(request, dtype, atol):


 @pytest.mark.skipif(f"not ({is_hip_condition})", reason="Test requires HIP device")
+@pytest.mark.xfail(reason="Test fails on execution with fatal python error")
 @pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_token_embedding_mock_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
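
For reference, a minimal standalone sketch (not from this repository) of how pytest.mark.xfail behaves: the test body still runs by default, a failure is reported as XFAIL rather than an error, and run=False can be used when even executing the body is unsafe:

    import pytest

    # Non-strict xfail: the test still executes; its failure is reported as XFAIL.
    @pytest.mark.xfail(reason="demonstrates an expected failure")
    def test_expected_failure():
        assert 1 + 1 == 3

    # For a test that would hard-crash the interpreter, run=False tells pytest
    # to record XFAIL without executing the body at all.
    @pytest.mark.xfail(run=False, reason="would crash the process")
    def test_crashing_case():
        raise SystemError("never reached")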
