Skip to content

Commit 0a3dccc

Browse files
Merge commit '7088c64d7cd87b654c73ad90480e67bdbb1510f7'
2 parents baede21 + 7088c64 commit 0a3dccc

File tree

35 files changed

+344
-202
lines changed

35 files changed

+344
-202
lines changed

include/triton/Dialect/Triton/IR/CMakeLists.txt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,11 @@ mlir_tablegen(AttrInterfaces.h.inc -gen-attr-interface-decls)
2121
mlir_tablegen(AttrInterfaces.cpp.inc -gen-attr-interface-defs)
2222

2323
set(LLVM_TARGET_DEFINITIONS TritonTypeInterfaces.td)
24-
mlir_tablegen(TritonTypeInterfaces.h.inc -gen-type-interface-decls)
25-
mlir_tablegen(TritonTypeInterfaces.cpp.inc -gen-type-interface-defs)
24+
mlir_tablegen(TypeInterfaces.h.inc -gen-type-interface-decls)
25+
mlir_tablegen(TypeInterfaces.cpp.inc -gen-type-interface-defs)
26+
27+
set(LLVM_TARGET_DEFINITIONS TritonOpInterfaces.td)
28+
mlir_tablegen(OpInterfaces.h.inc -gen-op-interface-decls)
29+
mlir_tablegen(OpInterfaces.cpp.inc -gen-op-interface-defs)
2630

2731
add_public_tablegen_target(TritonTableGen)

include/triton/Dialect/Triton/IR/Dialect.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
#include "mlir/Interfaces/FunctionInterfaces.h"
1414
#include "mlir/Interfaces/SideEffectInterfaces.h"
1515
#include "triton/Dialect/Triton/IR/Dialect.h.inc"
16+
#include "triton/Dialect/Triton/IR/OpInterfaces.h"
1617
#include "triton/Dialect/Triton/IR/OpsEnums.h.inc"
1718
#include "triton/Dialect/Triton/IR/Traits.h"
1819
#include "triton/Dialect/Triton/IR/Types.h"
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
#ifndef TRITON_IR_OP_INTERFACES_H_
2+
#define TRITON_IR_OP_INTERFACES_H_
3+
4+
#include "mlir/IR/OpDefinition.h"
5+
6+
namespace mlir {
7+
8+
namespace triton {
9+
10+
namespace impl {
11+
12+
LogicalResult verifyTransposeOpInterface(Operation *op);
13+
14+
} // namespace impl
15+
16+
} // namespace triton
17+
} // namespace mlir
18+
19+
#include "triton/Dialect/Triton/IR/OpInterfaces.h.inc"
20+
21+
#endif // TRITON_IR_OP_INTERFACES_H_
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
#ifndef TRITON_OP_INTERFACES
2+
#define TRITON_OP_INTERFACES
3+
4+
include "mlir/IR/OpBase.td"
5+
6+
7+
def TransposeOpInterface : OpInterface<"TransposeOpInterface"> {
8+
let description = [{
9+
This interface is implemented by operations that perform a transpose.
10+
It provides methods to access common properties such as the order attribute and the source operand.
11+
}];
12+
13+
let cppNamespace = "::mlir::triton";
14+
15+
let methods = [
16+
InterfaceMethod<
17+
/*desc=*/[{
18+
Get the source operand of the transposition.
19+
}],
20+
/*retType=*/"::mlir::Value",
21+
/*methodName=*/"getSrc",
22+
/*args=*/(ins)>,
23+
InterfaceMethod<
24+
/*desc=*/[{
25+
Get the order of the transposition.
26+
}],
27+
/*retType=*/"::mlir::ArrayRef<int32_t>",
28+
/*methodName=*/"getOrder",
29+
/*args=*/(ins)>
30+
];
31+
32+
let verify = [{ return ::mlir::triton::impl::verifyTransposeOpInterface($_op); }];
33+
}
34+
35+
36+
#endif // TRITON_OP_INTERFACES

