Skip to content

Commit acb526d

Browse files
committed
Test support for tensor_descriptors in gluon, initial load/store/prefetch operations
1 parent 93a910f commit acb526d

File tree

7 files changed

+296
-12
lines changed

7 files changed

+296
-12
lines changed

python/src/gluon_ir.cc

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ namespace ttg = triton::gpu;
3131
namespace ttng = triton::nvidia_gpu;
3232
namespace gluon = mlir::triton::gluon;
3333
namespace ttag = mlir::triton::amdgpu;
34+
namespace ttgi = mlir::triton::gpu::intel;
3435

3536
// Helper to check if an MLIR type or attribute has a verifier method.
3637
template <typename AttrOrType>
@@ -897,7 +898,72 @@ void init_gluon_ir(py::module &&m) {
897898
.def("create_lds_barrier_arrive",
898899
[](GluonOpBuilder &self, Value memDesc, int count) -> Value {
899900
return self.create<ttag::ArriveBarrierOp>(memDesc, count);
900-
});
901+
})
902+
.def("create_prefetch",
903+
[](GluonOpBuilder &self, Value tensorDesc, std::vector<Value> &offsets,
904+
bool isVolatile) {
905+
// Get the base pointer from tensor descriptor
906+
auto makeTensorDescOp = tensorDesc.getDefiningOp<triton::MakeTensorDescOp>();
907+
if (!makeTensorDescOp) {
908+
throw std::runtime_error("Expected tensor descriptor from MakeTensorDescOp");
909+
}
910+
911+
Value base = makeTensorDescOp.getBase();
912+
auto shape = makeTensorDescOp.getShape();
913+
auto strides = makeTensorDescOp.getStrides();
914+
915+
// Convert shape from i32 to i64 for MakeTensorPtrOp
916+
// Needed because:
917+
// error: 'tt.make_tensor_ptr' op operand #1 must be
918+
// variadic of 64-bit signless integer, but got 'i32'
919+
SmallVector<Value> i64Shape;
920+
for (auto shapeVal : shape) {
921+
auto i64Val = self.create<arith::ExtSIOp>(self.getBuilder().getI64Type(), shapeVal);
922+
i64Shape.push_back(i64Val);
923+
}
924+
925+
// Get block shape from tensor descriptor type
926+
auto descType = cast<triton::TensorDescType>(tensorDesc.getType());
927+
auto blockType = cast<RankedTensorType>(descType.getBlockType());
928+
auto tensorShape = blockType.getShape();
929+
930+
// Convert to int32 vector for MakeTensorPtrOp
931+
std::vector<int32_t> blockShape;
932+
for (int64_t dim : tensorShape) {
933+
blockShape.push_back(static_cast<int32_t>(dim));
934+
}
935+
936+
// Default order for 2D tensors (row-major)
937+
std::vector<int32_t> order = {1, 0};
938+
if (blockShape.size() != 2) {
939+
// For non-2D tensors, use sequential order
940+
order.resize(blockShape.size());
941+
std::iota(order.begin(), order.end(), 0);
942+
}
943+
944+
// Empty mask
945+
Value maskVal = Value();
946+
947+
auto tensorPtrOp = self.create<mlir::triton::MakeTensorPtrOp>(base, /*shape*/i64Shape, strides, offsets,
948+
/*tensor_shape*/blockShape, order);
949+
950+
auto op = self.create<ttgi::PrefetchOp>(
951+
/*base*/tensorPtrOp.getResult(), maskVal, tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL, isVolatile);
952+
return op.getOperation();
953+
})
954+
// Example for passing block_ptr
955+
// .def("create_prefetch",
956+
// [](GluonOpBuilder &self, Value ptr, //, py::object mask,
957+
// //triton::CacheModifier cache, triton::EvictionPolicy evict,
958+
// bool isVolatile) {
959+
// //auto c = triton::CacheModifier();
960+
// //Value maskVal = mask.is_none() ? Value() : mask.cast<Value>();
961+
// Value maskVal = Value();
962+
//
963+
// self.create<ttgi::PrefetchOp>(
964+
// ptr, maskVal, tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL, isVolatile);
965+
// })
966+
;
901967

902968
m.def(
903969
"compute_tmem_reg_layout",

python/src/ir.cc

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -648,7 +648,11 @@ void init_triton_ir(py::module &&m) {
648648
if (!ret)
649649
return py::none();
650650
return py::str(ret.getValue().str());
651-
});
651+
})
652+
.def("set_attr",
653+
[](Operation &self, const std::string &name, Attribute &attr) {
654+
self.setAttr(name, attr);
655+
});
652656

