Skip to content

Commit 024f495

Browse files
committed
Linter fix + add annotate_module for gluon
1 parent acb526d commit 024f495

File tree

6 files changed

+49
-49
lines changed

6 files changed

+49
-49
lines changed

python/src/gluon_ir.cc

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -900,12 +900,14 @@ void init_gluon_ir(py::module &&m) {
900900
return self.create<ttag::ArriveBarrierOp>(memDesc, count);
901901
})
902902
.def("create_prefetch",
903-
[](GluonOpBuilder &self, Value tensorDesc, std::vector<Value> &offsets,
904-
bool isVolatile) {
903+
[](GluonOpBuilder &self, Value tensorDesc,
904+
std::vector<Value> &offsets, bool isVolatile) {
905905
// Get the base pointer from tensor descriptor
906-
auto makeTensorDescOp = tensorDesc.getDefiningOp<triton::MakeTensorDescOp>();
906+
auto makeTensorDescOp =
907+
tensorDesc.getDefiningOp<triton::MakeTensorDescOp>();
907908
if (!makeTensorDescOp) {
908-
throw std::runtime_error("Expected tensor descriptor from MakeTensorDescOp");
909+
throw std::runtime_error(
910+
"Expected tensor descriptor from MakeTensorDescOp");
909911
}
910912

911913
Value base = makeTensorDescOp.getBase();
@@ -918,7 +920,8 @@ void init_gluon_ir(py::module &&m) {
918920
// variadic of 64-bit signless integer, but got 'i32'
919921
SmallVector<Value> i64Shape;
920922
for (auto shapeVal : shape) {
921-
auto i64Val = self.create<arith::ExtSIOp>(self.getBuilder().getI64Type(), shapeVal);
923+
auto i64Val = self.create<arith::ExtSIOp>(
924+
self.getBuilder().getI64Type(), shapeVal);
922925
i64Shape.push_back(i64Val);
923926
}
924927

@@ -944,11 +947,14 @@ void init_gluon_ir(py::module &&m) {
944947
// Empty mask
945948
Value maskVal = Value();
946949

947-
auto tensorPtrOp = self.create<mlir::triton::MakeTensorPtrOp>(base, /*shape*/i64Shape, strides, offsets,
948-
/*tensor_shape*/blockShape, order);
950+
auto tensorPtrOp = self.create<mlir::triton::MakeTensorPtrOp>(
951+
base, /*shape*/ i64Shape, strides, offsets,
952+
/*tensor_shape*/ blockShape, order);
949953

950954
auto op = self.create<ttgi::PrefetchOp>(
951-
/*base*/tensorPtrOp.getResult(), maskVal, tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL, isVolatile);
955+
/*base*/ tensorPtrOp.getResult(), maskVal,
956+
tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL,
957+
isVolatile);
952958
return op.getOperation();
953959
})
954960
// Example for passing block_ptr
@@ -961,7 +967,8 @@ void init_gluon_ir(py::module &&m) {
961967
// Value maskVal = Value();
962968
//
963969
// self.create<ttgi::PrefetchOp>(
964-
// ptr, maskVal, tt::CacheModifier::NONE, tt::EvictionPolicy::NORMAL, isVolatile);
970+
// ptr, maskVal, tt::CacheModifier::NONE,
971+
// tt::EvictionPolicy::NORMAL, isVolatile);
965972
// })
966973
;
967974

python/src/ir.cc

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -649,10 +649,8 @@ void init_triton_ir(py::module &&m) {
649649
return py::none();
650650
return py::str(ret.getValue().str());
651651
})
652-
.def("set_attr",
653-
[](Operation &self, const std::string &name, Attribute &attr) {
654-
self.setAttr(name, attr);
655-
});
652+
.def("set_attr", [](Operation &self, const std::string &name,
653+
Attribute &attr) { self.setAttr(name, attr); });
656654

657655
// dynamic_attr is used to transfer ownership of the MLIR context to the
658656
// module
@@ -1534,7 +1532,7 @@ void init_triton_ir(py::module &&m) {
15341532
})
15351533
.def("create_descriptor_store",
15361534
[](TritonOpBuilder &self, Value desc, Value value,
1537-
std::vector<Value> &indices) -> Operation* {//void {
1535+
std::vector<Value> &indices) -> Operation * { // void {
15381536
auto op = self.create<DescriptorStoreOp>(desc, value, indices);
15391537
return op.getOperation();
15401538
})
Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
from . import xe
22

33
__all__ = ["xe"]
4-

python/triton/experimental/gluon/language/intel/xpu/xe.py

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -2,23 +2,19 @@
22

33
from typing import List, Tuple, Sequence
44
from dataclasses import dataclass
5-
import triton.language.core as tl_core
65

76
import triton.experimental.gluon.language._core as ttgl
87
from triton.experimental.gluon.language._layouts import DotOperandLayout
98
from triton.experimental.gluon.language.intel._layouts import IntelDPASLayout
109
from triton.experimental.gluon.language._core import builtin, _unwrap_if_constexpr
1110
from triton.language.core import ir, constexpr, tensor_descriptor_base, block_type, tensor, tuple
1211

13-
1412
# load_tensor_descriptor = builtin(tl_core.load_tensor_descriptor)
1513
# store_tensor_descriptor = builtin(tl_core.store_tensor_descriptor)
1614

17-
1815
__all__ = ["make_tensor_descriptor", "dot_fma"]
1916

2017

21-
2218
class tensor_descriptor(tensor_descriptor_base):
2319
"""A descriptor representing a tensor in global memory."""
2420

@@ -31,12 +27,9 @@ def __init__(self, handle, shape: List[tensor], strides: List[tensor], block_typ
3127
self.strides = tuple(strides)
3228
self.layout = layout
3329

34-
self.type = tensor_descriptor_type(
35-
block_type,
36-
shape_type=self.shape.type,
37-
strides_type=self.strides.type,
38-
layout=self.layout, # comment
39-
)
30+
self.type = tensor_descriptor_type(block_type, shape_type=self.shape.type, strides_type=self.strides.type,
31+
layout=self.layout, # comment
32+
)
4033

4134
def _flatten_ir(self, handles: List[ir.value]) -> None:
4235
handles.append(self.handle)
@@ -72,14 +65,14 @@ def store(self, offsets: Sequence[constexpr | tensor], value: tensor, is_2d_bloc
7265
return op
7366

7467
@builtin
75-
def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False, is_2d_block=False, _semantic=None):
68+
def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None, evict=None, is_volatile=False,
69+
is_2d_block=False, _semantic=None):
7670
# TODO: handle other ttig.prefetch params
7771
# ptr is just temporary, support for tensor descriptor is needed
7872
# calculate offsets like tt.advance
7973
# maybe add support for mask, seems optional
8074
# also 2d block attr and others
8175
#return _semantic.builder.create_prefetch(ptr.handle, False)
82-
8376
"""
8477
pyton/triton/language/semantic.py @ load:1077 (TritonSemantic)
8578
cache_modifier: str, eviction_policy: str
@@ -98,7 +91,6 @@ def prefetch(self, offsets: Sequence[constexpr | tensor], mask=None, cache=None,
9891
return op
9992

10093

101-
10294
@dataclass(eq=True)
10395
class tensor_descriptor_type(ttgl.base_type):
10496
"""The type for a tensor descriptor."""
@@ -137,9 +129,8 @@ def mangle(self) -> str:
137129

138130

139131
@builtin
140-
def make_tensor_descriptor(ptr: ttgl.tensor, shape: List[int], strides: List[int],
141-
block_shape: List[int], layout: IntelDPASLayout,
142-
_semantic=None) -> tensor_descriptor:
132+
def make_tensor_descriptor(ptr: ttgl.tensor, shape: List[int], strides: List[int], block_shape: List[int],
133+
layout: IntelDPASLayout, _semantic=None) -> tensor_descriptor:
143134
# Unwrap constexpr if needed
144135
layout = _unwrap_if_constexpr(layout)
145136

@@ -168,20 +159,16 @@ def make_tensor_descriptor(ptr: ttgl.tensor, shape: List[int], strides: List[int
168159
shape_tuple = ttgl.tuple(shape_tensors, shape_type)
169160
strides_tuple = ttgl.tuple(stride_tensors, strides_type)
170161

171-
desc_type = tensor_descriptor_type(block_type, shape_type, strides_type, layout) #, shape_handles)
162+
desc_type = tensor_descriptor_type(block_type, shape_type, strides_type, layout) #, shape_handles)
172163

173164
# Create the descriptor
174165
padding = _semantic._str_to_padding_option("zero")
175-
desc_handle = _semantic.builder.create_make_tensor_descriptor(
176-
desc_type._to_ir(_semantic.builder),
177-
ptr_handle,
178-
shape_handles,
179-
stride_handles,
180-
padding
181-
)
166+
desc_handle = _semantic.builder.create_make_tensor_descriptor(desc_type._to_ir(_semantic.builder), ptr_handle,
167+
shape_handles, stride_handles, padding)
182168

183169
return tensor_descriptor(desc_handle, shape_tuple, strides_tuple, block_type, layout)
184170

171+
185172
@builtin
186173
def dot_fma(a, b, acc, _semantic=None):
187174
assert isinstance(a, tensor), "a must be a tensor"
@@ -199,4 +186,3 @@ def dot_fma(a, b, acc, _semantic=None):
199186

200187
handle = _semantic.dot(a, b, acc, input_precision=None, max_num_imprecise_acc=None, out_dtype=acc.dtype).handle
201188
return tensor(handle, acc.type)
202-

third_party/intel/backend/compiler.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ class XPUOptions:
2929
num_ctas: int = 1
3030
num_stages: int = 2
3131
cluster_dims: tuple = (1, 1, 1)
32-
warp_size: int = 16 #32 # TODO:[mdziado]
32+
warp_size: int = 16 #32 # TODO:[mdziado]
3333
optimize_epilogue: bool = False
3434
enable_fp_fusion: bool = True
3535
launch_cooperative_grid: bool = False
@@ -306,8 +306,15 @@ def make_ttgir(cls, mod, metadata, opt, properties):
306306
metadata["cluster_dims"] = (cluster_info.clusterDimX, cluster_info.clusterDimY, cluster_info.clusterDimZ)
307307
return mod
308308

309-
def gluon_to_ttgir(self, src, metadata, options):
310-
mod = src
309+
def gluon_to_ttgir(self, mod, metadata, options):
310+
pm = ir.pass_manager(mod.context)
311+
pm.enable_debug()
312+
313+
module_opts = intel.passes.ttgpuir.AnnotateModuleOptions()
314+
self.annotate_module(module_opts, self.properties, options)
315+
intel.passes.ttgpuir.add_triton_annotate_module(pm, module_opts)
316+
pm.run(mod, 'annotate_module')
317+
311318
pm = ir.pass_manager(mod.context)
312319
pm.enable_debug()
313320

third_party/intel/lib/Dialect/Triton/Transforms/TensorDescToBlockPointer.cpp

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
#include "intel/include/Dialect/Triton/Transforms/Passes.h"
22
#include "intel/include/Dialect/TritonGEN/IR/TritonGENDialect.h"
3+
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
34
#include "intel/include/Utils/Utility.h"
45
#include "mlir/IR/BuiltinTypes.h"
56
#include "mlir/IR/Verifier.h"
67
#include "mlir/Interfaces/LoopLikeInterface.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
8-
#include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
9-
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
109
#include "triton/Dialect/Triton/IR/Types.h"
10+
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
1111
#include "llvm/ADT/STLExtras.h"
1212
#include "llvm/ADT/TypeSwitch.h"
1313
#include "llvm/Support/Debug.h"
@@ -143,7 +143,8 @@ struct TritonIntelTensorDescToBlockPointer
143143
auto tensorType = RankedTensorType::get(
144144
SmallVector<int64_t>(sizes.begin(), sizes.end()),
145145
pointerType.getPointeeType(), encoding);
146-
auto resultType = mlir::triton::PointerType::get(tensorType, pointerType.getAddressSpace());
146+
auto resultType = mlir::triton::PointerType::get(
147+
tensorType, pointerType.getAddressSpace());
147148

148149
auto makeTensorPtr = builder.create<tt::MakeTensorPtrOp>(
149150
loc, resultType, base, shape, strides, offsets,
@@ -278,8 +279,9 @@ struct TritonIntelTensorDescToBlockPointer
278279
/*padding*/ std::nullopt, op.getCache(), op.getEvict(),
279280
/*volatile*/ false);
280281
if (blockIOAttr) {
281-
auto* loadOpInst = loadOp.getDefiningOp();
282-
loadOpInst->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(), blockIOAttr);
282+
auto *loadOpInst = loadOp.getDefiningOp();
283+
loadOpInst->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
284+
blockIOAttr);
283285
}
284286
LLVM_DEBUG(llvm::dbgs().indent(2) << loadOp << "\n");
285287
op.replaceAllUsesWith(loadOp);
@@ -288,7 +290,8 @@ struct TritonIntelTensorDescToBlockPointer
288290
loc, ptr, op.getSrc(), boundaryCheck, tt::CacheModifier::NONE,
289291
tt::EvictionPolicy::NORMAL);
290292
if (blockIOAttr) {
291-
storeOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(), blockIOAttr);
293+
storeOp->setAttr(ttgi::TritonIntelGPUDialect::getBlockIOAttrName(),
294+
blockIOAttr);
292295
}
293296
LLVM_DEBUG(llvm::dbgs().indent(2) << storeOp << "\n");
294297
}

0 commit comments

Comments
 (0)