Skip to content

Commit fe234f8

Browse files
authored
[VectorDistribute] Flush denormals for attention reduction config (#22041)
We already flush denormals for the intrinsic-based attention configuration; this patch does the same for the subgroup-reduction-based attention configuration.
1 parent f60ba58 commit fe234f8

File tree

2 files changed

+6
-1
lines changed

2 files changed

+6
-1
lines changed

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1933,15 +1933,19 @@ static LogicalResult setAttentionReductionConfig(
19331933
decompositionConfig.emplace_back(IREE::LinalgExt::AttentionOp::getPVAttrStr(),
19341934
pvAttrDict);
19351935

1936+
SmallVector<NamedAttribute, 1> pipelineAttrs;
1937+
setAttentionPipelineAttributes(target, pipelineAttrs);
1938+
19361939
// Set attention decomposition control config.
19371940
op.setDecompositionConfigAttr(b.getDictionaryAttr(decompositionConfig));
19381941

19391942
auto configDict = b.getDictionaryAttr(attrs);
19401943
auto loweringConfig = IREE::GPU::LoweringConfigAttr::get(context, configDict);
1944+
auto pipelineConfig = DictionaryAttr::get(context, pipelineAttrs);
19411945

19421946
return setOpConfigAndEntryPointFnTranslation(
19431947
entryPoint, op, loweringConfig, CodeGenPipeline::LLVMGPUVectorDistribute,
1944-
workgroupSize, targetSubgroupSize);
1948+
workgroupSize, targetSubgroupSize, pipelineConfig);
19451949

19461950
return success();
19471951
}

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
// RUN: --pass-pipeline="builtin.module(iree-llvmgpu-select-lowering-strategy)" %s | FileCheck %s
44

55
// CHECK: #iree_codegen.translation_info<pipeline = LLVMGPUVectorDistribute
6+
// CHECK-SAME: iree_codegen.denormal_fp_math_f32 = #iree_codegen.denormal_fp_math<"preserve-sign">
67

78
#pipeline_layout = #hal.pipeline.layout<bindings = [
89
#hal.pipeline.binding<storage_buffer>,

0 commit comments

Comments
 (0)