653657
// dynamic_attr is used to transfer ownership of the MLIR context to the
654658
// module
@@ -1530,8 +1534,9 @@ void init_triton_ir(py::module &&m) {
15301534
})
15311535
.def("create_descriptor_store",
15321536
[](TritonOpBuilder &self, Value desc, Value value,
1533-
std::vector<Value> &indices) -> void {
1534-
self.create<DescriptorStoreOp>(desc, value, indices);
1537+
std::vector<Value> &indices) -> Operation* {//void {
1538+
auto op = self.create<DescriptorStoreOp>(desc, value, indices);
1539+
return op.getOperation();
15351540
})
15361541
.def("create_descriptor_reduce",
15371542
[](TritonOpBuilder &self, DescriptorReduceKind kind, Value desc,

python/triton/experimental/gluon/language/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,6 @@
5050
device_assert,
5151
device_print,
5252
dot_fma,
53-
xpu_dot_fma,
5453
expand_dims,
5554
full,
5655
fp4_to_fp,

python/triton/experimental/gluon/language/_core.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -590,9 +590,3 @@ def dot_fma(a, b, acc, _semantic=None):
590590

591591
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
592592
return tensor(handle, acc.type)
593-
594-
595-
@builtin
596-
def xpu_dot_fma(a, b, acc, _semantic=None):
597-
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
598-
return tensor(handle, acc.type)
Lines changed: 202 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
from __future__ import annotations
2+
3+
from typing import List, Tuple, Sequence
4+
from dataclasses import dataclass
5+
import triton.language.core as tl_core
6+
7+
import triton.experimental.gluon.language._core as ttgl
8+
from triton.experimental.gluon.language._layouts import DotOperandLayout
9+
from triton.experimental.gluon.language.intel._layouts import IntelDPASLayout
10+
from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
11+
from triton.language.core import ir, constexpr, tensor_descriptor_base, block_type, tensor, tuple
12+
13+
14+
# load_tensor_descriptor = builtin(tl_core.load_tensor_descriptor)
15+
# store_tensor_descriptor = builtin(tl_core.store_tensor_descriptor)
16+
17+
18+
__all__ = ["make_tensor_descriptor", "dot_fma"]
19+
20+
21+
22+
class tensor_descriptor(tensor_descriptor_base):
23+
"""A descriptor representing a tensor in global memory."""
24+
25+
def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_type: block_type, layout):
26+
"""Not called by user code."""
27+
# IR handle
28+
super().__init__(handle, block_type)
29+
# Global shape
30+
self.shape = tuple(shape)
31+
self.strides = tuple(strides)
32+
self.layout = layout
33+
34+
self.type = tensor_descriptor_type(
35+
block_type,
36+
shape_type=self.shape.type,
37+
strides_type=self.strides.type,
38+
layout=self.layout, # comment
39+
)
40+
41+
def _flatten_ir(self, handles: List[ir.value]) -> None:
42+
handles.append(self.handle)
43+
self.shape._flatten_ir(handles)
44+
self.strides._flatten_ir(handles)
45+
46+
# TODO: MaterializeBlockPointers.cpp
47+
# Add 2d_block_io parameter + validation to set proper attribute
48+
# Validation: (?)
49+
# > 2 dims
50+
# > stride 16 bytes aligned
51+
# and others
52+
@builtin
53+
def load(self, offsets: Sequence[constexpr | tensor], is_2d_block=False, _semantic=None) -> tensor:
54+
op = _semantic.descriptor_load(self, offsets, "", "")
55+
56+
if is_2d_block:
57+
# TODO: proper handling like below test example
58+
# Option to set row/column major and other params
59+
attr = _semantic.builder.get_string_attr("row_major")
60+
op.handle.set_attr("ttig.block_io", attr)
61+
62+
return op
63+
64+
@builtin
65+
def store(self, offsets: Sequence[constexpr | tensor], value: tensor, is_2d_block=False, _semantic=None) -> tensor:
66+
op = _semantic.descriptor_store(self, value, offsets)
67+
68+
if is_2d_block:
69+
attr = _semantic.builder.get_string_attr("row_major")
70+
op.handle.set_attr("ttig.block_io", attr)
71+
72+
return op
73+
74+
@builtin
75+
def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False, is_2d_block=False, _semantic=None):
76+
# TODO: handle other ttig.prefetch params
77+
# ptr is just temporary, support for tensor descriptor is needed
78+
# calculate offsets like tt.advance
79+
# maybe add support for mask, seems optional
80+
# also 2d block attr and others
81+
#return _semantic.builder.create_prefetch(ptr.handle, False)
82+
83+
"""
84+
pyton/triton/language/semantic.py @ load:1077 (TritonSemantic)
85+
cache_modifier: str, eviction_policy: str
86+
cache = self._str_to_load_cache_modifier(cache_modifier)
87+
eviction = self._str_to_eviction_policy(eviction_policy)
88+
"""
89+
90+
ptr_handle = self.handle
91+
offsets_handles = [offset.handle if hasattr(offset, 'handle') else offset for offset in offsets]
92+
op = _semantic.builder.create_prefetch(ptr_handle, offsets_handles, False)
93+
94+
if is_2d_block:
95+
attr = _semantic.builder.get_string_attr("row_major")
96+
op.set_attr("ttig.block_io", attr)
97+
98+
return op
99+
100+
101+
102+
@dataclass(eq=True)
103+
class tensor_descriptor_type(ttgl.base_type):
104+
"""The type for a tensor descriptor."""
105+
106+
block_type: ttgl.block_type
107+
shape_type: ttgl.tuple_type
108+
strides_type: ttgl.tuple_type
109+
layout: IntelDPASLayout
110+
111+
def __str__(self) -> str:
112+
return f"tensor_descriptor<{self.block_type}, {self.layout}>"
113+
114+
def _unflatten_ir(self, handles: List[ir.value], cursor: int) -> Tuple[tensor_descriptor, int]:
115+
handle = handles[cursor]
116+
cursor += 1
117+
shape, cursor = self.shape_type._unflatten_ir(handles, cursor)
118+
strides, cursor = self.strides_type._unflatten_ir(handles, cursor)
119+
value = tensor_descriptor(handle, shape, strides, self.block_type, self.layout)
120+
return value, cursor
121+
122+
def _to_ir(self, builder: ir.builder) -> ir.type:
123+
is_signed = self.block_type.element_ty.is_int_signed()
124+
return builder.get_tensor_descriptor_layout_type(
125+
self.block_type.to_ir(builder),
126+
is_signed,
127+
self.layout._to_ir(builder),
128+
)
129+
130+
def _flatten_ir_types(self, builder: ir.builder, out: List[ir.type]) -> None:
131+
out.append(self._to_ir(builder))
132+
self.shape_type._flatten_ir_types(builder, out)
133+
self.strides_type._flatten_ir_types(builder, out)
134+
135+
def mangle(self) -> str:
136+
return f"TD{self.block_type.mangle()}_{self.shape_type.mangle()}_{self.strides_type.mangle()}_{self.layout.mangle()}TD"
137+
138+
139+
@builtin
140+
def make_tensor_descriptor(ptr: ttgl.tensor, shape: List[int], strides: List[int],
141+
block_shape: List[int], layout: IntelDPASLayout,
142+
_semantic=None) -> tensor_descriptor:
143+
# Unwrap constexpr if needed
144+
layout = _unwrap_if_constexpr(layout)
145+
146+
# Get the pointer handle directly
147+
ptr_handle = ptr.handle
148+
149+
# Convert shape and strides to IR values AND create tensor objects
150+
shape_handles = _semantic._convert_to_ir_values(shape, require_i64=False)
151+
stride_handles = _semantic._convert_to_ir_values(strides, require_i64=True)
152+
153+
# Create tensor objects from the handles
154+
shape_tensors = [ttgl.tensor(h, ttgl.int32) for h in shape_handles]
155+
stride_tensors = [ttgl.tensor(h, ttgl.int64) for h in stride_handles]
156+
157+
# Build type information
158+
block_type = ttgl.block_type(ptr.type.element_ty, block_shape)
159+
160+
# TODO: this is w/a for xpu_dot_fma assertion - layout for block_type is not implemented yet
161+
# See: gluon/language/_core.py:19
162+
block_type.layout = layout
163+
164+
shape_type = ttgl.tuple_type([ttgl.int32] * len(shape))
165+
strides_type = ttgl.tuple_type([ttgl.int64] * len(strides))
166+
167+
# Pass tensor objects, not constexpr values
168+
shape_tuple = ttgl.tuple(shape_tensors, shape_type)
169+
strides_tuple = ttgl.tuple(stride_tensors, strides_type)
170+
171+
desc_type = tensor_descriptor_type(block_type, shape_type, strides_type, layout) #, shape_handles)
172+
173+
# Create the descriptor
174+
padding = _semantic._str_to_padding_option("zero")
175+
desc_handle = _semantic.builder.create_make_tensor_descriptor(
176+
desc_type._to_ir(_semantic.builder),
177+
ptr_handle,
178+
shape_handles,
179+
stride_handles,
180+
padding
181+
)
182+
183+
return tensor_descriptor(desc_handle, shape_tuple, strides_tuple, block_type, layout)
184+
185+
@builtin
186+
def dot_fma(a, b, acc, _semantic=None):
187+
assert isinstance(a, tensor), "a must be a tensor"
188+
assert isinstance(b, tensor), "b must be a tensor"
189+
assert isinstance(acc, tensor), "acc must be a tensor"
190+
191+
mma_layout = acc.type.layout
192+
assert isinstance(mma_layout, IntelDPASLayout), "acc must have a BlockedLayout"
193+
assert isinstance(a.type.layout, DotOperandLayout), "a must have a DotOperandLayout"
194+
assert isinstance(b.type.layout, DotOperandLayout), "b must have a DotOperandLayout"
195+
assert a.type.layout.parent == mma_layout, "a's parent layout must be the same as acc's layout"
196+
assert b.type.layout.parent == mma_layout, "b's parent layout must be the same as acc's layout"
197+
assert a.type.layout.operand_index == 0, "a's operand index must be 0"
198+
assert b.type.layout.operand_index == 1, "b's operand index must be 1"
199+
200+
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
201+
return tensor(handle, acc.type)
202+

