Skip to content

Commit 93a910f

Browse files
committed
Check layout within tensor_descriptor -> block_pointer conversion pass
1 parent 983aa34 commit 93a910f

File tree

5 files changed

+26
-4
lines changed

5 files changed

+26
-4
lines changed

python/triton/experimental/gluon/language/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
device_assert,
5151
device_print,
5252
dot_fma,
53+
xpu_dot_fma,
5354
expand_dims,
5455
full,
5556
fp4_to_fp,

python/triton/experimental/gluon/language/_core.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -590,3 +590,9 @@ def dot_fma(a, b, acc, _semantic=None):
590590

591591
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
592592
return tensor(handle, acc.type)
593+
594+
595+
@builtin
596+
def xpu_dot_fma(a, b, acc, _semantic=None):
597+
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
598+
return tensor(handle, acc.type)
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from ._layouts import IntelDPASLayout
2+
from . import xpu
23

3-
__all__ = ["IntelDPASLayout"]
4+
__all__ = ["IntelDPASLayout", "xpu"]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import xe
2+
3+
__all__ = ["xe"]
4+

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,8 @@ struct TritonIntelTensorDescToBlockPointer
109109
tt::MakeTensorPtrOp
110110
findOrCreateMakeTensorPtr(Location loc, Value base, ValueRange shape,
111111
ValueRange strides, ValueRange offsets,
112-
ArrayRef<int32_t> sizes, OpBuilder &builder) {
112+
ArrayRef<int32_t> sizes, Attribute encoding,
113+
OpBuilder &builder) {
113114
Block *block = builder.getInsertionBlock();
114115
const Block::iterator insertPoint = builder.getInsertionPoint();
115116
auto it = std::find_if(block->begin(), insertPoint, [&](Operation &op) {
@@ -134,8 +135,15 @@ struct TritonIntelTensorDescToBlockPointer
134135
});
135136

136137
auto makeTensorPtrOp = [&]() {
138+
// Create the tensor type with encoding
139+
auto pointerType = cast<mlir::triton::PointerType>(base.getType());
140+
auto tensorType = RankedTensorType::get(
141+
SmallVector<int64_t>(sizes.begin(), sizes.end()),
142+
pointerType.getPointeeType(), encoding);
143+
auto resultType = mlir::triton::PointerType::get(tensorType, pointerType.getAddressSpace());
144+
137145
auto makeTensorPtr = builder.create<tt::MakeTensorPtrOp>(
138-
loc, base, shape, strides, offsets, sizes,
146+
loc, resultType, base, shape, strides, offsets,
139147
builder.getDenseI32ArrayAttr({1, 0}));
140148
return makeTensorPtr;
141149
};
@@ -190,6 +198,8 @@ struct TritonIntelTensorDescToBlockPointer
190198
Location loc = op.getLoc();
191199
tt::TensorDescType tDescType = op.getType();
192200

201+
// Extract encoding from the tensor descriptor's block type
202+
Attribute encoding = tDescType.getBlockType().getEncoding();
193203
// Create a new block pointer if a suitable one doesn't already exist.
194204
SmallVector<Value> shapes, strides, offsets;
195205
SmallVector<int32_t> sizes;
@@ -209,7 +219,7 @@ struct TritonIntelTensorDescToBlockPointer
209219
}
210220

211221
auto tensorPtr = findOrCreateMakeTensorPtr(
212-
loc, op.getBase(), shapes, strides, offsets, sizes, builder);
222+
loc, op.getBase(), shapes, strides, offsets, sizes, encoding, builder);
213223
LLVM_DEBUG({
214224
llvm::dbgs() << "With:\n";
215225
llvm::dbgs().indent(2) << tensorPtr << "\n";

0 commit comments

Comments
 (0)