LLM Layer testcases #2289
Open: PhaneeshB wants to merge 39 commits into main from layers_test.
Changes from all commits (39 commits):
eaf2738  add torchVsIREE tests for llama layers (PhaneeshB)
ba858db  add output lm head test (PhaneeshB)
685234a  update compile flags and irpa path (PhaneeshB)
0ceef71  finalise test (PhaneeshB)
822c3a4  refactor iree vs torch method (PhaneeshB)
db9604f  add copyright txt (PhaneeshB)
58bddab  use parameters option from test root conftest (PhaneeshB)
5ba4829  move manual seed to test files (PhaneeshB)
9116900  remove old tests and debug prints (PhaneeshB)
c7cd2d6  update token embedding test (PhaneeshB)
b393838  refactor helpers (PhaneeshB)
91a1b1c  move flags to config file and use in all test (PhaneeshB)
2a2d253  fix pre-commit (PhaneeshB)
8ec0a61  layers only run for hip devices (PhaneeshB)
0aa8684  add ci for layerstest (PhaneeshB)
fbd2834  xfail token emb mock test (PhaneeshB)
acd9d59  refactor helpers (PhaneeshB)
05bc4cc  xfail test with issues (PhaneeshB)
2433066  Update xfails, add garbage collection to some tests, and split test s… (eagarvey-amd)
cc39cd8  Protect iree device arrays from escaping their device's lifetime. (eagarvey-amd)
d927971  Xfail lm head test with numerics issue. (eagarvey-amd)
ec7c786  Fix parameter overwriting issue with conv2d test and update lm_head x… (eagarvey-amd)
1556450  Pass in temporary directory path for parameter write. (eagarvey-amd)
a7eb4a5  Add waits in fragile rotary tests, more garbage collection control (eagarvey-amd)
6030be4  Fixups to garbage collection/device array destruction (eagarvey-amd)
0ee3c7e  Fixup rotary embedding test utils usage (eagarvey-amd)
ad92d1a  Use sharded parameters path on conv2d layer module load. (eagarvey-amd)
2164d0b  Clone torch tensors before exiting device context. (eagarvey-amd)
4f0b26e  Move garbage collection outside of device contexts (eagarvey-amd)
2ff6786  Require HIP device for rotary IREE test. (eagarvey-amd)
9cde8f9  Update with_iree_device_context docs. (eagarvey-amd)
3c1cb39  Revert changes to linear quant test. (eagarvey-amd)
d6bebd8  Preserve iree_to_torch prior functionality. (eagarvey-amd)
6590ded  Simplify conv2d param path args. (eagarvey-amd)
49cafeb  small iree_to_torch rework and fixes (eagarvey-amd)
aff894c  Rename default func name, small fixes (eagarvey-amd)
038a000  refactor ref and export (PhaneeshB)
6f8baef  shark migration update (PhaneeshB)
2ee8ba4  lint fixes (PhaneeshB)
24 changes: 24 additions & 0 deletions
24
amdsharktank/amdsharktank/utils/_iree_compile_flags_config.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,24 @@ | ||
| # Copyright 2024 Advanced Micro Devices, Inc. | ||
| # | ||
| # Licensed under the Apache License v2.0 with LLVM Exceptions. | ||
| # See https://llvm.org/LICENSE.txt for license information. | ||
| # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception | ||
|
|
||
| """ | ||
| IREE compilation flags for specific usecases. | ||
| """ | ||
|
|
||
| LLM_HIP_COMPILE_FLAGS = [ | ||
| "--iree-hal-target-device=hip", | ||
| "--iree-hip-target=gfx942", # MI300 example; adjust to your GPU if needed | ||
| "--iree-execution-model=async-external", | ||
| "--iree-opt-strip-assertions=true", | ||
| "--iree-opt-level=O3", | ||
| "--iree-dispatch-creation-propagate-collapse-across-expands=true", | ||
| "--iree-stream-affinity-solver-max-iterations=1024", | ||
| "--iree-hal-indirect-command-buffers=true", | ||
| "--iree-stream-resource-memory-model=discrete", | ||
| "--iree-hip-specialize-dispatches", | ||
| "--iree-hal-memoization=true", | ||
| "--iree-codegen-enable-default-tuning-specs=true", | ||
| ] |
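These flags feed directly into the compiler invocation used later in this PR. A minimal sketch of compiling an exported MLIR module with them, using the iree.compiler Python API (file names here are placeholders, not from the PR):

```python
# Minimal sketch: compile an exported MLIR file to a VMFB with the flag list
# above. Requires the iree-compiler Python package; paths are illustrative.
import iree.compiler

from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS

iree.compiler.compile_file(
    "module.mlir",                     # MLIR produced by torch export
    output_file="module.vmfb",         # compiled artifact loadable by iree.runtime
    extra_args=LLM_HIP_COMPILE_FLAGS,  # HIP/gfx942 flags defined in this file
)
```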
@@ -10,6 +10,7 @@
 import json
 from copy import deepcopy
 from pathlib import Path
+import tempfile

 import numpy as np
 import collections.abc
@@ -29,10 +30,16 @@ | |
| torch_tree_flatten, | ||
| ) | ||
| from amdsharktank.utils import verify_exactly_one_is_not_none | ||
| from amdsharktank.utils.export import ( | ||
| export_torch_module_to_mlir_file, | ||
| get_torch_eager_output, | ||
| _as_tuple, | ||
| ) | ||
| from .tree import Tree | ||
| from iree.runtime import FileHandle | ||
| import iree.runtime | ||
|
|
||
|
|
||
| if TYPE_CHECKING: | ||
| from ..layers import ModelConfig | ||
|
|
||
|
|
@@ -252,7 +259,8 @@ def f(): | |
| ``` | ||
| Although the dev variable will be deleted after all other variables, in practice | ||
| with the various object wrappings with numpy and torch, the underlying HalBuffer | ||
| may get destroyed after the device. | ||
| may get destroyed after the device. For torch, use tensor.clone().detach() to break the ref to an IREE HalBuffer | ||
| and return the cloned tensor. With this process, usually garbage collection will do the right thing. | ||
| """ | ||
| res = fn(devices) | ||
| gc.collect() | ||
|
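As a standalone illustration of the aliasing problem this docstring describes (numpy-backed rather than IREE-backed, but the same mechanism: a wrapping tensor can share storage with a buffer it does not own):

```python
# Standalone illustration of the aliasing issue described above.
# torch.from_numpy() shares memory with the source buffer, while
# .clone().detach() produces a tensor that owns its own storage.
import numpy as np
import torch

buf = np.zeros(4, dtype=np.float32)
aliased = torch.from_numpy(buf)   # shares memory with `buf`
owned = aliased.clone().detach()  # independent copy, safe to outlive `buf`

buf[:] = 1.0
print(aliased)  # tensor([1., 1., 1., 1.]) -- sees the mutation
print(owned)    # tensor([0., 0., 0., 0.]) -- unaffected
```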
@@ -615,8 +623,14 @@ def call_torch_module_function(
    return res


-def iree_to_torch(*tensors: iree.runtime.DeviceArray) -> List[torch.Tensor]:
-    return [device_array_to_host(tensor) for tensor in tensors]
+def iree_to_torch(
+    *tensors: iree.runtime.DeviceArray, to_host: bool = False
+) -> List[torch.Tensor]:
+    res_torch = [device_array_to_host(tensor) for tensor in tensors]
+    if to_host:
+        # Clone to sever any reference to IREE-owned device buffers so the
+        # results can safely outlive the HAL device.
+        res_torch = [tensor.clone().detach() for tensor in res_torch]
+    del tensors
+    return res_torch


def make_hal_buffer_view_trace_default_callback(
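For orientation, a sketch of how the new to_host flag is meant to be used inside a device context. This is a hypothetical fragment, not from the PR: iree_module, vm_context, and iree_args stand for values produced by the load/prepare helpers referenced in this file.

```python
# Hypothetical fragment: run a module function, then convert results with
# to_host=True so they stop referencing IREE HalBuffers before the device
# context is torn down.
def _run(devices):
    iree_out = run_iree_module_function(
        module=iree_module,     # from load_iree_module(...)
        vm_context=vm_context,  # from load_iree_module(...)
        args=iree_args,         # from prepare_iree_module_function_args(...)
        device=devices[0],
        function_name="forward",
    )
    return iree_to_torch(*iree_out, to_host=True)  # safe to outlive `devices`

results = with_iree_device_context(_run, devices)
```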
@@ -753,3 +767,159 @@ def run_model_with_iree_run_module( | |
| input_args = [f"--input={arg.strip()}" for arg in input_args] | ||
| cmd += input_args | ||
| subprocess.check_call(cmd, **subprocess_run_kwargs) | ||
|
|
||
|
|
||
def run_iree_module_from_vmfb(
    vmfb_path: Path,
    devices: list[iree.runtime.HalDevice] | None = None,
    args=(),
    *,
    entrypoint="forward",
    parameters_path=None,
    driver="hip",
    device_count=1,
):
    """
    Load a VMFB and run it with IREE.

    Args:
        vmfb_path: Path to the VMFB file
        devices: Pre-created HAL devices; if None, devices are created from `driver`
        args: Input arguments for the module
        entrypoint: Name of the function to run
        parameters_path: Optional path to a parameters file
        driver: IREE driver to use
        device_count: Number of devices

    Returns:
        IREE module output as host torch tensors
    """
    args = _as_tuple(args)

    # Load & run with IREE
    if devices is None:
        devices = get_iree_devices(driver=driver, device_count=device_count)

    def run_with_devices(devices):
        iree_module, vm_context, _ = load_iree_module(
            module_path=str(vmfb_path),
            devices=devices,
            parameters_path=parameters_path,
        )
        iree_args = prepare_iree_module_function_args(args=args, devices=devices)

        iree_out = run_iree_module_function(
            module=iree_module,
            vm_context=vm_context,
            args=iree_args,
            device=devices[0],
            function_name=entrypoint,
        )
        # Clone to host so results do not reference buffers tied to `devices`.
        results_host = iree_to_torch(*iree_out, to_host=True)
        return results_host

    return with_iree_device_context(run_with_devices, devices)
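A hedged usage sketch of this helper. The import path is a guess (the diff's filename was not captured in this view), and the vmfb/irpa paths are placeholders, not from the PR:

```python
# Hypothetical usage sketch; import path assumed, artifact paths are placeholders.
import torch
from amdsharktank.utils.iree import run_iree_module_from_vmfb  # assumed path

outputs = run_iree_module_from_vmfb(
    vmfb_path="module.vmfb",          # compiled artifact (placeholder)
    args=(torch.randn(1, 32, 128),),  # example input matching the entrypoint
    entrypoint="forward",
    parameters_path="model.irpa",     # optional parameter archive (placeholder)
    driver="hip",
)
print([t.shape for t in outputs])  # host torch tensors; device already released
```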
def assert_iree_torch_outputs_close(
    iree_output: torch.Tensor | tuple[torch.Tensor, ...],
    torch_output: torch.Tensor | tuple[torch.Tensor, ...],
    *,
    atol: Optional[float] = None,
    rtol: Optional[float] = None,
):
    """
    Compare IREE output with the torch eager reference and assert closeness.

    Args:
        iree_output: Output from the IREE module, a torch.Tensor or tuple of torch.Tensors
        torch_output: Output from torch eager execution
        atol/rtol: tolerances passed to torch.testing.assert_close
    """
    # Normalize both outputs to tuples before comparing
    expected = torch_output
    actual = iree_output

    if isinstance(expected, torch.Tensor):
        expected = (expected,)
    if isinstance(actual, torch.Tensor):
        actual = (actual,)

    torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
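Since the helper forwards atol/rtol directly, note that torch.testing.assert_close requires them to be specified together or not at all. A self-contained illustration of the tolerance semantics used above:

```python
# Self-contained illustration of the tolerance check performed above.
# assert_close requires atol and rtol to be set together (or both omitted).
import torch

a = torch.tensor([1.0, 2.0, 3.0])
b = a + 5e-5  # small absolute deviation, e.g. from reduced-precision accumulation

torch.testing.assert_close(b, a, atol=1e-4, rtol=0.0)  # passes: |b - a| <= 1e-4
try:
    torch.testing.assert_close(b, a, atol=1e-6, rtol=0.0)
except AssertionError as e:
    print("tighter tolerance fails:", type(e).__name__)
```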
def run_iree_vs_torch_eager(
    module: torch.nn.Module,
    input_args=(),
    kwargs=None,
    *,
    atol=1e-4,
    rtol=0.0,
    entrypoint="forward",
    parameters_path=None,
    compile_flags: list[str] | None = None,
    driver="hip",
    device_count=1,
    directory=".",
):
    """
    Wrapper for MLIR export via FxProgramsBuilder(model) and IREE vs torch eager comparison.

    Args:
        module: torch.nn.Module under test
        input_args: example positional inputs (tuple required)
        kwargs: example kwargs
        atol/rtol: tolerances passed to torch.testing.assert_close
        entrypoint: the method name exported/invoked ("forward" by default)
        parameters_path: Optional path to a parameters file
        compile_flags: List of compilation flags for IREE (required)
        driver: IREE driver to use
        device_count: Number of devices
    """
    with tempfile.TemporaryDirectory() as td:
        td = Path(td)
        mlir_path = td / "module.mlir"
        vmfb_path = td / "module.vmfb"

        # Get torch reference output
        torch_output = get_torch_eager_output(
            module=module,
            input_args=input_args,
            kwargs=kwargs,
        )

        # Export to MLIR
        export_torch_module_to_mlir_file(
            module=module,
            input_args=input_args,
            kwargs=kwargs,
            mlir_path=mlir_path,
            target_fn=entrypoint,
        )

        # Compile MLIR to VMFB
        if compile_flags is None:
            raise ValueError("compile_flags must be provided")

        iree.compiler.compile_file(
            str(mlir_path),
            output_file=str(vmfb_path),
            extra_args=compile_flags,
        )
        iree_devices = get_iree_devices(driver=driver, device_count=device_count)
        # Run with IREE
        iree_output = run_iree_module_from_vmfb(
            vmfb_path=vmfb_path,
            devices=iree_devices,
            args=input_args,
            entrypoint=entrypoint,
            parameters_path=parameters_path,
        )
        # Compare outputs
        assert_iree_torch_outputs_close(
            iree_output=iree_output,
            torch_output=torch_output,
            atol=atol,
            rtol=rtol,
        )
        del iree_devices
        gc.collect()
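To show how a layer test might drive this helper end to end, a hedged sketch: the toy module and test body are illustrative, not from the PR, and running it requires a HIP device plus the amdsharktank utilities on the path.

```python
# Hypothetical layer test using the helper above. TinyLinear is a toy stand-in;
# run_iree_vs_torch_eager is assumed importable from the utils module this diff
# extends (its path is not shown in this view).
import torch
from amdsharktank.utils._iree_compile_flags_config import LLM_HIP_COMPILE_FLAGS

class TinyLinear(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(x)

def test_tiny_linear_iree_vs_eager():
    torch.manual_seed(0)  # tests seed manually, per the commit history
    run_iree_vs_torch_eager(
        TinyLinear(),
        input_args=(torch.randn(2, 64),),
        atol=1e-4,
        rtol=0.0,
        compile_flags=LLM_HIP_COMPILE_FLAGS,  # requires a HIP device
        driver="hip",
    )
```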
Comment on lines +924 to +925 (del iree_devices / gc.collect())
Contributor: What kind of problems did you run into that require manual GC?
Comment: PR #2526 does very similar things.

Reply: #2606 does propose utilities that are similar in nature and useful. Since those didn't exist when we started, we developed our own version. After #2606 is merged, we can make it a point to use those utilities for future tests as needed.