mlir/lib/Dialect/AMDGPU/Transforms/FoldMemRefsOps.cpp (109 changes: 63 additions & 46 deletions)
@@ -28,62 +28,79 @@ struct AmdgpuFoldMemRefOpsPass final
}
};

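/// Resolve `indices` of `view` through its producing view op. If `view` is
/// defined by a memref.subview, memref.expand_shape, or memref.collapse_shape,
/// rewrite the indices in terms of the producer's source, returning them in
/// `resolvedIndices` and the source memref in `memrefBase`. Fails when `view`
/// has no defining op or when the producer is not one of those ops; `role`
/// ("source"/"destination") only feeds the match-failure message.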
static LogicalResult foldMemrefViewOp(PatternRewriter &rewriter, Location loc,
Value view, mlir::OperandRange indices,
SmallVectorImpl<Value> &resolvedIndices,
Value &memrefBase, StringRef role) {
Operation *defOp = view.getDefiningOp();
if (!defOp) {
return failure();
}
return llvm::TypeSwitch<Operation *, LogicalResult>(defOp)
.Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
rewriter, loc, subviewOp.getMixedOffsets(),
subviewOp.getMixedStrides(), subviewOp.getDroppedDims(), indices,
resolvedIndices);
memrefBase = subviewOp.getSource();
return success();
})
.Case<memref::ExpandShapeOp>([&](memref::ExpandShapeOp expandShapeOp) {
if (failed(mlir::memref::resolveSourceIndicesExpandShape(
loc, rewriter, expandShapeOp, indices, resolvedIndices,
false))) {
return failure();
}
memrefBase = expandShapeOp.getViewSource();
return success();
})
.Case<memref::CollapseShapeOp>(
[&](memref::CollapseShapeOp collapseShapeOp) {
if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
loc, rewriter, collapseShapeOp, indices,
resolvedIndices))) {
return failure();
}
memrefBase = collapseShapeOp.getViewSource();
return success();
})
.Default([&](Operation *op) {
return rewriter.notifyMatchFailure(
op, (role + " producer is not one of SubViewOp, ExpandShapeOp, or "
"CollapseShapeOp")
.str());
});
}

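/// Folds view-producing ops (memref.subview, memref.expand_shape,
/// memref.collapse_shape) on either operand of amdgpu.gather_to_lds into the
/// op's indices. The global source and the LDS destination are folded
/// independently; a side whose producer cannot be folded keeps its original
/// operand and indices.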
struct FoldMemRefOpsIntoGatherToLDSOp final : OpRewritePattern<GatherToLDSOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(GatherToLDSOp op,
PatternRewriter &rewriter) const override {
Location loc = op.getLoc();

-    Value memrefSource;
-    SmallVector<Value> sourceIndices;
-    auto foldResult =
-        llvm::TypeSwitch<Operation *, LogicalResult>(
-            op.getSrc().getDefiningOp())
-            .Case<memref::SubViewOp>([&](memref::SubViewOp subviewOp) {
-              // If the source is a SubViewOp, we can directly rewrite the
-              // GatherToLDSOp.
-              mlir::affine::resolveIndicesIntoOpWithOffsetsAndStrides(
-                  rewriter, loc, subviewOp.getMixedOffsets(),
-                  subviewOp.getMixedStrides(), subviewOp.getDroppedDims(),
-                  op.getSrcIndices(), sourceIndices);
-              memrefSource = subviewOp.getSource();
-              return success();
-            })
-            .Case<memref::ExpandShapeOp>(
-                [&](memref::ExpandShapeOp expandShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesExpandShape(
-                          loc, rewriter, expandShapeOp, op.getSrcIndices(),
-                          sourceIndices, false))) {
-                    return failure();
-                  }
-                  memrefSource = expandShapeOp.getViewSource();
-                  return success();
-                })
-            .Case<memref::CollapseShapeOp>(
-                [&](memref::CollapseShapeOp collapseShapeOp) {
-                  if (failed(mlir::memref::resolveSourceIndicesCollapseShape(
-                          loc, rewriter, collapseShapeOp, op.getSrcIndices(),
-                          sourceIndices))) {
-                    return failure();
-                  }
-                  memrefSource = collapseShapeOp.getViewSource();
-                  return success();
-                })
-            .Default([&](Operation *op) {
-              // If the source is not a SubViewOp, ExpandShapeOp, or
-              // CollapseShapeOp, we cannot fold the GatherToLDSOp.
-              return rewriter.notifyMatchFailure(
-                  op,
-                  "source producer is not one of SubViewOp, ExpandShapeOp, or "
-                  "CollapseShapeOp");
-            });
SmallVector<Value> sourceIndices, destIndices;
Value memrefSource, memrefDest;

auto foldSrcResult =
foldMemrefViewOp(rewriter, loc, op.getSrc(), op.getSrcIndices(),
sourceIndices, memrefSource, "source");

if (failed(foldSrcResult)) {
memrefSource = op.getSrc();
sourceIndices = op.getSrcIndices();
}

auto foldDstResult =
foldMemrefViewOp(rewriter, loc, op.getDst(), op.getDstIndices(),
destIndices, memrefDest, "destination");

-    if (failed(foldResult)) {
-      return failure();
if (failed(foldDstResult)) {
memrefDest = op.getDst();
destIndices = op.getDstIndices();
}

rewriter.replaceOpWithNewOp<GatherToLDSOp>(op, memrefSource, sourceIndices,
-                                               op.getDst(), op.getDstIndices(),
memrefDest, destIndices,
op.getTransferType());

return success();
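Before the test updates below, here is a minimal before/after sketch of the new destination-side fold. The IR is illustrative only and not taken from the patch; %src, %i, %j, and %k are assumed to be a memref<8192xf16> and index values defined elsewhere.

// Before the fold: the LDS destination is a view produced by memref.expand_shape.
%lds = memref.alloc() : memref<4096xf16, 3>
%view = memref.expand_shape %lds [[0, 1]] output_shape [64, 64]
    : memref<4096xf16, 3> into memref<64x64xf16, 3>
amdgpu.gather_to_lds %src[%i], %view[%j, %k]
    : vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>

// After the fold: the view is gone and the destination indices are linearized.
%lds = memref.alloc() : memref<4096xf16, 3>
%idx = affine.linearize_index [%j, %k] by (64, 64) : index
amdgpu.gather_to_lds %src[%i], %lds[%idx]
    : vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>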
mlir/test/Dialect/AMDGPU/amdgpu-fold-memrefs.mlir (91 changes: 80 additions & 11 deletions)
@@ -54,18 +54,20 @@ func.func @subview_folding_offset(%offset_i: index, %offset_j: index) {
// CHECK: func @test_expand_shape
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func.func @test_expand_shape(%offset_i: index, %offset_j: index) {
-  // CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[IDX:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDX]]], %[[LOCAL]][%[[C0]], %[[C0]]]
-  // CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, 3>
// CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG0]], %[[ARG1]]] by (64, 128) : index
// CHECK: %[[IDXL:.*]] = affine.linearize_index [%[[C0]], %[[C0]]] by (64, 64) : index
// CHECK: amdgpu.gather_to_lds %[[MEM]][%[[IDXM]]], %[[LOCAL]][%[[IDXL]]]
// CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>

-  %alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
%alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<8192xf16>
-  %expand = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
%expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16> into memref<64x128xf16>
%expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
%c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %expand[%offset_i, %offset_j], %alloc[%c0, %c0]
amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %expand_alloc[%c0, %c0]
: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, #gpu_lds_addrspace>
func.return
}
@@ -80,15 +80,82 @@ func.func @test_collapse_shape(%offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<64x64xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<64x128xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
-  // CHECK: %[[INDICES:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
-  // CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES]]#0, %[[INDICES]]#1], %[[LOCAL]][%[[C0]], %[[C0]]]
// CHECK: %[[INDICES_MEM:.*]]:2 = affine.delinearize_index %[[ARG0]] into (64, 128) : index, index
// CHECK: %[[INDICES_LDS:.*]]:2 = affine.delinearize_index %[[ARG1]] into (64, 64) : index, index
// CHECK: amdgpu.gather_to_lds %[[MEM]][%[[INDICES_MEM]]#0, %[[INDICES_MEM]]#1], %[[LOCAL]][%[[INDICES_LDS]]#0, %[[INDICES_LDS]]#1]
// CHECK-SAME: vector<8xf16>, memref<64x128xf16>, memref<64x64xf16, 3>

%alloc = memref.alloc() : memref<64x64xf16, #gpu_lds_addrspace>
%collapse_alloc = memref.collapse_shape %alloc [[0, 1]] : memref<64x64xf16, #gpu_lds_addrspace> into memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<64x128xf16>
-  %collapse = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
%collapse_mem = memref.collapse_shape %mem [[0, 1]] : memref<64x128xf16> into memref<8192xf16>
%c0 = arith.constant 0 : index
-  amdgpu.gather_to_lds %collapse[%offset_i], %alloc[%c0, %c0]
amdgpu.gather_to_lds %collapse_mem[%offset_i], %collapse_alloc[%offset_j]
: vector<8xf16>, memref<8192xf16>, memref<4096xf16, #gpu_lds_addrspace>
func.return
}


// -----

#gpu_lds_addrspace = 3


// CHECK: func @test_expand_shape_src_raw_buffer
// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
func.func @test_expand_shape_src_raw_buffer(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[IDXM:.*]] = affine.linearize_index [%[[ARG1]], %[[ARG2]]] by (64, 128) : index
// CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[IDXM]]], %[[LOCAL]][%[[C0]]]
// CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>

%alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
%expand_mem = memref.expand_shape %mem [[0, 1]] output_shape [64, 128] : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>> into memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>

%c0 = arith.constant 0 : index
amdgpu.gather_to_lds %expand_mem[%offset_i, %offset_j], %alloc[%c0]
: vector<8xf16>, memref<64x128xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
func.return
}

// -----

#gpu_lds_addrspace = 3

// CHECK: func @test_expand_shape_dst_only
// CHECK-SAME: %[[ARG0:.*]]: index, %[[ARG1:.*]]: index
func.func @test_expand_shape_dst_only(%offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: %[[MEM:.*]] = memref.alloc() : memref<8192xf16>
// CHECK: %[[C0:.*]] = arith.constant 0 : index
// CHECK: %[[IDX_LDS:.*]] = affine.linearize_index [%[[ARG1]], %[[C0]]] by (64, 64) : index
// CHECK: amdgpu.gather_to_lds %[[MEM]][%[[ARG0]]], %[[LOCAL]][%[[IDX_LDS]]]
// CHECK-SAME: vector<8xf16>, memref<8192xf16>, memref<4096xf16, 3>

%alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
%mem = memref.alloc() : memref<8192xf16>
%expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64] : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>

%c0 = arith.constant 0 : index
amdgpu.gather_to_lds %mem[%offset_i], %expand_alloc[%offset_j, %c0]
: vector<8xf16>, memref<8192xf16>, memref<64x64xf16, #gpu_lds_addrspace>
func.return
}

// -----

#gpu_lds_addrspace = 3

// CHECK: func @test_nop
// CHECK-SAME: %[[ARG0:.*]]: memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index
func.func @test_nop(%mem : memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, %offset_i: index, %offset_j: index) {
// CHECK: %[[LOCAL:.*]] = memref.alloc() : memref<4096xf16, 3>
// CHECK: amdgpu.gather_to_lds %[[ARG0]][%[[ARG1]]], %[[LOCAL]][%[[ARG2]]]
// CHECK-SAME: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, 3>

%alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
amdgpu.gather_to_lds %mem[%offset_i], %alloc[%offset_j]
: vector<8xf16>, memref<8192xf16, #amdgpu.address_space<fat_raw_buffer>>, memref<4096xf16, #gpu_lds_addrspace>
func.return
}
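
Because the source and destination sides are now folded independently, a single op can mix different view kinds on the two operands. The function below is an illustrative sketch in the style of this test file, not part of the patch: the name, shapes, and offsets are made up, and the expected output is simply the combination of the subview and expand_shape rewrites checked above (the test file's RUN line, which is outside the shown hunks, is what actually invokes the pass).

#gpu_lds_addrspace = 3

// Illustrative only: subview on the global source, expand_shape on the LDS
// destination, both folded by the same pattern.
func.func @mixed_subview_src_expand_dst(%offset_i: index, %offset_j: index) {
  %alloc = memref.alloc() : memref<4096xf16, #gpu_lds_addrspace>
  %mem = memref.alloc() : memref<128x128xf16>
  %subview_mem = memref.subview %mem[32, 0] [64, 128] [1, 1]
      : memref<128x128xf16> to memref<64x128xf16, strided<[128, 1], offset: 4096>>
  %expand_alloc = memref.expand_shape %alloc [[0, 1]] output_shape [64, 64]
      : memref<4096xf16, #gpu_lds_addrspace> into memref<64x64xf16, #gpu_lds_addrspace>
  %c0 = arith.constant 0 : index
  amdgpu.gather_to_lds %subview_mem[%offset_i, %offset_j], %expand_alloc[%offset_i, %c0]
      : vector<8xf16>, memref<64x128xf16, strided<[128, 1], offset: 4096>>, memref<64x64xf16, #gpu_lds_addrspace>
  func.return
}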