Skip to content

Commit 650c4be

Browse files
Clean up the way config values are retrieved. (#21610)
The main motivation here is to remove the `getConfigStringAttr`, `getConfigIntegerAttr` and `getConfigBoolAttr` methods from `Utils.h` to avoid leaking the names used for attributes stored in the `DictionaryAttr:configuration` field of `IREE::HAL::ExecutableTargetAttr`. Instead:
- Have specific methods to retrieve known properties from the `configuration`, and have these methods return the value directly instead of the attribute.
- Most of these methods now just use `DictionaryAttr`. They could use `IREE::HAL::ExecutableTargetAttr` directly. This was initially tried, but some cases (related to MaterializeEncoding) use the `DictionaryAttr` without using the `IREE::HAL::ExecutableTargetAttr`. So it was more uniform to just use the `DictionaryAttr` everywhere (the only exception being `getGPUTargetAttr`, due to the complication of the target being specified from the command line instead of through `IREE::HAL::ExecutableTargetAttr`).
- The names used to access the various target information encoded in the `configuration` field of the `hal.executable.target` are now private, and these fields can be queried through the newly added utility methods.
- Rename `iree.gpu.target` -> `iree_codegen.target_info`.
Most of the null-value checks that were needed are there mainly to make lit tests happy. The end-to-end compilation flow always ensures the attributes are never empty.
---------
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent 7749b93 commit 650c4be

File tree

72 files changed

+570
-438
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

72 files changed

+570
-438
lines changed

compiler/plugins/target/CUDA/CUDATarget.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ class CUDATargetBackend final : public TargetBackend {
424424

425425
if (auto target = GPU::getCUDATargetDetails(
426426
options.clTarget, options.clTargetFeatures, context)) {
427-
configItems.emplace_back(kGPUTargetAttrName, target);
427+
addConfigGPUTarget(context, target, configItems);
428428
}
429429

430430
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
@@ -467,7 +467,8 @@ class CUDATargetBackend final : public TargetBackend {
467467
auto targetAttr = variantOp.getTargetAttr();
468468
StringRef targetArch = options.clTarget;
469469
StringRef targetFeatures = options.clTargetFeatures;
470-
if (auto attr = getGPUTargetAttr(targetAttr)) {
470+
if (auto attr =
471+
getGPUTargetAttr(executableBuilder.getContext(), targetAttr)) {
471472
targetArch = attr.getArch();
472473
targetFeatures = attr.getFeatures();
473474
}

compiler/plugins/target/LLVMCPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,7 @@ iree_compiler_cc_library(
128128
],
129129
deps = [
130130
":ResolveCPUAndCPUFeatures",
131+
"//compiler/src/iree/compiler/Codegen/LLVMCPU",
131132
"//compiler/src/iree/compiler/Utils",
132133
"@llvm-project//llvm:Analysis",
133134
"@llvm-project//llvm:Core",

compiler/plugins/target/LLVMCPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ iree_cc_library(
123123
LLVMTarget
124124
LLVMTargetParser
125125
MLIRIR
126+
iree::compiler::Codegen::LLVMCPU
126127
iree::compiler::Utils
127128
PUBLIC
128129
)

compiler/plugins/target/LLVMCPU/LLVMCPUTarget.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ class LLVMCPUTargetBackend final : public TargetBackend {
245245

246246
void buildTranslationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
247247
OpPassManager &passManager) override {
248-
bool enableAArch64SME = isAArch64(targetAttr) && hasSMEFeature(targetAttr);
248+
bool enableAArch64SME = isAArch64(targetAttr.getConfiguration()) &&
249+
hasSMEFeature(targetAttr.getConfiguration());
249250
buildLLVMCPUCodegenPassPipeline(passManager, enableAArch64SME);
250251
}
251252

@@ -556,7 +557,7 @@ class LLVMCPUTargetBackend final : public TargetBackend {
556557

557558
if (target.linkUkernelBitcode) {
558559
// Link in ukernel bitcode.
559-
if (hasUkernel(variantOp.getTarget())) {
560+
if (hasUkernel(variantOp.getTarget().getConfiguration())) {
560561
llvm::Expected<std::unique_ptr<llvm::Module>> bitcode =
561562
loadUKernelBitcode(targetMachine.get(), context);
562563
if (!bitcode) {

compiler/plugins/target/LLVMCPU/LLVMTargetOptions.cpp

Lines changed: 10 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include "compiler/plugins/target/LLVMCPU/LLVMTargetOptions.h"
88

99
#include "compiler/plugins/target/LLVMCPU/ResolveCPUAndCPUFeatures.h"
10+
#include "iree/compiler/Codegen/LLVMCPU/Utils.h"
1011
#include "llvm/ADT/APFloat.h"
1112
#include "llvm/ADT/StringRef.h"
1213
#include "llvm/Analysis/TargetTransformInfo.h"
@@ -126,20 +127,17 @@ void LLVMTarget::storeToConfigAttrs(MLIRContext *context,
126127
auto addBool = [&](StringRef name, bool value) {
127128
config.emplace_back(b.getStringAttr(name), b.getBoolAttr(value));
128129
};
129-
auto addInt64 = [&](StringRef name, int64_t value) {
130-
config.emplace_back(b.getStringAttr(name), b.getI64IntegerAttr(value));
131-
};
132130

133-
addString("target_triple", triple);
131+
addConfigTargetTriple(context, triple, config);
134132
addString("cpu", cpu);
135-
addString("cpu_features", cpuFeatures);
133+
addConfigCpuFeatures(context, cpuFeatures, config);
136134
if (!dataLayout.empty()) {
137-
addString("data_layout", dataLayout);
135+
addConfigDataLayout(context, dataLayout, config);
138136
}
139137
if (vectorWidthInBytes != DEFAULT_VECTOR_WIDTH_IN_BYTES) {
140-
addInt64("native_vector_size", vectorWidthInBytes);
138+
addConfigNativeVectorSize(context, vectorWidthInBytes, config);
141139
}
142-
addInt64("max_stack_allocation_size", maxStackAllocSizeInBytes);
140+
addConfigMaxStackAllocationSize(context, maxStackAllocSizeInBytes, config);
143141
if (linkEmbedded != DEFAULT_LINK_EMBEDDED) {
144142
addBool("link_embedded", linkEmbedded);
145143
}
@@ -234,24 +232,13 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config,
234232
}
235233
return fallback;
236234
};
237-
auto getInt64 = [&](StringRef name, int64_t fallback) -> int64_t {
238-
Attribute attr = config.get(name);
239-
if (auto iattr = llvm::dyn_cast_if_present<IntegerAttr>(attr)) {
240-
return iattr.getValue().getSExtValue();
241-
} else if (attr) {
242-
hasFailures = true;
243-
emitError(loc) << "executable config '" << name
244-
<< "' requires i64 but got " << attr;
245-
}
246-
return fallback;
247-
};
248235

249236
LLVMTarget target;
250237

251238
// Constructor arguments.
252-
auto triple = getOptionalString("target_triple");
239+
auto triple = getConfigTargetTriple(config);
253240
auto cpu = getOptionalString("cpu");
254-
auto cpuFeatures = getOptionalString("cpu_features");
241+
auto cpuFeatures = getConfigCpuFeatures(config);
255242
bool linkEmbedded = getBool("link_embedded", DEFAULT_LINK_EMBEDDED);
256243
if (triple || cpu || cpuFeatures) {
257244
if (!triple) {
@@ -279,9 +266,9 @@ LLVMTarget::loadFromConfigAttr(Location loc, DictionaryAttr config,
279266
target.copy(defaultTarget);
280267
}
281268

282-
target.dataLayout = getString("data_layout", DEFAULT_DATA_LAYOUT, false);
269+
target.dataLayout = getConfigDataLayout(config).value_or(DEFAULT_DATA_LAYOUT);
283270
target.vectorWidthInBytes =
284-
getInt64("native_vector_size", DEFAULT_VECTOR_WIDTH_IN_BYTES);
271+
getConfigNativeVectorSize(config).value_or(DEFAULT_VECTOR_WIDTH_IN_BYTES);
285272

286273
target.debugSymbols = getBool("debug_symbols", DEFAULT_DEBUG_SYMBOLS);
287274
target.linkStatic = getBool("link_static", DEFAULT_LINK_STATIC);

compiler/plugins/target/MetalSPIRV/MetalSPIRVTarget.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ class MetalSPIRVTargetBackend : public TargetBackend {
9898
Builder b(context);
9999
SmallVector<NamedAttribute, 1> configItems;
100100
if (auto target = GPU::getMetalTargetDetails(context)) {
101-
configItems.emplace_back(kGPUTargetAttrName, target);
101+
addConfigGPUTarget(context, target, configItems);
102102
}
103103

104104
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(

compiler/plugins/target/MetalSPIRV/test/smoketest.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module attributes {
44
hal.device.targets = [
55
#hal.device.target<"metal", [
66
#hal.executable.target<"metal-spirv", "metal-msl-fb", {
7-
iree.gpu.target = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
7+
iree_codegen.target_info = #iree_gpu.target<arch = "", features = "spirv:v1.3,cap:Shader", wgp = <
88
compute = fp32|int32, storage = b32, subgroup = none, dot = none, mma = [], scaled_mma = [], subgroup_size_choices = [32],
99
max_workgroup_sizes = [128, 128, 64], max_thread_count_per_workgroup = 128, max_workgroup_memory_bytes = 16384,
1010
max_workgroup_counts = [65535, 65535, 65535]>>

compiler/plugins/target/ROCM/ROCMTarget.cpp

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,7 @@ class ROCMTargetBackend final : public TargetBackend {
285285

286286
if (auto target = GPU::getHIPTargetDetails(
287287
options.target, options.targetFeatures, context)) {
288-
addConfig(kGPUTargetAttrName, target);
288+
addConfigGPUTarget(context, target, configItems);
289289
if (options.encodingLayoutResolver !=
290290
GPU::kNoEncodingLayoutResolverName) {
291291
if (Attribute encoding = GPU::getHIPTargetEncodingLayoutAttr(
@@ -326,7 +326,7 @@ class ROCMTargetBackend final : public TargetBackend {
326326

327327
addConfig("ukernels", b.getStringAttr(options.enableROCMUkernels));
328328
if (options.wavesPerEu > 0) {
329-
addConfig("waves_per_eu", b.getI64IntegerAttr(options.wavesPerEu));
329+
addConfigWavesPerEu(b.getContext(), options.wavesPerEu, configItems);
330330
}
331331

332332
return b.getAttr<IREE::HAL::ExecutableTargetAttr>(
@@ -350,7 +350,7 @@ class ROCMTargetBackend final : public TargetBackend {
350350
buildConfigurationPassPipeline(IREE::HAL::ExecutableTargetAttr targetAttr,
351351
OpPassManager &passManager) override {
352352
if (options.specializeDispatches) {
353-
if (auto attr = getGPUTargetAttr(targetAttr)) {
353+
if (auto attr = getGPUTargetAttr(targetAttr.getContext(), targetAttr)) {
354354
ROCM::ApplyBuiltinPDLPatternsPassOptions options;
355355
options.enableSpecialization = true;
356356
if (IREE::GPU::TargetChipAttr chip = attr.getChip()) {
@@ -443,7 +443,7 @@ class ROCMTargetBackend final : public TargetBackend {
443443
auto targetAttr = variantOp.getTargetAttr();
444444
StringRef targetArch = options.target;
445445
StringRef targetFeatures = options.targetFeatures;
446-
if (auto attr = getGPUTargetAttr(targetAttr)) {
446+
if (auto attr = getGPUTargetAttr(variantOp.getContext(), targetAttr)) {
447447
targetArch = attr.getArch();
448448
targetFeatures = attr.getFeatures();
449449
}

compiler/plugins/target/ROCM/test/external_function_validation.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
// remain in the final bitcode (post device bitcode linking).
66

77
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
8-
{iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "",
8+
{iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
99
wgp = <compute = fp16, storage = b16,
1010
subgroup = none, dot = none, mma = [], scaled_mma = [],
1111
subgroup_size_choices = [64],

compiler/plugins/target/ROCM/test/smoketest_hsaco.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
]>
1818
#executable_target = #hal.executable.target<"rocm", "rocm-hsaco-fb", {
1919
abi = "hip",
20-
iree.gpu.target = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], scaled_mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>
20+
iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "", wgp = <compute = fp64|fp32|fp16|int64|int32|int16|int8, storage = b64|b32|b16|b8, subgroup = shuffle|arithmetic, dot = dp4xi8toi32, mma = [<MFMA_F32_16x16x16_BF16>, <MFMA_F32_32x32x8_BF16>, <MFMA_F32_16x16x32_F8E5M2FNUZ>, <MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ>, <MFMA_F32_16x16x32_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ>, <MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ>, <MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ>, <MFMA_I32_16x16x32_I8>, <MFMA_I32_32x32x16_I8>, <MFMA_F64_16x16x4_F64>, <MFMA_F32_16x16x4_F32>, <MFMA_F32_16x16x16_F16>, <MFMA_F32_32x32x8_F16>], scaled_mma = [], subgroup_size_choices = [64], max_workgroup_sizes = [1024, 1024, 1024], max_thread_count_per_workgroup = 1024, max_workgroup_memory_bytes = 65536, max_workgroup_counts = [2147483647, 2147483647, 2147483647], max_load_instruction_bits = 128, simds_per_wgp = 4, vgpr_space_bits = 16384>>
2121
}>
2222
hal.executable public @executable {
2323
hal.executable.variant public @rocm_hsaco_fb target(#executable_target) {

0 commit comments

Comments
 (0)