Skip to content

Commit 3315f73

Browse files
committed
add xevm operations and conversion tests
1 parent b7129d8 commit 3315f73

File tree

11 files changed

+1041
-10
lines changed

11 files changed

+1041
-10
lines changed

include/gc/Conversion/Passes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#ifndef GC_CONVERSION_PASSES_H
1010
#define GC_CONVERSION_PASSES_H
1111

12-
#include "gc/Conversion/XeVMToLLVM.h"
12+
#include "gc/Conversion/XeVMToLLVM/XeVMToLLVM.h"
1313

1414
namespace mlir {
1515

include/gc/Conversion/XeVMToLLVM/XeVMToLLVM.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class RewritePatternSet;
1717
class Pass;
1818

1919
#define GEN_PASS_DECL_CONVERTXEVMTOLLVMPASS
20-
#include "mlir/Conversion/Passes.h.inc"
20+
#include "gc/Conversion/Passes.h.inc"
2121

2222
void populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns);
2323

include/gc/Dialect/LLVMIR/XeVMOps.td

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,15 @@ def XeVM_Dialect : Dialect {
1919
let name = "xevm";
2020
let cppNamespace = "::mlir::xevm";
2121
let dependentDialects = ["LLVM::LLVMDialect"];
22+
23+
let extraClassDeclaration = [{
24+
/// Get the name for the attribute used to specify cache control
25+
/// decorations.
26+
static constexpr ::llvm::StringRef getCacheControlsAttrName() {
27+
return ::llvm::StringLiteral("xevm.DecorationCacheControl");
28+
}
29+
}];
30+
2231
let useDefaultAttributePrinterParser = 1;
2332
}
2433

@@ -161,6 +170,52 @@ def XeVM_BlockStore2dOp : XeVM_Op<"blockstore2d">,
161170
let hasVerifier = 1;
162171
}
163172

