@@ -61,10 +61,22 @@ namespace mlir::iree_compiler::IREE::HAL {
6161
6262namespace {
6363
// Selects the container format used for the serialized executable binary.
// Exposed to users through the `iree-rocm-container-type` flag.
enum class ContainerType {
  // Automatically detect the container type from the target ABI attribute.
  // This is the default; it is resolved to HIP or AMDGPU before the binary
  // is wrapped.
  Auto,
  // HIP ExecutableDef flatbuffer.
  HIP,
  // AMDGPU ExecutableDef flatbuffer.
  AMDGPU,
  // Raw HSACO image (ELF), emitted without a flatbuffer wrapper.
  HSACO,
};
74+
6475// TODO(#18792): rename flags back to iree-rocm- as they are not HIP-specific.
6576struct ROCMOptions {
6677 std::string target = " " ;
6778 std::string targetFeatures = " " ;
79+ ContainerType containerType = ContainerType::Auto;
6880 std::string bitcodeDirectory = getDefaultBitcodeDirectory();
6981 int wavesPerEu = 0 ;
7082 std::string enableROCMUkernels = " none" ;
@@ -80,6 +92,7 @@ struct ROCMOptions {
8092 void bindOptions (OptionsBinder &binder) {
8193 using namespace llvm ;
8294 static cl::OptionCategory category (" HIP HAL Target" );
95+
8396 binder.opt <std::string>(
8497 " iree-hip-target" , target, cl::cat (category),
8598 cl::desc (
@@ -93,16 +106,34 @@ struct ROCMOptions {
93106 " for more details."
94107 // clang-format on
95108 ));
109+
96110 binder.opt <std::string>(
97111 " iree-hip-target-features" , targetFeatures, cl::cat (category),
98112 cl::desc (" HIP target features as expected by LLVM AMDGPU backend; "
99113 " e.g., '+sramecc,+xnack'." ));
114+
115+ binder.opt <ContainerType>(
116+ " iree-rocm-container-type" , containerType,
117+ llvm::cl::desc (" Serialized executable container type." ),
118+ llvm::cl::cat (category),
119+ llvm::cl::values (clEnumValN (ContainerType::Auto, " auto" ,
120+ " Automatically detect the container type "
121+ " from the target ABI attribute." ),
122+ clEnumValN (ContainerType::HIP, " hip" ,
123+ " HIP ExecutableDef flatbuffer." ),
124+ clEnumValN (ContainerType::AMDGPU, " amdgpu" ,
125+ " AMDGPU ExecutableDef flatbuffer." ),
126+ clEnumValN (ContainerType::HSACO, " hsaco" ,
127+ " Raw HSACO image (ELF)." )));
128+
100129 binder.opt <std::string>(" iree-hip-bc-dir" , bitcodeDirectory,
101130 cl::cat (category),
102131 cl::desc (" Directory of HIP Bitcode." ));
132+
103133 binder.opt <int >(" iree-hip-waves-per-eu" , wavesPerEu, cl::cat (category),
104134 cl::desc (" Optimization hint specifying minimum "
105135 " number of waves per execution unit." ));
136+
106137 binder.opt <std::string>(
107138 " iree-hip-enable-ukernels" , enableROCMUkernels, cl::cat (category),
108139 cl::desc (" Enables microkernels in the HIP compiler backend. May be "
@@ -124,6 +155,7 @@ struct ROCMOptions {
124155 " to be passed to the target backend compiler during HIP "
125156 " executable serialization" ),
126157 cl::ZeroOrMore, cl::cat (category));
158+
127159 binder.opt <bool >(" iree-hip-llvm-slp-vec" , slpVectorization,
128160 cl::cat (category),
129161 cl::desc (" Enable slp vectorization in llvm opt." ));
@@ -673,14 +705,44 @@ class ROCMTargetBackend final : public TargetBackend {
673705 " .hsaco" , targetHSACO);
674706 }
675707
676- // Wrap the HSACO ELF binary in a Flatbuffers container.
708+ // Determine container type from the target ABI attribute.
709+ ContainerType containerType = options.containerType ;
710+ if (containerType == ContainerType::Auto) {
711+ if (getABI (targetAttr) == " amdgpu" ) {
712+ containerType = ContainerType::AMDGPU;
713+ } else {
714+ containerType = ContainerType::HIP;
715+ }
716+ }
717+
718+ // Wrap the HSACO ELF binary in the requested container type (if any).
677719 FailureOr<DenseIntElementsAttr> binaryContainer;
678- if (getABI (targetAttr) == " amdgpu" ) {
720+ switch (containerType) {
721+ case ContainerType::Auto: {
722+ // Resolved above; unreachable. Fall-through to error case.
723+ assert (false && " auto container type must have resolved earlier" );
724+ break ;
725+ }
726+ case ContainerType::AMDGPU: {
679727 binaryContainer = serializeAMDGPUBinaryContainer (
680728 serializationOptions, variantOp, exportOps, targetHSACO);
681- } else {
729+ break ;
730+ }
731+ case ContainerType::HIP: {
682732 binaryContainer = serializeHIPBinaryContainer (
683733 serializationOptions, variantOp, exportOps, targetHSACO);
734+ break ;
735+ }
736+ case ContainerType::HSACO: {
737+ SmallVector<uint8_t > image;
738+ image.resize (targetHSACO.size ());
739+ std::memcpy (image.data (), targetHSACO.data (), image.size ());
740+ binaryContainer = DenseIntElementsAttr::get (
741+ VectorType::get ({static_cast <int64_t >(targetHSACO.size ())},
742+ executableBuilder.getI8Type ()),
743+ image);
744+ break ;
745+ }
684746 }
685747 if (failed (binaryContainer) || !binaryContainer.value ()) {
686748 return failure ();
0 commit comments