Skip to content

Commit 6d5fb9f

Browse files
authored
[Gluon][Blackwell] Add _reinterpret to tmem descriptor and fix its lowering for TMEM (#7160)
1 parent 0e868f8 commit 6d5fb9f

File tree

8 files changed

+71
-9
lines changed

8 files changed

+71
-9
lines changed

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,6 +298,7 @@ def TTG_MemDescReinterpretOp : TTG_Op<"memdesc_reinterpret", [Pure, MemDescViewT
298298
}];
299299

300300
let hasVerifier = 1;
301+
let hasFolder = 1;
301302
}
302303

303304
def TTG_LocalLoadOp : TTG_Op<"local_load"> {

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,12 @@ LogicalResult MemDescReinterpretOp::verify() {
495495
return success();
496496
}
497497

498+
/// Fold an identity reinterpret: when the result type is exactly the source
/// type, the cast is a no-op and the source value can be forwarded as-is.
OpFoldResult MemDescReinterpretOp::fold(FoldAdaptor adaptor) {
  if (getSrc().getType() != getType())
    return {};
  return getSrc();
}
503+
498504
// LocalAllocOp
499505
void LocalAllocOp::getEffects(
500506
SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>

python/triton/experimental/gluon/language/nvidia/blackwell/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,16 @@ def index(self, index, _semantic: GluonSemantic = None) -> tensor_memory_descrip
157157
ret.handle = builder.create_memdesc_subview(ret.type.to_ir(builder), self.handle, offsets)
158158
return ret
159159

160+
@builtin
def _reinterpret(self, dtype, shape, layout, _semantic: GluonSemantic = None) -> tensor_memory_descriptor:
    """Create a view of this TMEM descriptor with a new element type, shape,
    and layout. Only the descriptor type changes; no data is moved.
    """
    elem_ty = _unwrap_if_constexpr(dtype)
    view_shape = [_unwrap_if_constexpr(dim) for dim in shape]
    view_layout = _unwrap_if_constexpr(layout)

    builder = _semantic.builder
    # The reinterpreted view's allocation shape matches its logical shape.
    view_ty = tensor_memory_descriptor_type(elem_ty, view_shape, view_layout, view_shape)
    handle = builder.create_memdesc_reinterpret(view_ty.to_ir(builder), self.handle)
    return tensor_memory_descriptor(handle, **view_ty.__dict__)
160170

161171
@builtin
162172
def allocate_tensor_memory(element_ty, shape, layout, value=None, _semantic=None):

test/Conversion/tritongpu_to_llvm_blackwell.mlir

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -591,3 +591,18 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32} {
591591
tt.return
592592
}
593593
}
594+
595+
// -----

#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>

