Skip to content

Commit ddc8f87

Browse files
Merge commit '21119e31f2e5d517643749657403d4e43deb13d4'
2 parents 45aeede + 21119e3 commit ddc8f87

File tree

16 files changed

+489
-132
lines changed

16 files changed

+489
-132
lines changed

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,10 @@ def TT_ReduceOp: TT_Op<"reduce",
731731
llvm::SmallVector<RankedTensorType> getInputTypes();
732732
llvm::SmallVector<Type> getElementTypes();
733733
unsigned getNumOperands();
734+
735+
// Returns the CombineOp iff this ReduceOp's region contains only
736+
// one CombineOp other than the return, or nullptr if not applicable.
737+
::mlir::Operation *getSingleCombiner();
734738
}];
735739
}
736740

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -772,14 +772,6 @@ def MmaEncodingTrait : AttrInterface<"MmaEncodingTrait"> {
772772
"int":$kWidth,
773773
"int":$opIdx)>,
774774

775-
InterfaceMethod<"Return total element size per thread for dot operands.",
776-
"unsigned",
777-
"getTotalElemsPerThreadForOperand",
778-
(ins "ArrayRef<int64_t>":$tensorShape,
779-
"Type":$eltTy,
780-
"int":$kWidth,
781-
"int":$opIdx)>,
782-
783775
InterfaceMethod<"Return size per thread for dot operands.",
784776
"SmallVector<unsigned>",
785777
"getSizePerThreadForOperand",
@@ -1156,7 +1148,6 @@ For example, the matrix L corresponding to blockTileSize=[32,16] is:
11561148
};
11571149
SmallVector<unsigned> getSizePerThreadForOperand(int kWidth, int opIdx) const;
11581150
SmallVector<unsigned> getShapePerCTATileForOperand(ArrayRef<int64_t> shape, int kWidth, int opIdx) const;
1159-
unsigned getTotalElemsPerThreadForOperand(ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const;
11601151

11611152
SmallVector<unsigned> getContigPerThread() {
11621153
assert(isAmpere() || isHopper());

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -503,6 +503,22 @@ llvm::SmallVector<Type> ReduceOp::getElementTypes() {
503503
return getElementTypesImpl(this->getOperands());
504504
}
505505

506+
::mlir::Operation *ReduceOp::getSingleCombiner() {
507+
if (getNumOperands() != 1 || getNumResults() != 1)
508+
return nullptr;
509+
Block *block = &(*getCombineOp().begin());
510+
Operation *yield = block->getTerminator();
511+
Operation *reduceOp = yield->getOperand(0).getDefiningOp();
512+
if (!reduceOp || reduceOp->getNumOperands() != 2 ||
513+
reduceOp->getNumResults() != 1)
514+
return nullptr;
515+
if (reduceOp->getOperand(0) != block->getArgument(0) ||
516+
reduceOp->getOperand(1) != block->getArgument(1))
517+
return nullptr;
518+
519+
return reduceOp;
520+
}
521+
506522
unsigned ReduceOp::getNumOperands() { return this->getOperands().size(); }
507523

508524
//-- ScanOp --

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -951,11 +951,11 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
951951
elemsPerThread[rank - 1] = (idx == 0) ? rep[2] * kWidth : rep[2];
952952
return elemsPerThread;
953953
} else if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(parent)) {
954-
if (mma.isAmpere()) {
954+
if (mma.isAmpere() || mma.isHopper()) {
955955
auto bitwidth = getPointeeType(eltTy).getIntOrFloatBitWidth();
956956
auto rep = mma.getRepForOperand(shape, bitwidth, idx);
957957
auto sizePerThread = getSizePerThread();
958-
auto elemsPerKRep = 32 / bitwidth * 2;
958+
auto elemsPerKRep = mma.isHopper() ? (kWidth * 2) : (32 / bitwidth * 2);
959959
if (rank == 3)
960960
elemsPerThread[0] = rep[0];
961961
elemsPerThread[rank - 2] =
@@ -980,12 +980,18 @@ DotOperandEncodingAttr::getElemsPerThread(ArrayRef<int64_t> shape,
980980
unsigned DotOperandEncodingAttr::getTotalElemsPerThread(ArrayRef<int64_t> shape,
981981
Type eltTy) const {
982982
if (auto mmaParent = mlir::dyn_cast<MmaEncodingTrait>(getParent())) {
983-
if (auto nvidiaMmaParent = mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent);
984-
nvidiaMmaParent && nvidiaMmaParent.isAmpere()) {
983+
if (auto nvidiaMmaParent =
984+
mlir::dyn_cast<NvidiaMmaEncodingAttr>(mmaParent)) {
985985
return product<unsigned>(getElemsPerThread(shape, eltTy));
986986
}
987-
return mmaParent.getTotalElemsPerThreadForOperand(shape, eltTy, getKWidth(),
988-
getOpIdx());
987+
if (auto amdMfmaParent = mlir::dyn_cast<AMDMfmaEncodingAttr>(getParent())) {
988+
return amdMfmaParent.getTotalElemsPerThreadForOperand(
989+
shape, eltTy, getKWidth(), getOpIdx());
990+
}
991+
if (auto amdWmmaParent = mlir::dyn_cast<AMDWmmaEncodingAttr>(getParent())) {
992+
return amdWmmaParent.getTotalElemsPerThreadForOperand(
993+
shape, eltTy, getKWidth(), getOpIdx());
994+
}
989995
}
990996
if (auto blockedLayout = mlir::dyn_cast<BlockedEncodingAttr>(getParent())) {
991997
auto shapePerCTA = getShapePerCTA(*this, shape);
@@ -2021,26 +2027,9 @@ NvidiaMmaEncodingAttr::getRepForOperand(ArrayRef<int64_t> shape, int bitwidth,
20212027
}
20222028
}
20232029

2024-
unsigned NvidiaMmaEncodingAttr::getTotalElemsPerThreadForOperand(
2025-
ArrayRef<int64_t> shape, Type eltTy, int kWidth, int opIdx) const {
2026-
auto shapePerCTA = getShapePerCTA(*this, shape);
2027-
int warpsPerCTAM = getWarpsPerCTA()[0];
2028-
int warpsPerCTAN = getWarpsPerCTA()[1];
2029-
// H100
2030-
if (isHopper()) {
2031-
assert(opIdx == 0);
2032-
auto instrMNK = getInstrShape();
2033-
int repM = ceil<unsigned>(shapePerCTA[0], instrMNK[0] * warpsPerCTAM);
2034-
int repK = ceil<unsigned>(shapePerCTA[1], instrMNK[2]);
2035-
// For each WGMMA instr, a 2x2 matrix fragment is loaded. Each thread holds
2036-
// kWidth elements for each quadrant. WGMMA is repeated repM * repK times.
2037-
return 4 * kWidth * repM * repK;
2038-
}
2039-
llvm_unreachable("unknown mma layout");
2040-
}
20412030
SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
20422031
ArrayRef<int64_t> shape, int kWidth, int opIdx) const {
2043-
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
2032+
assert(isAmpere() && "mmaLayout Hopper is not implemented yet");
20442033
auto shapePerCTATile = getShapePerCTATile(shape);
20452034
auto rank = shapePerCTATile.size();
20462035
auto kDim = opIdx == 0 ? rank - 1 : rank - 2;
@@ -2050,7 +2039,6 @@ SmallVector<unsigned> NvidiaMmaEncodingAttr::getShapePerCTATileForOperand(
20502039
}
20512040
SmallVector<unsigned>
20522041
NvidiaMmaEncodingAttr::getSizePerThreadForOperand(int kWidth, int opIdx) const {
2053-
assert(isAmpere() && "mmaLayout version = 1 is not implemented yet");
20542042
auto rank = getWarpsPerCTA().size();
20552043
auto sizePerThread = SmallVector<unsigned>(rank, 1);
20562044
if (opIdx == 0) {

python/test/unit/language/test_core.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4136,13 +4136,12 @@ def _kernel(dst, src, CACHE: tl.constexpr):
41364136
cv_cache_modifier_str = 'sc0 sc1'
41374137
buffer_load_line = [line for line in amdgcn.splitlines() if "buffer_load" in line]
41384138
global_load_line = [line for line in amdgcn.splitlines() if "global_load" in line]
4139-
flat_load_line = [line for line in amdgcn.splitlines() if "flat_load" in line]
41404139
if cache == '' or cache == '.ca':
41414140
assert cg_cache_modifier_str not in (global_load_line[0] if global_load_line else buffer_load_line[0])
41424141
if cache == '.cg':
41434142
assert cg_cache_modifier_str in global_load_line[0]
41444143
if cache == '.cv':
4145-
assert cv_cache_modifier_str in flat_load_line[0]
4144+
assert cv_cache_modifier_str in global_load_line[0]
41464145

41474146
if is_cuda():
41484147
ptx = pgm.asm['ptx']

test/Conversion/amd/load_store.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 :
1515
%7 = tt.splat %arg1 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
1616
%8 = tt.addptr %7, %4 : tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xi32, #blocked0>
1717
// Load 8 elements from A with two vectorized load instruction
18-
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
18+
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
1919
%9 = tt.load %6 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
2020
// Load 8 elements from B with two vectorized load instruction
21-
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
21+
// CHECK-COUNT-2: llvm.intr.masked.load {{.*}} : (!llvm.ptr<1>, vector<4xi1>, vector<4xf32>) -> vector<4xf32>
2222
%10 = tt.load %8 {cache = 1 : i32, evict = 1 : i32, isVolatile = false} : tensor<256x!tt.ptr<f32>, #blocked0>
2323
%11 = arith.addf %9, %10 : tensor<256xf32, #blocked0>
2424
%12 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<256x!tt.ptr<f32>, #blocked0>
@@ -51,7 +51,7 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
5151
%105 = tt.splat %arg0 : !tt.ptr<f16> -> tensor<32x32x!tt.ptr<f16>, #mma>
5252
%106 = tt.addptr %105, %104 : tensor<32x32x!tt.ptr<f16>, #mma>, tensor<32x32xi32, #mma>
5353
// Store 16 elements with four vectorized store instruction
54-
// CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr
54+
// CHECK-COUNT-4: llvm.intr.masked.store {{.*}}, {{.*}}, {{.*}} {alignment = 16 : i32} : vector<4xf16>, vector<4xi1> into !llvm.ptr<1>
5555
tt.store %106, %2 : tensor<32x32x!tt.ptr<f16>, #mma>
5656
tt.return
5757
}

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 78 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,8 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
2525
// CHECK: llvm.cond_br
2626
// CHECK: llvm.atomicrmw
2727
// CHECK: llvm.atomicrmw
28-
// CHECK: %[[ADDR1:.*]] = llvm.addrspacecast
29-
// CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR1]]
30-
// CHECK: %[[ADDR2:.*]] = llvm.addrspacecast
31-
// CHECK: llvm.intr.masked.store %{{.*}}, %[[ADDR2]]
28+
// CHECK: llvm.intr.masked.store
29+
// CHECK: llvm.intr.masked.store
3230
%0 = tt.atomic_rmw fadd, relaxed, gpu, %arg0, %arg2, %arg1 : (tensor<256x!tt.ptr<f32>, #blocked0>, tensor<256xf32, #blocked0>, tensor<256xi1, #blocked0>) -> tensor<256xf32, #blocked0>
3331
tt.store %arg0, %0 : tensor<256x!tt.ptr<f32>, #blocked0>
3432
tt.return
@@ -134,3 +132,79 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
134132
tt.return
135133
}
136134
}
135+
136+
// -----
137+
138+
#blocked3 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
139+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
140+
// CHECK-LABEL: reduce_dpp_max
141+
tt.func @reduce_dpp_max(%arg0: tensor<64xf32, #blocked3>) {
142+
// CHECK: rocdl.update.dpp
143+
// CHECK-SAME: with 280, 15, 15, true : f32
144+
// CHECK-NEXT: llvm.intr.maxnum
145+
146+
// CHECK-NEXT: rocdl.update.dpp
147+
// CHECK-SAME: with 276, 15, 15, true : f32
148+
// CHECK-NEXT: llvm.intr.maxnum
149+
150+
// CHECK-NEXT: rocdl.update.dpp
151+
// CHECK-SAME: with 274, 15, 15, true : f32
152+
// CHECK-NEXT: llvm.intr.maxnum
153+
154+
// CHECK-NEXT: rocdl.update.dpp
155+
// CHECK-SAME: with 273, 15, 15, true : f32
156+
// CHECK-NEXT: llvm.intr.maxnum
157+
158+
// CHECK-NEXT: rocdl.update.dpp
159+
// CHECK-SAME: with 322, 10, 15, true : f32
160+
// CHECK-NEXT: llvm.intr.maxnum
161+
162+
// CHECK-NEXT: rocdl.update.dpp
163+
// CHECK-SAME: with 323, 15, 15, true : f32
164+
// CHECK-NEXT: llvm.intr.maxnum
165+
166+
// CHECK: llvm.amdgcn.readlane
167+
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
168+
^bb0(%arg1: f32, %arg2: f32):
169+
%1 = arith.maxnumf %arg1, %arg2 : f32
170+
tt.reduce.return %1 : f32
171+
}) : (tensor<64xf32, #blocked3>) -> f32
172+
tt.return
173+
}
174+
}
175+
176+
// -----
177+
178+
#blocked4 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [64], warpsPerCTA = [1], order = [0]}>
179+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-warp" = 64 : i32} {
180+
// CHECK-LABEL: reduce_xor_max
181+
tt.func @reduce_xor_max(%arg0: tensor<32xf32, #blocked4>) {
182+
// CHECK: rocdl.ds_swizzle
183+
// CHECK: llvm.intr.maxnum
184+
185+
// CHECK: rocdl.update.dpp
186+
// CHECK-SAME: with 280, 15, 12, false : i32
187+
// CHECK: rocdl.update.dpp
188+
// CHECK-SAME: with 264, 15, 3, false : i32
189+
// CHECK: llvm.intr.maxnum
190+
191+
// CHECK: rocdl.update.dpp
192+
// CHECK-SAME: with 276, 15, 10, false : i32
193+
// CHECK: rocdl.update.dpp
194+
// CHECK-SAME: with 260, 15, 5, false : i32
195+
// CHECK: llvm.intr.maxnum
196+
197+
// CHECK: rocdl.update.dpp
198+
// CHECK-SAME: with 78, 15, 15, false : i32
199+
// CHECK: llvm.intr.maxnum
200+
201+
// CHECK: rocdl.update.dpp
202+
// CHECK-SAME: with 177, 15, 15, false : i32
203+
%0 = "tt.reduce"(%arg0) <{axis = 0 : i32}> ({
204+
^bb0(%arg1: f32, %arg2: f32):
205+
%1 = arith.maxnumf %arg1, %arg2 : f32
206+
tt.reduce.return %1 : f32
207+
}) : (tensor<32xf32, #blocked4>) -> f32
208+
tt.return
209+
}
210+
}

test/TritonGPU/amd/amd-sched-2nd-load.mlir

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
4949
}
5050
}
5151

52+
// -----
53+
54+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
55+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
56+
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
57+
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
58+
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
59+
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
60+
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
61+
5262
// Should apply: tile size 256x256x64 with single dot
5363
// CHECK-LABEL: sink_2nd_load_256x256x64
5464
// CHECK: %[[tileA:.*]] = tt.load
@@ -78,6 +88,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
7888
}
7989
}
8090

