
Commit 207c3c2

refactor iree vs torch method

1 parent 0a179a2 commit 207c3c2

3 files changed: +179 -72 lines changed

sharktank/sharktank/utils/_helpers.py

Lines changed: 167 additions & 57 deletions
@@ -15,6 +15,7 @@
 )
 
 from iree.turbine.aot import *
+# from iree.turbine.aot import FxProgramsBuilder, export
 
 DEFAULT_COMPILE_FLAGS = [
     "--iree-hal-target-device=hip",  # change to your backend (e.g., local, cuda, vulkan)
@@ -38,40 +39,38 @@ def _as_tuple(x):
         return tuple(x)
     return (x,)
 
-def run_iree_vs_torch_fx(
+def export_torch_module_to_mlir(
     module: torch.nn.Module,
     args=(),
     kwargs=None,
     *,
-    atol=1e-4,
-    rtol=0.0,
-    entrypoint="forward",
-    parameters_path=None,
+    mlir_path: Path,
+    target_fn="run_forward",
 ):
     """
-    Exports MLIR via FxProgramsBuilder(model) and compares IREE vs Torch eager.
+    Export torch module to MLIR and get torch eager reference output.
 
     Args:
-        module: torch.nn.Module under test
-        args: example positional inputs (tuple required)
-        kwargs: example kwargs
-        atol/rtol: tolerances passed to torch.testing.assert_close
-        entrypoint: the method name exported/invoked ("forward" by default)
+        module: torch.nn.Module under test
+        args: example positional inputs (tuple required)
+        kwargs: example kwargs
+        mlir_path: Path where to save the MLIR file
+        target_fn: name of the exported function
+
+    Returns:
+        Tuple of (torch_eager_output, export_output)
     """
     kwargs = kwargs or {}
     args = _as_tuple(args)
     torch.manual_seed(1234)
-    target_fn = "run_forward"
-    entrypoint = target_fn
 
-    # ---- 1) Torch eager reference ----
+    # ---- Torch eager reference ----
     module.eval()
     with torch.no_grad():
        expected = module(*args, **kwargs)
 
     fxb = FxProgramsBuilder(module)
 
-
     # empty tensors for export input
     # there needs to be one corresponding to each arg
     # NOTE: assuming args are not nested.
@@ -95,55 +94,99 @@ def run_iree_vs_torch_fx(
9594
def _(module, *fn_args):
9695
return module.forward(*fn_args)
9796

98-
# Export the selected entry point (callable) from the instance `module`.
99-
# We pass a bound method so export() can trace that entry.
100-
# target_fn = getattr(type(module), entrypoint)
101-
10297
export_output = export(fxb, import_symbolic_shape_expressions=True)
98+
export_output.save_mlir(mlir_path)
10399

104-
# The turbine builder attaches a Torch-MLIR operation on the exported program.
105-
# Retrieve MLIR text and compile it with iree-compile.
106-
# Note: sharktank's exporter uses the same fx-builder object to drive MLIR generation.
107-
# See export_paged_llm_v1.py (fxb usage).
108-
# mlir_text = ep.mlir_module_operation.get_asm(enable_debug_info=False)
100+
return expected, export_output
109101

110-
# Compile MLIR -> VMFB
111-
with tempfile.TemporaryDirectory() as td:
112-
td = Path(td)
113-
mlir_path = td / "module.mlir"
114-
vmfb_path = td / "module.vmfb"
115-
export_output.save_mlir(mlir_path)
116102

117-
iree.compiler.compile_file(
118-
str(mlir_path),
119-
output_file=str(vmfb_path),
120-
extra_args=DEFAULT_COMPILE_FLAGS,
121-
)
103+
def compile_mlir_to_vmfb(
104+
mlir_path: Path,
105+
vmfb_path: Path,
106+
*,
107+
compile_flags=None,
108+
):
109+
"""
110+
Compile MLIR file to VMFB.
111+
112+
Args:
113+
mlir_path: Path to the MLIR file
114+
vmfb_path: Path where to save the VMFB file
115+
compile_flags: List of compilation flags (uses DEFAULT_COMPILE_FLAGS if None)
116+
"""
117+
compile_flags = compile_flags or DEFAULT_COMPILE_FLAGS
118+
119+
iree.compiler.compile_file(
120+
str(mlir_path),
121+
output_file=str(vmfb_path),
122+
extra_args=compile_flags,
123+
)
124+
125+
126+
def run_iree_module_from_vmfb(
127+
vmfb_path: Path,
128+
args=(),
129+
*,
130+
entrypoint="run_forward",
131+
parameters_path=None,
132+
driver="hip",
133+
device_count=1,
134+
):
135+
"""
136+
Load VMFB and run with IREE.
137+
138+
Args:
139+
vmfb_path: Path to the VMFB file
140+
args: Input arguments for the module
141+
entrypoint: Name of the function to run
142+
parameters_path: Optional path to parameters file
143+
driver: IREE driver to use
144+
device_count: Number of devices
145+
146+
Returns:
147+
IREE module output
148+
"""
149+
args = _as_tuple(args)
150+
151+
# Load & run with IREE
152+
devices = get_iree_devices(driver=driver, device_count=device_count)
153+
iree_module, vm_context, _ = load_iree_module(
154+
module_path=str(vmfb_path),
155+
devices=devices,
156+
parameters_path=parameters_path,
157+
)
158+
iree_args = prepare_iree_module_function_args(args=args, devices=devices)
159+
160+
iree_out = run_iree_module_function(
161+
module=iree_module,
162+
vm_context=vm_context,
163+
args=iree_args,
164+
device=devices[0],
165+
function_name=entrypoint,
166+
)
167+
168+
return iree_out
122169

123-
# Load & run with IREE
124-
devices = get_iree_devices(driver="hip", device_count=1) # adjust driver
125-
iree_module, vm_context, _ = load_iree_module(
126-
module_path=str(vmfb_path),
127-
devices=devices,
128-
parameters_path=parameters_path,
129-
)
130-
iree_args = prepare_iree_module_function_args(args=args, devices=devices)
131-
132-
# For FxProgramsBuilder export, the function name is typically "forward".
133-
# If you exported a different method, pass entrypoint=<that name>.
134-
# do we need logic to identify the correct entrypoints, will we have multi entry point executions in these pytests?
135-
136-
iree_out = run_iree_module_function(
137-
module=iree_module,
138-
vm_context=vm_context,
139-
args=iree_args,
140-
device=devices[0],
141-
function_name=entrypoint,
142-
)
143170

144-
# TODO: refactor to separate it from iree compile and run
171+
def compare_iree_torch_outputs(
172+
iree_output,
173+
torch_output,
174+
*,
175+
atol=1e-4,
176+
rtol=0.0,
177+
):
178+
"""
179+
Compare IREE output with torch eager reference and assert closeness.
180+
181+
Args:
182+
iree_output: Output from IREE module
183+
torch_output: Output from torch eager execution
184+
atol/rtol: tolerances passed to torch.testing.assert_close
185+
"""
145186
# Convert and compare
146-
actual = iree_to_torch(*iree_out)
187+
actual = iree_to_torch(*iree_output)
188+
expected = torch_output
189+
147190
if isinstance(expected, torch.Tensor):
148191
expected = (expected,)
149192
if isinstance(actual, torch.Tensor):
@@ -154,3 +197,70 @@ def _(module, *fn_args):
     torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
     print(f"actual : {actual}")
     print(f"expected : {expected}")
+
+
+def run_iree_vs_torch_fx(
+    module: torch.nn.Module,
+    args=(),
+    kwargs=None,
+    *,
+    atol=1e-4,
+    rtol=0.0,
+    entrypoint="run_forward",
+    parameters_path=None,
+    compile_flags=None,
+    driver="hip",
+    device_count=1,
+):
+    """
+    Wrapper for MLIR export via FxProgramsBuilder(model) and IREE vs Torch eager comparison.
+
+    Args:
+        module: torch.nn.Module under test
+        args: example positional inputs (tuple required)
+        kwargs: example kwargs
+        atol/rtol: tolerances passed to torch.testing.assert_close
+        entrypoint: the method name exported/invoked ("run_forward" by default)
+        parameters_path: Optional path to parameters file
+        compile_flags: List of compilation flags (uses DEFAULT_COMPILE_FLAGS if None)
+        driver: IREE driver to use
+        device_count: Number of devices
+    """
+    with tempfile.TemporaryDirectory() as td:
+        td = Path(td)
+        mlir_path = td / "module.mlir"
+        vmfb_path = td / "module.vmfb"
+
+        # Export to MLIR and get torch reference
+        torch_output, _ = export_torch_module_to_mlir(
+            module=module,
+            args=args,
+            kwargs=kwargs,
+            mlir_path=mlir_path,
+            target_fn=entrypoint,
+        )
+
+        # Compile MLIR to VMFB
+        compile_mlir_to_vmfb(
+            mlir_path=mlir_path,
+            vmfb_path=vmfb_path,
+            compile_flags=compile_flags,
+        )
+
+        # Run with IREE
+        iree_output = run_iree_module_from_vmfb(
+            vmfb_path=vmfb_path,
+            args=args,
+            entrypoint=entrypoint,
+            parameters_path=parameters_path,
+            driver=driver,
+            device_count=device_count,
+        )
+
+        # Compare outputs
+        compare_iree_torch_outputs(
+            iree_output=iree_output,
+            torch_output=torch_output,
+            atol=atol,
+            rtol=rtol,
+        )
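Note: the point of this refactor (per the removed TODO) is that export, compile, run, and compare can now be driven as separate stages, e.g. to keep the intermediate artifacts around for debugging. A minimal sketch of staged use; the Toy module and /tmp output paths are illustrative only, everything else follows the signatures above:

# Sketch: staged use of the refactored helpers. Toy and the /tmp paths
# are hypothetical; the helper signatures are from this commit.
from pathlib import Path

import torch

from sharktank.utils._helpers import (
    compare_iree_torch_outputs,
    compile_mlir_to_vmfb,
    export_torch_module_to_mlir,
    run_iree_module_from_vmfb,
)


class Toy(torch.nn.Module):
    def forward(self, x):
        return x * 2.0


out_dir = Path("/tmp/iree_debug")  # persistent, so artifacts survive for inspection
out_dir.mkdir(parents=True, exist_ok=True)
mlir_path = out_dir / "module.mlir"
vmfb_path = out_dir / "module.vmfb"
args = (torch.randn(2, 8),)

# 1) Export MLIR and capture the torch eager reference output.
torch_out, _ = export_torch_module_to_mlir(Toy(), args=args, mlir_path=mlir_path)

# 2) Compile the saved MLIR to a VMFB (DEFAULT_COMPILE_FLAGS when None).
compile_mlir_to_vmfb(mlir_path=mlir_path, vmfb_path=vmfb_path)

# 3) Load the VMFB and invoke the exported entrypoint ("run_forward" by default).
iree_out = run_iree_module_from_vmfb(vmfb_path=vmfb_path, args=args)

# 4) Convert the IREE results to torch tensors and assert closeness.
compare_iree_torch_outputs(iree_output=iree_out, torch_output=torch_out, atol=1e-4, rtol=0.0)

Unlike the run_iree_vs_torch_fx wrapper, which writes both artifacts into a tempfile.TemporaryDirectory that vanishes on exit, the staged form keeps module.mlir and module.vmfb between runs.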

sharktank/tests/layers/output_lm_test_with_iree.py

Lines changed: 3 additions & 6 deletions
@@ -5,7 +5,6 @@
55
from sharktank.layers import LinearLayer, RMSNormLayer
66
from sharktank.types import Dataset, Theta
77
from sharktank.layers.configs import LlamaModelConfig
8-
from sharktank.utils import cli
98

109

1110
class OutputLMHead(torch.nn.Module):
@@ -54,7 +53,6 @@ def create_output_lm_head_from_irpa(irpa_path: str) -> tuple[OutputLMHead, torch
5453
# Create model config from dataset
5554
llama_config = LlamaModelConfig.from_dataset(
5655
dataset=dataset,
57-
use_hf=True, # or False depending on your model
5856
attention_kernel="torch",
5957
matmul_kernel="sharktank.asm;*",
6058
activation_dtype=torch.float16,
@@ -80,14 +78,13 @@ def create_output_lm_head_from_irpa(irpa_path: str) -> tuple[OutputLMHead, torch
8078

8179
# Test cases
8280
@pytest.mark.parametrize("dtype,atol", [
83-
(torch.float16, 1e-3)
81+
(torch.float16, 1e-4)
8482
])
8583
def test_output_lm_head_iree_vs_eager(request, dtype, atol):
8684
"""
8785
Test OutputLMHead module comparing IREE vs PyTorch eager execution.
8886
8987
Use --irpa-path command line argument to specify the IRPA file path.
90-
Example: pytest tests/layers/output_lm_test_with_iree.py::test_output_lm_head_iree_vs_eager --irpa-path /path/to/model.irpa
9188
"""
9289
# Get IRPA path from command line argument
9390
irpa_path = request.config.getoption("--irpa-path")
@@ -142,7 +139,7 @@ def test_output_lm_head_mock():
142139
config = LlamaModelConfig(
143140
hp=hp,
144141
activation_dtype=torch.float16,
145-
attention_dtype=torch.float32,
142+
# attention_dtype=torch.float32,
146143
)
147144

148145
# Create mock theta with synthetic weights
@@ -152,7 +149,7 @@ def test_output_lm_head_mock():
152149
output_norm_weight = torch.randn(hp.embedding_length, dtype=torch.float32)
153150

154151
# Mock output (lm_head) weights
155-
output_weight = torch.randn(hp.vocab_size, hp.embedding_length, dtype=torch.float32)
152+
output_weight = torch.randn(hp.vocab_size, hp.embedding_length, dtype=torch.float16)
156153

157154
# Create theta structure
158155
theta_dict = {
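The IRPA-backed test reads its input via request.config.getoption("--irpa-path"). For that option to exist, it has to be registered through a pytest hook; presumably sharktank's conftest already provides one, along the lines of this sketch:

# conftest.py (sketch): registers the --irpa-path option that
# request.config.getoption("--irpa-path") reads. Illustrative only;
# sharktank presumably already ships an equivalent hook.
def pytest_addoption(parser):
    parser.addoption(
        "--irpa-path",
        action="store",
        default=None,
        help="Path to the IRPA parameter file for IREE-vs-eager tests",
    )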
sharktank/tests/layers/token_embedding_with_iree_test.py

Lines changed: 9 additions & 9 deletions

@@ -1,8 +1,10 @@
 # sharktank/tests/layers/token_embedding_with_iree_test.py
 import torch
 import pytest
-from sharktank.utils._helpers import run_iree_vs_torch_fx
+from pathlib import Path
 from sharktank.layers.token_embedding import TokenEmbeddingLayer
+from sharktank.types.theta import Dataset
+from sharktank.utils._helpers import run_iree_vs_torch_fx
 
 class TokenEmbeddingSmall(torch.nn.Module):
     def __init__(self, vocab_size=128, hidden=64, dtype=torch.float32):
@@ -12,20 +14,18 @@ def __init__(self, vocab_size=128, hidden=64, dtype=torch.float32):
     def forward(self, ids: torch.Tensor):
         return self.weight[ids]
 
-@pytest.mark.parametrize("dtype,atol", [(torch.float32, 1e-4), (torch.float16, 1e-4)])
+@pytest.mark.parametrize("dtype,atol", [(torch.float16, 1e-4)])
 def test_token_embedding_iree_vs_eager(dtype, atol):
     torch.manual_seed(0)
 
     # Each test assumes all inputs are in the correct dtype
     # as that information is required to export the model
-    m = TokenEmbeddingSmall(vocab_size=128, hidden=64, dtype=dtype)
-
-    from pathlib import Path
-    irpa_path = Path('/shark-dev/8b/instruct/weights/llama3.1_8b_instruct_fp16.irpa')
-    from sharktank.types.theta import Dataset
+    # Example usage of dummy Token Embedding Layer
+    # m = TokenEmbeddingSmall(vocab_size=128, hidden=64, dtype=dtype)
+
+    irpa_path = Path('/shark-dev/llama3.1/405b/instruct/weights/fp4/fp4_2025_07_10_fn.irpa')
     dataset = Dataset.load(irpa_path)
 
     m = TokenEmbeddingLayer(dataset.root_theta("token_embd"), dtype=dtype)
-    # ids = torch.randint(0, 128, (2, 8), dtype=torch.long)
     inp_tensors = torch.randint(0, 128, (2, 8), dtype=torch.long)
-    run_iree_vs_torch_fx(m, args=(inp_tensors,), atol=atol, rtol=0.0)
+    run_iree_vs_torch_fx(m, args=(inp_tensors,), atol=atol, rtol=0.0, parameters_path=irpa_path)
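For reference, the dummy-module path that this commit comments out can still be exercised without an IRPA file, exactly as the pre-refactor test did; a sketch using the toy dimensions from the test (no parameters_path, since the toy weights are baked into the export):

# Sketch: IREE-vs-eager comparison on the toy embedding, no IRPA needed.
m = TokenEmbeddingSmall(vocab_size=128, hidden=64, dtype=torch.float16)
ids = torch.randint(0, 128, (2, 8), dtype=torch.long)  # token ids stay int64
run_iree_vs_torch_fx(m, args=(ids,), atol=1e-4, rtol=0.0)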