173+
def XeVM_BlockPrefetch2dOp : XeVM_Op<"blockprefetch2d">,
174+
Arguments<(ins
175+
Arg<LLVM_AnyPointer, "", [MemRead]>:$ptr,
176+
I32:$base_width,
177+
I32:$base_height,
178+
I32:$base_pitch,
179+
I32:$x,
180+
I32:$y,
181+
I32Attr:$elem_size_in_bits,
182+
I32Attr:$tile_width,
183+
I32Attr:$tile_height,
184+
I32Attr:$v_blocks,
185+
DefaultValuedAttr<XeVM_L1LoadCacheControl, "::mlir::xevm::L1LoadCacheControl::DEFAULT">:$l1_cache_control,
186+
DefaultValuedAttr<XeVM_L3LoadCacheControl, "::mlir::xevm::L3LoadCacheControl::DEFAULT">:$l3_cache_control
187+
)> {
188+
189+
let summary = "2D block prefetch";
190+
191+
let description = [{
192+
The `xevm.blockprefetch2d` operation prefetches a two dimensional tile
193+
from a larger matrix residing in memory. The parameters are:
194+
$ptr - the base address of the matrix containing the tile to prefetch
195+
$base_width, $base_height, $base_pitch - the shape of the matrix
196+
$x, $y, $tile_width, $tile_height - the starting offsets and shape of tile to prefetch
197+
$elem_size_in_bits - the size in bits of the matrix element
198+
- 32 for f32, bf32
199+
- 16 for f16, int16, bf16
200+
- 8 for int8, int4, int2
201+
$v_blocks - number of tiles to prefetch
202+
$cache_control - an enumerator that sets the L1 and L3 cache behaviour
203+
204+
Notes:
205+
- coordinate is provided in elements, while width and pitch are provided in bytes.
206+
}];
207+
208+
let assemblyFormat = [{
209+
operands ` ` `{` `elem_size_in_bits` `=` $elem_size_in_bits `,` `tile_width` `=` $tile_width `,`
210+
`tile_height` `=` $tile_height `,` `v_blocks` `=` $v_blocks `,` `l1_cache_control` `=` $l1_cache_control `,`
211+
`l3_cache_control` `=` $l3_cache_control `}`
212+
attr-dict `:` `(` type(operands) `)`
213+
}];
214+
215+
let hasVerifier = 1;
216+
}
217+
218+
164219
def XeVM_TargetAttr : XeVM_Attr<"XeVMTarget", "target"> {
165220
let description = [{
166221
GPU target attribute for controlling compilation of targets. All

lib/gc/Conversion/XeVMToLLVM/XeVMToLLVM.cpp

Lines changed: 265 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,16 @@
1111
#include "gc/Dialect/LLVMIR/XeVMDialect.h"
1212
#include "mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313
#include "mlir/Conversion/LLVMCommon/Pattern.h"
14+
#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1415
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
1516
#include "mlir/Pass/Pass.h"
1617
#include "mlir/Support/LLVM.h"
18+
#include "llvm/Support/FormatVariadic.h"
19+
20+
#include "mlir/IR/BuiltinTypes.h"
21+
#include "mlir/IR/Types.h"
22+
23+
#include "llvm/ADT/TypeSwitch.h"
1724

1825
#define DEBUG_TYPE "xevm-to-llvm"
1926

@@ -26,6 +33,226 @@ using namespace mlir;
2633
using namespace xevm;
2734

2835
namespace {
36+
struct LLVMFuncAttributeOptions {
37+
bool isConvergent = false;
38+
bool isNoUnwind = false;
39+
bool isWillReturn = false;
40+
LLVM::MemoryEffectsAttr memEffectsAttr{};
41+
};
42+
static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
43+
false, true, false, {}};
44+
static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
45+
false, true, true, {}};
46+
47+
std::string getTypeMangling(Type ty, bool isUnsigned = false) {
48+
return TypeSwitch<Type, std::string>(ty)
49+
.Case([isUnsigned](VectorType ty) -> std::string {
50+
return "Dv" + std::to_string(ty.getNumElements()) + "_" +
51+
getTypeMangling(ty.getElementType(), isUnsigned);
52+
})
53+
.Case([](Float16Type) -> std::string { return "Dh"; })
54+
.Case([](Float32Type) -> std::string { return "f"; })
55+
.Case([](Float64Type) -> std::string { return "d"; })
56+
.Case([isUnsigned](IntegerType ty) -> std::string {
57+
switch (ty.getWidth()) {
58+
case 8:
59+
return isUnsigned ? "h" : "c";
60+
case 16:
61+
return isUnsigned ? "t" : "s";
62+
case 32:
63+
return isUnsigned ? "j" : "i";
64+
case 64:
65+
return isUnsigned ? "m" : "l";
66+
default:
67+
llvm_unreachable("unhandled integer type");
68+
}
69+
});
70+
}
71+
72+
template <typename OpType>
73+
static std::optional<ArrayAttr>
74+
getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
75+
const bool isLoad) {
76+
if ((op.getL1CacheControlAttr() ==
77+
xevm::L1StoreCacheControlAttr::get(
78+
rewriter.getContext(), xevm::L1StoreCacheControl::DEFAULT) &&
79+
op.getL3CacheControlAttr() ==
80+
xevm::L3StoreCacheControlAttr::get(
81+
rewriter.getContext(), xevm::L3StoreCacheControl::DEFAULT)) ||
82+
83+
(op.getL1CacheControlAttr() ==
84+
xevm::L1LoadCacheControlAttr::get(
85+
rewriter.getContext(), xevm::L1LoadCacheControl::DEFAULT) &&
86+
op.getL3CacheControlAttr() ==
87+
xevm::L3LoadCacheControlAttr::get(
88+
rewriter.getContext(), xevm::L3LoadCacheControl::DEFAULT))) {
89+
return {};
90+
}
91+
constexpr int32_t decorationCacheControlArity{4};
92+
constexpr int32_t loadCacheControlKey{6442};
93+
constexpr int32_t storeCacheControlKey{6443};
94+
constexpr int32_t l1Level{0};
95+
constexpr int32_t l3Level{1};
96+
const int32_t controlKey{isLoad ? loadCacheControlKey : storeCacheControlKey};
97+
SmallVector<int32_t, decorationCacheControlArity> decorationsL1{
98+
controlKey, l1Level, static_cast<int32_t>(op.getL1CacheControl()), 0};
99+
SmallVector<int32_t, decorationCacheControlArity> decorationsL3{
100+
controlKey, l3Level, static_cast<int32_t>(op.getL3CacheControl()), 0};
101+
auto arrayAttrL1 = rewriter.getI32ArrayAttr(decorationsL1);
102+
auto arrayAttrL3 = rewriter.getI32ArrayAttr(decorationsL3);
103+
104+
SmallVector<Attribute, 2> combinedAttrs = {arrayAttrL1, arrayAttrL3};
105+
return rewriter.getArrayAttr(combinedAttrs);
106+
}
107+
108+
static LLVM::CallOp createDeviceFunctionCall(
109+
ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
110+
ArrayRef<Type> argTypes, ArrayRef<Value> args,
111+
mlir::ArrayRef<std::pair<unsigned, mlir::StringRef>> paramAttrs,
112+
LLVMFuncAttributeOptions funcAttributeOptions) {
113+
auto moduleOp = rewriter.getBlock()->getParent()->getParentOfType<ModuleOp>();
114+
MLIRContext *ctx = rewriter.getContext();
115+
Location loc = UnknownLoc::get(ctx);
116+
117+
LLVM::LLVMFuncOp funcOp =
118+
LLVM::lookupOrCreateFn(moduleOp, funcName, argTypes, retType);
119+
funcOp.setCConv(LLVM::cconv::CConv::SPIR_FUNC);
120+
funcOp.setConvergent(funcAttributeOptions.isConvergent);
121+
funcOp.setNoUnwind(funcAttributeOptions.isNoUnwind);
122+
funcOp.setWillReturn(funcAttributeOptions.isWillReturn);
123+
124+
if (funcAttributeOptions.memEffectsAttr)
125+
funcOp.setMemoryEffectsAttr(funcAttributeOptions.memEffectsAttr);
126+
127+
for (auto [idx, attrName] : paramAttrs)
128+
funcOp.setArgAttr(idx, attrName, rewriter.getUnitAttr());
129+
130+
auto callOp = rewriter.create<LLVM::CallOp>(loc, funcOp, args);
131+
callOp->setAttrs(funcOp->getAttrs());
132+
133+
return callOp;
134+
}
135+
136+
template <typename OpType>
137+
class LoadStorePrefetchToOCLPattern : public OpConversionPattern<OpType> {
138+
using OpConversionPattern<OpType>::OpConversionPattern;
139+
LogicalResult
140+
matchAndRewrite(OpType op, typename OpType::Adaptor adaptor,
141+
ConversionPatternRewriter &rewriter) const override {
142+
constexpr bool isLoad = std::is_same_v<OpType, xevm::BlockLoad2dOp>;
143+
constexpr bool isPrefetch = std::is_same_v<OpType, xevm::BlockPrefetch2dOp>;
144+
145+
auto loc = op.getLoc();
146+
VectorType vecType;
147+
bool vnni = false;
148+
bool transpose = false;
149+
if constexpr (isLoad) {
150+
vecType = op.getRes().getType();
151+
vnni = op.getVnniTransform();
152+
transpose = op.getTranspose();
153+
} else if constexpr (!isPrefetch) {
154+
vecType = op.getStoredVal().getType();
155+
}
156+
157+
auto i32Type = rewriter.getI32Type();
158+
Value byteCoord =
159+
rewriter.create<LLVM::UndefOp>(loc, VectorType::get(2, i32Type));
160+
Value zero = rewriter.create<LLVM::ConstantOp>(
161+
loc, i32Type, rewriter.getI32IntegerAttr(0));
162+
Value one = rewriter.create<LLVM::ConstantOp>(
163+
loc, i32Type, rewriter.getI32IntegerAttr(1));
164+
byteCoord = rewriter.create<LLVM::InsertElementOp>(
165+
loc, VectorType::get(2, i32Type), byteCoord, op.getX(), zero);
166+
byteCoord = rewriter.create<LLVM::InsertElementOp>(
167+
loc, VectorType::get(2, i32Type), byteCoord, op.getY(), one);
168+
SmallVector<Value> args{op.getPtr(), op.getBaseWidth(), op.getBaseHeight(),
169+
op.getBasePitch(), byteCoord};
170+
SmallVector<Type> retTypes;
171+
Value spvLoadDstPtr;
172+
std::string funcName{"intel_sub_group_2d_block_"};
173+
std::string bitWidthId;
174+
LLVMFuncAttributeOptions funcAttr{noUnwindWillReturnAttrs};
175+
SmallVector<std::pair<unsigned, mlir::StringRef>, 4> paramAttrs;
176+
if constexpr (isPrefetch) { // Prefetch
177+
funcName += "prefetch";
178+
paramAttrs = {std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName())};
179+
auto memAttr = rewriter.getAttr<LLVM::MemoryEffectsAttr>(
180+
/*other=*/LLVM::ModRefInfo::NoModRef,
181+
/*argMem=*/LLVM::ModRefInfo::Ref,
182+
/*inaccessibleMem=*/LLVM::ModRefInfo::NoModRef);
183+
auto funcAttrs = noUnwindAttrs;
184+
funcAttrs.memEffectsAttr = memAttr;
185+
} else {
186+
auto vecElemType = vecType.getElementType();
187+
auto vecElemBitWidth = vecElemType.getIntOrFloatBitWidth();
188+
Value numElems = rewriter.create<LLVM::ConstantOp>(
189+
loc, i32Type, vecType.getNumElements());
190+
auto dstOrSrcPtr = rewriter.create<LLVM::AllocaOp>(
191+
loc, LLVM::LLVMPointerType::get(rewriter.getContext()), vecElemType,
192+
numElems);
193+
args.push_back(dstOrSrcPtr);
194+
if constexpr (isLoad) { // Load
195+
funcName += "read";
196+
bitWidthId = getTypeMangling(vecElemType, /*isUnsigned=*/true);
197+
if (vnni)
198+
funcName += "_transform";
199+
else if (transpose)
200+
funcName += "_transpose";
201+
spvLoadDstPtr = dstOrSrcPtr;
202+
retTypes.push_back(vecType);
203+
paramAttrs = {
204+
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
205+
std::make_pair(0, LLVM::LLVMDialect::getReadonlyAttrName()),
206+
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
207+
std::make_pair(5, LLVM::LLVMDialect::getWriteOnlyAttrName()),
208+
};
209+
} else { // Store
210+
funcName += "write";
211+
bitWidthId = (vecElemBitWidth == 32)
212+
? "j"
213+
: ((vecElemBitWidth == 16) ? "t" : "h");
214+
rewriter.create<LLVM::StoreOp>(loc, op.getStoredVal(), dstOrSrcPtr);
215+
paramAttrs = {
216+
std::make_pair(0, LLVM::LLVMDialect::getNonNullAttrName()),
217+
std::make_pair(0, LLVM::LLVMDialect::getWriteOnlyAttrName()),
218+
std::make_pair(5, LLVM::LLVMDialect::getNonNullAttrName()),
219+
std::make_pair(5, LLVM::LLVMDialect::getReadonlyAttrName()),
220+
};
221+
}
222+
}
223+
224+
funcName =
225+
llvm::formatv("{0}_{1}b_{2}r{3}x{4}c", funcName, op.getElemSizeInBits(),
226+
op.getTileHeight(), op.getTileWidth(), op.getVBlocks())
227+
.str();
228+
funcName = llvm::formatv("_Z{0}{1}PU3AS1viiiDv2_i{2}{3}", funcName.size(),
229+
funcName, isPrefetch ? "" : "P", bitWidthId)
230+
.str();
231+
SmallVector<Type> argTypes;
232+
for (auto arg : args) {
233+
argTypes.push_back(arg.getType());
234+
}
235+
LLVM::CallOp call = createDeviceFunctionCall(
236+
rewriter, funcName, LLVM::LLVMVoidType::get(rewriter.getContext()),
237+
argTypes, args, paramAttrs, funcAttr);
238+
if (std::optional<ArrayAttr> optCacheControls =
239+
getCacheControlMetadata(rewriter, op, isLoad || isPrefetch)) {
240+
call->setAttr(xevm::XeVMDialect::getCacheControlsAttrName(),
241+
*optCacheControls);
242+
}
243+
if constexpr (isLoad)
244+
rewriter.replaceOp(
245+
op, rewriter.create<LLVM::LoadOp>(loc, vecType, spvLoadDstPtr));
246+
else
247+
rewriter.eraseOp(op);
248+
return success();
249+
}
250+
};
251+
252+
//===----------------------------------------------------------------------===//
253+
// Pass Definition
254+
//===----------------------------------------------------------------------===//
255+
29256
struct ConvertXeVMToLLVMPass
30257
: public impl::ConvertXeVMToLLVMPassBase<ConvertXeVMToLLVMPass> {
31258
using Base::Base;
@@ -37,19 +264,51 @@ struct ConvertXeVMToLLVMPass
37264
void runOnOperation() override {
38265
ConversionTarget target(getContext());
39266
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
40-
RewritePatternSet pattern(&getContext());
41-
mlir::populateXeVMToLLVMConversionPatterns(pattern);
42-
if (failed(
43-
applyPartialConversion(getOperation(), target, std::move(pattern))))
267+
target.addIllegalDialect<xevm::XeVMDialect>();
268+
RewritePatternSet patterns(&getContext());
269+
mlir::populateXeVMToLLVMConversionPatterns(patterns);
270+
if (failed(applyPartialConversion(getOperation(), target,
271+
std::move(patterns))))
44272
signalPassFailure();
45273
}
46274
};
47275
} // namespace
48276