third_party/intel/backend/compiler.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class XPUOptions:
2929
num_ctas: int = 1
3030
num_stages: int = 2
3131
cluster_dims: tuple = (1, 1, 1)
32-
warp_size: int = 32
32+
warp_size: int = 16 #32 # TODO:[mdziado]
3333
optimize_epilogue: bool = False
3434
enable_fp_fusion: bool = True
3535
launch_cooperative_grid: bool = False
@@ -311,6 +311,11 @@ def gluon_to_ttgir(self, src, metadata, options):
311311
pm = ir.pass_manager(mod.context)
312312
pm.enable_debug()
313313

314+
# TODO: support tensor descriptors
315+
# This is W/A to convert them into block_pointers
316+
intel.passes.ttir.add_convert_tdesc_to_block_pointer(pm)
317+
passes.ttir.add_rewrite_tensor_descriptor_to_pointer(pm)
318+
314319
passes.gluon.add_inliner(pm)
315320
passes.gluon.add_resolve_auto_encodings(pm)
316321
passes.common.add_sccp(pm)

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
#include "mlir/IR/Verifier.h"
66
#include "mlir/Interfaces/LoopLikeInterface.h"
77
#include "triton/Dialect/Triton/IR/Dialect.h"
8+
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
9+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
810
#include "triton/Dialect/Triton/IR/Types.h"
911
#include "llvm/ADT/STLExtras.h"
1012
#include "llvm/ADT/TypeSwitch.h"
@@ -15,6 +17,7 @@
1517

