Commit cb55700

Update
1 parent 0917afe commit cb55700

3 files changed: +46 -3 lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_promote_matmul_operands.mlir

Lines changed: 32 additions & 0 deletions
@@ -392,3 +392,35 @@ func.func @swizzle_operand_no_promote_fill(%b: tensor<128x128xf32>) -> tensor<4x
 // CHECK-NOT: tensor.expand_shape
 // CHECK: linalg.matmul
 // CHECK: return
+
+// -----
+
+// Verify that when use_global_load_dma is requested but the input comes from
+// tensor.pad, it falls back to derived_thread_config since the padded data
+// is not from global memory. Non-padded inputs should still use DMA.
+
+#lowering_config_dma_with_pad = #iree_gpu.lowering_config<{
+  promote_operands = [0, 1],
+  promotion_types = [#iree_gpu.use_global_load_dma, #iree_gpu.use_global_load_dma]}>
+
+func.func @no_dma_for_padded_input(%a : tensor<4x127xf32>, %b: tensor<128x128xf32>) -> tensor<4x128xf32> {
+  %cst = arith.constant 0.000000e+00 : f32
+  %empty = tensor.empty() : tensor<4x128xf32>
+  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32>
+  %padded = tensor.pad %a low[0, 0] high[0, 1] {
+  ^bb0(%arg0: index, %arg1: index):
+    tensor.yield %cst : f32
+  } : tensor<4x127xf32> to tensor<4x128xf32>
+  %mm = linalg.matmul {lowering_config = #lowering_config_dma_with_pad}
+      ins(%padded, %b : tensor<4x128xf32>, tensor<128x128xf32>) outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32>
+  return %mm : tensor<4x128xf32>
+}
+
+// Padded input falls back to derived_thread_config, non-padded uses DMA.
+// CHECK-LABEL: func.func @no_dma_for_padded_input
+// CHECK: tensor.pad
+// CHECK: linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.derived_thread_config
+// CHECK: linalg.copy
+// CHECK-SAME: lowering_config = #iree_gpu.use_global_load_dma
+// CHECK: linalg.matmul
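
For reference, a rough, hand-written sketch of the IR shape that the CHECK lines above describe. This is not actual pass output: the tensor.empty destinations and the value names (%lhs_dest, %rhs_dest, %lhs, %rhs) are illustrative assumptions; only the pairing of each linalg.copy with its lowering_config is taken from the CHECK lines.

func.func @promoted_ir_sketch(%a: tensor<4x127xf32>, %b: tensor<128x128xf32>) -> tensor<4x128xf32> {
  %cst = arith.constant 0.000000e+00 : f32
  %empty = tensor.empty() : tensor<4x128xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<4x128xf32>) -> tensor<4x128xf32>
  %padded = tensor.pad %a low[0, 0] high[0, 1] {
  ^bb0(%i: index, %j: index):
    tensor.yield %cst : f32
  } : tensor<4x127xf32> to tensor<4x128xf32>
  // Padded operand: the promotion copy falls back to derived_thread_config.
  %lhs_dest = tensor.empty() : tensor<4x128xf32>
  %lhs = linalg.copy {lowering_config = #iree_gpu.derived_thread_config}
      ins(%padded : tensor<4x128xf32>) outs(%lhs_dest : tensor<4x128xf32>) -> tensor<4x128xf32>
  // Non-padded operand: the promotion copy keeps the global-load-DMA config.
  %rhs_dest = tensor.empty() : tensor<128x128xf32>
  %rhs = linalg.copy {lowering_config = #iree_gpu.use_global_load_dma}
      ins(%b : tensor<128x128xf32>) outs(%rhs_dest : tensor<128x128xf32>) -> tensor<128x128xf32>
  %mm = linalg.matmul ins(%lhs, %rhs : tensor<4x128xf32>, tensor<128x128xf32>)
      outs(%fill : tensor<4x128xf32>) -> tensor<4x128xf32>
  return %mm : tensor<4x128xf32>
}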

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/PromotionImpls.cpp

Lines changed: 12 additions & 1 deletion
@@ -9,6 +9,7 @@
 
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
+#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUOps.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
@@ -124,8 +125,18 @@ Value defaultPromotionImpl(OpBuilder &builder, OpOperand &operand,
   if (promotedValue.has_value()) {
     return promotedValue.value();
   }
+
+  // Global load DMA requires the source to come from global memory. If the
+  // source comes from tensor.pad, the data is not in global memory, so fall
+  // back to derived thread config.
+  Attribute effectiveAttr = attr;
+  if (isa<UseGlobalLoadDMAAttr>(attr) &&
+      operand.get().getDefiningOp<tensor::PadOp>()) {
+    effectiveAttr = DerivedThreadConfigAttr::get(builder.getContext());
+  }
+
   return promoteValue(builder, operand.getOwner()->getLoc(), operand.get(),
-                      attr);
+                      effectiveAttr);
 }
 
 /// Inserts a `linalg.copy` directly before the given operation on the

compiler/src/iree/compiler/Codegen/Dialect/GPU/TargetUtils/ConfigUtils.cpp

Lines changed: 2 additions & 2 deletions
@@ -924,8 +924,8 @@ getMatmulOrIGEMMLoweringConfigAndWorkgroupSize(
   }
   // Do not use direct load DMA when padding is needed, as the source will
   // go through tensor.pad and won't be directly from global memory.
-  ArrayRef<Attribute> promotionTypes =
-      (useDirectLoad && !couldNeedPadding) ? ArrayRef<Attribute>(promotionArray)
+  ArrayRef<Attribute> promotionTypes = (useDirectLoad && !couldNeedPadding)
+                                           ? ArrayRef<Attribute>(promotionArray)
                                            : ArrayRef<Attribute>{};
   GPU::appendPromotedOperandsList(context, attrs, promotionList,
                                   promotionTypes);
