Skip to content

Commit baa03c6

Browse files
authored
[AMD] Support extract slice in canonicalize pointers pass (#7090)
There are scenarios where one can use `ExtractSliceOp` for slicing tensors of pointers. In this scenario, the `CanonicalizePointers` pass fails. This PR adds a `ConvertExtractSliceOp` rewrite pattern to the `CanonicalizePointers` pass to handle the scenario. A lit test is provided.
1 parent 9c22937 commit baa03c6

File tree

2 files changed

+71
-1
lines changed

2 files changed

+71
-1
lines changed

test/TritonGPU/amd/amd-canonicalize-pointers.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1471,6 +1471,30 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.targ
14711471

14721472
// -----
14731473

1474+
#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [4, 16], warpsPerCTA = [2, 2], order = [1, 0]}>
1475+
module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 64 : i32} {
1476+
tt.func @conversion_extract_slice(%arg0: !tt.ptr<f32>, %arg1: tensor<256x256xi32, #blocked>) -> tensor<128x256xf32, #blocked> {
1477+
%3 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<256x256x!tt.ptr<f32>, #blocked>
1478+
%4 = tt.addptr %3, %arg1 : tensor<256x256x!tt.ptr<f32>, #blocked>, tensor<256x256xi32, #blocked>
1479+
%5 = amdgpu.extract_slice %4 [0, 0] : tensor<256x256x!tt.ptr<f32>, #blocked> to tensor<128x256x!tt.ptr<f32>, #blocked>
1480+
%6 = tt.load %5 : tensor<128x256x!tt.ptr<f32>, #blocked>
1481+
tt.return %6 : tensor<128x256xf32, #blocked>
1482+
}
1483+
}
// CHECK-LABEL: tt.func @conversion_extract_slice(
// CHECK-SAME: %[[ARG_0:.*]]: !tt.ptr<f32>, %[[ARG_1:.*]]: tensor<256x256xi32, #blocked>) -> tensor<128x256xf32, #blocked> {
// CHECK: %[[VAR_0:.*]] = arith.extsi %[[ARG_1]] : tensor<256x256xi32, #blocked> to tensor<256x256xi64, #blocked>
// CHECK: %[[VAR_1:.*]] = amdgpu.extract_slice %[[VAR_0]] [0, 0] : tensor<256x256xi64, #blocked> to tensor<128x256xi64, #blocked>
// CHECK: %[[VAR_2:.*]] = arith.trunci %[[VAR_1]] : tensor<128x256xi64, #blocked> to tensor<128x256xi32, #blocked>
// CHECK: %[[VAR_3:.*]] = tt.splat %[[ARG_0]] : !tt.ptr<f32> -> tensor<128x256x!tt.ptr<f32>, #blocked>
// CHECK: %[[VAR_4:.*]] = tt.addptr %[[VAR_3]], %[[VAR_2]] : tensor<128x256x!tt.ptr<f32>, #blocked>, tensor<128x256xi32, #blocked>
// CHECK: %[[VAR_5:.*]] = tt.load %[[VAR_4]] : tensor<128x256x!tt.ptr<f32>, #blocked>
// CHECK: tt.return %[[VAR_5]] : tensor<128x256xf32, #blocked>
// CHECK: }

// -----
14741498
module attributes {"ttg.num-warps" = 4 : i32} {
14751499
tt.func @ifOpPoison(%arg0: !tt.ptr<f32>, %arg1: tensor<1024xf32>, %arg2: i1) -> tensor<1024xf32> {
14761500
%c1024_i32 = arith.constant 1024 : i32

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,50 @@ class ConvertAddPtrOp : public PointerCanonicalizationPattern<tt::AddPtrOp> {
635635
}
636636
};
637637

638+
/// Slice only offset and keep base - i.e.,
639+
/// slice(fatPtrBase, fatPtrOffset) -> (fatPtrBase, slice(fatPtrOffset))
640+
class ConvertExtractSliceOp
641+
: public PointerCanonicalizationPattern<tt::amdgpu::ExtractSliceOp> {
642+
public:
643+
using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
644+
645+
LogicalResult
646+
matchAndRewrite_(tt::amdgpu::ExtractSliceOp extractSliceOp,
647+
OneToNOpAdaptor adaptor,
648+
ConversionPatternRewriter &rewriter) const override {
649+
ValueRange remappedOperands = adaptor.getSource();
650+
if (remappedOperands.size() != 2)
651+
return success();
652+
653+
Value fatPtrBase = remappedOperands[0];
654+
Value fatPtrOffset = remappedOperands[1];
655+
if (!llvm::isa<tt::PointerType>(fatPtrBase.getType()))
656+
return rewriter.notifyMatchFailure(extractSliceOp,
657+
"non tt.ptr base unimplemented");
658+
659+
auto fatPtrOffsetTy = dyn_cast<RankedTensorType>(fatPtrOffset.getType());
660+
if (!fatPtrOffsetTy)
661+
return rewriter.notifyMatchFailure(
662+
extractSliceOp, "non RankedTensorType offset unimplemented");
663+
664+
Location loc = extractSliceOp->getLoc();
665+
RankedTensorType resultType = extractSliceOp.getResult().getType();
666+
auto slicedOffsetsTy = RankedTensorType::get(
667+
resultType.getShape(), fatPtrOffsetTy.getElementType(),
668+
resultType.getEncoding());
669+
Value slicedOffsets = rewriter.create<tt::amdgpu::ExtractSliceOp>(
670+
loc, Type{slicedOffsetsTy}, Value{fatPtrOffset},
671+
extractSliceOp.getStaticOffsetsAttr());
672+
673+
rewriter.replaceOpWithMultiple(extractSliceOp,
674+
{{fatPtrBase, slicedOffsets}});
675+
fatPtrs[{fatPtrBase, slicedOffsets}] =
676+
fatPtrs.at({fatPtrBase, fatPtrOffset});
677+
678+
return success();
679+
}
680+
};
681+
638682
/// Rewrite init args and result type and bb args.
639683
class ConvertSCFForOp : public PointerCanonicalizationPattern<scf::ForOp> {
640684
using PointerCanonicalizationPattern::PointerCanonicalizationPattern;
@@ -1510,6 +1554,8 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
15101554
target.addDynamicallyLegalDialect<scf::SCFDialect>(isLegal);
15111555
target.addDynamicallyLegalDialect<cf::ControlFlowDialect>(isLegal);
15121556
target.addDynamicallyLegalDialect<arith::ArithDialect>(isLegal);
1557+
target.addDynamicallyLegalDialect<triton::amdgpu::TritonAMDGPUDialect>(
1558+
isLegal);
15131559

15141560
// Rewrite the rest of the ops.
15151561
// Note we *do not* declare unrealized_cast an illegal op here in order that
@@ -1521,7 +1567,7 @@ void TritonAMDGPUCanonicalizePointersPass::runOnOperation() {
15211567
RewritePatternSet patterns(&getContext());
15221568
patterns.add<
15231569
ConvertFuncOpArgsUnrealizedCasts, ConvertBroadcastOp, ConvertSplatOp,
1524-
ConvertConvertLayoutOp, ConvertAddPtrOp,
1570+
ConvertConvertLayoutOp, ConvertAddPtrOp, ConvertExtractSliceOp,
15251571
MaterializeFatPointer<tt::AtomicCASOp>,
15261572
MaterializeFatPointer<tt::AtomicRMWOp>,
15271573
MaterializeFatPointer<tt::BitcastOp>, MaterializeFatPointer<tt::LoadOp>,

0 commit comments

Comments
 (0)