Commit 6e2fff0

Merge commit '1e0e51c4aeb3e1beea000da5d0e494f8b9ac40dd'
2 parents: 79015d5 + 1e0e51c

56 files changed: +826 additions, −679 deletions


Makefile

Lines changed: 2 additions & 2 deletions
@@ -45,14 +45,14 @@ test-regression: all

 .PHONY: test-interpret
 test-interpret: all
-        cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter language/test_core.py language/test_standard.py \
+        cd python/test/unit && TRITON_INTERPRET=1 $(PYTEST) -s -n 16 -m interpreter cuda language/test_core.py language/test_standard.py \
         language/test_random.py language/test_block_pointer.py language/test_subprocess.py language/test_line_info.py \
         runtime/test_autotuner.py::test_kwargs[False] \
         ../../tutorials/06-fused-attention.py::test_op --device=cpu

 .PHONY: test-proton
 test-proton: all
-        $(PYTEST) -s third_party/proton/test
+        $(PYTEST) -s -n 8 third_party/proton/test

 .PHONY: test-python
 test-python: test-unit test-regression test-interpret test-proton
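
For local runs these targets are invoked as make test-interpret and make test-proton: the interpreter suite now also collects the cuda tests alongside the language tests, and the proton suite gains worker parallelism via -n 8 (pytest's worker-count flag from pytest-xdist).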

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 3 additions & 1 deletion
@@ -581,6 +581,7 @@ def TT_TransOp : TT_Op<"trans", [Pure,
   let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";

   let hasFolder = 1;
+  let hasVerifier = 1;
 }

 //
@@ -830,7 +831,8 @@ def TT_MakeRangeOp : TT_Op<"make_range", [Pure]> {
 def TT_ElementwiseInlineAsmOp : TT_Op<"elementwise_inline_asm", [
     Elementwise,
     SameOperandsAndResultEncoding,
-    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
+    DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
+    DeclareOpInterfaceMethods<ConditionallySpeculatable>
 ]> {
   let summary = "inline assembly applying an elementwise operation to a group of packed elements.";
   let description = [{
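
ConditionallySpeculatable is the upstream MLIR interface that transforms such as loop-invariant code motion consult before hoisting an op out of a loop. The matching getSpeculatability() implementation lands in the lib/Dialect/Triton/IR/Ops.cpp hunk below, where only inline asm marked pure is treated as speculatable.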

include/triton/Dialect/TritonGPU/IR/Dialect.h

Lines changed: 1 addition & 2 deletions
@@ -14,8 +14,7 @@
 #include <unordered_map>

 // LinearLayoutCache Utils
-using CacheKey =
-    std::tuple<std::vector<int64_t>, mlir::Attribute, std::optional<int32_t>>;
+using CacheKey = std::tuple<std::vector<int64_t>, mlir::Attribute>;

 namespace llvm {
 template <typename T> size_t hash_value(const std::vector<T> &vec) {
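
Dropping the optional bit width means layouts are now memoized purely on (shape, encoding). A standalone sketch of that compute-on-miss pattern, with std::string standing in for mlir::Attribute and LinearLayout so it compiles outside MLIR; the getOrCompute interface here is an assumption for illustration, not part of this commit:

    #include <cstdint>
    #include <map>
    #include <string>
    #include <tuple>
    #include <vector>

    using FakeAttribute = std::string;    // stand-in for mlir::Attribute
    using FakeLinearLayout = std::string; // stand-in for LinearLayout

    // Mirrors the simplified key: shape + encoding, no element bit width.
    using CacheKey = std::tuple<std::vector<int64_t>, FakeAttribute>;

    class LinearLayoutCacheSketch {
      std::map<CacheKey, FakeLinearLayout> cache;

    public:
      // Look up a layout, computing and storing it on the first miss. Two
      // tensors with the same shape and encoding now share one cache entry
      // regardless of their element type.
      template <typename ComputeFn>
      const FakeLinearLayout &getOrCompute(const CacheKey &key,
                                           ComputeFn compute) {
        auto it = cache.find(key);
        if (it == cache.end())
          it = cache.emplace(key, compute()).first;
        return it->second;
      }
    };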

include/triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h

Lines changed: 1 addition & 3 deletions
@@ -41,8 +41,7 @@ class NVMMASharedEncodingAttr;
 // shared layouts with nvmma_shared layout) but is otherwise unused.
 //
 // Returns std::nullopt if the given layout can't be converted to an LL.
-LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
-                            std::optional<int32_t> elemBitWidth = std::nullopt);
+LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);

 // Convert the shared encoding of a tensor with `nvmma_shared` layout to a
 // LinearLayout that maps from a linear shared memory offset to tensor index.
@@ -51,7 +50,6 @@ LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
 // swizzling.
 LinearLayout sharedToLinearLayoutLeadingOffset(ArrayRef<int64_t> shape,
                                                NVMMASharedEncodingAttr shared,
-                                               int32_t elemBitWidth,
                                                bool disableSwizzle = false);

 // Given a linear layout where the input dimensions contain a "block" dimension,
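
The caller-side effect of this signature change is visible in the lib/Conversion/TritonGPUToLLVM/Utility.cpp hunks later in this commit; the before/after shape of a call site, quoted from there rather than written fresh:

    // Before: every caller threaded the element bit width through.
    LinearLayout sharedLayout = triton::gpu::toLinearLayout(
        shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());

    // After: the bit width travels with the encoding itself (see the
    // NVMMASharedEncodingAttr change below), so the call shrinks to:
    LinearLayout sharedLayout =
        triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());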

include/triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td

Lines changed: 3 additions & 2 deletions
@@ -423,6 +423,7 @@ def NVMMASharedEncodingAttr :
     ins
     "unsigned":$swizzlingByteWidth,
     "bool":$transposed,
+    "unsigned":$elementBitWidth,
     "CTALayoutAttr":$CTALayout
   );

@@ -433,7 +434,7 @@ def NVMMASharedEncodingAttr :
                    "Type":$eltTy), [{
     auto shapePerCTA = getShapePerCTA(CTALayout.getCTASplitNum(), shape);
     int32_t swizzlingByteWidth = 0;
-    int32_t eleBitWidth = eltTy.getIntOrFloatBitWidth();
+    unsigned eleBitWidth = eltTy.getIntOrFloatBitWidth();

     // get proper shared memory swizzling mode from the contiguous dimension
     // size of the origin blocked layout.
@@ -448,7 +449,7 @@ def NVMMASharedEncodingAttr :
       llvm_unreachable("unsupported shared memory layout for MMAv3");
     }
     bool transposed = order[0] == 0;
-    return $_get(context, swizzlingByteWidth, transposed, CTALayout);
+    return $_get(context, swizzlingByteWidth, transposed, eleBitWidth, CTALayout);
   }]>
 ];
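
With the new ins parameter, the generated builders take the bit width between transposed and the CTA layout, as the updated $_get call above shows. A minimal sketch of a direct construction under that parameter order (the variable names are illustrative):

    auto enc = NVMMASharedEncodingAttr::get(
        ctx, /*swizzlingByteWidth=*/128, /*transposed=*/false,
        /*elementBitWidth=*/eltTy.getIntOrFloatBitWidth(), ctaLayout);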

include/triton/Dialect/TritonGPU/IR/TritonGPUDialect.td

Lines changed: 1 addition & 2 deletions
@@ -44,8 +44,7 @@ def TritonGPU_Dialect : Dialect {
       return cast<IntegerAttr>(threadsPerWarp).getInt();
     }

-    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout,
-                                std::optional<int32_t> elemBitWidth);
+    LinearLayout toLinearLayout(ArrayRef<int64_t> shape, Attribute layout);

   private:
     LinearLayoutCache llCache;
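
The conversion lives on the dialect so results can be memoized in llCache. A hedged usage sketch; the getLoadedDialect lookup is standard MLIR, but this exact call site is illustrative rather than taken from the commit:

    // Illustrative only: ask the loaded TritonGPU dialect for the cached
    // linear layout of a tensor's shape and encoding.
    auto *dialect =
        ctx->getLoadedDialect<mlir::triton::gpu::TritonGPUDialect>();
    LinearLayout ll = dialect->toLinearLayout(tensorTy.getShape(),
                                              tensorTy.getEncoding());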

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 1 addition & 1 deletion
@@ -225,7 +225,7 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {

 def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
                                                   TransposeOpInterface,
-                                                  DeclareOpInterfaceMethods<InferTypeOpInterface>,
+                                                  InferTypeOpWithLayoutEquivalence,
                                                   SameOperandsAndResultElementType]> {
   let summary = "transpose the descriptor";

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 5 additions & 7 deletions
@@ -317,8 +317,8 @@ bool emitTransferBetweenRegistersAndShared(
   StringAttr kWarp = str_attr("warp");

   auto shape = sharedTy.getShape();
-  LinearLayout sharedLayout = triton::gpu::toLinearLayout(
-      shape, sharedTy.getEncoding(), elemLlvmTy.getIntOrFloatBitWidth());
+  LinearLayout sharedLayout =
+      triton::gpu::toLinearLayout(shape, sharedTy.getEncoding());
   LinearLayout regToSharedLayout = regLayout.invertAndCompose(sharedLayout);

   // TODO(jlebar): We don't currently support loading from shared memory in a
@@ -363,8 +363,7 @@ bool emitTransferBetweenRegistersAndShared(
   auto allocShape = sharedTy.getAllocShape();
   LinearLayout invertAllocSharedLayout =
       triton::gpu::toLinearLayout(allocShape.take_back(sharedTy.getRank()),
-                                  sharedTy.getEncoding(),
-                                  elemLlvmTy.getIntOrFloatBitWidth())
+                                  sharedTy.getEncoding())
           .pseudoinvert();

   int numElems = regToSharedLayout.getInDimSize(kRegister);
@@ -386,9 +385,8 @@ bool emitTransferBetweenRegistersAndShared(
     const SharedMemoryObject &smemObj, Location loc, RewriterBase &rewriter,
     const TargetInfoBase &target,
     std::function<void(VectorType, Value /*shmemAddr*/)> perVectorCallback) {
-  auto regLayout = triton::gpu::toLinearLayout(
-      registerTy.getShape(), registerTy.getEncoding(),
-      elemLlvmTy.getIntOrFloatBitWidth());
+  auto regLayout = triton::gpu::toLinearLayout(registerTy.getShape(),
+                                               registerTy.getEncoding());
   return emitTransferBetweenRegistersAndShared(
       regLayout, sharedTy, elemLlvmTy, maxVecElems, smemObj, loc, rewriter,
       target, perVectorCallback);

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 23 additions & 0 deletions
@@ -209,6 +209,23 @@ OpFoldResult TransOp::fold(FoldAdaptor adaptor) {
   return {};
 }

+LogicalResult TransOp::verify() {
+  auto order = getOrder();
+  auto srcTy = cast<RankedTensorType>(getSrc().getType());
+  if (order.size() != srcTy.getShape().size()) {
+    return emitError("order must have the same size as the source tensor");
+  }
+  if (!isPermutationOfIota(order)) {
+    return emitError("order must be a permutation of 0..n-1");
+  }
+  SmallVector<int64_t> retShape = applyPermutation(srcTy.getShape(), order);
+  if (retShape != getType().getShape()) {
+    return emitError(
+        "result shape must match the permutation of the source shape");
+  }
+  return success();
+}
+
 LogicalResult TransOp::inferReturnTypes(
     MLIRContext *context, std::optional<Location> location,
     TransOp::Adaptor adaptor, SmallVectorImpl<Type> &inferredReturnTypes) {
@@ -1037,6 +1054,12 @@ void ElementwiseInlineAsmOp::getEffects(
                         SideEffects::DefaultResource::get());
 }

+Speculation::Speculatability ElementwiseInlineAsmOp::getSpeculatability() {
+  if (getPure())
+    return Speculation::Speculatable;
+  return Speculation::NotSpeculatable;
+}
+
 LogicalResult ElementwiseInlineAsmOp::verify() {
   if (getNumOperands() >= 1) {
     auto tensorType = dyn_cast<RankedTensorType>(getOperand(0).getType());
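
isPermutationOfIota is used by the new verifier but not defined in this hunk; a plausible implementation, written here as an assumption rather than quoted from the codebase:

    // Returns true iff `order` is exactly a permutation of 0..n-1, which is
    // what TransOp::verify() above requires of its `order` attribute.
    static bool isPermutationOfIota(ArrayRef<int32_t> order) {
      SmallVector<int32_t> sorted(order.begin(), order.end());
      llvm::sort(sorted);
      for (int32_t i = 0; i < (int32_t)sorted.size(); ++i)
        if (sorted[i] != i)
          return false;
      return true;
    }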

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 10 additions & 3 deletions
@@ -1927,6 +1927,7 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {

   unsigned swizzlingByteWidth;
   bool transposed;
+  unsigned elementBitWidth;
   std::optional<SmallVector<unsigned>> CTAsPerCGA;
   std::optional<SmallVector<unsigned>> CTASplitNum;
   std::optional<SmallVector<unsigned>> CTAOrder;
@@ -1938,6 +1939,9 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
     } else if (attr.getName() == "transposed") {
       if (parseBool(parser, attr, transposed, "transposed").failed())
         return {};
+    } else if (attr.getName() == "elementBitWidth") {
+      if (parseUInt(parser, attr, elementBitWidth, "elementBitWidth").failed())
+        return {};
     } else if (attr.getName() == "CTAsPerCGA") {
       if (parseIntArrayAttr(parser, attr, CTAsPerCGA.emplace(), "CTAsPerCGA")
               .failed())
@@ -1963,13 +1967,15 @@ Attribute NVMMASharedEncodingAttr::parse(AsmParser &parser, Type type) {
     return {};

   return parser.getChecked<NVMMASharedEncodingAttr>(
-      parser.getContext(), swizzlingByteWidth, transposed, *CTALayout);
+      parser.getContext(), swizzlingByteWidth, transposed, elementBitWidth,
+      *CTALayout);
 }

 void NVMMASharedEncodingAttr::print(AsmPrinter &printer) const {
   printer << "<{"
           << "swizzlingByteWidth = " << getSwizzlingByteWidth() //
-          << ", transposed = " << getTransposed();
+          << ", transposed = " << getTransposed() //
+          << ", elementBitWidth = " << getElementBitWidth();
   maybePrintCTALayout(getContext(), printer, getCTALayout(),
                       /*rank=*/2);
   printer << "}>";
@@ -2611,7 +2617,8 @@ struct TritonGPUInferLayoutInterface
       return failure();
     }
     resultEncoding = NVMMASharedEncodingAttr::get(
-        ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(), *ctaLayout);
+        ctx, enc.getSwizzlingByteWidth(), !enc.getTransposed(),
+        enc.getElementBitWidth(), *ctaLayout);
     return success();
   }