91+
// -----
92+
93+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
94+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
95+
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
96+
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
97+
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
98+
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
99+
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
100+
81101
// Should NOT apply: tile size 256x64x128 with single dot
82102
// CHECK-LABEL: sink_2nd_load_256x64x128
83103
// CHECK: %[[tileA:.*]] = tt.load
@@ -107,6 +127,16 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
107127
}
108128
}
109129

130+
// -----
131+
132+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
133+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
134+
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
135+
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
136+
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
137+
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
138+
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
139+
110140
// Should NOT apply: tile size 256x256x32 with single dot
111141
// CHECK-LABEL: sink_2nd_load_256x256x32
112142
// CHECK: %[[tileA:.*]] = tt.load
@@ -136,6 +166,15 @@ module attributes {"triton_gpu.num-warps" = 1 : i32, "triton_gpu.threads-per-war
136166
}
137167
}
138168

169+
// -----
170+
171+
#blocked = #triton_gpu.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [1, 0]}>
172+
#blocked1 = #triton_gpu.blocked<{sizePerThread = [8, 1], threadsPerWarp = [8, 8], warpsPerCTA = [1, 1], order = [0, 1]}>
173+
#mma = #triton_gpu.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [1, 1], instrShape = [16, 16], isTransposed = true}>
174+
#shared = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [1, 0], hasLeadingOffset = false}>
175+
#shared1 = #triton_gpu.shared<{vec = 8, perPhase = 1, maxPhase = 8, order = [0, 1], hasLeadingOffset = false}>
176+
#dotOp0 = #triton_gpu.dot_op<{opIdx = 0, parent = #mma, kWidth = 8}>
177+
#dotOp1 = #triton_gpu.dot_op<{opIdx = 1, parent = #mma, kWidth = 8}>
139178

