Commit 701de1f

fix pre-commit
1 parent 3c2e58f commit 701de1f

File tree

7 files changed (+110, -77 lines)

sharktank/sharktank/utils/_helpers.py

Lines changed: 16 additions & 16 deletions
@@ -25,6 +25,7 @@ def _as_tuple(x):
         return tuple(x)
     return (x,)
 
+
 def export_torch_module_to_mlir(
     module: torch.nn.Module,
     input_args=(),
@@ -55,25 +56,22 @@ def export_torch_module_to_mlir(
     expected = module(*input_args, **kwargs)
 
     fxb = FxProgramsBuilder(module)
-
+
     # empty tensors for export input
     # there needs to be one corresponding to each arg
     # NOTE: assuming args are not nested.
-    empty_args = tuple([
-        torch.empty(arg.shape, dtype=arg.dtype) for arg in input_args
-    ])
+    empty_args = tuple([torch.empty(arg.shape, dtype=arg.dtype) for arg in input_args])
 
     # need to get this info from the test, currently only for static shapes
     # one corresponding to each arg
     dynamic_shapes = tuple([dict() for _ in input_args])
 
-
     @fxb.export_program(
         name=target_fn,
         args=empty_args,
         dynamic_shapes=(dynamic_shapes,),
-          strict=False,
-          )
+        strict=False,
+    )
     def _(module, *fn_args):
         return module.forward(*fn_args)
 
@@ -175,36 +173,38 @@ def compare_iree_torch_outputs(
         actual = (actual,)
 
     # Match dtypes to be safe (IREE may produce f32 by default in some paths)
-    actual = tuple(a.to(e.dtype) if hasattr(a, "dtype") else a for a, e in zip(actual, expected))
+    actual = tuple(
+        a.to(e.dtype) if hasattr(a, "dtype") else a for a, e in zip(actual, expected)
+    )
     torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
 
 
 def validate_and_get_irpa_path(request):
     """
     Validate and get IRPA path from pytest request configuration.
-
+
     Args:
         request: pytest request fixture
-
+
     Returns:
         str: Path to the IRPA file
-
+
     Raises:
         pytest.skip: If IRPA path is not provided or file doesn't exist
     """
     from pytest import skip
-
+
     # Get IRPA path from command line argument
     irpa_path = request.config.getoption("--parameters")
-
+
     # Skip test if no IRPA path provided
     if irpa_path is None:
         skip("No IRPA path provided. Use --parameters to specify the IRPA file.")
-
+
     # Skip test if IRPA file doesn't exist
     if not Path(irpa_path).exists():
         skip(f"IRPA file not found: {irpa_path}")
-
+
     return irpa_path
 
 
@@ -217,7 +217,7 @@ def run_iree_vs_torch_fx(
     rtol=0.0,
     entrypoint="run_forward",
     parameters_path=None,
-    compile_flags: list[str]|None=None,
+    compile_flags: list[str] | None = None,
     driver="hip",
     device_count=1,
 ):
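
For reference, the decorator pattern in this helper follows iree-turbine's AOT export API. A minimal standalone sketch of the same flow, assuming iree-turbine is installed; the `export` entry point and `save_mlir` method are iree.turbine.aot names that may vary across versions, and `Toy`/"toy.mlir" are purely illustrative:

import torch
from iree.turbine.aot import FxProgramsBuilder, export


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


m = Toy()
x = torch.randn(2, 8, 64)
fxb = FxProgramsBuilder(m)

# One placeholder tensor per positional arg; an empty dict per arg
# means fully static shapes, mirroring the helper above.
empty_args = (torch.empty(x.shape, dtype=x.dtype),)
dynamic_shapes = ({},)


@fxb.export_program(
    name="run_forward",
    args=empty_args,
    dynamic_shapes=(dynamic_shapes,),
    strict=False,
)
def _(module, *fn_args):
    return module.forward(*fn_args)


output = export(fxb)  # assumed API: wraps the collected FX programs
output.save_mlir("toy.mlir")  # MLIR consumable by iree-compile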

sharktank/sharktank/utils/_iree_compile_flags_config.py

Lines changed: 3 additions & 3 deletions
@@ -10,7 +10,7 @@
 
 LLM_HIP_COMPILE_FLAGS = [
     "--iree-hal-target-device=hip",
-    "--iree-hip-target=gfx942", # MI300 example; adjust to your GPU if needed
+    "--iree-hip-target=gfx942",  # MI300 example; adjust to your GPU if needed
     "--iree-execution-model=async-external",
     "--iree-opt-strip-assertions=true",
     "--iree-opt-level=O3",
@@ -20,5 +20,5 @@
     "--iree-stream-resource-memory-model=discrete",
     "--iree-hip-specialize-dispatches",
     "--iree-hal-memoization=true",
-    "--iree-codegen-enable-default-tuning-specs=true"
-]
+    "--iree-codegen-enable-default-tuning-specs=true",
+]
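
For context, a flag list like this is normally handed either to the iree-compile CLI or to the iree-compiler Python bindings. A minimal sketch, assuming the iree-compiler package is installed and "toy.mlir" is a placeholder module produced by the export helper:

from iree.compiler import compile_file

from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS

# compile_file returns the compiled VMFB bytes when no output_file is given.
vmfb = compile_file("toy.mlir", extra_args=LLM_HIP_COMPILE_FLAGS)
with open("toy.vmfb", "wb") as f:
    f.write(vmfb)

The CLI equivalent would be `iree-compile toy.mlir -o toy.vmfb` followed by the same flags.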

sharktank/tests/layers/ffn_with_iree_test.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@
 from sharktank.utils._helpers import run_iree_vs_torch_fx
 from sharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
 
+
 class FFN(torch.nn.Module):
     def __init__(self, hidden=64, inter=128, dtype=torch.float32, activation="silu"):
         super().__init__()
@@ -22,9 +23,12 @@ def forward(self, x):
         else:
             return self.w_down(torch.nn.functional.gelu(self.w_up(x)))
 
+
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
 def test_ffn_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
     m = FFN(hidden=64, inter=128, dtype=dtype, activation="silu")
     x = torch.randn(2, 8, 64, dtype=dtype)
-    run_iree_vs_torch_fx(m, input_args=(x,), atol=atol, rtol=0, compile_flags=LLM_HIP_COMPILE_FLAGS)
+    run_iree_vs_torch_fx(
+        m, input_args=(x,), atol=atol, rtol=0, compile_flags=LLM_HIP_COMPILE_FLAGS
+    )
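
Since run_iree_vs_torch_fx bottoms out in torch.testing.assert_close with rtol=0, a CPU-only analogue of what this test asserts can be sketched without IREE in the loop; TinyFFN below is a simplified stand-in (the real class's silu branch is not shown in this hunk), and the torch.export round trip plays the role of the compiled module:

import torch


class TinyFFN(torch.nn.Module):
    def __init__(self, hidden=64, inter=128):
        super().__init__()
        self.w_up = torch.nn.Linear(hidden, inter, bias=False)
        self.w_down = torch.nn.Linear(inter, hidden, bias=False)

    def forward(self, x):
        return self.w_down(torch.nn.functional.silu(self.w_up(x)))


torch.manual_seed(42)
m = TinyFFN()
x = torch.randn(2, 8, 64)
expected = m(x)  # eager reference
actual = torch.export.export(m, (x,)).module()(x)  # exported graph, same inputs
torch.testing.assert_close(actual, expected, atol=1e-4, rtol=0)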

sharktank/tests/layers/linear_with_iree_test.py

Lines changed: 4 additions & 1 deletion
@@ -18,9 +18,12 @@ def __init__(self, in_f, out_f, bias=False, dtype=torch.float32):
     def forward(self, x):
         return self.lin(x)
 
+
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
 def test_linear_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
     m = Linear(64, 64, bias=False, dtype=dtype)
     x = torch.randn(2, 8, 64, dtype=dtype)
-    run_iree_vs_torch_fx(m, input_args=(x,), atol=atol, rtol=0, compile_flags=LLM_HIP_COMPILE_FLAGS)
+    run_iree_vs_torch_fx(
+        m, input_args=(x,), atol=atol, rtol=0, compile_flags=LLM_HIP_COMPILE_FLAGS
+    )
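
A caveat worth flagging for both this test and the FFN one: atol=1e-4 sits below float16's machine epsilon, so for values of magnitude near 1.0 the fp16 parametrization effectively demands bit-identical results between IREE and eager. The relevant epsilons are easy to confirm:

import torch

# Spacing of adjacent representable values near 1.0, per dtype:
print(torch.finfo(torch.float32).eps)   # ~1.19e-07
print(torch.finfo(torch.float16).eps)   # ~9.77e-04
print(torch.finfo(torch.bfloat16).eps)  # ~7.81e-03 (hence atol=1e-2 in the RMSNorm test)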

sharktank/tests/layers/output_lm_test_with_iree.py

Lines changed: 58 additions & 46 deletions
@@ -16,100 +16,105 @@
 
 class OutputLMHead(torch.nn.Module):
     """Standalone output_lm_head block extracted from PagedLlmModelV1"""
-
+
     def __init__(self, theta: Theta, config: LlamaModelConfig):
         super().__init__()
         self.config = config
         self.hp = config.hp
-
+
         # Output normalization layer
         self.output_norm = RMSNormLayer(
-            theta("output_norm"),
-            epsilon=self.hp.attention_layer_norm_rms_epsilon
+            theta("output_norm"), epsilon=self.hp.attention_layer_norm_rms_epsilon
         )
-
+
         # Output linear layer (language model head)
         self.output_lm_head = LinearLayer(
             theta("output"),
             matmul_kernel=config.matmul_kernel,
         )
-
+
     def forward(self, h: torch.Tensor) -> torch.Tensor:
         # Apply normalization
-        h_norm = self.output_norm(h) # output fp16 && weights float32
-
+        h_norm = self.output_norm(h)  # output fp16 && weights float32
+
         # Apply final linear transformation
-        logits = self.output_lm_head(h_norm) # output && weights fp16
-
+        logits = self.output_lm_head(h_norm)  # output && weights fp16
+
         return logits
 
 
-def create_output_lm_head_from_irpa(irpa_path: str) -> tuple[OutputLMHead, torch.Tensor]:
+def create_output_lm_head_from_irpa(
+    irpa_path: str,
+) -> tuple[OutputLMHead, torch.Tensor]:
     """
     Create OutputLMHead module from IRPA file and generate sample input.
-
+
     Args:
         irpa_path: Path to the IRPA file
-
+
     Returns:
         Tuple of (OutputLMHead module, sample input tensor)
     """
     # Load dataset from IRPA file
     dataset = Dataset.load(Path(irpa_path))
-
+
     # Create model config from dataset
     llama_config = LlamaModelConfig.from_dataset(
         dataset=dataset,
         attention_kernel="torch",
         matmul_kernel="sharktank.asm;*",
         activation_dtype=torch.float16,
     )
-
+
     # Create the output LM head module
     output_lm_head = OutputLMHead(dataset.root_theta, llama_config)
-
+
     # Generate sample input tensor matching expected dimensions
     # Typical shape: [batch_size, seq_len, hidden_dim]
     # TODO: Check if there are other more suitable sizes to test.
     batch_size = 2
     seq_len = 8
-    hidden_dim = llama_config.hp.embedding_length # Use embedding_length instead of model_dim
-
+    hidden_dim = (
+        llama_config.hp.embedding_length
+    )  # Use embedding_length instead of model_dim
+
     sample_input = torch.randn(
-        batch_size, seq_len, hidden_dim,
-        dtype=llama_config.activation_dtype
+        batch_size, seq_len, hidden_dim, dtype=llama_config.activation_dtype
     )
-
+
     return output_lm_head, sample_input
 
 
 # Test cases
-@pytest.mark.parametrize("dtype,atol", [
-    (torch.float16, 1e-4)
-])
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_output_lm_head_iree_vs_eager(request, dtype, atol):
     """
     Test OutputLMHead module comparing IREE vs PyTorch eager execution.
-
+
     Use --parameters command line argument to specify the IRPA file path.
     """
     # Validate and get IRPA path
     irpa_path = validate_and_get_irpa_path(request)
-
+
     try:
         # Create module and sample input from IRPA
-        module, sample_input  = create_output_lm_head_from_irpa(irpa_path)
+        module, sample_input = create_output_lm_head_from_irpa(irpa_path)
     except Exception as e:
         pytest.skip(f"Failed to load model from IRPA: {e}")
 
     # Convert to desired dtype
     # module = module.to(dtype)
     sample_input = sample_input.to(dtype)
-
+
     # Run IREE vs torch comparison
-    run_iree_vs_torch_fx(module, input_args=(sample_input,), atol=atol, rtol=0,
-                         compile_flags=LLM_HIP_COMPILE_FLAGS,
-                         parameters_path=irpa_path)
+    run_iree_vs_torch_fx(
+        module,
+        input_args=(sample_input,),
+        atol=atol,
+        rtol=0,
+        compile_flags=LLM_HIP_COMPILE_FLAGS,
+        parameters_path=irpa_path,
+    )
 
 
 def test_output_lm_head_mock():
@@ -118,10 +123,10 @@ def test_output_lm_head_mock():
     Adding this test to work without requiring an IRPA file.
     """
     torch.manual_seed(42)
-
+
     # Mock configuration - provide all required parameters
     from sharktank.layers.configs import LlamaHParams
-
+
     # Create LlamaHParams with all required parameters
     hp = LlamaHParams(
         model_arch="llama",
@@ -135,41 +140,48 @@ def test_output_lm_head_mock():
         attention_head_count_kv=8,
         vocab_size=32000,
     )
-
+
     # Create mock config
     config = LlamaModelConfig(
         hp=hp,
         activation_dtype=torch.float16,
         # attention_dtype=torch.float32,
     )
-
+
     # Create mock theta with synthetic weights
     from sharktank.types import DefaultPrimitiveTensor
-
+
     # Mock output_norm weights
     output_norm_weight = torch.randn(hp.embedding_length, dtype=torch.float32)
-
-    # Mock output (lm_head) weights
+
+    # Mock output (lm_head) weights
     output_weight = torch.randn(hp.vocab_size, hp.embedding_length, dtype=torch.float16)
-
+
     # Create theta structure
     theta_dict = {
         "output_norm": {"weight": DefaultPrimitiveTensor(data=output_norm_weight)},
         "output": {"weight": DefaultPrimitiveTensor(data=output_weight)},
     }
-
+
     theta = Theta(theta_dict)
-
+
     # Create module
     module = OutputLMHead(theta, config)
-
+
     # Create sample input
     batch_size, seq_len = 2, 8
-    sample_input = torch.randn(batch_size, seq_len, hp.embedding_length, dtype=torch.float32)
-
+    sample_input = torch.randn(
+        batch_size, seq_len, hp.embedding_length, dtype=torch.float32
+    )
+
     # Run IREE vs torch comparison
-    run_iree_vs_torch_fx(module, input_args=(sample_input,), atol=1e-4, rtol=0,
-                         compile_flags=LLM_HIP_COMPILE_FLAGS,)
+    run_iree_vs_torch_fx(
+        module,
+        input_args=(sample_input,),
+        atol=1e-4,
+        rtol=0,
+        compile_flags=LLM_HIP_COMPILE_FLAGS,
+    )
 
 
 if __name__ == "__main__":
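
The --parameters option consumed by validate_and_get_irpa_path has to be registered by a pytest hook; a hypothetical conftest.py sketch (the actual sharktank conftest may declare it differently):

def pytest_addoption(parser):
    parser.addoption(
        "--parameters",
        action="store",
        default=None,
        help="Path to an IRPA parameter file for the IREE-vs-eager tests",
    )

With that in place, the IRPA-backed test is invoked as, e.g., `pytest output_lm_test_with_iree.py --parameters /path/to/model.irpa` (path illustrative); without the flag, both it and the loader fall through to pytest.skip.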

sharktank/tests/layers/rms_norm_with_iree_test.py

Lines changed: 5 additions & 3 deletions
@@ -22,12 +22,14 @@ def forward(self, x):
         var = (x.to(torch.float32) ** 2).mean(dim=-1, keepdim=True)
         inv = torch.rsqrt(var + self.eps)
         y = x * inv
-        return (y * self.weight) # broadcast over last dim
+        return y * self.weight  # broadcast over last dim
+
 
 @pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.bfloat16, 1e-2)])
 def test_rms_norm_iree_vs_eager(dtype, atol):
     torch.manual_seed(42)
     m = RMSNorm(hidden=64, dtype=dtype)
     x = torch.randn(2, 8, 64, dtype=dtype)
-    run_iree_vs_torch_fx(m, input_args=(x,), atol=atol, rtol=0,
-                         compile_flags=LLM_HIP_COMPILE_FLAGS)
+    run_iree_vs_torch_fx(
+        m, input_args=(x,), atol=atol, rtol=0, compile_flags=LLM_HIP_COMPILE_FLAGS
+    )
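
The hand-rolled RMSNorm above can be cross-checked against PyTorch's built-in reference; a minimal sketch, assuming PyTorch >= 2.4 (where torch.nn.functional.rms_norm was added), using float32 inputs so both paths accumulate identically:

import torch

torch.manual_seed(0)
x = torch.randn(2, 8, 64)
weight = torch.randn(64)
eps = 1e-6

# Formula under test: y = x * rsqrt(mean(x^2) + eps) * weight
var = (x.to(torch.float32) ** 2).mean(dim=-1, keepdim=True)
y = x * torch.rsqrt(var + eps) * weight

ref = torch.nn.functional.rms_norm(x, (64,), weight=weight, eps=eps)
torch.testing.assert_close(y, ref, atol=1e-5, rtol=0)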
