Skip to content

Commit baa4a36

Browse files
committed
Merge remote-tracking branch 'origin/main' into dchigarev/new-imex
2 parents 3f5c79b + 33c3e5f commit baa4a36

File tree

6 files changed

+563
-3
lines changed

6 files changed

+563
-3
lines changed

include/gc/Transforms/Passes.td

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,14 @@ def AddContextArg : Pass<"add-ctx-arg", "func::FuncOp"> {
101101
}];
102102
}
103103

104+
def AllocsToSLM : Pass<"allocs-to-slm", "func::FuncOp"> {
105+
let summary = "Add 'shared' memory space to memrefs allocated inside a gpu.block.";
106+
let description = [{Add 'shared' memory space to memrefs allocated inside a gpu.block.}];
107+
let dependentDialects = [
108+
"gpu::GPUDialect", "memref::MemRefDialect"
109+
];
110+
}
111+
104112
def GpuToGpuOcl : Pass<"gpu-to-gpuocl", "ModuleOp"> {
105113
let summary = "Convert the GPU operations to GpuOclRuntime calls.";
106114
let description = [{

lib/gc/Dialect/LLVMIR/IR/XeVMDialect.cpp

Lines changed: 260 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,18 +9,275 @@
99

1010
#include "mlir/Dialect/GPU/IR/CompilationInterfaces.h"
1111
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
12+
#include "mlir/Dialect/Utils/StaticValueUtils.h"
1213
#include "mlir/IR/DialectImplementation.h"
1314
#include "llvm/ADT/TypeSwitch.h"
15+
#include "llvm/Support/MathExtras.h"
1416

1517
using namespace mlir;
1618
using namespace xevm;
1719

1820
#include "gc/Dialect/LLVMIR/XeVMOpsDialect.cpp.inc"
1921
#include "gc/Dialect/LLVMIR/XeVMOpsEnums.cpp.inc"
2022

21-
// TODO
22-
LogicalResult BlockLoad2dOp::verify() { return success(); }
23-
LogicalResult BlockStore2dOp::verify() { return success(); }
23+
namespace {
24+
constexpr uint32_t subgroupSize = 16;
25+
26+
template <typename Op> LogicalResult verifyMatrixInput(Op op) {
27+
static_assert(llvm::is_one_of<Op, BlockLoad2dOp, BlockStore2dOp>::value,
28+
"Unexpected template parameter");
29+
30+
std::optional<int64_t> width = getConstantIntValue(op.getBaseWidth());
31+
std::optional<int64_t> pitch = getConstantIntValue(op.getBasePitch());
32+
if (pitch && width && *pitch < *width)
33+
return op->emitOpError(
34+
"4th operand (base pitch) should be >= 2nd operand (base width)");
35+
36+
uint32_t elemSize = op.getElemSizeInBits();
37+
if (elemSize < 8 || !llvm::isPowerOf2_32(elemSize) || elemSize > 32)
38+
return op->emitOpError("expecting 'elem_size_in_bits' to be 8, 16, or 32");
39+
40+
uint32_t tileHeight = op.getTileHeight();
41+
if (tileHeight > 32 || !llvm::isPowerOf2_32(tileHeight))
42+
return op->emitOpError("expecting tile_height to be 1, 2, 4, 8, 16, or 32");
43+
44+
uint32_t vBlocks = op.getVBlocks();
45+
if (vBlocks > 8 || !llvm::isPowerOf2_32(vBlocks))
46+
return op->emitOpError("expecting v_blocks to be 1, 2, 4, or 8");
47+
48+
return success();
49+
}
50+
51+
LogicalResult verify2DBlockLoadHWRestriction(BlockLoad2dOp op) {
52+
VectorType resTy = op.getRes().getType();
53+
if (!resTy.getElementType().isIntOrFloat())
54+
return op.emitOpError()
55+
<< "expecting result element type to be int or float";
56+
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
57+
unsigned resSize = resTy.getNumElements() * resElemTySize;
58+
unsigned expectedSize = op.getElemSizeInBits() * op.getTileHeight() *
59+
op.getTileWidth() * op.getVBlocks() / subgroupSize;
60+
if (resSize != expectedSize)
61+
return op.emitOpError() << "result size of " << resSize
62+
<< " bits does not match the expected size of "
63+
<< expectedSize << " bits";
64+
65+
if (op.getTranspose() && op.getVnniTransform())
66+
return op.emitOpError(
67+
"transpose and vnni_transform are mutually exclusive");
68+
69+
if (!op.getTranspose() && !op.getVnniTransform()) {
70+
uint32_t tileHeight = op.getTileHeight();
71+
if (tileHeight < 1 || tileHeight > 32)
72+
return op.emitOpError("expecting tile_height to be between 1 and 32");
73+
74+
uint32_t tileWidth = op.getTileWidth();
75+
uint32_t vBlocks = op.getVBlocks();
76+
switch (op.getElemSizeInBits()) {
77+
case 8:
78+
if (tileWidth < 4 || tileWidth > 64)
79+
return op.emitOpError("expecting tile_width to be between 4 and 64");
80+
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
81+
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
82+
if (tileWidth * vBlocks > 64)
83+
return op.emitOpError(
84+
"tile_width * v_blocks should be less than or equal "
85+
"to 64 for 8 bit elements");
86+
break;
87+
case 16:
88+
if (tileWidth < 2 || tileWidth > 32)
89+
return op.emitOpError("expecting tile_width to be between 2 and 32");
90+
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
91+
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
92+
if (tileWidth * vBlocks > 32)
93+
return op.emitOpError(
94+
"tile_width * v_blocks should be less than or equal "
95+
"to 32 for 16 bit elements");
96+
break;
97+
case 32:
98+
if (tileWidth < 1 || tileWidth > 16)
99+
return op.emitOpError("expecting tile_width to be between 1 and 16");
100+
if (vBlocks != 1 && vBlocks != 2)
101+
return op.emitOpError("expecting v_blocks to be 1 or 2");
102+
if (tileWidth * vBlocks > 16)
103+
return op.emitOpError(
104+
"tile_width * v_blocks should be less than or equal "
105+
"to 16 for 32 bit elements");
106+
break;
107+
case 64:
108+
if (tileWidth < 1 || tileWidth > 8)
109+
return op.emitOpError("expecting tile_width to be between 1 and 8");
110+
if (vBlocks != 1)
111+
return op.emitOpError("expecting v_blocks to be 1");
112+
break;
113+
default:
114+
return op.emitOpError(
115+
"expecting elem_size_in_bits to be 8, 16, 32, or 64");
116+
}
117+
118+
return success();
119+
}
120+
121+
if (op.getTranspose()) {
122+
assert(!op.getVnniTransform() &&
123+
"Expecting vnni_transform should be false");
124+
125+
uint32_t vBlocks = op.getVBlocks();
126+
if (vBlocks != 1)
127+
return op.emitOpError("expecting v_blocks to be 1");
128+
129+
uint32_t tileHeight = op.getTileHeight();
130+
uint32_t tileWidth = op.getTileWidth();
131+
switch (op.getElemSizeInBits()) {
132+
case 32:
133+
if (tileHeight < 1 || tileHeight > 32)
134+
return op.emitOpError("expecting tile_height to be between 1 and 32");
135+
if (tileWidth < 1 || tileWidth > 8)
136+
return op.emitOpError("expecting tile_width to be between 1 and 8");
137+
break;
138+
case 64:
139+
if (tileHeight != 8)
140+
return op.emitOpError(
141+
"expecting tile_height to be 8 for 64 bit elements");
142+
if (tileWidth != 1 && tileWidth != 2 && tileWidth != 4)
143+
return op.emitOpError("expecting tile_width to be 1, 2, or 4");
144+
break;
145+
default:
146+
return op.emitOpError("transpose is only supported for 32 and 64 bit "
147+
"elements");
148+
}
149+
150+
return success();
151+
}
152+
153+
assert(op.getVnniTransform() && !op.getTranspose() &&
154+
"Expecting vnni_transform should be true and transpose should be "
155+
"false");
156+
157+
uint32_t vBlocks = op.getVBlocks();
158+
if (vBlocks != 1 && vBlocks != 2 && vBlocks != 4)
159+
return op.emitOpError("expecting v_blocks to be 1, 2, or 4");
160+
161+
uint32_t tileHeight = op.getTileHeight();
162+
uint32_t tileWidth = op.getTileWidth();
163+
switch (op.getElemSizeInBits()) {
164+
case 8:
165+
if (tileHeight < 4 || tileHeight > 32)
166+
return op.emitOpError("expecting tile_height to be between 4 and 32");
167+
if (tileWidth < 4 || tileWidth > 16)
168+
return op.emitOpError("expecting tile_width to be between 4 and 16");
169+
break;
170+
case 16:
171+
if (tileHeight < 2 || tileHeight > 32)
172+
return op.emitOpError("expecting tile_height to be between 2 and 32");
173+
if (tileWidth < 2 || tileWidth > 16)
174+
return op.emitOpError("expecting tile_width to be between 2 and 16");
175+
if (tileWidth * vBlocks > 32)
176+
return op.emitOpError(
177+
"tile_width * v_blocks should be less than or equal "
178+
"to 32 for 16 bit elements");
179+
break;
180+
default:
181+
return op.emitOpError("vnni_transform is only supported for 8 and 16 bit "
182+
"elements");
183+
}
184+
185+
return success();
186+
}
187+
188+
static LogicalResult verify2DBlockStoreHWRestriction(BlockStore2dOp op) {
189+
uint32_t tileHeight = op.getTileHeight();
190+
if (tileHeight < 1 || tileHeight > 8)
191+
return op.emitOpError("expecting tile_height to be between 1 and 8");
192+
193+
uint32_t tileWidth = op.getTileWidth();
194+
switch (op.getElemSizeInBits()) {
195+
case 8:
196+
if (tileWidth < 4 || tileWidth > 64)
197+
return op.emitOpError("expecting tile_width to be between 4 and 64");
198+
break;
199+
case 16:
200+
if (tileWidth < 2 || tileWidth > 32)
201+
return op.emitOpError("expecting tile_width to be between 2 and 32");
202+
break;
203+
case 32:
204+
if (tileWidth < 1 || tileWidth > 16)
205+
return op.emitOpError("expecting tile_width to be between 1 and 16");
206+
break;
207+
case 64:
208+
if (tileWidth < 1 || tileWidth > 8)
209+
return op.emitOpError("expecting tile_width to be between 1 and 8");
210+
break;
211+
default:
212+
return op.emitOpError("expecting elem_size_in_bits to be 8, 16, 32, or 64");
213+
}
214+
215+
uint32_t vBlocks = op.getVBlocks();
216+
if (vBlocks != 1)
217+
return op.emitOpError("expecting v_blocks to be 1");
218+
return success();
219+
}
220+
221+
} // namespace
222+
223+
LogicalResult BlockLoad2dOp::verify() {
224+
if (verify2DBlockLoadHWRestriction(*this).failed())
225+
return failure();
226+
227+
if (verifyMatrixInput(*this).failed())
228+
return failure();
229+
230+
VectorType resTy = getRes().getType();
231+
if (!resTy.getElementType().isIntOrFloat())
232+
return emitOpError() << "expecting result element type to be int of float";
233+
unsigned resElemTySize = resTy.getElementType().getIntOrFloatBitWidth();
234+
if (getElemSizeInBits() == 32 || getVnniTransform()) {
235+
if (resElemTySize != 32)
236+
return emitOpError() << "expecting result element type to be 32 bits";
237+
}
238+
239+
uint32_t tileWidth = getTileWidth();
240+
if (getVnniTransform()) {
241+
if (tileWidth != 16)
242+
return emitOpError(
243+
"tile_width when vnni_transform is true should be equal "
244+
"to subgroup size (16 elements)");
245+
return success();
246+
}
247+
248+
return success();
249+
}
250+
251+
LogicalResult BlockStore2dOp::verify() {
252+
if (verify2DBlockStoreHWRestriction(*this).failed())
253+
return failure();
254+
255+
if (verifyMatrixInput(*this).failed())
256+
return failure();
257+
258+
uint32_t tileWidth = getTileWidth();
259+
switch (getElemSizeInBits()) {
260+
case 8:
261+
if (tileWidth != 16 && tileWidth != 32)
262+
return emitOpError("tile_width for 8 bit elements should be equal to "
263+
"16 or 32");
264+
break;
265+
case 16:
266+
if (tileWidth != 16)
267+
return emitOpError("tile_width for 16 bit elements should be equal "
268+
"to 16");
269+
break;
270+
case 32:
271+
if (tileWidth != 16)
272+
return emitOpError("tile_width for 32 bit elements should be equal "
273+
"to 16");
274+
break;
275+
default:
276+
llvm_unreachable("unexpected element size");
277+
}
278+
279+
return success();
280+
}
24281

25282
void XeVMDialect::initialize() {
26283
// NOLINTBEGIN
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//===- AllocsToSLM.cpp - A pass adding shared mem-space attr ----*- C++ -*-===//
2+
//
3+
// This file is licensed under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "gc/Transforms/Passes.h"
10+
11+
#include "mlir/Dialect/Func/IR/FuncOps.h"
12+
#include "mlir/Dialect/GPU/TransformOps/Utils.h"
13+
#include "mlir/Dialect/MemRef/IR/MemRef.h"
14+
#include "mlir/IR/Dialect.h"
15+
#include "mlir/Pass/Pass.h"
16+
#include "mlir/Pass/PassManager.h"
17+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
18+
19+
#include <numeric>
20+
#include <optional>
21+
22+
using namespace mlir;
23+
using namespace mlir::gc;
24+
25+
namespace mlir {
26+
namespace gc {
27+
#define GEN_PASS_DEF_ALLOCSTOSLM
28+
#include "gc/Transforms/Passes.h.inc"
29+
} // namespace gc
30+
} // namespace mlir
31+
32+
namespace {
33+
34+
bool isInGpuLaunch(Operation *op) {
35+
auto launchOp = op->getParentOfType<gpu::LaunchOp>();
36+
return launchOp != nullptr;
37+
}
38+
39+
bool hasAssignedMemSpace(Value value) {
40+
if (auto memrefType = dyn_cast<MemRefType>(value.getType())) {
41+
if (memrefType.getMemorySpace()) {
42+
return true;
43+
}
44+
}
45+
return false;
46+
}
47+
48+
struct ConvertAlloc : public OpRewritePattern<memref::AllocOp> {
49+
using OpRewritePattern<memref::AllocOp>::OpRewritePattern;
50+
51+
ConvertAlloc(MLIRContext *ctx) : OpRewritePattern<memref::AllocOp>(ctx) {}
52+
53+
LogicalResult matchAndRewrite(memref::AllocOp allocOp,
54+
PatternRewriter &rewriter) const override {
55+
if (hasAssignedMemSpace(allocOp->getResult(0))) {
56+
return rewriter.notifyMatchFailure(
57+
allocOp, "Memref already has some memory space attribute");
58+
}
59+
60+
if (!isInGpuLaunch(allocOp)) {
61+
return rewriter.notifyMatchFailure(allocOp,
62+
"Only support allocs in GPU regions");
63+
}
64+
65+
Value memref = allocOp->getResult(0);
66+
MemRefType originalMemRefType = cast<MemRefType>(memref.getType());
67+
68+
IntegerAttr sharedAddressSpace =
69+
IntegerAttr::get(rewriter.getIntegerType(64),
70+
static_cast<int64_t>(gpu::AddressSpace::Private));
71+
72+
// Create a new MemRefType with the desired address space
73+
MemRefType newMemRefType = MemRefType::get(
74+
originalMemRefType.getShape(), originalMemRefType.getElementType(),
75+
originalMemRefType.getLayout(), sharedAddressSpace);
76+
77+
Value newMemRef = rewriter.create<memref::AllocOp>(
78+
allocOp.getLoc(), newMemRefType, allocOp.getOperands());
79+
80+
memref.replaceAllUsesWith(newMemRef);
81+
82+
return success();
83+
}
84+
};
85+
86+
struct AllocsToSLM : public gc::impl::AllocsToSLMBase<AllocsToSLM> {
87+
void runOnOperation() override {
88+
const auto ctx = &getContext();
89+
90+
RewritePatternSet patterns(ctx);
91+
patterns.add<ConvertAlloc>(patterns.getContext());
92+
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
93+
}
94+
};
95+
96+
} // namespace

lib/gc/Transforms/GPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ set_property(GLOBAL APPEND PROPERTY IMEX_LIBS ${IMEX_LIBS})
1212

1313
gc_add_mlir_library(GcGpuPasses
1414
AddContextArg.cpp
15+
AllocsToSLM.cpp
1516
GpuToGpuOcl.cpp
1617
LinalgToXeGPU.cpp
1718
Pipeline.cpp

0 commit comments

Comments
 (0)