Commit e48f743

Move util functions to a separate file
Signed-off-by: dchigarev <[email protected]>
1 parent f63e8f6 commit e48f743

File tree

3 files changed: +75 additions, -76 deletions


include/gc/Transforms/Utils/ValueUtils.h

Lines changed: 17 additions & 0 deletions
@@ -33,6 +33,23 @@ FailureOr<SmallVector<int64_t>> getStaticStrides(Value val);
 // is not a memref.
 std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value operand);
 
+template <typename T>
+Value createTypedVector(PatternRewriter &rewriter, Location loc,
+                        ArrayRef<T> values, Type elementType) {
+  mlir::VectorType vectorType =
+      mlir::VectorType::get({static_cast<int64_t>(values.size())}, elementType);
+  mlir::DenseElementsAttr denseAttr =
+      mlir::DenseElementsAttr::get(vectorType, values);
+  auto vector =
+      rewriter.create<mlir::arith::ConstantOp>(loc, vectorType, denseAttr)
+          .getResult();
+  return vector;
+}
+
+Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref);
+
+bool hasSharedMemSpace(mlir::Value memref);
+
 } // namespace utils
 } // namespace mlir
 
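For reference, a call to the newly exported helper from a rewrite pattern might look like the sketch below. This is only an illustration: the wrapper function name, the include path, and the lane count are assumptions, not part of this commit; it mirrors the call sites updated in LinalgToXeGPU.cpp further down.

#include "gc/Transforms/Utils/ValueUtils.h" // assumed include path
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical wrapper: builds an arith.constant holding a dense
// vector<32xindex> whose lanes all equal `skipPerLoad`.
static Value buildOffsetPerLoad(PatternRewriter &rewriter, Location loc,
                                int64_t skipPerLoad) {
  SmallVector<int64_t> offsets(32, skipPerLoad);
  return utils::createTypedVector<int64_t>(rewriter, loc, offsets,
                                           rewriter.getIndexType());
}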

lib/gc/Transforms/GPU/LinalgToXeGPU.cpp

Lines changed: 15 additions & 76 deletions
@@ -49,78 +49,15 @@ namespace gc {
 
 namespace {
 
-// TODO: move these to utils
-template <typename T>
-static Value createTypedVector(PatternRewriter &rewriter, Location loc,
-                               ArrayRef<T> values, Type elementType) {
-  mlir::VectorType vectorType =
-      mlir::VectorType::get({static_cast<int64_t>(values.size())}, elementType);
-  mlir::DenseElementsAttr denseAttr =
-      mlir::DenseElementsAttr::get(vectorType, values);
-  auto vector =
-      rewriter.create<mlir::arith::ConstantOp>(loc, vectorType, denseAttr)
-          .getResult();
-  return vector;
-}
-
-static Value createIndexVector(PatternRewriter &rewriter, Location loc,
-                               ArrayRef<int64_t> values) {
-  return createTypedVector(rewriter, loc, values, rewriter.getIndexType());
-}
-
-static Value createIndexConstant(PatternRewriter &rewriter, Location loc,
-                                 int64_t value) {
-  return rewriter.create<arith::ConstantIndexOp>(loc, value);
-}
-
-static Value flattenMemref(PatternRewriter &rewriter, Location loc,
-                           Value srcMemref) {
-  auto srcType = cast<MemRefType>(srcMemref.getType());
-
-  assert(srcType && "Expected a memref type");
-  assert(srcType.getRank() == 2 && "Expected a 2D memref");
-
-  int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
-
-  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
-  Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);
-  Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
-
-  // Use memref.reinterpret_cast to flatten the memref
-  auto flatMemRefType = MemRefType::get({flatSize}, srcType.getElementType(),
-                                        nullptr, srcType.getMemorySpace());
-  auto flatMemref =
-      rewriter
-          .create<memref::ReinterpretCastOp>(loc, flatMemRefType, srcMemref,
-                                             offset, size, stride)
-          .getResult();
-  return flatMemref;
-}
-
-static bool hasSharedMemSpace(mlir::Value memref) {
-  auto type = mlir::dyn_cast<mlir::MemRefType>(memref.getType());
-  if (!type)
-    return false;
-
-  auto memSpace = type.getMemorySpace();
-  if (!memSpace)
-    return false;
-
-  if (auto gpuAttr = mlir::dyn_cast<mlir::gpu::AddressSpaceAttr>(memSpace))
-    return gpuAttr.getValue() == mlir::gpu::AddressSpace::Private;
-
-  if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(memSpace))
-    return intAttr.getValue() ==
-           static_cast<int64_t>(mlir::gpu::AddressSpace::Private);
-
-  return false;
-}
-
 static Value createFullMask(PatternRewriter &rewriter, Location loc,
                             int64_t size) {
-  auto maskVal = createIndexConstant(rewriter, loc, 32);
+  auto maskVal = rewriter.create<arith::ConstantIndexOp>(loc, 32);
   mlir::VectorType maskVectorType =
       mlir::VectorType::get({size}, rewriter.getI1Type());
+  // HACK: creating mask vector through this strange op instead of
+  // simple 'arith.constant dense<true>' to avoid the mask being
+  // moved out of the GPU kernel (it causes strange behaviour
+  // when a bit-mask is passed as a kernel parameter).
   auto res = rewriter.create<vector::CreateMaskOp>(
       loc, maskVectorType, SmallVector<Value>({maskVal}));
   return res.getResult();
@@ -826,8 +763,9 @@ static SmallVector<Value> createScatterDescriptorTiles(
   }
 
   int64_t skipPerLoad = memrefStrides[0] * rowsPerLoad;
-  auto offsetPerLoad =
-      createIndexVector(rewriter, loc, SmallVector<int64_t>(32, skipPerLoad));
+  auto offsetPerLoad = utils::createTypedVector<int64_t>(
+      rewriter, loc, SmallVector<int64_t>(32, skipPerLoad),
+      rewriter.getIndexType());
 
   auto offsetVecType = VectorType::get({maxLoadSize}, rewriter.getIndexType());
   auto descType = getTensorDescType(
@@ -843,7 +781,8 @@ static SmallVector<Value> createScatterDescriptorTiles(
 
   SmallVector<Value> tiles;
   for (int i = 0; i < numColTiles; i++) {
-    auto offsetsShift = createIndexVector(rewriter, loc, offsetShiftValues[i]);
+    auto offsetsShift = utils::createTypedVector<int64_t>(
+        rewriter, loc, offsetShiftValues[i], rewriter.getIndexType());
     auto offsets0 =
         rewriter.create<arith::AddIOp>(loc, blockOffsetV, offsetsShift);
 
@@ -924,7 +863,7 @@ static SmallVector<Value> createSLMDescTiles(PatternRewriter &rewriter,
   }
 
   // Scatter descriptors only work with 1D memrefs
-  src = flattenMemref(rewriter, loc, src);
+  src = utils::flattenMemref(rewriter, loc, src);
 
   return createScatterDescriptorTiles(
       rewriter, loc, /*flatMemref=*/src, /*loadShape2D=*/loadShape,
@@ -938,7 +877,7 @@ static SmallVector<Value> createDescriptorTiles(
     std::optional<ArrayRef<int64_t>> loadOffsets = std::nullopt,
     int arrayLength = 1, bool transpose = false) {
 
-  if (hasSharedMemSpace(src)) {
+  if (utils::hasSharedMemSpace(src)) {
     assert(!transpose && "Transpose is not supported for shared memory");
     assert(arrayLength == 1 &&
            "Array descriptors are not supported for shared memory");
@@ -1092,8 +1031,8 @@ loadScatterDescTiles(PatternRewriter &rewriter, Location loc,
   // Accumulator vector for the current tile (its number of elements equals to
   // tileShape) HACK: we first create a flat vector of zeros and then cast it
   // to the 2D shape. Otherwise 'imex::ConvertGPUXToSPIRVPass' fails.
-  auto accumVector =
-      createTypedVector<Attribute>(rewriter, loc, accumValues, elementType);
+  auto accumVector = utils::createTypedVector<Attribute>(
+      rewriter, loc, accumValues, elementType);
   accumVector =
       rewriter.create<vector::ShapeCastOp>(loc, accumVectorType, accumVector);
 
@@ -1961,7 +1900,7 @@ LogicalResult createMemoryFillKernel(linalg::LinalgOp linalgOp,
   }
 
   // Extract SIMD sized sub-tiles
-  int64_t maxSizeSIMD = hasSharedMemSpace(output) ? 32 : 256;
+  int64_t maxSizeSIMD = utils::hasSharedMemSpace(output) ? 32 : 256;
   int64_t subTileCols = std::min(outputShape[1], maxSizeSIMD);
  int64_t subTileRows =
      std::min(outputShape[0], std::max(maxSizeSIMD / subTileCols, 1L));
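To make the HACK comment in createFullMask above concrete: the "simple" form it avoids would be a plain dense<true> constant, roughly as in the sketch below. This alternative is hypothetical and not part of the commit; as the comment notes, such a pure constant can get hoisted out of the GPU kernel and end up passed as a kernel parameter, which the vector.create_mask construction sidesteps.

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/PatternMatch.h"

using namespace mlir;

// Hypothetical alternative avoided by createFullMask: an all-true constant
// mask. Being a regular constant, it is eligible for hoisting out of the
// GPU kernel body.
static Value createFullMaskAsConstant(PatternRewriter &rewriter, Location loc,
                                      int64_t size) {
  auto maskVectorType = VectorType::get({size}, rewriter.getI1Type());
  // Single-element attribute list is treated as a splat: every i1 lane is true.
  auto allTrue = DenseElementsAttr::get(maskVectorType, ArrayRef<bool>{true});
  return rewriter.create<arith::ConstantOp>(loc, maskVectorType, allTrue);
}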

lib/gc/Transforms/Utils/ValueUtils.cpp

Lines changed: 43 additions & 0 deletions
@@ -6,6 +6,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
@@ -150,5 +151,47 @@ std::pair<Value, Value> getPtrAndOffset(OpBuilder &builder, Value operand) {
   return std::make_pair(alignedPointer, offset);
 }
 
+Value flattenMemref(PatternRewriter &rewriter, Location loc, Value srcMemref) {
+  auto srcType = cast<MemRefType>(srcMemref.getType());
+
+  assert(srcType && "Expected a memref type");
+  assert(srcType.getRank() == 2 && "Expected a 2D memref");
+
+  int64_t flatSize = srcType.getShape()[0] * srcType.getShape()[1];
+
+  Value offset = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  Value size = rewriter.create<arith::ConstantIndexOp>(loc, flatSize);
+  Value stride = rewriter.create<arith::ConstantIndexOp>(loc, 1);
+
+  // Use memref.reinterpret_cast to flatten the memref
+  auto flatMemRefType = MemRefType::get({flatSize}, srcType.getElementType(),
+                                        nullptr, srcType.getMemorySpace());
+  auto flatMemref =
+      rewriter
+          .create<memref::ReinterpretCastOp>(loc, flatMemRefType, srcMemref,
+                                             offset, size, stride)
+          .getResult();
+  return flatMemref;
+}
+
+bool hasSharedMemSpace(mlir::Value memref) {
+  auto type = mlir::dyn_cast<mlir::MemRefType>(memref.getType());
+  if (!type)
+    return false;
+
+  auto memSpace = type.getMemorySpace();
+  if (!memSpace)
+    return false;
+
+  if (auto gpuAttr = mlir::dyn_cast<mlir::gpu::AddressSpaceAttr>(memSpace))
+    return gpuAttr.getValue() == mlir::gpu::AddressSpace::Private;
+
+  if (auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>(memSpace))
+    return intAttr.getValue() ==
+           static_cast<int64_t>(mlir::gpu::AddressSpace::Private);
+
+  return false;
+}
+
 } // namespace utils
 } // namespace mlir