1618
using namespace mlir;
1719
namespace tt = mlir::triton;
20+
namespace ttgi = mlir::triton::gpu::intel;
1821

1922
namespace mlir::triton::intel {
2023
#define GEN_PASS_DEF_TRITONINTELTENSORDESCTOBLOCKPOINTER
@@ -265,18 +268,28 @@ struct TritonIntelTensorDescToBlockPointer
265268
for (size_t i = 0; i < tensorType.getRank(); ++i)
266269
boundaryCheck.push_back(i);
267270

271+
Attribute blockIOAttr =
272+
op->getAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName());
273+
268274
constexpr bool isLoad = std::is_same_v<OpTy, tt::DescriptorLoadOp>;
269275
if constexpr (isLoad) {
270276
auto loadOp = builder.createOrFold<tt::LoadOp>(
271277
loc, ptr, boundaryCheck,
272278
/*padding*/ std::nullopt, op.getCache(), op.getEvict(),
273279
/*volatile*/ false);
280+
if (blockIOAttr) {
281+
auto* loadOpInst = loadOp.getDefiningOp();
282+
loadOpInst->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(), blockIOAttr);
283+
}
274284
LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n");
275285
op.replaceAllUsesWith(loadOp);
276286
} else {
277287
[[maybe_unused]] auto storeOp = builder.createOrFold<tt::StoreOp>(
278288
loc, ptr, op.getSrc(), boundaryCheck, tt::CacheModifier::NONE,
279289
tt::EvictionPolicy::NORMAL);
290+
if (blockIOAttr) {
291+
storeOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(), blockIOAttr);
292+
}
280293
LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n");
281294
}
282295

0 commit comments

Comments
 (0)