
Commit 6a9a7d8

chsigg authored and committed (via Copybara)
OpenXLA-specific changes.
PiperOrigin-RevId: 741163570
1 parent 4e364a7 commit 6a9a7d8

File tree

36 files changed: +3581 −928 lines


BUILD

Lines changed: 934 additions & 0 deletions
Large diffs are not rendered by default.

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 4 additions & 2 deletions
@@ -587,15 +587,17 @@ We call each individual tile "rep".
       "unsigned",
       "getTotalElemsPerThread",
       (ins "ArrayRef<int64_t>":$shape),
+      /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return toLinearEncoding($_self, shape).getTotalElemsPerThread(shape);
+        return toLinearEncoding($_attr, shape).getTotalElemsPerThread(shape);
       }]>,
     InterfaceMethod<"Return element size per thread in each dimension.",
       "SmallVector<unsigned>",
       "getElemsPerThread",
       (ins "ArrayRef<int64_t>":$shape),
+      /*methodBody=*/[{}],
       /*defaultImplementation=*/[{
-        return toLinearEncoding($_self, shape).getElemsPerThread(shape);
+        return toLinearEncoding($_attr, shape).getElemsPerThread(shape);
       }]>,
     InterfaceMethod<"Convert to LinearLayout.",
       "LinearLayout",

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 4 additions & 2 deletions
@@ -57,12 +57,14 @@ LinearEncodingAttr toLinearEncoding(Attribute layout, ArrayRef<int64_t> shape) {
 }

 unsigned getTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
-  return toLinearEncoding(layout, shape).getTotalElemsPerThread(shape);
+  auto distributedEncoding = mlir::cast<DistributedEncodingTrait>(layout);
+  return distributedEncoding.getTotalElemsPerThread(shape);
 }

 SmallVector<unsigned> getElemsPerThread(Attribute layout,
                                         ArrayRef<int64_t> shape) {
-  return toLinearEncoding(layout, shape).getElemsPerThread(shape);
+  auto distributedEncoding = mlir::cast<DistributedEncodingTrait>(layout);
+  return distributedEncoding.getElemsPerThread(shape);
 }

 SmallVector<unsigned> getElemsPerThread(Type type) {
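
The .td and Dialect.cpp changes go together: the interface methods now declare an empty methodBody and a default implementation keyed on $_attr, and the free helpers dispatch through the DistributedEncodingTrait interface instead of always materializing a linear encoding first. A minimal sketch of the resulting call path, assuming the names visible in this file (illustrative only, not part of the commit):

// Encodings that override getTotalElemsPerThread answer directly; the rest
// fall back to the default implementation above, which still converts to a
// LinearEncodingAttr via toLinearEncoding().
unsigned queryTotalElemsPerThread(Attribute layout, ArrayRef<int64_t> shape) {
  auto distributed = mlir::cast<DistributedEncodingTrait>(layout);
  return distributed.getTotalElemsPerThread(shape);
}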

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 5 additions & 0 deletions
@@ -160,6 +160,11 @@ struct CanonicalizeConvertFromAlloc
     auto convert = op.getSrc().getDefiningOp<ConvertLayoutOp>();
     if (!convert)
       return failure();
+    // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+    // to SharedEncoding, so we want to keep this layout conversion.
+    if (mlir::isa<triton::gpu::DotOperandEncodingAttr>(
+            convert.getSrc().getType().getEncoding()))
+      return failure();
     rewriter.replaceOpWithNewOp<triton::gpu::LocalAllocOp>(
         op, op->getResult(0).getType(), convert.getSrc());
     return mlir::success();
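
In effect, CanonicalizeConvertFromAlloc folds local_alloc(convert_layout(x)) into local_alloc(x); the new early exit keeps the conversion whenever x is in a dot-operand layout. A hypothetical helper (not in the commit) stating the same condition in isolation:

// Folding is only safe when the converted source is not dot-operand encoded,
// because the LocalAllocOp lowering cannot copy such a tensor to shared memory.
static bool canFoldConvertIntoLocalAlloc(triton::gpu::ConvertLayoutOp convert) {
  auto srcEncoding = convert.getSrc().getType().getEncoding();
  return !mlir::isa<triton::gpu::DotOperandEncodingAttr>(srcEncoding);
}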

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 29 additions & 1 deletion
@@ -185,6 +185,21 @@ getSharedMemoryMMAOperand(Value v, mlir::PatternRewriter &rewriter, int opIdx,
   auto newType = MemDescType::get(argType.getShape(), argType.getElementType(),
                                   newLayout, SharedMemorySpace);
   rewriter.setInsertionPointAfterValue(arg);
+
+  // LocalAllocOp lowering doesn't support going from DotOperandEncoding
+  // to SharedEncoding.
+  if (auto dotOpEnc = mlir::dyn_cast<DotOperandEncodingAttr>(
+          argType.getEncoding())) {
+    // Create a layout conversion from DotOperandEncoding to BlockedEncoding,
+    // then pass it to the LocalAllocOp.
+    auto newArgType = RankedTensorType::get(
+        argType.getShape(), argType.getElementType(), dotOpEnc.getParent());
+    auto dotOperandToBlockedCvt =
+        rewriter.create<ConvertLayoutOp>(arg.getLoc(), newArgType, arg);
+    return rewriter.create<LocalAllocOp>(arg.getLoc(), newType,
+                                         dotOperandToBlockedCvt);
+  }
+
   return rewriter.create<LocalAllocOp>(arg.getLoc(), newType, arg);
 }

@@ -222,9 +237,22 @@ getWarpsPerTile(DotOp dotOp, const ArrayRef<int64_t> shape, int version,
 }

 static bool bwdFilter(Operation *op) {
+  // Dot operand layout assignment to predicates is not currently supported
+  // during lowering from TritonGPU to LLVM in Triton for MMA cases. This
+  // condition limits visibility of the original bit-width so that predicates
+  // are not considered; hence, kwidth can never be 32.
+  if (isa<arith::UIToFPOp>(op)) {
+    Type srcType = getElementTypeOrSelf(op->getOperand(0));
+    if (srcType.isInteger(1))
+      return false;
+  }
+
+  // b/405045790: We don't want to propagate through the BroadcastOp because we
+  // probably don't care about the load before a broadcast as it would likely be
+  // small. This is just a heuristic to avoid a regression.
   return (op->hasTrait<OpTrait::Elementwise>() && isMemoryEffectFree(op)) ||
          isView(op) ||
-         isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, BroadcastOp, ConvertLayoutOp>(
+         isa<Fp4ToFpOp, LoadOp, DescriptorLoadOp, /*BroadcastOp,*/ ConvertLayoutOp>(
             op);
 }

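
bwdFilter guides the backward slice that is walked from a dot operand to find its original element bit-width, which in turn drives the kwidth chosen for the dot-operand layout. Stopping at uitofp-from-i1 keeps a 1-bit source from pushing kwidth to 32, and BroadcastOp is no longer traversed per the heuristic above. A minimal sketch of how such a filter is typically consumed (helper name and options here are illustrative, not taken from this file):

#include "mlir/Analysis/SliceAnalysis.h"

// Prune the backward slice of a dot operand with the predicate shown above.
static void collectOperandSlice(mlir::Value operand,
                                llvm::SetVector<mlir::Operation *> &slice) {
  mlir::BackwardSliceOptions opt;
  opt.omitBlockArguments = true;
  opt.filter = bwdFilter; // ops rejected by the filter end the walk
  mlir::getBackwardSlice(operand, &slice, opt);
}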

lib/Dialect/TritonGPU/Transforms/Prefetch.cpp

Lines changed: 23 additions & 1 deletion
@@ -147,8 +147,14 @@ Value Prefetcher::generatePrefetch(Value v, unsigned opIdx, bool isPrologue,
                                  type.getMutableMemory(), type.getAllocShape()),
       v, offsetsVal);

+  // We need to set kwidth to zero when the parent layout is Blocked;
+  // otherwise the verifier emits a failure. The parent layout is Blocked
+  // only when Tensor Cores are disabled.
+  int kwidth = dyn_cast<triton::gpu::BlockedEncodingAttr>(dotEncoding)
+                   ? 0
+                   : prefetchWidth / 8;
   auto dotOperandEnc = triton::gpu::DotOperandEncodingAttr::get(
-      builder.getContext(), opIdx, dotEncoding, prefetchWidth / 8);
+      builder.getContext(), opIdx, dotEncoding, kwidth);
   Value prefetchSlice = builder.create<triton::gpu::LocalLoadOp>(
       v.getLoc(), RankedTensorType::get(shape, elementType, dotOperandEnc),
       newSmem);
@@ -198,6 +204,22 @@ LogicalResult Prefetcher::initialize() {
         break;
       if (!op->getResult(0).hasOneUse())
         break;
+      // Similar to issues faced in the HoistLayoutConversion pattern in
+      // OptimizeDotOperands.cpp, we can't propagate through type casts from
+      // predicates, as they aren't supported in Triton when encoded with a
+      // dot_op layout.
+      if (isa<arith::UIToFPOp>(op)) {
+        Type srcType = getElementTypeOrSelf(op->getOperand(0));
+        if (srcType.isInteger(1))
+          break;
+      }
+      // Propagation through ExpandDims is currently not supported. This blindly
+      // replaces the encoding with a dot encoding, but ExpandDims requires a
+      // SliceEncoding. This could be rewritten to support it somehow, but I
+      // don't think it's trivial, and it's currently crashing.
+      if (isa<ExpandDimsOp>(op)) {
+        break;
+      }
       rets.push_back(op->getOperand(0));
       if (auto cvt = dyn_cast<triton::gpu::LocalLoadOp>(op)) {
         // NYI for other encodings, for example if we have transpose
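
The kwidth choice above matters because kwidth only describes operand packing for MMA parents; with a Blocked parent (Tensor Cores disabled) the comment indicates the DotOperandEncoding verifier rejects a nonzero value. A hypothetical helper (names are illustrative, not from this file) expressing the same selection:

// kwidth is only meaningful for MMA parent layouts; a Blocked parent means
// Tensor Cores are disabled, and the verifier then expects kwidth == 0.
static int pickPrefetchKWidth(mlir::Attribute dotEncoding, int prefetchWidth) {
  if (mlir::isa<triton::gpu::BlockedEncodingAttr>(dotEncoding))
    return 0;
  return prefetchWidth / 8;
}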

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 1 addition & 2 deletions
@@ -1121,8 +1121,7 @@ void LayoutRematerialization::hoistConvertDotOperand(
   // The pass is targeted to Nvidia mma/wgmma dot operands

   auto canBePipelined = [&](ConvertLayoutOp convertOp) {
-    // FIXME: Check that the parent is a for loop
-    auto parent = convertOp->getParentOp();
+    auto parent = dyn_cast<scf::ForOp>(convertOp->getParentOp());
     if (!parent)
       return false;


lib/Dialect/TritonNvidiaGPU/Transforms/FenceInsertion.cpp

Lines changed: 7 additions & 7 deletions
@@ -41,6 +41,7 @@ struct FenceInsertionPass
     if (::triton::tools::getBoolEnv("DISABLE_MMA_V3"))
       return;
     ModuleOp mod = getOperation();
+    DenseSet<std::pair<Operation *, unsigned>> trace;
     mod.walk([&](Operation *op) {
       if (!isa<ttng::WarpGroupDotOp>(op))
         return WalkResult::advance();
@@ -51,8 +52,8 @@ struct FenceInsertionPass
           cast<RankedTensorType>(op->getResult(0).getType()).getEncoding());
       if (!mmaEncoding || !mmaEncoding.isHopper())
         return WalkResult::advance();
-      bool aDependsOnShared = dependOnSharedEncOperand(a);
-      bool bDependsOnShared = dependOnSharedEncOperand(b);
+      bool aDependsOnShared = dependOnSharedEncOperand(a, trace);
+      bool bDependsOnShared = dependOnSharedEncOperand(b, trace);
       if (!aDependsOnShared && !bDependsOnShared)
         return WalkResult::advance();
       Operation *fence = builder.create<ttng::FenceAsyncSharedOp>(
@@ -73,8 +74,7 @@ struct FenceInsertionPass
   }

 private:
-  bool dependOnSharedEncOperand(Value operand) {
-    static DenseSet<std::pair<Operation *, unsigned>> trace;
+  bool dependOnSharedEncOperand(Value operand, DenseSet<std::pair<Operation *, unsigned>> &trace) {
     auto op = operand.getDefiningOp();
     // avoid redundant insertion
     if (op && isa<mlir::triton::DotOpInterface>(op))
@@ -89,7 +89,7 @@ struct FenceInsertionPass
     // op and not BlockArgument
     if (op && !isa<BlockArgument>(operand)) {
       for (auto v : op->getOperands()) {
-        if (dependOnSharedEncOperand(v))
+        if (dependOnSharedEncOperand(v, trace))
           return true;
       }
     }
@@ -104,7 +104,7 @@ struct FenceInsertionPass
       auto iterOperands = forOp.getInitArgs();
       if (argNum == 0)
         return false;
-      if (dependOnSharedEncOperand(iterOperands[argNum - 1]))
+      if (dependOnSharedEncOperand(iterOperands[argNum - 1], trace))
         return true;
       // yield
       auto yieldOp = forOp.getBody()->getTerminator();
@@ -117,7 +117,7 @@ struct FenceInsertionPass
       else
         trace.insert(entry);

-      if (dependOnSharedEncOperand(v))
+      if (dependOnSharedEncOperand(v, trace))
         return true;
     } else if (auto whileOp = dyn_cast<scf::WhileOp>(argOwner)) {
       assert(false && "FenceInsertionPass does not supported WhileOp");
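
This refactor replaces a function-local static DenseSet with a trace set created once per pass run and threaded through the recursion. A function-local static lives for the whole process, so results memoized in one FenceInsertionPass invocation would leak into later invocations and other modules. A stand-alone C++ illustration of that pitfall (not Triton code):

#include <cassert>
#include <set>

// A function-local static persists across calls, so state cached during one
// "run" is still visible to the next one.
int countUnseen(int value) {
  static std::set<int> seen; // lives for the lifetime of the process
  return seen.insert(value).second ? 1 : 0;
}

int main() {
  assert(countUnseen(42) == 1); // first run: 42 is new
  assert(countUnseen(42) == 0); // later run: stale cache says "already seen"
  return 0;
}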

lib/Dialect/TritonNvidiaGPU/Transforms/OptimizeDescriptorEncoding.cpp

Lines changed: 2 additions & 0 deletions
@@ -220,6 +220,7 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
     break;
   case 1:
     result.ctaLayout = ctaLayouts[0];
+    break;
   default:
     break;
   }
@@ -237,6 +238,7 @@ EncodingInfo combineEncodings(const EncodingInfo &lhs, const EncodingInfo &rhs,
     break;
   case 1:
     result.desiredEncoding = desiredEncodings[0];
+    break;
   default:
     break;
   }

python/BUILD

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
+# NOTE: Do not depend on any targets from this directory,
+# but use //third_party/py/triton instead.
+
+load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
+
+package(
+    default_applicable_licenses = ["@triton//:license"],
+    default_visibility = [
+        "//third_party/py/triton:__pkg__",
+        "@triton//python:__subpackages__",
+    ],
+)
+
+cc_library(
+    name = "passes",
+    hdrs = ["src/passes.h"],
+    includes = ["src"],
+    visibility = ["@triton//third_party:__subpackages__"],
+)
+
+pybind_extension(
+    name = "libtriton",
+    srcs = [
+        "src/interpreter.cc",
+        "src/ir.cc",
+        "src/llvm.cc",
+        "src/main.cc",
+        "src/passes.cc",
+    ],
+    copts = ["-DTRITON_BACKENDS_TUPLE=(nvidia)"],
+    deps = [
+        ":passes",
+        "@llvm-project//llvm:Core",
+        "@llvm-project//llvm:IPO",
+        "@llvm-project//llvm:IRReader",
+        "@llvm-project//llvm:InstCombine",
+        "@llvm-project//llvm:Instrumentation",
+        "@llvm-project//llvm:Linker",
+        "@llvm-project//llvm:MC",
+        "@llvm-project//llvm:Passes",
+        "@llvm-project//llvm:Support",
+        "@llvm-project//llvm:Target",
+        "@llvm-project//mlir:BuiltinToLLVMIRTranslation",
+        "@llvm-project//mlir:BytecodeWriter",
+        "@llvm-project//mlir:ControlFlowDialect",
+        "@llvm-project//mlir:ConversionPasses",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:IndexDialect",
+        "@llvm-project//mlir:LLVMDialect",
+        "@llvm-project//mlir:LLVMIRTransforms",
+        "@llvm-project//mlir:LLVMToLLVMIRTranslation",
+        "@llvm-project//mlir:NVVMToLLVMIRTranslation",
+        "@llvm-project//mlir:Parser",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+        "@llvm-project//mlir:ToLLVMIRTranslation",
+        "@llvm-project//mlir:Transforms",
+        "@llvm-project//mlir:UBDialect",
+        "@triton//:TritonDialects",
+        "@triton//:TritonGPUToLLVM",
+        "@triton//:TritonGPUTransforms",
+        "@triton//:TritonHSACO",
+        "@triton//:TritonLLVMIR",
+        "@triton//:TritonNvidiaGPUTransforms",
+        "@triton//:TritonPTX",
+        "@triton//:TritonToTritonGPU",
+        "@triton//:TritonTools",
+        "@triton//:TritonTransforms",
+        "@triton//third_party/nvidia:triton_nvidia",
+        "@triton//third_party/proton:ProtonIRDialect",
+    ],
+)
+
+filegroup(
+    name = "files",
+    srcs = glob(
+        include = ["triton/**/*.py"],
+    ),
+)
