Skip to content

Commit 922f751

Browse files
authored
Adding --iree-rocm-container-type= flag. (#20902)
This should almost always be `auto`, but when performing executable compilation (`--compile-mode=hal-executable`) it can be explicitly set to `hsaco` so that the output of the compile command is the raw HSACO ELF image.

Example:
```
iree-compile \
  --compile-mode=hal-executable \
  --iree-hal-target-device=hip --iree-hip-target=gfx942 \
  --iree-rocm-container-type=hsaco \
  tools/test/iree-benchmark-executable.mlir \
  -o=hsaco.elf
```
1 parent 26e6e97 commit 922f751

File tree

4 files changed

+97
-4
lines changed

4 files changed

+97
-4
lines changed

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 65 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,10 +61,22 @@ namespace mlir::iree_compiler::IREE::HAL {
6161

6262
namespace {
6363

64+
// Container format used to package the compiled HSACO when an executable is
// serialized. Selected with the --iree-rocm-container-type= flag.
enum class ContainerType {
65+
// Automatically detect the container type from the target ABI attribute:
// an ABI of "amdgpu" resolves to AMDGPU and any other ABI resolves to HIP.
66+
Auto,
67+
// HIP ExecutableDef flatbuffer.
68+
HIP,
69+
// AMDGPU ExecutableDef flatbuffer.
70+
AMDGPU,
71+
// Raw HSACO image (ELF), emitted as-is without a flatbuffer wrapper.
72+
HSACO,
73+
};
74+
6475
// TODO(#18792): rename flags back to iree-rocm- as they are not HIP-specific.
6576
struct ROCMOptions {
6677
std::string target = "";
6778
std::string targetFeatures = "";
79+
ContainerType containerType = ContainerType::Auto;
6880
std::string bitcodeDirectory = getDefaultBitcodeDirectory();
6981
int wavesPerEu = 0;
7082
std::string enableROCMUkernels = "none";
@@ -80,6 +92,7 @@ struct ROCMOptions {
8092
void bindOptions(OptionsBinder &binder) {
8193
using namespace llvm;
8294
static cl::OptionCategory category("HIP HAL Target");
95+
8396
binder.opt<std::string>(
8497
"iree-hip-target", target, cl::cat(category),
8598
cl::desc(
@@ -93,16 +106,34 @@ struct ROCMOptions {
93106
"for more details."
94107
// clang-format on
95108
));
109+
96110
binder.opt<std::string>(
97111
"iree-hip-target-features", targetFeatures, cl::cat(category),
98112
cl::desc("HIP target features as expected by LLVM AMDGPU backend; "
99113
"e.g., '+sramecc,+xnack'."));
114+
115+
binder.opt<ContainerType>(
116+
"iree-rocm-container-type", containerType,
117+
llvm::cl::desc("Serialized executable container type."),
118+
llvm::cl::cat(category),
119+
llvm::cl::values(clEnumValN(ContainerType::Auto, "auto",
120+
"Automatically detect the container type "
121+
"from the target ABI attribute."),
122+
clEnumValN(ContainerType::HIP, "hip",
123+
"HIP ExecutableDef flatbuffer."),
124+
clEnumValN(ContainerType::AMDGPU, "amdgpu",
125+
"AMDGPU ExecutableDef flatbuffer."),
126+
clEnumValN(ContainerType::HSACO, "hsaco",
127+
"Raw HSACO image (ELF).")));
128+
100129
binder.opt<std::string>("iree-hip-bc-dir", bitcodeDirectory,
101130
cl::cat(category),
102131
cl::desc("Directory of HIP Bitcode."));
132+
103133
binder.opt<int>("iree-hip-waves-per-eu", wavesPerEu, cl::cat(category),
104134
cl::desc("Optimization hint specifying minimum "
105135
"number of waves per execution unit."));
136+
106137
binder.opt<std::string>(
107138
"iree-hip-enable-ukernels", enableROCMUkernels, cl::cat(category),
108139
cl::desc("Enables microkernels in the HIP compiler backend. May be "
@@ -124,6 +155,7 @@ struct ROCMOptions {
124155
"to be passed to the target backend compiler during HIP "
125156
"executable serialization"),
126157
cl::ZeroOrMore, cl::cat(category));
158+
127159
binder.opt<bool>("iree-hip-llvm-slp-vec", slpVectorization,
128160
cl::cat(category),
129161
cl::desc("Enable slp vectorization in llvm opt."));
@@ -673,14 +705,44 @@ class ROCMTargetBackend final : public TargetBackend {
673705
".hsaco", targetHSACO);
674706
}
675707

676-
// Wrap the HSACO ELF binary in a Flatbuffers container.
708+
// Determine container type from the target ABI attribute.
709+
ContainerType containerType = options.containerType;
710+
if (containerType == ContainerType::Auto) {
711+
if (getABI(targetAttr) == "amdgpu") {
712+
containerType = ContainerType::AMDGPU;
713+
} else {
714+
containerType = ContainerType::HIP;
715+
}
716+
}
717+
718+
// Wrap the HSACO ELF binary in the requested container type (if any).
677719
FailureOr<DenseIntElementsAttr> binaryContainer;
678-
if (getABI(targetAttr) == "amdgpu") {
720+
switch (containerType) {
721+
case ContainerType::Auto: {
722+
// Resolved above; unreachable. Fall-through to error case.
723+
assert(false && "auto container type must have resolved earlier");
724+
break;
725+
}
726+
case ContainerType::AMDGPU: {
679727
binaryContainer = serializeAMDGPUBinaryContainer(
680728
serializationOptions, variantOp, exportOps, targetHSACO);
681-
} else {
729+
break;
730+
}
731+
case ContainerType::HIP: {
682732
binaryContainer = serializeHIPBinaryContainer(
683733
serializationOptions, variantOp, exportOps, targetHSACO);
734+
break;
735+
}
736+
case ContainerType::HSACO: {
737+
SmallVector<uint8_t> image;
738+
image.resize(targetHSACO.size());
739+
std::memcpy(image.data(), targetHSACO.data(), image.size());
740+
binaryContainer = DenseIntElementsAttr::get(
741+
VectorType::get({static_cast<int64_t>(targetHSACO.size())},
742+
executableBuilder.getI8Type()),
743+
image);
744+
break;
745+
}
684746
}
685747
if (failed(binaryContainer) || !binaryContainer.value()) {
686748
return failure();

compiler/plugins/target/ROCM/test/BUILD.bazel

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
# See https://llvm.org/LICENSE.txt for license information.
55
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66

7-
load("//build_tools/bazel:build_defs.oss.bzl", "iree_cmake_extra_content")
87
load("//build_tools/bazel:iree_lit_test.bzl", "iree_lit_test_suite")
98

109
package(

compiler/plugins/target/ROCM/test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ iree_lit_test_suite(
4242
SRCS
4343
"external_function_validation.mlir"
4444
"smoketest.mlir"
45+
"smoketest_hsaco.mlir"
4546
"target_device_features.mlir"
4647
TOOLS
4748
FileCheck
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
// RUN: iree-opt --split-input-file --iree-hal-serialize-all-executables --iree-rocm-container-type=hsaco %s | FileCheck %s
2+
3+
// This smoketest verifies that serializing with the --iree-rocm-container-type=
4+
// flag set to raw HSACO ELF files produces an embedded ELF file. This cannot be
5+
// used with the IREE runtime but can be useful when using
6+
// --compile-mode=hal-executable and wanting an HSACO to pass to other tooling
7+
// without needing to unwrap the Flatbuffer. To avoid test churn we just check
8+
// that the `.ELF` magic bytes are present at the start and ignore the contents.
9+
10+
// CHECK: hal.executable public @executable
11+
// CHECK: hal.executable.binary public @rocm_hsaco_fb attributes {
12+
// CHECK-SAME: data = dense<"0x7F454C46
13+
// CHECK-SAME: format = "rocm-hsaco-fb"
14+
15+
#pipeline_layout = #hal.pipeline.layout<bindings = [
16+
#hal.pipeline.binding<storage_buffer>
17+
]>
18+
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
19+
abi = "hip",
20+
iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>
21+
}>
22+
hal.executable public @executable {
23+
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {
24+
hal.executable.export public @export ordinal(0) layout(#pipeline_layout)
25+
builtin.module {
26+
llvm.func @export() {
27+
llvm.return
28+
}
29+
}
30+
}
31+
}

0 commit comments

Comments
 (0)