include/triton/Dialect/Triton/IR/TritonOps.td

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,13 @@ include "triton/Dialect/Triton/IR/TritonInterfaces.td"
88
include "mlir/IR/OpBase.td"
99
include "mlir/IR/SymbolInterfaces.td" // SymbolUserOpInterface
1010
include "mlir/IR/OpAsmInterface.td" // OpAsmOpInterface
11-
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
12-
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
1311
include "mlir/Interfaces/FunctionInterfaces.td" // FunctionOpInterface
1412
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
1513
include "mlir/Interfaces/ControlFlowInterfaces.td" // BranchOpInterface
1614
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
17-
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
18-
include "mlir/Interfaces/CastInterfaces.td" // CastOpInterface
1915
include "mlir/Interfaces/CallInterfaces.td" // CallOpInterface
2016
include "triton/Dialect/Triton/IR/TritonTypeInterfaces.td"
17+
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
2118

2219

2320
//
@@ -44,8 +41,7 @@ class TT_Op<string mnemonic, list<Trait> traits = []> :
4441
def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
4542
SameOperandsAndResultShape,
4643
SameOperandsAndResultEncoding,
47-
Pure,
48-
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
44+
Pure]> {
4945
let summary = "Cast int64 to pointer";
5046

5147
let arguments = (ins TT_I64Like:$src);
@@ -58,8 +54,7 @@ def TT_IntToPtrOp : TT_Op<"int_to_ptr", [Elementwise,
5854
def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
5955
SameOperandsAndResultShape,
6056
SameOperandsAndResultEncoding,
61-
Pure,
62-
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
57+
Pure]> {
6358
let summary = "Cast pointer to int64";
6459

6560
let arguments = (ins TT_PtrLike:$src);
@@ -73,8 +68,7 @@ def TT_PtrToIntOp : TT_Op<"ptr_to_int", [Elementwise,
7368
def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
7469
SameOperandsAndResultShape,
7570
SameOperandsAndResultEncoding,
76-
Pure,
77-
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
71+
Pure]> {
7872
let summary = "Cast between types of the same bitwidth";
7973

8074
let arguments = (ins TT_Type:$src);
@@ -89,8 +83,7 @@ def TT_BitcastOp : TT_Op<"bitcast", [Elementwise,
8983
def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
9084
SameOperandsAndResultShape,
9185
SameOperandsAndResultEncoding,
92-
Pure,
93-
/*DeclareOpInterfaceMethods<CastOpInterface>*/]> {
86+
Pure]> {
9487
let summary = "Floating point casting for custom types";
9588

9689
let description = [{
@@ -118,8 +111,8 @@ def TT_FpToFpOp : TT_Op<"fp_to_fp", [Elementwise,
118111
//
119112

120113
def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
121-
SameOperandsAndResultType,
122-
Pure]> {
114+
SameOperandsAndResultType,
115+
Pure]> {
123116
let summary = "Clamp operation for floating point types";
124117

125118
let description = [{
@@ -149,8 +142,8 @@ def TT_ClampFOp : TT_Op<"clampf", [Elementwise,
149142
//
150143

151144
def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
152-
SameOperandsAndResultType,
153-
Pure]> {
145+
SameOperandsAndResultType,
146+
Pure]> {
154147
let summary = "Precise sqrt for floating point types";
155148

156149
let description = [{
@@ -165,8 +158,8 @@ def TT_PreciseSqrtOp : TT_Op<"precise_sqrt", [Elementwise,
165158
}
166159

167160
def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
168-
SameOperandsAndResultType,
169-
Pure]> {
161+
SameOperandsAndResultType,
162+
Pure]> {
170163
let summary = "Precise div for floating point types";
171164

172165
let description = [{
@@ -181,8 +174,8 @@ def TT_PreciseDivFOp : TT_Op<"precise_divf", [Elementwise,
181174
}
182175

183176
def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
184-
SameOperandsAndResultType,
185-
Pure]> {
177+
SameOperandsAndResultType,
178+
Pure]> {
186179
let summary = "Most significant N bits of the 2N-bit product of two integers";
187180

188181
let description = [{
@@ -200,12 +193,12 @@ def TT_MulhiUIOp : TT_Op<"mulhiui", [Elementwise,
200193
// Pointer Arith Ops
201194
//
202195
def TT_AddPtrOp : TT_Op<"addptr",
203-
[Pure,
204-
Elementwise,
205-
SameOperandsAndResultShape,
206-
SameOperandsAndResultEncoding,
207-
TypesMatchWith<"result type matches ptr type",
208-
"result", "ptr", "$_self">]> {
196+
[Pure,
197+
Elementwise,
198+
SameOperandsAndResultShape,
199+
SameOperandsAndResultEncoding,
200+
TypesMatchWith<"result type matches ptr type",
201+
"result", "ptr", "$_self">]> {
209202
let arguments = (ins TT_PtrLike:$ptr, TT_IntLike:$offset);
210203

211204
let results = (outs TT_PtrLike:$result);
@@ -546,6 +539,7 @@ def TT_SplitOp : TT_Op<"split", [
546539
}
547540

548541
def TT_TransOp : TT_Op<"trans", [Pure,
542+
TransposeOpInterface,
549543
DeclareOpInterfaceMethods<InferTypeOpInterface>,
550544
SameOperandsAndResultElementType]> {
551545

@@ -579,16 +573,15 @@ def TT_TransOp : TT_Op<"trans", [Pure,
579573
}];
580574

581575
let arguments = (
582-
ins TT_TensorOrMemDesc:$src,
576+
ins TT_Tensor:$src,
583577
DenseI32ArrayAttr:$order
584578
);
585579

586-
let results = (outs TT_TensorOrMemDesc:$result);
580+
let results = (outs TT_Tensor:$result);
587581

588582
let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
589583

590584
let hasFolder = 1;
591-
let hasVerifier = 1;
592585
}
593586

594587
//
@@ -677,10 +670,10 @@ def TT_DotOp : TT_Op<"dot", [Pure,
677670
// DotScaled Op
678671
//
679672
def TT_DotScaledOp : TT_Op<"dot_scaled", [Pure,
680-
AttrSizedOperandSegments,
681-
DotLike,
682-
TypesMatchWith<"result's type matches accumulator's type",
683-
"d", "c", "$_self">]> {
673+
AttrSizedOperandSegments,
674+
DotLike,
675+
TypesMatchWith<"result's type matches accumulator's type",
676+
"d", "c", "$_self">]> {
684677
let summary = "dot_scaled";
685678

686679
let description = [{
@@ -783,10 +776,10 @@ def TT_ScanReturnOp: TT_Op<"scan.return",
783776
// External Elementwise op
784777
//
785778
def TT_ExternElementwiseOp : TT_Op<"extern_elementwise", [Elementwise,
786-
SameOperandsAndResultEncoding,
787-
SameVariadicOperandSize,
788-
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
789-
ConditionallySpeculatable]> {
779+
SameOperandsAndResultEncoding,
780+
SameVariadicOperandSize,
781+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
782+
ConditionallySpeculatable]> {
790783

791784
let description = [{
792785
call an external function $symbol implemented in $libpath/$libname with $args

include/triton/Dialect/Triton/IR/Types.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
#define GET_TYPEDEF_CLASSES
99
#include "triton/Dialect/Triton/IR/Types.h.inc"
1010

11-
#include "triton/Dialect/Triton/IR/TritonTypeInterfaces.h.inc"
11+
#include "triton/Dialect/Triton/IR/TypeInterfaces.h.inc"
1212

1313
namespace mlir {
1414

include/triton/Dialect/TritonGPU/IR/TritonGPUOps.td

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ include "triton/Dialect/TritonGPU/IR/TritonGPUAttrDefs.td"
77
include "mlir/Dialect/Arith/IR/ArithBase.td"
88
include "triton/Dialect/Triton/IR/TritonTypes.td"
99
include "triton/Dialect/Triton/IR/TritonAttrDefs.td"
10+
include "triton/Dialect/Triton/IR/TritonOpInterfaces.td"
1011
include "mlir/IR/OpBase.td"
1112
include "mlir/Interfaces/SideEffectInterfaces.td" // Pure
1213
include "mlir/Interfaces/InferTypeOpInterface.td" // SameOperandsAndResultType
@@ -221,6 +222,31 @@ def TTG_MemDescSubviewOp : TTG_Op<"memdesc_subview", [Pure]> {
221222
let hasVerifier = 1;
222223
}
223224

225+
def TTG_MemDescTransOp : TTG_Op<"memdesc_trans", [Pure,
226+
TransposeOpInterface,
227+
DeclareOpInterfaceMethods<InferTypeOpInterface>,
228+
SameOperandsAndResultElementType]> {
229+
let summary = "transpose the descriptor";
230+
231+
let description = [{
232+
This operation returns a new descriptor
233+
representing a transposed view of the buffer.
234+
}];
235+
236+
let arguments = (ins TT_MemDescType:$src, Variadic<I32>:$order);
237+
238+
let arguments = (
239+
ins TT_MemDescType:$src,
240+
DenseI32ArrayAttr:$order
241+
);
242+
243+
let results = (outs TT_MemDescType:$result);
244+
245+
let assemblyFormat = "$src attr-dict `:` qualified(type($src)) `->` qualified(type($result))";
246+
247+
let hasFolder = 1;
248+
}
249+
224250
def TTG_LocalLoadOp : TTG_Op<"local_load", [DeclareOpInterfaceMethods<MemoryEffectsOpInterface>]> {
225251
let summary = "Load a buffer from local memory into a distributed tensor";
226252

lib/Analysis/Alias.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,8 @@ LogicalResult SharedMemoryAliasAnalysis::visitOperation(
3838
if (isa<triton::gpu::LocalAllocOp>(op)) {
3939
aliasInfo.insert(result);
4040
pessimistic = false;
41-
} else if (isa<triton::gpu::MemDescSubviewOp, triton::TransOp>(op)) {
42-
// extract_slice %src
43-
// trans %src
41+
} else if (isa<triton::gpu::MemDescSubviewOp, triton::gpu::MemDescTransOp>(
42+
op)) {
4443
aliasInfo = AliasInfo(operands[0]->getValue());
4544
pessimistic = false;
4645
} else {

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -752,26 +752,26 @@ SmallVector<Value> convertMxfp4x2ToBf16x2(RewriterBase &rewriter, Location loc,
752752
ArrayRef<Value> values) {
753753
SmallVector<Value> results;
754754
for (auto v : values) {
755-
auto em0 = and_(v, i8_val(0x70));
756-
auto em1 = and_(v, i8_val(0x7));
757-
Value v0 = or_(shl(zext(i16_ty, em0), i16_val(2)),
758-
shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
759-
Value v1 = or_(shl(zext(i16_ty, em1), i16_val(6)),
755+
auto em0 = and_(v, i8_val(0x7));
756+
auto em1 = and_(v, i8_val(0x70));
757+
Value v0 = or_(shl(zext(i16_ty, em0), i16_val(6)),
760758
shl(zext(i16_ty, and_(v, i8_val(0x8))), i16_val(12)));
759+
Value v1 = or_(shl(zext(i16_ty, em1), i16_val(2)),
760+
shl(zext(i16_ty, and_(v, i8_val(0x80))), i16_val(8)));
761761

762762
// Three cases:
763763
// 1) x is normal and non-zero: Correct bias
764-
v0 = select(icmp_ne(and_(em0, i8_val(0x60)), i8_val(0)),
764+
v0 = select(icmp_ne(and_(em0, i8_val(0x6)), i8_val(0)),
765765
add(v0, i16_val((127 - 1) << 7)), v0);
766-
v1 = select(icmp_ne(and_(em1, i8_val(0x6)), i8_val(0)),
766+
v1 = select(icmp_ne(and_(em1, i8_val(0x60)), i8_val(0)),
767767
add(v1, i16_val((127 - 1) << 7)), v1);
768768

769769
// 2) x is subnormal (x == 0bs001 where s is the sign): Map to +-0.5 in
770770
// bf16
771-
v0 = bitcast(select(icmp_eq(em0, i8_val(0x10)),
771+
v0 = bitcast(select(icmp_eq(em0, i8_val(0x1)),
772772
or_(i16_val(16128), and_(v0, i16_val(0x8000))), v0),
773773
bf16_ty);
774-
v1 = bitcast(select(icmp_eq(em1, i8_val(0x1)),
774+
v1 = bitcast(select(icmp_eq(em1, i8_val(0x10)),
775775
or_(i16_val(16128), and_(v1, i16_val(0x8000))), v1),
776776
bf16_ty);
777777
// 3) x is zero, nothing to do

0 commit comments

Comments
 (0)