277+
//===----------------------------------------------------------------------===//
278+
// Pattern Population
279+
//===----------------------------------------------------------------------===//
280+
49281
void mlir::populateXeVMToLLVMConversionPatterns(RewritePatternSet &patterns) {
50-
/*TODO*/
282+
patterns.add<LoadStorePrefetchToOCLPattern<xevm::BlockLoad2dOp>,
283+
LoadStorePrefetchToOCLPattern<xevm::BlockStore2dOp>,
284+
LoadStorePrefetchToOCLPattern<xevm::BlockPrefetch2dOp>>(
285+
patterns.getContext());
51286
}
52287

288+
//===----------------------------------------------------------------------===//
289+
// ConvertToLLVMPatternInterface implementation
290+
//===----------------------------------------------------------------------===//
291+
292+
namespace {
293+
/// Implement the interface to convert XeVM to LLVM.
294+
struct XeVMToLLVMDialectInterface : public ConvertToLLVMPatternInterface {
295+
using ConvertToLLVMPatternInterface::ConvertToLLVMPatternInterface;
296+
void loadDependentDialects(MLIRContext *context) const final {
297+
context->loadDialect<LLVM::LLVMDialect>();
298+
}
299+
300+
/// Hook for derived dialect interface to provide conversion patterns
301+
/// and mark dialect legal for the conversion target.
302+
void populateConvertToLLVMConversionPatterns(
303+
ConversionTarget &target, LLVMTypeConverter &typeConverter,
304+
RewritePatternSet &patterns) const final {
305+
populateXeVMToLLVMConversionPatterns(patterns);
306+
}
307+
};
308+
} // namespace
309+
53310
void mlir::registerConvertXeVMToLLVMInterface(DialectRegistry &registry) {
54-
/*TODO*/
311+
registry.addExtension(+[](MLIRContext *ctx, xevm::XeVMDialect *dialect) {
312+
dialect->addInterfaces<XeVMToLLVMDialectInterface>();
313+
});
55314
}

0 commit comments

Comments
 (0)