Commit 861f963

[Gluon] Add C++ -> gluon layout translation, use to implement permute (#7120)
1 parent 9ef21f2 commit 861f963
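
The change can be exercised from Gluon roughly as follows. This is a minimal sketch that condenses the new test added in python/test/gluon/test_frontend.py below; the kernel name and the standalone @gluon.jit function are illustrative and not part of the commit.

from triton.experimental import gluon
from triton.experimental.gluon import language as ttgl

@gluon.jit
def permute_example():
    # A 32x16 int32 tensor with an explicit blocked layout.
    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
    a = ttgl.full([32, 16], 0, ttgl.int32, layout=layout)
    # permute lowers to tt.trans; the inferred result encoding is translated
    # back into a Gluon layout object by the new C++ -> gluon path, so
    # res.type.layout is a concrete BlockedLayout.
    res = ttgl.permute(a, [1, 0])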

File tree

4 files changed: +95 -0 lines changed

python/src/gluon_ir.cc
python/test/gluon/test_frontend.py
python/triton/experimental/gluon/language/_core.py
python/triton/experimental/gluon/language/_semantic.py

python/src/gluon_ir.cc

Lines changed: 75 additions & 0 deletions
@@ -82,6 +82,75 @@ struct GluonOpBuilder : public TritonOpBuilder {
   }
 };
 
+struct GluonLayouts {
+  py::handle BlockedLayout;
+  py::handle SliceLayout;
+  py::handle DistributedLinearLayout;
+  py::handle NVMMASharedLayout;
+  py::handle SwizzledSharedLayout;
+
+  GluonLayouts() {
+    auto layouts =
+        py::module::import("triton.experimental.gluon.language._layouts");
+    BlockedLayout = py::object(layouts.attr("BlockedLayout")).release();
+    SliceLayout = py::object(layouts.attr("SliceLayout")).release();
+    DistributedLinearLayout =
+        py::object(layouts.attr("DistributedLinearLayout")).release();
+    NVMMASharedLayout = py::object(layouts.attr("NVMMASharedLayout")).release();
+    SwizzledSharedLayout =
+        py::object(layouts.attr("SwizzledSharedLayout")).release();
+  }
+};
+
+template <typename T> std::vector<T> toStdVector(llvm::ArrayRef<T> array) {
+  return std::vector<T>(array.begin(), array.end());
+}
+
+py::object layoutToGluon(Attribute layout) {
+  static GluonLayouts layouts;
+  if (auto blocked = dyn_cast<ttg::BlockedEncodingAttr>(layout)) {
+    auto ctaLayout = blocked.getCTALayout();
+    return layouts.BlockedLayout(toStdVector(blocked.getSizePerThread()),
+                                 toStdVector(blocked.getThreadsPerWarp()),
+                                 toStdVector(blocked.getWarpsPerCTA()),
+                                 toStdVector(blocked.getOrder()),
+                                 toStdVector(ctaLayout.getCTAsPerCGA()),
+                                 toStdVector(ctaLayout.getCTASplitNum()),
+                                 toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto sliced = dyn_cast<ttg::SliceEncodingAttr>(layout)) {
+    return layouts.SliceLayout(sliced.getDim(),
+                               layoutToGluon(sliced.getParent()));
+  } else if (auto linear = dyn_cast<ttg::LinearEncodingAttr>(layout)) {
+    auto ll = linear.getLinearLayout();
+    auto ctx = layout.getContext();
+    auto kReg = mlir::StringAttr::get(ctx, "register");
+    auto kLane = mlir::StringAttr::get(ctx, "lane");
+    auto kWarp = mlir::StringAttr::get(ctx, "warp");
+    auto kBlock = mlir::StringAttr::get(ctx, "block");
+    return layouts.DistributedLinearLayout(
+        ll.getBases().lookup(kReg), ll.getBases().lookup(kLane),
+        ll.getBases().lookup(kWarp), ll.getBases().lookup(kBlock),
+        ll.getOutDimSizes());
+  } else if (auto nvmma = dyn_cast<ttg::NVMMASharedEncodingAttr>(layout)) {
+    auto ctaLayout = nvmma.getCTALayout();
+    return layouts.NVMMASharedLayout(
+        nvmma.getSwizzlingByteWidth(), nvmma.getElementBitWidth(),
+        ctaLayout.getRank(), nvmma.getTransposed(), nvmma.getFp4Padded(),
+        toStdVector(ctaLayout.getCTAsPerCGA()),
+        toStdVector(ctaLayout.getCTASplitNum()),
+        toStdVector(ctaLayout.getCTAOrder()));
+  } else if (auto swizzled =
+                 dyn_cast<ttg::SwizzledSharedEncodingAttr>(layout)) {
+    auto ctaLayout = swizzled.getCTALayout();
+    return layouts.SwizzledSharedLayout(
+        swizzled.getVec(), swizzled.getPerPhase(), swizzled.getMaxPhase(),
+        swizzled.getOrder(), toStdVector(ctaLayout.getCTAsPerCGA()),
+        toStdVector(ctaLayout.getCTASplitNum()),
+        toStdVector(ctaLayout.getCTAOrder()));
+  }
+  throw py::value_error("Unhandled encoding encountered");
+}
+
 void init_gluon_ir(py::module &&m) {
   using ret = py::return_value_policy;
 
@@ -189,6 +258,12 @@ void init_gluon_ir(py::module &&m) {
                  ctx, block[0], block[1], unpacked, ctaSplitNum[0],
                  ctaSplitNum[1]);
            })
+      .def("get_gluon_layout_from_tensor",
+           [](GluonOpBuilder &self, Value tensor) -> py::object {
+             auto ty = dyn_cast<RankedTensorType>(tensor.getType());
+             assert(ty.getEncoding());
+             return layoutToGluon(ty.getEncoding());
+           })
       .def("create_convert_layout",
            [](GluonOpBuilder &self, Type resultTy, Value value) -> Value {
             return self.create<ttg::ConvertLayoutOp>(resultTy, value);
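
For reference on the Python side of the translation: each branch of layoutToGluon constructs one of the dataclasses defined in triton.experimental.gluon.language._layouts, passing the encoding fields positionally. Below is a sketch of the object the blocked branch produces for the encoding used in the new test; the values come from the test's CHECK lines, and building the layout by hand like this is purely illustrative.

from triton.experimental.gluon.language._layouts import BlockedLayout

# #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8],
#               warpsPerCTA = [4, 1], order = [1, 0]}> with a trivial CTA
# layout becomes, in positional order (sizePerThread, threadsPerWarp,
# warpsPerCTA, order, CTAsPerCGA, CTASplitNum, CTAOrder):
layout = BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0], [1, 1], [1, 1], [1, 0])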

python/test/gluon/test_frontend.py

Lines changed: 13 additions & 0 deletions
@@ -829,3 +829,16 @@ def test_linear_layout(fresh_knobs):
 } loc(#loc)
 #loc = loc(unknown)
 """)
+
+
+@filecheck_test
+@gluon.jit
+def test_tensor_permute():
+    # CHECK-DAG: [[BLOCKED:#.*]] = #ttg.blocked<{sizePerThread = [1, 2], threadsPerWarp = [4, 8], warpsPerCTA = [4, 1], order = [1, 0]}>
+    # CHECK-DAG: [[BLOCKED1:#.*]] = #ttg.blocked<{sizePerThread = [2, 1], threadsPerWarp = [8, 4], warpsPerCTA = [1, 4], order = [0, 1]}>
+    layout: ttgl.constexpr = ttgl.BlockedLayout([1, 2], [4, 8], [4, 1], [1, 0])
+    a = ttgl.full([32, 16], 0, ttgl.int32, layout=layout)
+    # CHECK: tt.trans{{.*}} : tensor<32x16xi32, [[BLOCKED]]> -> tensor<16x32xi32, [[BLOCKED1]]>
+    res = ttgl.permute(a, [1, 0])
+    permuted_layout: ttgl.constexpr = ttgl.BlockedLayout([2, 1], [8, 4], [1, 4], [0, 1], [1, 1], [1, 1], [1, 0])
+    ttgl.static_assert(permuted_layout == res.type.layout)
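
The relationship between the two CHECK-DAG layouts in this test can be reproduced with a few lines of plain Python, shown below. This is illustrative arithmetic rather than anything the commit adds: each per-dimension list is gathered by dims, and each entry of order is remapped through the inverse permutation of dims.

dims = [1, 0]
size_per_thread, threads_per_warp, warps_per_cta = [1, 2], [4, 8], [4, 1]
order = [1, 0]

gather = lambda xs: [xs[d] for d in dims]            # new dim i takes old dim dims[i]
inverse = [dims.index(i) for i in range(len(dims))]  # old dim index -> new dim index

print(gather(size_per_thread))      # [2, 1]
print(gather(threads_per_warp))     # [8, 4]
print(gather(warps_per_cta))        # [1, 4]
print([inverse[o] for o in order])  # [0, 1]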

python/triton/experimental/gluon/language/_core.py

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@
     "where",  # NOQA: F822
     "maximum",  # NOQA: F822
     "minimum",  # NOQA: F822
+    "permute",
 ]
 
 __all__ = [

python/triton/experimental/gluon/language/_semantic.py

Lines changed: 6 additions & 0 deletions
@@ -57,6 +57,12 @@ def expand_dims(self, input: TensorTy, axis: int) -> TensorTy:
         handle = self.builder.create_expand_dims(input.handle, axis, ret_ty.to_ir(self.builder))
         return self.tensor(handle, ret_ty)
 
+    def permute(self, input: TensorTy, dims: Tuple[int]) -> TensorTy:
+        value = super().permute(input, dims)
+        layout = self.builder.get_gluon_layout_from_tensor(value.handle)
+        res_ty = ttgl.distributed_type(value.type.scalar, value.shape, layout)
+        return self.tensor(value.handle, res_ty)
+
     def broadcast_impl_shape(self, input: TensorTy, shape: Tuple[int]) -> TensorTy:
         _check(isinstance(input.type, ttgl.distributed_type),
                lambda: f"expected expand_dims input to be a distributed_type but got: {input.type!r}")
