Skip to content

Commit d1c128a

Browse files
authored
[iree][codegen] Set #iree_codegen.denormal_fp_math in attention dispatches (#21940)
This commit adds support for attaching `#iree_codegen.denormal_fp_math` at the function level. The attribute is added to the `translation_info` dictionary, from which it is propagated until it reaches the *AnnotateKernelForTranslation passes. A `#iree_codegen.denormal_fp_math` attribute specifies how denormal floating-point values are handled. --------- Signed-off-by: Fabian Mora <[email protected]>
1 parent 88d8316 commit d1c128a

File tree

5 files changed

+130
-5
lines changed

5 files changed

+130
-5
lines changed

compiler/src/iree/compiler/Codegen/Common/ReconcileTranslationInfo.cpp

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -700,13 +700,15 @@ static FailureOr<int64_t> reconcileSubgroupSize(
700700

701701
/// Helper function to retrieve the attribute of type `ConcreteTy` stored under
702702
/// `key` in the translation info configuration.
703-
static DictionaryAttr
704-
getTargetFuncAttrs(IREE::Codegen::TranslationInfoAttr translationInfo) {
703+
template <typename ConcreteTy>
704+
static ConcreteTy
705+
getTranslationInfoAttrs(IREE::Codegen::TranslationInfoAttr translationInfo,
706+
StringRef key) {
705707
auto translationConfig = translationInfo.getConfiguration();
706708
if (!translationConfig) {
707709
return nullptr;
708710
}
709-
auto attr = translationConfig.getAs<DictionaryAttr>("llvm_func_attrs");
711+
auto attr = translationConfig.getAs<ConcreteTy>(key);
710712
if (!attr) {
711713
return nullptr;
712714
}
@@ -801,10 +803,18 @@ void ReconcileTranslationInfoPass::runOnOperation() {
801803
// translation info into the func-like op. This is not the best
802804
// place to do this, but the intent is after this pass all the
803805
// lowering configs and translation infos will be deleted.
804-
DictionaryAttr targetFuncAttrs = getTargetFuncAttrs(translationInfo);
806+
auto targetFuncAttrs = getTranslationInfoAttrs<DictionaryAttr>(
807+
translationInfo, "llvm_func_attrs");
805808
if (targetFuncAttrs) {
806809
funcOp->setAttr("llvm_func_attrs", targetFuncAttrs);
807810
}
811+
if (auto denormalAttr =
812+
getTranslationInfoAttrs<IREE::Codegen::DenormalFpMathAttr>(
813+
translationInfo,
814+
IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName())) {
815+
funcOp->setAttr(IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName(),
816+
denormalAttr);
817+
}
808818
}
809819

810820
// Reconcile workgroup sizes.

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1326,6 +1326,16 @@ setMatmulVectorDistributionConfig(IREE::GPU::TargetAttr target,
13261326
targetSubgroupSize, pipelineConfig);
13271327
}
13281328

1329+
/// Sets attention specific pipeline attributes.
///
/// Appends to `pipelineAttrs` an entry keyed by
/// `DenormalFpMathAttr::getFP32DictKeyName()` whose value is a
/// `DenormalFpMathAttr` built with `DenormalFpMath::PreserveSign`. `target`
/// is only used here to obtain the context through `getContext()`.
1330+
static void
1331+
setAttentionPipelineAttributes(IREE::GPU::TargetAttr target,
1332+
SmallVectorImpl<NamedAttribute> &pipelineAttrs) {
1333+
pipelineAttrs.emplace_back(
1334+
IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName(),
1335+
IREE::Codegen::DenormalFpMathAttr::get(
1336+
target.getContext(), IREE::Codegen::DenormalFpMath::PreserveSign));
1337+
}
1338+
13291339
static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig(
13301340
IREE::GPU::TargetAttr target, mlir::FunctionOpInterface entryPoint,
13311341
IREE::LinalgExt::AttentionOp op) {
@@ -1577,6 +1587,8 @@ static LogicalResult setAttentionIntrinsicBasedVectorDistributionConfig(
15771587

15781588
SmallVector<NamedAttribute, 1> pipelineAttrs;
15791589

1590+
setAttentionPipelineAttributes(target, pipelineAttrs);
1591+
15801592
// TODO: We do not turn prefetching on even when requested by the prefetching
15811593
// flag because there is a shared memory allocation the two matmuls, which
15821594
// the prefetching pass cannot understand.

compiler/src/iree/compiler/Codegen/LLVMGPU/ROCDLAnnotateKernelForTranslation.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include <cassert>
88
#include "iree/compiler/Codegen/Common/PassUtils.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
910
#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1011
#include "iree/compiler/Codegen/LLVMGPU/ROCDLPasses.h"
1112
#include "iree/compiler/Codegen/Utils/GPUUtils.h"
@@ -81,6 +82,21 @@ annotateKernelForTranslation(LLVM::LLVMFuncOp funcOp,
8182
IREE::Codegen::stringifyDenormalFpMath(attr.getValue()));
8283
}
8384

85+
// Check if the `denormal_fp_math_f32` attribute is set and process it.
86+
auto denormalFp32 = cast_or_null<IREE::Codegen::DenormalFpMathAttr>(
87+
funcOp->getDiscardableAttr(
88+
IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName()));
89+
if (denormalFp32) {
90+
if (denormalFp32.getValue() != IREE::Codegen::DenormalFpMath::None) {
91+
funcOp.setDenormalFpMathF32(
92+
IREE::Codegen::stringifyDenormalFpMath(denormalFp32.getValue()));
93+
}
94+
95+
// Discard the attribute.
96+
funcOp->removeDiscardableAttr(
97+
IREE::Codegen::DenormalFpMathAttr::getFP32DictKeyName());
98+
}
99+
84100
// Kernel argument preloading is only supported on gfx942 and newer targets
85101
// from the CDNA family. This is enabled using the `inreg` function argument
86102
// attribute.

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/annotate_kernel_for_translation.mlir

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,3 +186,88 @@ builtin.module {
186186
}
187187
}
188188
}
189+
190+
// -----
191+
192+
// Check that a "preserve-sign" `iree_codegen.denormal_fp_math_f32` attribute is
// converted to the LLVM `denormal_fp_math_f32` attribute and then removed.
193+
194+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
195+
{iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
196+
wgp = <compute = int32, storage = b32,
197+
subgroup = none,
198+
subgroup_size_choices = [64],
199+
max_workgroup_sizes = [1024, 1024, 1024],
200+
max_thread_count_per_workgroup = 1024,
201+
max_workgroup_memory_bytes = 65536,
202+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
203+
ukernels = "none"}>
204+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>],
205+
flags = Indirect>
206+
builtin.module {
207+
hal.executable public @test_kern_arg {
208+
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
209+
hal.executable.export public @test_kern_arg ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
210+
%c128 = arith.constant 128 : index
211+
%c2 = arith.constant 2 : index
212+
%c1 = arith.constant 1 : index
213+
hal.return %c128, %c2, %c1 : index, index, index
214+
} attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
215+
builtin.module {
216+
llvm.func @test_kern_arg(%arg0: i32) attributes {
217+
iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">,
218+
llvm_func_attrs = {check_attr}
219+
} {
220+
llvm.return
221+
}
222+
}
223+
}
224+
}
225+
}
226+
227+
// CHECK-LABEL: llvm.func @test_kern_arg
228+
// CHECK: denormal_fp_math_f32 = "preserve-sign"
229+
// CHECK-NOT: iree_codegen.denormal_fp_math_f32
230+
// CHECK: llvm_func_attrs = {check_attr
231+
232+
233+
// -----
234+
235+
// Check that a `none` `iree_codegen.denormal_fp_math_f32` attribute is removed
// without setting the LLVM `denormal_fp_math_f32` attribute.
236+
237+
#executable_target_rocm_hsaco_fb = #hal.executable.target<"rocm", "rocm-hsaco-fb",
238+
{iree_codegen.target_info = #iree_gpu.target<arch = "gfx942", features = "",
239+
wgp = <compute = int32, storage = b32,
240+
subgroup = none,
241+
subgroup_size_choices = [64],
242+
max_workgroup_sizes = [1024, 1024, 1024],
243+
max_thread_count_per_workgroup = 1024,
244+
max_workgroup_memory_bytes = 65536,
245+
max_workgroup_counts = [2147483647, 2147483647, 2147483647]>>,
246+
ukernels = "none"}>
247+
#pipeline_layout = #hal.pipeline.layout<bindings = [#hal.pipeline.binding<storage_buffer, Indirect>],
248+
flags = Indirect>
249+
builtin.module {
250+
hal.executable public @test_kern_arg {
251+
hal.executable.variant public @rocm_hsaco_fb target(#executable_target_rocm_hsaco_fb) {
252+
hal.executable.export public @test_kern_arg ordinal(0) layout(#pipeline_layout) count(%arg0: !hal.device) -> (index, index, index) {
253+
%c128 = arith.constant 128 : index
254+
%c2 = arith.constant 2 : index
255+
%c1 = arith.constant 1 : index
256+
hal.return %c128, %c2, %c1 : index, index, index
257+
} attributes {subgroup_size = 64 : index, workgroup_size = [128 : index, 2 : index, 1 : index]}
258+
builtin.module {
259+
llvm.func @test_kern_arg(%arg0: i32) attributes {
260+
iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<none>,
261+
llvm_func_attrs = {check_attr}
262+
} {
263+
llvm.return
264+
}
265+
}
266+
}
267+
}
268+
}
269+
270+
// CHECK-LABEL: llvm.func @test_kern_arg
271+
// CHECK-NOT: denormal_fp_math_f32
272+
// CHECK-NOT: iree_codegen.denormal_fp_math_f32
273+
// CHECK: llvm_func_attrs = {check_attr

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_gfx950.mlir

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,6 +240,7 @@ func.func @attention_20x4096x64x4096x64() {
240240

241241
// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
242242
// CHECK-NOT: prefetch_shared_memory = true
243+
// CHECK: iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">
243244

244245
// CHECK-LABEL: func.func @attention_large_head_dim_shared_mem()
245246

@@ -291,7 +292,8 @@ func.func @attention_large_head_dim_shared_mem() {
291292
// and the QK matmul used MFMA_F32_32x32x64_F8E4M3FN. Vector distribution failed
292293
// to distribute these layouts to threads.
293294

294-
// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64, {}>
295+
// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute workgroup_size = [256, 1, 1] subgroup_size = 64
296+
// CHECK: iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">
295297
// CHECK-LABEL: func.func @attention_check_mma_accs_compatable
296298

297299
#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d3)>

0 commit comments

Comments
 (0)