140179
// Category 2: single dot with two loads and tile size is large enough (128x128x128).
141180
// We make sure the move is legal.

third_party/amd/include/TritonAMDGPUToLLVM/TargetUtils.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,17 @@ enum class ISAFamily {
1919
// Deduces the corresponding ISA family for the given target gfx |arch|.
2020
ISAFamily deduceISAFamily(llvm::StringRef arch);
2121

22+
// Here is a partial definition of DppCtrl enums. For the complete definition,
23+
// please check:
24+
// https://github.com/llvm/llvm-project/blob/8c75290/llvm/lib/Target/AMDGPU/SIDefines.h#L939
25+
enum class DppCtrl : uint32_t {
26+
QUAD_PERM_FIRST = 0,
27+
ROW_SHL0 = 0x100,
28+
ROW_SHR0 = 0x110,
29+
BCAST15 = 0x142,
30+
BCAST31 = 0x143
31+
};
32+
2233
} // namespace mlir::triton::AMD
2334

2435
#endif // TRITON_CONVERSION_TRITONGPU_TO_LLVM_TARGETUTILS_H

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
300300
assert(wordNElems * nWords * numVecs == numElems);
301301

302302
Value pred = mask ? maskElems[vecStart] : int_val(1, 1);
303-
Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);
303+
Value ptr = ptrElems[vecStart];
304304

305305
Value falseVal = createZeroVector(rewriter, loc, cast<VectorType>(vecTy));
306306
// If we need to mask the loaded value with other elements
@@ -477,7 +477,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
477477

478478
SmallVector<std::pair<Value, std::string>> asmArgs;
479479
Value elem = valueElems[vecStart];
480-
Value ptr = addrspacecast(ptr_ty(getContext()), ptrElems[vecStart]);
480+
Value ptr = ptrElems[vecStart];
481481

482482
// Create the store val
483483
Value storeVal = packElementRangeIntoVector(

0 commit comments

Comments
 (0)