Skip to content

Commit 2973181

Browse files
committed
Folding
1 parent c0cd803 commit 2973181

File tree

4 files changed

+100
-0
lines changed

4 files changed

+100
-0
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1395,6 +1395,7 @@ def AMDGPU_MakeDmaDescriptorOp :
13951395
}];
13961396

13971397
let hasVerifier = 1;
1398+
let hasFolder = 1;
13981399
}
13991400

14001401
def AMDGPU_TensorLoadToLDSOp :

mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -760,6 +760,9 @@ LogicalResult MakeDmaDescriptorOp::verify() {
760760
if (rank < 2) {
761761
return emitOpError("tensor and tile must be at least of rank 2.");
762762
}
763+
if (rank > 5) {
764+
return emitOpError("tensor and tile must be at most of rank 5.");
765+
}
763766
if (rank != globalStaticStrides.size()) {
764767
return emitOpError("strides and sizes must have same rank.");
765768
}
@@ -779,6 +782,82 @@ LogicalResult MakeDmaDescriptorOp::verify() {
779782
return success();
780783
}
781784

785+
static bool maybeUpdateDynamicIndexList(
786+
ArrayRef<int64_t> staticElements, ArrayRef<Attribute> foldedElements,
787+
SmallVector<Value> dynamicElements, SmallVector<int64_t> &newStaticElements,
788+
SmallVector<Value> &newDynamicElements) {
789+
bool changed = false;
790+
int index = 0;
791+
792+
for (int64_t static_element : staticElements) {
793+
if (!ShapedType::isDynamic(static_element)) {
794+
newStaticElements.push_back(static_element);
795+
continue;
796+
}
797+
798+
Attribute folded_element = foldedElements[index++];
799+
if (auto attr = dyn_cast<IntegerAttr>(folded_element)) {
800+
newStaticElements.push_back(attr.getInt());
801+
changed = true;
802+
continue;
803+
}
804+
805+
newStaticElements.push_back(ShapedType::kDynamic);
806+
newDynamicElements.push_back(dynamicElements[index]);
807+
}
808+
return changed;
809+
}
810+
811+
OpFoldResult MakeDmaDescriptorOp::fold(FoldAdaptor adaptor) {
812+
ArrayRef<int64_t> oldGlobalStaticStrides = adaptor.getGlobalStaticStrides();
813+
ArrayRef<Attribute> foldedGlobalDynamicStrides =
814+
adaptor.getGlobalDynamicStrides();
815+
SmallVector<Value> oldGlobalDynamicStrides = getGlobalDynamicStrides();
816+
817+
SmallVector<int64_t> newGlobalStaticStrides;
818+
SmallVector<Value> newGlobalDynamicStrides;
819+
820+
bool change = maybeUpdateDynamicIndexList(
821+
oldGlobalStaticStrides, foldedGlobalDynamicStrides,
822+
oldGlobalDynamicStrides, newGlobalStaticStrides, newGlobalDynamicStrides);
823+
824+
ArrayRef<int64_t> oldGlobalStaticSizes = adaptor.getGlobalStaticSizes();
825+
ArrayRef<Attribute> foldedGlobalDynamicSizes =
826+
adaptor.getGlobalDynamicSizes();
827+
SmallVector<Value> oldGlobalDynamicSizes = getGlobalDynamicSizes();
828+
829+
SmallVector<int64_t> newGlobalStaticSizes;
830+
SmallVector<Value> newGlobalDynamicSizes;
831+
832+
change |= maybeUpdateDynamicIndexList(
833+
oldGlobalStaticSizes, foldedGlobalDynamicSizes, oldGlobalDynamicSizes,
834+
newGlobalStaticSizes, newGlobalDynamicSizes);
835+
836+
ArrayRef<int64_t> oldSharedStaticSizes = adaptor.getSharedStaticSizes();
837+
ArrayRef<Attribute> foldedSharedDynamicSizes =
838+
adaptor.getSharedDynamicSizes();
839+
SmallVector<Value> oldSharedDynamicSizes = getSharedDynamicSizes();
840+
841+
SmallVector<int64_t> newSharedStaticSizes;
842+
SmallVector<Value> newSharedDynamicSizes;
843+
844+
change |= maybeUpdateDynamicIndexList(
845+
oldSharedStaticSizes, foldedSharedDynamicSizes, oldSharedDynamicSizes,
846+
newSharedStaticSizes, newSharedDynamicSizes);
847+
848+
if (change) {
849+
setGlobalStaticStrides(newGlobalStaticStrides);
850+
getGlobalDynamicStridesMutable().assign(newGlobalDynamicStrides);
851+
setGlobalStaticSizes(newGlobalStaticSizes);
852+
getGlobalDynamicSizesMutable().assign(newGlobalDynamicSizes);
853+
setSharedStaticSizes(newSharedStaticSizes);
854+
getSharedDynamicSizesMutable().assign(newSharedDynamicSizes);
855+
return getResult();
856+
}
857+
858+
return nullptr;
859+
}
860+
782861
//===----------------------------------------------------------------------===//
783862
// ScaledMFMAOp
784863
//===----------------------------------------------------------------------===//
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
// RUN: mlir-opt --canonicalize %s | FileCheck %s
2+
3+
// CHECK-LABEL: @make_dma_descriptor_fold
4+
// CHECK-SAME: (%[[BASE:.+]]: !amdgpu.tdm_base<i32>, %[[IDX:.+]]: index)
5+
func.func @make_dma_descriptor_fold(%base: !amdgpu.tdm_base<i32>, %idx: index) -> !amdgpu.tdm_descriptor {
6+
%c64 = arith.constant 64 : index
7+
8+
// CHECK: amdgpu.make_dma_descriptor %[[BASE]]
9+
%0 = amdgpu.make_dma_descriptor %base
10+
// CHECK-SAME: globalSize [64, 64]
11+
globalSize [%c64, %c64]
12+
// CHECK-SAME: globalStride [64, 1]
13+
globalStride [%c64, 1]
14+
// CHECK-SAME: sharedSize [64, 64]
15+
sharedSize [%c64, %c64]
16+
iterate %idx, %idx, %idx
17+
: !amdgpu.tdm_base<i32> -> !amdgpu.tdm_descriptor
18+
func.return %0 : !amdgpu.tdm_descriptor
19+
}

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -758,3 +758,4 @@ func.func @tensor_load_store(%desc: !amdgpu.tdm_descriptor) {
758758
amdgpu.tensor_store_from_lds %desc : !amdgpu.tdm_descriptor
759759
return
760760
}
761+

0 commit comments

Comments
 (0)