module attributes {"ttg.num-warps" = 4 : i32} {

// CHECK-LABEL: @reinterpret
tt.func private @reinterpret(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory> {
  %0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory>
  // CHECK-NEXT: return %arg0
  tt.return %0 : !ttg.memdesc<16x16xf16, #tmem, #ttng.tensor_memory>
}

}
Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,23 @@
11
// RUN: triton-opt %s -canonicalize | FileCheck %s
22

3-
// CHECK-LABEL: @test_dce_tmem_alloc
4-
// CHECK-NOT: ttng.tmem_alloc
5-
// CHECK: tt.return
63
#linear = #ttg.linear<{register = [[0, 1], [0, 2], [32, 0], [64, 0]], lane = [[1, 0], [2, 0], [4, 0], [8, 0], [16, 0]], warp = [[0, 0], [0, 0]], block = []}>
74
#tmem_scales = #ttng.tensor_memory_scales_encoding<>
5+
#tmem = #ttng.tensor_memory_encoding<blockM = 128, blockN = 128, unpacked = true>
86
module attributes {"ttg.num-warps" = 8 : i32, "ttg.num-ctas" = 1 : i32, "ttg.target" = "cuda:80"} {
7+
8+
// CHECK-LABEL: @test_dce_tmem_alloc
99
tt.func @test_dce_tmem_alloc(%arg: tensor<128x4xi8, #linear>) {
10-
%a = ttng.tmem_alloc %arg : (tensor<128x4xi8, #linear>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
11-
tt.return
10+
// CHECK-NOT: ttng.tmem_alloc
11+
%a = ttng.tmem_alloc %arg : (tensor<128x4xi8, #linear>) -> !ttg.memdesc<128x4xi8, #tmem_scales, #ttng.tensor_memory>
12+
// CHECK-NEXT: tt.return
13+
tt.return
1214
}
15+
16+
// CHECK-LABEL: @reinterpret_fold
17+
tt.func @reinterpret_fold(%arg0: !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>) -> !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> {
18+
%0 = ttg.memdesc_reinterpret %arg0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory> -> !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>
19+
// CHECK-NEXT: return %arg0
20+
tt.return %0 : !ttg.memdesc<128xf32, #tmem, #ttng.tensor_memory>
21+
}
22+
1323
} // end module

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM
2727
LINK_LIBS PUBLIC
2828
TritonGPUToLLVM
2929
TritonProtonToLLVM
30+
MLIRReconcileUnrealizedCasts
3031
)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/ConvertWarpSpecializeToLLVM.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
#include "TargetInfo.h"
22
#include "TritonNVIDIAGPUToLLVM/PTXAsmFormat.h"
33
#include "mlir/Analysis/TopologicalSortUtils.h"
4+
#include "mlir/Conversion/Passes.h"
45
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
56
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
67
#include "mlir/IR/BuiltinOps.h"
78
#include "mlir/IR/ImplicitLocOpBuilder.h"
9+
#include "mlir/Pass/PassManager.h"
810
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
911
#include "triton/Conversion/TritonGPUToLLVM/Passes.h"
1012
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
@@ -561,10 +563,9 @@ struct ConvertWarpSpecializeToLLVM
561563
if (isa<WarpSpecializeOp, WarpSpecializePartitionsOp, WarpYieldOp>(op))
562564
convertOpTypes(op, typeConverter);
563565
});
564-
RewritePatternSet patterns(&getContext());
565-
UnrealizedConversionCastOp::getCanonicalizationPatterns(patterns,
566-
&getContext());
567-
if (failed(applyPatternsGreedily(mod, std::move(patterns))))
566+
OpPassManager pm;
567+
pm.addPass(createReconcileUnrealizedCastsPass());
568+
if (failed(runPipeline(pm, mod)))
568569
return signalPassFailure();
569570

570571
SmallVector<LLVM::LLVMFuncOp> kernels;

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TensorMemoryToLLVM.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -884,6 +884,23 @@ struct MemDescSubviewOpConversion
884884
}
885885
};
886886

887+
class MemDescReinterpretOpConversion
888+
: public ConvertOpToLLVMPattern<MemDescReinterpretOp> {
889+
public:
890+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
891+
892+
LogicalResult
893+
matchAndRewrite(MemDescReinterpretOp op, OpAdaptor adaptor,
894+
ConversionPatternRewriter &rewriter) const override {
895+
if (!isa<triton::nvidia_gpu::TensorMemoryEncodingAttr>(
896+
op.getSrc().getType().getEncoding())) {
897+
return failure();
898+
}
899+
rewriter.replaceOp(op, adaptor.getSrc());
900+
return success();
901+
}
902+
};
903+
887904
struct TMEMSubSliceOpConversion
888905
: public ConvertOpToLLVMPattern<triton::nvidia_gpu::TMEMSubSliceOp> {
889906
using ConvertOpToLLVMPattern<
@@ -937,5 +954,6 @@ void mlir::triton::NVIDIA::populateTensorMemorySubviewOpToLLVMPattern(
937954
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
938955
PatternBenefit benefit) {
939956
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);
957+
patterns.add<MemDescReinterpretOpConversion>(typeConverter, benefit);
940958
return;
941959
}

0 commit comments

Comments
 (0)