Changes from all commits (39 commits)
eaf2738
add torchVsIREE tests for llama layers
PhaneeshB Sep 1, 2025
ba858db
add output lm head test
PhaneeshB Sep 19, 2025
685234a
update compile flags and irpa path
PhaneeshB Sep 19, 2025
0ceef71
finalise test
PhaneeshB Sep 19, 2025
822c3a4
refactor iree vs torch method
PhaneeshB Sep 24, 2025
db9604f
add copyright txt
PhaneeshB Sep 25, 2025
58bddab
use parameters option from test root conftest
PhaneeshB Sep 25, 2025
5ba4829
move manual seed to test files
PhaneeshB Sep 25, 2025
9116900
remove old tests and debug prints
PhaneeshB Sep 26, 2025
c7cd2d6
update token embedding test
PhaneeshB Sep 26, 2025
b393838
refactor helpers
PhaneeshB Sep 26, 2025
91a1b1c
move flags to config file and use in all test
PhaneeshB Sep 29, 2025
2a2d253
fix pre-commit
PhaneeshB Sep 30, 2025
8ec0a61
layers only run for hip devices
PhaneeshB Sep 30, 2025
0aa8684
add ci for layerstest
PhaneeshB Sep 30, 2025
fbd2834
xfail token emb mock test
PhaneeshB Sep 30, 2025
acd9d59
refactor helpers
PhaneeshB Oct 3, 2025
05bc4cc
xfail test with issues
PhaneeshB Oct 3, 2025
2433066
Update xfails, add garbage collection to some tests, and split test s…
eagarvey-amd Oct 6, 2025
cc39cd8
Protect iree device arrays from escaping their device's lifetime.
eagarvey-amd Oct 8, 2025
d927971
Xfail lm head test with numerics issue.
eagarvey-amd Oct 8, 2025
ec7c786
Fix parameter overwriting issue with conv2d test and update lm_head x…
eagarvey-amd Oct 9, 2025
1556450
Pass in temporary directory path for parameter write.
eagarvey-amd Oct 9, 2025
a7eb4a5
Add waits in fragile rotary tests, more garbage collection control
eagarvey-amd Oct 14, 2025
6030be4
Fixups to garbage collection/device array destruction
eagarvey-amd Oct 14, 2025
0ee3c7e
Fixup rotary embedding test utils usage
eagarvey-amd Oct 14, 2025
ad92d1a
Use sharded parameters path on conv2d layer module load.
eagarvey-amd Oct 15, 2025
2164d0b
Clone torch tensors before exiting device context.
eagarvey-amd Oct 15, 2025
4f0b26e
Move garbage collection outside of device contexts
eagarvey-amd Oct 15, 2025
2ff6786
Require HIP device for rotary IREE test.
eagarvey-amd Oct 16, 2025
9cde8f9
Update with_iree_device_context docs.
eagarvey-amd Oct 16, 2025
3c1cb39
Revert changes to linear quant test.
eagarvey-amd Oct 20, 2025
d6bebd8
Preserve iree_to_torch prior functionality.
eagarvey-amd Oct 20, 2025
6590ded
Simplify conv2d param path args.
eagarvey-amd Oct 21, 2025
49cafeb
small iree_to_torch rework and fixes
eagarvey-amd Oct 21, 2025
aff894c
Rename default func name, small fixes
eagarvey-amd Oct 21, 2025
038a000
refactor ref and export
PhaneeshB Nov 24, 2025
6f8baef
shark migration update
PhaneeshB Nov 24, 2025
2ee8ba4
lint fixes
PhaneeshB Nov 25, 2025
16 changes: 16 additions & 0 deletions .github/workflows/ci-amdsharktank.yml
@@ -186,8 +186,24 @@ jobs:
            --iree-hal-target-device=hip \
            --iree-hip-target=gfx942 \
            --iree-device=hip://0 \
            --ignore=amdsharktank/tests/layers \
            --device=cuda:0

      - name: Run amdsharktank layers tests
        if: ${{ !cancelled() }}
        run: |
          pytest amdsharktank/tests/layers \
            --durations=10 \
            --capture=no \
            --log-cli-level=info \
            -v \
            --iree-hal-target-device=hip \
            --iree-hip-target=gfx942 \
            --iree-device=hip://0 \
            --parameters=/amdshark-dev/ossci-models/llama_3_1/405b/fp4/fp4_preshuffled_2025_09_12.irpa \
            --device=cuda:0 \
            -n 4

  test_with_data:
    name: "Data-dependent Tests"
    strategy:
24 changes: 24 additions & 0 deletions amdsharktank/amdsharktank/utils/_iree_compile_flags_config.py
@@ -0,0 +1,24 @@
# Copyright 2024 Advanced Micro Devices, Inc.
#
# Licensed under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

"""
IREE compilation flags for specific usecases.
"""

LLM_HIP_COMPILE_FLAGS = [
    "--iree-hal-target-device=hip",
    "--iree-hip-target=gfx942",  # MI300 example; adjust to your GPU if needed
    "--iree-execution-model=async-external",
    "--iree-opt-strip-assertions=true",
    "--iree-opt-level=O3",
    "--iree-dispatch-creation-propagate-collapse-across-expands=true",
    "--iree-stream-affinity-solver-max-iterations=1024",
    "--iree-hal-indirect-command-buffers=true",
    "--iree-stream-resource-memory-model=discrete",
    "--iree-hip-specialize-dispatches",
    "--iree-hal-memoization=true",
    "--iree-codegen-enable-default-tuning-specs=true",
]
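
For context, a minimal sketch of how this flag set is consumed when compiling an exported module; the MLIR/VMFB paths are placeholders, and iree.compiler.compile_file is invoked the same way in run_iree_vs_torch_eager further down this PR:

import iree.compiler

from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS

# Compile an exported MLIR file into a HIP VMFB using the shared flag set.
iree.compiler.compile_file(
    "module.mlir",  # placeholder input path
    output_file="module.vmfb",  # placeholder output path
    extra_args=LLM_HIP_COMPILE_FLAGS,
)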
102 changes: 99 additions & 3 deletions amdsharktank/amdsharktank/utils/export.py
@@ -6,6 +6,7 @@

from typing import Callable, Optional, Any
from os import PathLike
from pathlib import Path
import functools

import torch
@@ -16,7 +17,13 @@
from torch.utils._pytree import tree_structure, tree_unflatten, tree_flatten
from amdsharktank.types.tensors import ShardedTensor
from amdsharktank.types.theta import mark_export_external_theta
from amdsharktank.layers import BaseLayer, ThetaLayer

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from amdsharktank.layers import BaseLayer, ThetaLayer


def flatten_signature(
@@ -180,7 +187,7 @@ def flat_fn(*args, **kwargs):


def export_model_mlir(
    model: BaseLayer,
    model: "BaseLayer",
    output_path: PathLike,
    *,
    function_batch_sizes_map: Optional[dict[Optional[str], list[int]]] = None,
@@ -202,7 +209,7 @@

    assert not (function_batch_sizes_map is not None and batch_sizes is not None)

    # isinstance() cannot take a string; import the class lazily at runtime,
    # since the module-level import is now guarded by TYPE_CHECKING.
    from amdsharktank.layers import ThetaLayer

    if isinstance(model, ThetaLayer):
        mark_export_external_theta(model.theta)

    if batch_sizes is not None:
@@ -232,3 +239,92 @@ def _(model, **kwargs):

    output = aot.export(fxb)
    output.save_mlir(output_path)


def _as_tuple(x):
    if isinstance(x, tuple):
        return x
    if isinstance(x, list):
        return tuple(x)
    return (x,)


def get_torch_eager_output(
    module: torch.nn.Module,
    input_args=(),
    kwargs=None,
):
    """
    Get the torch eager reference output from a module.

    Args:
        module: torch.nn.Module to execute
        input_args: example positional inputs (coerced to a tuple)
        kwargs: example kwargs

    Returns:
        Output from torch eager execution
    """
    kwargs = kwargs or {}
    input_args = _as_tuple(input_args)

    module.eval()
    with torch.no_grad():
        output = module(*input_args, **kwargs)

    return output


def export_torch_module_to_mlir_file(
    module: torch.nn.Module,
    input_args=(),
    kwargs=None,
    *,
    mlir_path: Path,
    target_fn="forward",
):
    """
    Export a torch module to MLIR and save it to a file.

    Uses iree-turbine's aot.export functionality to export a torch module
    to MLIR format and save it to disk.

    Args:
        module: torch.nn.Module to export
        input_args: example positional inputs (coerced to a tuple)
        kwargs: example kwargs
        mlir_path: path where the MLIR file is written
        target_fn: name of the exported function

    Returns:
        ExportOutput from iree-turbine containing the exported program
    """
    kwargs = kwargs or {}
    input_args = _as_tuple(input_args)

    from iree.turbine.aot import FxProgramsBuilder

    fxb = FxProgramsBuilder(module)

    # Empty tensors stand in for the export inputs; one is needed per arg.
    # NOTE: assumes args are not nested.
    empty_args = tuple(torch.empty(arg.shape, dtype=arg.dtype) for arg in input_args)

    # Shape info should eventually come from the test; for now only static
    # shapes are supported, with one (empty) spec per arg.
    dynamic_shapes = tuple(dict() for _ in input_args)

    @fxb.export_program(
        name=target_fn,
        args=empty_args,
        dynamic_shapes=(dynamic_shapes,),
        strict=False,
    )
    def _(module, *fn_args):
        return module.forward(*fn_args)

    export_output = aot.export(fxb, import_symbolic_shape_expressions=True)
    export_output.save_mlir(mlir_path)

    return export_output
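
A minimal sketch of how these two helpers pair up in a test; the Linear layer and file path are hypothetical stand-ins for a real amdsharktank layer:

from pathlib import Path

import torch

from amdsharktank.utils.export import (
    export_torch_module_to_mlir_file,
    get_torch_eager_output,
)

torch.manual_seed(0)
layer = torch.nn.Linear(16, 16)  # hypothetical module under test
x = torch.randn(4, 16)

# Eager reference output...
reference = get_torch_eager_output(layer, input_args=(x,))
# ...and the matching MLIR export, saved to a placeholder path.
export_torch_module_to_mlir_file(
    layer,
    input_args=(x,),
    mlir_path=Path("linear.mlir"),
)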
176 changes: 173 additions & 3 deletions amdsharktank/amdsharktank/utils/iree.py
@@ -10,6 +10,7 @@
import json
from copy import deepcopy
from pathlib import Path
import tempfile

import numpy as np
import collections.abc
@@ -29,10 +30,16 @@
    torch_tree_flatten,
)
from amdsharktank.utils import verify_exactly_one_is_not_none
from amdsharktank.utils.export import (
    export_torch_module_to_mlir_file,
    get_torch_eager_output,
    _as_tuple,
)
from .tree import Tree
from iree.runtime import FileHandle
import iree.runtime


if TYPE_CHECKING:
    from ..layers import ModelConfig

@@ -252,7 +259,8 @@ def f():
```
Although the dev variable will be deleted after all other variables, in practice,
with the various object wrappings from numpy and torch, the underlying HalBuffer
may get destroyed after the device. For torch, call tensor.clone().detach() to break
the reference to the IREE HalBuffer and return the cloned tensor; with that in place,
garbage collection will usually do the right thing.
"""
    res = fn(devices)
    gc.collect()
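
A minimal sketch of the pattern the updated docs describe; my_run_fn is a hypothetical helper that returns torch views over IREE buffers:

from amdsharktank.utils.iree import get_iree_devices, with_iree_device_context

devices = get_iree_devices(driver="hip", device_count=1)

def run(devices):
    results = my_run_fn(devices)  # hypothetical: returns torch views of HalBuffers
    # Clone so nothing returned keeps a reference to device memory.
    return [t.clone().detach() for t in results]

host_results = with_iree_device_context(run, devices)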
@@ -615,8 +623,14 @@ def call_torch_module_function(
    return res


def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
    return [device_array_to_host(tensor) for tensor in tensors]

def iree_to_torch(
    *tensors: iree.runtime.DeviceArray, to_host: bool = False
) -> List[torch.Tensor]:
    res_torch = [device_array_to_host(tensor) for tensor in tensors]
    if to_host:
        # Clone to break the reference to the underlying IREE HalBuffer so the
        # results can outlive the device.
        res_torch = [tensor.clone().detach() for tensor in res_torch]
    del tensors
    return res_torch
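
Usage note, where iree_out stands in for the result of run_iree_module_function: to_host=False keeps zero-copy views that must not outlive the device, while to_host=True returns clones that are safe to use after device teardown.

views = iree_to_torch(*iree_out)  # views over device memory
host_tensors = iree_to_torch(*iree_out, to_host=True)  # safe after device teardown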


def make_hal_buffer_view_trace_default_callback(
@@ -753,3 +767,159 @@ def run_model_with_iree_run_module(
    input_args = [f"--input={arg.strip()}" for arg in input_args]
    cmd += input_args
    subprocess.check_call(cmd, **subprocess_run_kwargs)


[Review thread on run_iree_module_from_vmfb]
Contributor: The PR #2526 does very similar things.
Author: #2606. Yes, it proposes similar, useful utilities. Since those didn't exist when we started, we developed our own version; after #2606 is merged, future tests can use its utilities as needed.

def run_iree_module_from_vmfb(
    vmfb_path: Path,
    devices: list[iree.runtime.HalDevice] | None = None,
    args=(),
    *,
    entrypoint="forward",
    parameters_path=None,
    driver="hip",
    device_count=1,
):
"""
Load VMFB and run with IREE.

Args:
vmfb_path: Path to the VMFB file
args: Input arguments for the module
entrypoint: Name of the function to run
parameters_path: Optional path to parameters file
driver: IREE driver to use
device_count: Number of devices

Returns:
IREE module output
"""
args = _as_tuple(args)

# Load & run with IREE
if devices is None:
devices = get_iree_devices(driver=driver, device_count=device_count)

def run_with_devices(devices):
iree_module, vm_context, _ = load_iree_module(
module_path=str(vmfb_path),
devices=devices,
parameters_path=parameters_path,
)
iree_args = prepare_iree_module_function_args(args=args, devices=devices)

iree_out = run_iree_module_function(
module=iree_module,
vm_context=vm_context,
args=iree_args,
device=devices[0],
function_name=entrypoint,
)
results_host = iree_to_torch(*iree_out, to_host=True)
return results_host

return with_iree_device_context(run_with_devices, devices)
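
A sketch of a direct invocation, assuming a previously compiled module; the path and input shape are placeholders:

from pathlib import Path

import torch

from amdsharktank.utils.iree import run_iree_module_from_vmfb

x = torch.randn(4, 16)
outputs = run_iree_module_from_vmfb(
    vmfb_path=Path("module.vmfb"),  # placeholder for a compiled module
    args=(x,),
    entrypoint="forward",
    driver="hip",
)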


def assert_iree_torch_outputs_close(
    iree_output: torch.Tensor | tuple[torch.Tensor, ...],
    torch_output: torch.Tensor | tuple[torch.Tensor, ...],
    *,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
):
    """
    Compare IREE output with a torch eager reference and assert closeness.

    Args:
        iree_output: output from the IREE module, as a tensor or tuple of tensors
        torch_output: output from torch eager execution
        atol/rtol: tolerances passed to torch.testing.assert_close
    """
    expected = torch_output
    actual = iree_output

    # Normalize single tensors to 1-tuples so both sides have the same structure.
    if isinstance(expected, torch.Tensor):
        expected = (expected,)
    if isinstance(actual, torch.Tensor):
        actual = (actual,)

    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
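
A toy self-check of the comparison helper; the values are illustrative:

import torch

from amdsharktank.utils.iree import assert_iree_torch_outputs_close

t = torch.randn(2, 3)
# Identical tensors pass at tight tolerances; a perturbed copy would raise.
assert_iree_torch_outputs_close(
    iree_output=(t.clone(),),
    torch_output=(t,),
    atol=1e-6,
    rtol=0.0,
)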


def run_iree_vs_torch_eager(
    module: torch.nn.Module,
    input_args=(),
    kwargs=None,
    *,
    atol=1e-4,
    rtol=0.0,
    entrypoint="forward",
    parameters_path=None,
    compile_flags: list[str] | None = None,
    driver="hip",
    device_count=1,
    directory=".",
):
    """
    Wrapper for MLIR export via FxProgramsBuilder(model) and an IREE vs torch
    eager comparison.

    Args:
        module: torch.nn.Module under test
        input_args: example positional inputs (coerced to a tuple)
        kwargs: example kwargs
        atol/rtol: tolerances passed to torch.testing.assert_close
        entrypoint: the method name exported/invoked ("forward" by default)
        parameters_path: optional path to a parameters file
        compile_flags: list of IREE compilation flags (required)
        driver: IREE driver to use
        device_count: number of devices
        directory: currently unused; a temporary directory is always used
    """
    # Validate up front, before doing any export work.
    if compile_flags is None:
        raise ValueError("compile_flags must be provided")

    with tempfile.TemporaryDirectory() as td:
        td = Path(td)
        mlir_path = td / "module.mlir"
        vmfb_path = td / "module.vmfb"

        # Get the torch reference output.
        torch_output = get_torch_eager_output(
            module=module,
            input_args=input_args,
            kwargs=kwargs,
        )

        # Export to MLIR.
        export_torch_module_to_mlir_file(
            module=module,
            input_args=input_args,
            kwargs=kwargs,
            mlir_path=mlir_path,
            target_fn=entrypoint,
        )

        # Compile MLIR to VMFB.
        iree.compiler.compile_file(
            str(mlir_path),
            output_file=str(vmfb_path),
            extra_args=compile_flags,
        )

        # Run with IREE.
        iree_devices = get_iree_devices(driver=driver, device_count=device_count)
        iree_output = run_iree_module_from_vmfb(
            vmfb_path=vmfb_path,
            devices=iree_devices,
            args=input_args,
            entrypoint=entrypoint,
            parameters_path=parameters_path,
        )

        # Compare outputs.
        assert_iree_torch_outputs_close(
            iree_output=iree_output,
            torch_output=torch_output,
            atol=atol,
            rtol=rtol,
        )

    # Release devices outside the device context, then collect garbage.
    del iree_devices
    gc.collect()
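
For context, a sketch of how a layer test might call this wrapper on a HIP machine; the Linear layer is a hypothetical stand-in (the real tests use llama layers), and the flag set comes from _iree_compile_flags_config:

import torch

from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS
from amdsharktank.utils.iree import run_iree_vs_torch_eager

torch.manual_seed(0)
layer = torch.nn.Linear(32, 32)  # hypothetical module under test

run_iree_vs_torch_eager(
    layer,
    input_args=(torch.randn(2, 32),),
    atol=1e-4,
    rtol=0.0,
    compile_flags=LLM_HIP_COMPILE_FLAGS,
    driver="hip",
)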
[Review comment on lines +924 to +925 (del iree_devices / gc.collect())]
Contributor: What kind of problems did you run into that require manual GC? Is it something related to the explanation in with_iree_device_context?