Commit a1aa27f

Merge branch 'main' into tkuczynski/enable_test_small_batch_matmul
2 parents 76c7f03 + e14f5b9 commit a1aa27f

57 files changed, +2007 -445 lines changed


README.md

Lines changed: 1 addition & 0 deletions
@@ -232,6 +232,7 @@ See [`python/triton/knobs.py`](python/triton/knobs.py) for the full list of conf
 - `TRITON_F32_DEFAULT` sets the default input precision of `tl.dot` when using 32-bit floats, which can be either `ieee`, `tf32`, or `tf32x3`.
 - `TRITON_FRONT_END_DEBUGGING=1` disables exception wrapping when an error occurs in the compiler frontend, allowing the full stack trace to be seen.
 - `TRITON_DISABLE_LINE_INFO=1` removes all line information from the module.
+- `PTXAS_OPTIONS` passes additional command-line options to the PTX assembler `ptxas` (only on NVIDIA).
 
 > [!NOTE]
 > Some of these environment variables don't have a knob in `knobs.py` -- those are only relevant to the C++ layer(s), hence they don't exist in the Python layer.

include/triton/Dialect/Triton/IR/TritonOpInterfaces.td

Lines changed: 7 additions & 1 deletion
@@ -49,7 +49,12 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
         /*retType=*/"::mlir::Value",
         /*methodName=*/"getB",
         /*args=*/(ins)>,
-    InterfaceMethod<
+    InterfaceMethod<
+        /*desc=*/"Get the output tensor",
+        /*retType=*/"::mlir::Value",
+        /*methodName=*/"getD",
+        /*args=*/(ins)>,
+    InterfaceMethod<
         /*desc=*/"Verify the dimensions of the A and B DotOp operands.",
         /*retType=*/"bool",
         /*methodName=*/"verifyDims",
@@ -64,6 +69,7 @@ def DotOpInterface : OpInterface<"DotOpInterface"> {
     auto aTy = cast<ShapedType>($_op.getA().getType());
     auto bTy = cast<ShapedType>($_op.getB().getType());
     auto cTy = cast<ShapedType>($_op->getOperand(2).getType());
+    auto dTy = cast<ShapedType>($_op.getD().getType());
     auto aShape = aTy.getShape();
     auto bShape = bTy.getShape();
     auto cShape = cTy.getShape();
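With `getD` on the interface, passes can reach a dot op's output tensor without naming the concrete op. A minimal caller sketch (hypothetical pass code, not part of this commit; assumes the interface class lives in the usual `::mlir::triton` namespace):

    // Hypothetical: works for any op implementing DotOpInterface.
    if (auto dotOp = llvm::dyn_cast<mlir::triton::DotOpInterface>(op)) {
      mlir::Value a = dotOp.getA();
      mlir::Value b = dotOp.getB();
      mlir::Value d = dotOp.getD(); // output tensor, newly exposed here
      // e.g. verify d's element type matches the accumulator's
    }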

include/triton/Dialect/TritonGPU/IR/CMakeLists.txt

Lines changed: 5 additions & 0 deletions
@@ -24,3 +24,8 @@ set(LLVM_TARGET_DEFINITIONS TritonGPUTypeInterfaces.td)
 mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
 mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
 add_public_tablegen_target(TritonGPUTypeInterfacesIncGen)
+
+set(LLVM_TARGET_DEFINITIONS TritonGPUOpInterfaces.td)
+mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
+mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
+add_public_tablegen_target(TritonGPUOpInterfacesIncGen)

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 7 additions & 0 deletions
@@ -135,6 +135,13 @@ LinearLayout chooseScaledMfmaScaleLayout(MLIRContext *ctx, int dotOperandIdx,
                                          ArrayRef<unsigned> tilesPerWarp,
                                          ArrayRef<unsigned> warpsPerCTA);
 
+LinearLayout getSM120DotScaledScaleLayout(MLIRContext *ctx, int dotOperandIdx,
+                                          ArrayRef<int64_t> dotOperandShape,
+                                          ArrayRef<unsigned> tilesPerWarp,
+                                          ArrayRef<unsigned> warpsPerCTA,
+                                          unsigned instrM, unsigned instrN,
+                                          CTALayoutAttr ctaLayoutAttr);
+
 // Create LinearLayout for nvidia mma tile.
 LinearLayout nvidiaMmaTile(MLIRContext *ctx, ArrayRef<unsigned> tileShape,
                            unsigned kWidth, ArrayRef<unsigned> order,
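A rough sense of how a backend might call the new helper; every argument value below is invented for illustration, and `ctx`/`ctaLayout` are assumed to be in scope:

    // Hypothetical call; operand shape and instruction sizes are examples.
    LinearLayout scaleLl = getSM120DotScaledScaleLayout(
        ctx, /*dotOperandIdx=*/0, /*dotOperandShape=*/{128, 64},
        /*tilesPerWarp=*/{1, 1}, /*warpsPerCTA=*/{4, 1},
        /*instrM=*/16, /*instrN=*/8, ctaLayout);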

include/triton/Dialect/TritonGPU/IR/TritonGPUInterfaces.h

Lines changed: 3 additions & 0 deletions
@@ -1,8 +1,11 @@
 #ifndef TRITON_GPU_DIALECT_INTERFACES_H
 #define TRITON_GPU_DIALECT_INTERFACES_H
 
+#include "mlir/IR/OpDefinition.h"
+
 // clang-format off
 #include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
+#include "triton/Dialect/TritonGPU/IR/OpInterfaces.h.inc"
 #include "triton/Dialect/TritonGPU/IR/AttrInterfaces.h.inc"
 // clang-format on

include/triton/Dialect/TritonGPU/IR/TritonGPUOpInterfaces.td

Lines changed: 29 additions & 0 deletions
@@ -0,0 +1,29 @@
+#ifndef TRITONGPU_OP_INTERFACES
+#define TRITONGPU_OP_INTERFACES
+
+include "mlir/IR/OpBase.td"
+
+def UpcastFpOpInterface : OpInterface<"UpcastFpOpInterface"> {
+  let description = [{
+    This interface is for operations that upcast floating-point numbers.
+  }];
+
+  let cppNamespace = "::mlir::triton::gpu";
+
+  let methods = [
+    InterfaceMethod<
+      /*desc=*/"Infer destination encoding",
+      /*retType=*/"mlir::Attribute",
+      /*methodName=*/"inferDstEncoding",
+      /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$srcEnc)
+    >,
+    InterfaceMethod<
+      /*desc=*/"Infer operand encoding from dst encoding",
+      /*retType=*/"mlir::Attribute",
+      /*methodName=*/"inferSrcEncoding",
+      /*args=*/(ins "unsigned":$opIdx, "mlir::Attribute":$dstEnc)
+    >
+  ];
+}
+
+#endif // TRITONGPU_OP_INTERFACES
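Once tablegen generates the C++ class, encoding inference can treat every fp-upcast op uniformly instead of special-casing each one. A minimal sketch of a caller; the helper function itself is invented for illustration:

    // Hypothetical helper in a layout-inference pass.
    mlir::Attribute inferDstEncodingThroughUpcast(mlir::Operation *op,
                                                  unsigned opIdx,
                                                  mlir::Attribute srcEnc) {
      if (auto upcast = llvm::dyn_cast<mlir::triton::gpu::UpcastFpOpInterface>(op))
        return upcast.inferDstEncoding(opIdx, srcEnc);
      return {}; // not an upcast op; caller falls back to per-op logic
    }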

include/triton/Dialect/TritonGPU/Transforms/DecomposeScaledBlocked.h

Lines changed: 10 additions & 4 deletions
@@ -22,16 +22,22 @@ class DecomposeScaledBlocked : public OpRewritePattern<DotScaledOp> {
                ModuleOp mod, TypedValue<RankedTensorType> scale,
                int dim) const;
   TypedValue<RankedTensorType> maskNan(PatternRewriter &rewriter,
-                                       DotScaledOp scaledDotOp, ModuleOp mod,
+                                       DotScaledOp scaledDotOp,
                                        TypedValue<RankedTensorType> mxfp,
                                        TypedValue<RankedTensorType> scale,
                                        int dim) const;
-  TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
-                                        DotScaledOp scaledDotOp, int opIdx,
-                                        FloatType computeType) const;
+  virtual TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
+                                                DotScaledOp scaledDotOp,
+                                                int opIdx,
+                                                FloatType computeType) const;
   TypedValue<RankedTensorType>
   cvtDotOperand(PatternRewriter &rewriter, DotScaledOp scaledDotOp, int opIdx,
                 TypedValue<RankedTensorType> v) const;
+  TypedValue<RankedTensorType>
+  extendAndBroadcastScale(PatternRewriter &rewriter, DotScaledOp scaledDotOp,
+                          TypedValue<RankedTensorType> &scale,
+                          FloatType computeType, RankedTensorType dstType,
+                          int opIdx) const;
   static SmallVector<int, 2> getTransposeOrder(int rank);
 };
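Making scaleArg virtual implies the decomposition is meant to be subclassed, with a backend overriding only how scales are prepared. A hypothetical specialization (class name and body invented):

    // Hypothetical backend-specific pattern; everything else is inherited.
    class MyBackendDecomposeScaledBlocked : public DecomposeScaledBlocked {
      using DecomposeScaledBlocked::DecomposeScaledBlocked;

      TypedValue<RankedTensorType> scaleArg(PatternRewriter &rewriter,
                                            DotScaledOp scaledDotOp, int opIdx,
                                            FloatType computeType) const override {
        // A real backend would materialize its own scale layout here; this
        // sketch simply defers to the generic path.
        return DecomposeScaledBlocked::scaleArg(rewriter, scaledDotOp, opIdx,
                                                computeType);
      }
    };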

include/triton/Dialect/TritonGPU/Transforms/PipeliningUtility.h

Lines changed: 4 additions & 1 deletion
@@ -40,7 +40,10 @@ bool isPureScalarOp(Operation *op);
 bool getDominatingValueSetOpsToHoist(
     DominanceInfo &domInfo, Operation *refOp, ArrayRef<Value> valueSet,
     llvm::SetVector<Operation *> &toHoist,
-    function_ref<bool(Operation *)> canHoist = isPureScalarOp);
+    function_ref<bool(Operation *)> canHoist = isPureScalarOp,
+    function_ref<bool(BlockArgument)> canUseArg = [](BlockArgument) {
+      return false;
+    });
 
 // Hoist the given set of operations above the reference operation.
 void hoistOpsBefore(Operation *refOp,
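The new canUseArg parameter defaults to rejecting every block argument, which keeps the previous behavior; callers opt in by supplying a predicate. A hypothetical call site (the values and the predicate are invented):

    // Hypothetical: also allow hoisted ops to use refOp's block arguments.
    llvm::SetVector<Operation *> toHoist;
    bool ok = getDominatingValueSetOpsToHoist(
        domInfo, refOp, values, toHoist, isPureScalarOp,
        [&](BlockArgument arg) { return arg.getOwner() == refOp->getBlock(); });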

lib/Analysis/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ add_triton_library(TritonAnalysis
   TritonGPUTableGen
   TritonGPUAttrDefsIncGen
   TritonGPUTypeInterfacesIncGen
+  TritonGPUOpInterfacesIncGen
 
   LINK_LIBS PUBLIC
   MLIRAnalysis

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 12 additions & 4 deletions
@@ -1158,6 +1158,14 @@ SharedMemoryObject::getMaskSpanOffsets(triton::gpu::MemDescType srcTy) {
   if (allocShape == shape) {
     return 0;
   }
+  if (auto paddedEncoding = dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
+          srcTy.getEncoding())) {
+    // The mask is used to fold the constant part of a memory address into
+    // an immediate operand. A padded layout adds extra address computation
+    // between the main offset computation and the actual memory access,
+    // breaking that constant folding, so return a full mask to disable it.
+    return ~uint64_t(0);
+  }
   auto totalLl = triton::gpu::toLinearLayout(allocShape, srcTy.getEncoding());
   auto dimNames = standardOutDimNames(ctx, shape.size());
   // Remove the kBlock dimension
@@ -1194,14 +1202,15 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     return b.i32_val(0);
   }
 
+  LinearLayout ll;
   // We return the offset without the padding. The padding will be added in
   // the lowering.
   if (auto paddedSharedEncoding =
           dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(
               srcTy.getEncoding())) {
-    auto allocShape64 = srcTy.getAllocShape();
-    SmallVector<unsigned> allocShape(allocShape64.begin(), allocShape64.end());
-    return LLVM::linearize(rewriter, loc, offsets, allocShape);
+    ll = paddedSharedEncoding.getLinearComponent();
+  } else {
+    ll = triton::gpu::toLinearLayout(srcTy);
   }
 
   auto dimNames = standardOutDimNames(ctx, offsets.size());
@@ -1210,7 +1219,6 @@ Value SharedMemoryObject::getShmemOffset(Location loc, RewriterBase &rewriter,
     logicalOffsets.push_back({dim, offset});
   }
 
-  LinearLayout ll = triton::gpu::toLinearLayout(srcTy);
   ll = ll.sublayout({str_attr("offset")}, dimNames);
   auto offset =
       applyLinearLayout(loc, rewriter, ll.invert(), logicalOffsets)[0].second;
