
Commit 88f4435

Merge commit '60a1996d12840039381f012f6d1c13c32cbabe20'
2 parents: 8609010 + 60a1996

21 files changed: +284, -90 lines
include/triton/Dialect/Gluon/IR/CMakeLists.txt

Lines changed: 7 additions & 6 deletions

@@ -1,16 +1,17 @@
 set(MLIR_BINARY_DIR ${CMAKE_BINARY_DIR})
 
+set(LLVM_TARGET_DEFINITIONS GluonOps.td)
+mlir_tablegen(Ops.h.inc -gen-op-decls)
+mlir_tablegen(Ops.cpp.inc -gen-op-defs)
+add_mlir_doc(GluonOps GluonOps dialects/ -gen-op-doc)
+
 set(LLVM_TARGET_DEFINITIONS GluonDialect.td)
 mlir_tablegen(Dialect.h.inc -gen-dialect-decls -dialect=gluon)
 mlir_tablegen(Dialect.cpp.inc -gen-dialect-defs -dialect=gluon)
-mlir_tablegen(Ops.h.inc -gen-op-decls)
-mlir_tablegen(Ops.cpp.inc -gen-op-defs)
 add_mlir_doc(GluonDialect GluonDialect dialects/ -gen-dialect-doc)
-add_public_tablegen_target(GluonTableGen)
 
 set(LLVM_TARGET_DEFINITIONS GluonAttrDefs.td)
 mlir_tablegen(GluonAttrDefs.h.inc -gen-attrdef-decls)
 mlir_tablegen(GluonAttrDefs.cpp.inc -gen-attrdef-defs)
-mlir_tablegen(OpsEnums.h.inc -gen-enum-decls)
-mlir_tablegen(OpsEnums.cpp.inc -gen-enum-defs)
-add_public_tablegen_target(GluonAttrDefsIncGen)
+
+add_public_tablegen_target(GluonTableGen)

include/triton/Dialect/Gluon/IR/Dialect.h

Lines changed: 3 additions & 0 deletions

@@ -6,3 +6,6 @@
 
 #define GET_ATTRDEF_CLASSES
 #include "triton/Dialect/Gluon/IR/GluonAttrDefs.h.inc"
+
+#define GET_OP_CLASSES
+#include "triton/Dialect/Gluon/IR/Ops.h.inc"
include/triton/Dialect/Gluon/IR/GluonOps.td

Lines changed: 32 additions & 0 deletions

@@ -0,0 +1,32 @@
+#ifndef GLUON_OPS
+#define GLUON_OPS
+
+include "triton/Dialect/Gluon/IR/GluonDialect.td"
+include "triton/Dialect/Gluon/IR/GluonAttrDefs.td"
+include "triton/Dialect/Triton/IR/TritonInterfaces.td"
+include "triton/Dialect/Triton/IR/TritonTypes.td"
+
+class Gluon_Op<string mnemonic, list<Trait> traits = []> :
+    Op<Gluon_Dialect, mnemonic,
+       !listconcat(traits, [VerifyTensorLayoutsTrait])> {
+}
+
+def Gluon_SetAutoLayoutOp : Gluon_Op<"set_auto_layout",
+                                     [SameOperandsAndResultShape,
+                                      SameOperandsAndResultElementType]> {
+  let summary = "set auto encoding to a concrete encoding type";
+
+  let arguments = (ins TT_Tensor:$src);
+
+  let results = (outs TT_Tensor:$result);
+
+  let builders = [
+    OpBuilder<(ins "Attribute":$encoding, "Value":$value)>
+  ];
+
+  let hasVerifier = 1;
+
+  let assemblyFormat = "$src attr-dict `:` type($src) `->` type($result)";
+}
+
+#endif // GLUON_OPS
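For orientation, this is how the new op is driven from the Gluon frontend. A minimal sketch, assuming the import paths and the `ttgl.set_auto_layout` wrapper used in the test_frontend.py changes further down; the kernel itself is hypothetical:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def pin_layout_kernel():
    # Without an explicit layout the tensor carries #gluon.auto_encoding.
    i = ttgl.arange(0, 32)
    # set_auto_layout pins it to a concrete layout; per the assemblyFormat
    # above this prints as:
    #   gluon.set_auto_layout %i : tensor<32xi32, #gluon.auto_encoding>
    #                              -> tensor<32xi32, #blocked>
    i = ttgl.set_auto_layout(i, ttgl.BlockedLayout([1], [32], [4], [0]))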

lib/Dialect/Gluon/IR/CMakeLists.txt

Lines changed: 0 additions & 1 deletion

@@ -3,7 +3,6 @@ add_triton_library(GluonIR
 
   DEPENDS
   GluonTableGen
-  GluonAttrDefsIncGen
 
   LINK_LIBS PUBLIC
   TritonIR

lib/Dialect/Gluon/IR/Dialect.cpp

Lines changed: 21 additions & 0 deletions

@@ -12,6 +12,9 @@ namespace gluon = mlir::triton::gluon;
 #include "triton/Dialect/Gluon/IR/Dialect.cpp.inc"
 #include "triton/Dialect/Gluon/IR/GluonAttrDefs.cpp.inc"
 
+#define GET_OP_CLASSES
+#include "triton/Dialect/Gluon/IR/Ops.cpp.inc"
+
 namespace {
 
 // Layout inference for AutoEncodingAttr -> always propagate AutoEncodingAttr to
@@ -111,4 +114,22 @@ void GluonDialect::initialize() {
   addInterfaces<GluonInferLayoutInterface>();
 }
 
+void SetAutoLayoutOp::build(OpBuilder &builder, OperationState &state,
+                            Attribute enc, Value value) {
+  auto resTy = cast<RankedTensorType>(value.getType()).cloneWithEncoding(enc);
+  return build(builder, state, resTy, value);
+}
+
+LogicalResult SetAutoLayoutOp::verify() {
+  if (!isa<gluon::AutoEncodingAttr>(getSrc().getType().getEncoding())) {
+    return emitOpError("input tensor must have an auto layout type");
+  }
+  auto dstEncoding = getType().getEncoding();
+  if (!dstEncoding)
+    return emitOpError("result tensor must have an encoding");
+  if (isa<gluon::AutoEncodingAttr>(dstEncoding))
+    return emitOpError("result type must not be auto layout");
+  return success();
+}
+
 } // namespace mlir::triton::gluon
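The builder clones the source tensor type with the target encoding, and the verifier enforces three rules: the source must be auto-encoded, the result must carry an encoding, and the result must not itself be auto-encoded. A hedged frontend-level illustration of the first rule; the kernel is hypothetical, but the quoted message comes from the verifier above:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def rejected_kernel():
    # The source already has a concrete layout, so verification fails with
    # "input tensor must have an auto layout type".
    x = ttgl.arange(0, 32, layout=ttgl.BlockedLayout([1], [32], [4], [0]))
    ttgl.set_auto_layout(x, ttgl.BlockedLayout([2], [32], [4], [0]))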

lib/Dialect/Gluon/Transforms/ResolveAutoEncodings.cpp

Lines changed: 6 additions & 11 deletions

@@ -67,17 +67,11 @@ LogicalResult inferAutoLayouts(FuncOp func) {
   };
 
   // 1. Set seed values from layout conversions
-  auto res = func.walk([&](ttg::ConvertLayoutOp cvtOp) -> WalkResult {
-    auto src = cvtOp.getSrc();
-    auto res = cvtOp.getResult();
-    auto srcEnc = src.getType().getEncoding();
-    auto resEnc = res.getType().getEncoding();
-    auto isAutoSrc = isa<gluon::AutoEncodingAttr>(srcEnc);
-    auto isAutoRes = isa<gluon::AutoEncodingAttr>(resEnc);
-    if (isAutoSrc && !isAutoRes) {
-      return updateEncoding({src}, resEnc);
-    }
-    return WalkResult::advance();
+  auto res = func.walk([&](gluon::SetAutoLayoutOp op) -> WalkResult {
+    auto res = updateEncoding({op.getSrc()}, op.getType().getEncoding());
+    op.getResult().replaceAllUsesWith(op.getSrc());
+    op->erase();
+    return res;
   });
 
   if (res.wasInterrupted())
@@ -158,6 +152,7 @@ LogicalResult inferAutoLayouts(FuncOp func) {
       }
     }
   }
+
   return success();
 }
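Seeding now comes from explicit `gluon.set_auto_layout` ops instead of pattern-matching auto-to-concrete `ttg.convert_layout` ops: each op seeds its source value with the concrete result encoding, then folds away (uses rewired to the source, op erased) before propagation. A toy Python sketch of that seed-and-propagate shape, under the assumption of a simple value graph; this is the algorithm's outline, not the pass itself:

# Seeds come from gluon.set_auto_layout ops; connected auto-encoded
# values must all resolve to the same concrete encoding.
def resolve_auto_encodings(seeds, neighbors):
    encoding = dict(seeds)  # value -> concrete encoding
    worklist = list(seeds)
    while worklist:
        v = worklist.pop()
        for u in neighbors.get(v, ()):
            if u not in encoding:
                encoding[u] = encoding[v]  # propagate to unresolved values
                worklist.append(u)
            elif encoding[u] != encoding[v]:
                raise ValueError(f"conflicting encodings for {u!r}")
    return encoding


# Two values tied together by an elementwise op; one is seeded.
print(resolve_auto_encodings({"a": "#blocked"}, {"a": ["b"], "b": ["a"]}))
# -> {'a': '#blocked', 'b': '#blocked'}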

lib/Dialect/TritonGPU/Transforms/Pipeliner/AssignLatencies.cpp

Lines changed: 1 addition & 1 deletion

@@ -319,7 +319,7 @@ loadOpsToIndirectionLevel(scf::ForOp forOp, bool pipelineWithoutDot,
 
   // If the loop has numStages attribute, also consider pipelining other loads
   // that are not directly used by dot ops.
-  if (pipelineWithoutDot && !seenDot) {
+  if (pipelineWithoutDot) {
     for (Operation &op : forOp.getBody()->without_terminator()) {
       if (!isa<tt::LoadOp, tt::DescriptorLoadOp, tt::DescriptorGatherOp>(op))
         dfs(&op, &op, 0);
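With `!seenDot` dropped, a loop carrying a `num_stages` attribute also gets its loads that feed no dot considered for pipelining, even when the loop contains a dot. A hedged sketch of the affected pattern; the kernel and its pointer layout are hypothetical, and `tl.range(..., num_stages=...)` is what attaches the attribute:

import triton
import triton.language as tl


@triton.jit
def mixed_loop_kernel(a_ptr, b_ptr, c_ptr, out_ptr, aux_ptr,
                      K: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    acc = tl.zeros((BLOCK, BLOCK), dtype=tl.float32)
    extra = tl.zeros((BLOCK,), dtype=tl.float32)
    for k in tl.range(0, K, BLOCK, num_stages=3):
        a = tl.load(a_ptr + offs[:, None] * K + k + offs[None, :])
        b = tl.load(b_ptr + (k + offs[:, None]) * BLOCK + offs[None, :])
        acc += tl.dot(a, b)  # the loop contains a dot, so seenDot was true
        # This load feeds no dot; previously it was skipped whenever the loop
        # had a dot, now it remains a pipelining candidate.
        extra += tl.load(c_ptr + k + offs)
    tl.store(out_ptr + offs[:, None] * BLOCK + offs[None, :], acc)
    tl.store(aux_ptr + offs, extra)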

python/src/gluon_ir.cc

Lines changed: 6 additions & 3 deletions

@@ -114,10 +114,9 @@ static bool isConvertLayoutTrivial(RankedTensorType dstTy, Value value) {
   auto srcTy = cast<RankedTensorType>(value.getType());
   if (srcTy.getEncoding() == dstTy.getEncoding())
     return true;
-  // Handle unresolved layouts. auto -> T is trivial but T -> auto is not
-  // necessarily.
+  // Fail safe on unresolved layouts.
   if (isa<gluon::AutoEncodingAttr>(srcTy.getEncoding()))
-    return true;
+    return false;
   if (isa<gluon::AutoEncodingAttr>(dstTy.getEncoding()))
     return false;
 
@@ -404,6 +403,10 @@ void init_gluon_ir(py::module &&m) {
            [](GluonOpBuilder &self, Type resultType, Value src) -> Value {
              return self.create<ttg::MemDescReinterpretOp>(resultType, src);
            })
+      .def("create_set_auto_layout",
+           [](GluonOpBuilder &self, Attribute layout, Value value) -> Value {
+             return self.create<gluon::SetAutoLayoutOp>(layout, value);
+           })
       .def("create_split",
            [](GluonOpBuilder &self, Value &a) -> py::tuple {
              auto argTy = cast<RankedTensorType>(a.getType());
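`isConvertLayoutTrivial` previously treated auto-to-concrete conversion as trivial; it now fails safe, and the new `create_set_auto_layout` binding is the supported way to pin an unresolved layout. A short sketch of the user-visible consequence, with a hypothetical kernel; the error text matches the updated tests below:

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl


@gluon.jit
def pin_kernel(layout: ttgl.constexpr):
    x = ttgl.arange(0, 128, layout=ttgl.AutoLayout())
    # Now rejected: "layout conversion from AutoLayout() ... is not trivial".
    # ttgl.convert_layout(x, layout, assert_trivial=True)
    # Supported path: pin the auto layout explicitly.
    ttgl.set_auto_layout(x, layout)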

python/test/gluon/test_frontend.py

Lines changed: 26 additions & 24 deletions

@@ -73,41 +73,39 @@ def test_convert_layout_assert_trivial():
     # CHECK: ttg.convert_layout
     ttgl.convert_layout(value, equiv_layout, assert_trivial=True)
 
-    value = ttgl.arange(0, 128, layout=ttgl.AutoLayout())
-    # CHECK: ttg.convert_layout
-    ttgl.convert_layout(value, equiv_layout, assert_trivial=True)
-
 
 def test_convert_layout_not_trivial():
 
     @gluon.jit
-    def kernel():
-        src_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
-        dst_layout: ttgl.constexpr = ttgl.BlockedLayout([1], [32], [4], [0])
-
+    def kernel(src_layout: ttgl.constexpr, dst_layout: ttgl.constexpr):
         value = ttgl.arange(0, 128, layout=src_layout)
         ttgl.convert_layout(value, dst_layout, assert_trivial=True)
 
     with pytest.raises(CompilationError) as e:
-        run_parser(kernel)
+        src_layout = ttgl.BlockedLayout([2], [32], [4], [0])
+        dst_layout = ttgl.BlockedLayout([1], [32], [4], [0])
+        kernel.warmup(src_layout, dst_layout, grid=(1, ))
 
-    assert "layout conversion from BlockedLayout(size_per_thread=(2" in str(e.value.__cause__)
-    assert "to BlockedLayout(size_per_thread=(1" in str(e.value.__cause__)
+    assert "layout conversion from BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
+    assert "to BlockedLayout(size_per_thread=[1]" in str(e.value.__cause__)
     assert "is not trivial" in str(e.value.__cause__)
 
-    @gluon.jit
-    def kernel():
-        src_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
-        dst_layout: ttgl.constexpr = ttgl.AutoLayout()
+    with pytest.raises(CompilationError) as e:
+        src_layout = ttgl.BlockedLayout([2], [32], [4], [0])
+        dst_layout = ttgl.AutoLayout()
+        kernel.warmup(src_layout, dst_layout, grid=(1, ))
 
-        value = ttgl.arange(0, 128, layout=src_layout)
-        ttgl.convert_layout(value, dst_layout, assert_trivial=True)
+    assert "layout conversion from BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
+    assert "to AutoLayout() is not trivial" in str(e.value.__cause__)
 
     with pytest.raises(CompilationError) as e:
-        run_parser(kernel)
+        src_layout: ttgl.constexpr = ttgl.AutoLayout()
+        dst_layout: ttgl.constexpr = ttgl.BlockedLayout([2], [32], [4], [0])
+        kernel.warmup(src_layout, dst_layout, grid=(1, ))
 
-    assert "layout conversion from BlockedLayout(size_per_thread=(2" in str(e.value.__cause__)
-    assert "to AutoLayout() is not trivial" in str(e.value.__cause__)
+    assert "layout conversion from AutoLayout()" in str(e.value.__cause__)
+    assert "to BlockedLayout(size_per_thread=[2]" in str(e.value.__cause__)
+    assert "is not trivial" in str(e.value.__cause__)
 
 
 @gluon.jit
@@ -1223,6 +1221,7 @@ def kernel():
 @filecheck_test
 @gluon.jit
 def test_auto_layout():
+    # CHECK-DAG: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
     # CHECK: [[X_1D:%.*]] = arith.constant dense<7> : tensor<16xi32, #gluon.auto_encoding>
     # CHECK: [[Y_1D:%.*]] = arith.constant dense<2> : tensor<8xi32, #gluon.auto_encoding>
     x = ttgl.full([16], 7, ttgl.int32, layout=ttgl.AutoLayout())[:, None]
@@ -1232,8 +1231,11 @@ def test_auto_layout():
     # CHECK: (tensor<16x8xi32, #gluon.auto_encoding>) -> tensor<16xi32, #gluon.auto_encoding
     ttgl.sum(z, axis=1)
 
-    # CHECK: tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #gluon.auto_encoding>
-    ttgl.arange(0, 32)
+    # CHECK: [[I:%.*]] = tt.make_range {end = 32 : i32, start = 0 : i32} : tensor<32xi32, #gluon.auto_encoding>
+    i = ttgl.arange(0, 32)
+
+    # CHECK: gluon.set_auto_layout [[I]] : tensor<32xi32, #gluon.auto_encoding> -> tensor<32xi32, [[BLOCKED]]
+    ttgl.set_auto_layout(i, ttgl.BlockedLayout([1], [32], [4], [0]))
 
 
 @filecheck_test
@@ -1245,13 +1247,13 @@ def test_auto_layout_broadcast():
     x = ttgl.full([16, 1], 1, ttgl.int32, layout=ttgl.AutoLayout())
     y = ttgl.full([1, 16], 2, ttgl.int32, layout=ttgl.BlockedLayout([1, 1], [1, 32], [4, 1], [1, 0]))
 
-    # CHECK: [[XCVT:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[XCVT:%.*]] = gluon.set_auto_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
     # CHECK: [[XBCAST:%.*]] = tt.broadcast [[XCVT]]
     # CHECK: [[YBCAST:%.*]] = tt.broadcast [[Y]]
    # CHECK: arith.addi [[XBCAST]], [[YBCAST]] : tensor<16x16xi32, [[BLOCKED]]>
     _ = x + y
 
-    # CHECK: [[XCVT2:%.*]] = ttg.convert_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
+    # CHECK: [[XCVT2:%.*]] = gluon.set_auto_layout [[X]] : tensor<16x1xi32, #gluon.auto_encoding> -> tensor<16x1xi32, [[BLOCKED]]>
     # CHECK: [[YBCAST2:%.*]] = tt.broadcast [[Y]]
     # CHECK: [[XBCAST2:%.*]] = tt.broadcast [[XCVT2]]
     # CHECK: arith.muli [[YBCAST2]], [[XBCAST2]] : tensor<16x16xi32, [[BLOCKED]]>

python/test/unit/language/test_core.py

Lines changed: 30 additions & 0 deletions

@@ -6096,6 +6096,36 @@ def kernel(Semaphore, Out, total: tl.constexpr):
     assert out.item() >= 0
 
 
+def test_constexpr_flattens():
+    assert tl.constexpr(tl.constexpr(5)) == tl.constexpr(5)
+    assert tl.constexpr(tl.constexpr(tl.constexpr(5))) == tl.constexpr(5)
+
+
+@pytest.mark.parametrize("literal, tensor_ty", [(10, tl.int32), (32.1, tl.float32),
+                                                ((5, 6, 7), None),  # tuples can't be lifted to tensors
+                                                ])
+def test_constexpr_assignment(literal, tensor_ty):
+    from triton.language.core import constexpr_type
+
+    @triton.jit
+    def kernel(input_literal: tl.constexpr, tensor_type: tl.constexpr):
+        patched_literal: tl.constexpr = PATCHED
+        # Sanity checks
+        tl.static_assert(patched_literal.type == constexpr_type(PATCHED))
+        tl.static_assert(input_literal.type == constexpr_type(PATCHED))
+
+        assigned_literal: tl.constexpr = input_literal
+        tl.static_assert(assigned_literal.type == constexpr_type(PATCHED))
+        tl.static_assert(assigned_literal == patched_literal)
+
+        if tensor_type is not None:
+            assigned_variable = input_literal
+            tl.static_assert(assigned_variable.type == tensor_type)
+
+    kernel_patched = patch_kernel(kernel, {'PATCHED': f"{literal}"})
+    kernel_patched[(1, )](literal, tensor_ty)
+
+
 @triton.jit
 def return_poison(x):
     a = False
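These tests pin down the assignment semantics inside `@triton.jit` kernels: an annotated assignment keeps the value a compile-time `constexpr`, while a bare assignment lifts a scalar literal to a runtime tensor value. A minimal sketch of the distinction, with a hypothetical kernel and a scalar argument assumed:

import triton
import triton.language as tl


@triton.jit
def assign_kernel(n: tl.constexpr):
    a: tl.constexpr = n    # annotated: stays a compile-time constexpr
    tl.static_assert(a == n)
    b = n                  # bare: lifted to a runtime tensor value (e.g. int32)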
