1111#include " gc/Dialect/LLVMIR/XeVMDialect.h"
1212#include " mlir/Conversion/ConvertToLLVM/ToLLVMInterface.h"
1313#include " mlir/Conversion/LLVMCommon/Pattern.h"
14+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
1415#include " mlir/Dialect/LLVMIR/FunctionCallUtils.h"
1516#include " mlir/Dialect/LLVMIR/LLVMDialect.h"
1617#include " mlir/Pass/Pass.h"
@@ -53,6 +54,8 @@ static constexpr LLVMFuncAttributeOptions noUnwindAttrs = {
5354 false , true , false , {}};
5455static constexpr LLVMFuncAttributeOptions noUnwindWillReturnAttrs = {
5556 false , true , true , {}};
57+ static constexpr LLVMFuncAttributeOptions convergentNoUnwindWillReturnAttrs = {
58+ true , true , true , {}};
5659
5760std::string getTypeMangling (Type ty, bool isUnsigned = false ) {
5861 return TypeSwitch<Type, std::string>(ty)
@@ -79,6 +82,31 @@ std::string getTypeMangling(Type ty, bool isUnsigned = false) {
7982 });
8083}
8184
85+ std::string mangle (StringRef baseName, ArrayRef<Type> types,
86+ ArrayRef<bool > isUnsigned = {}) {
87+ assert ((isUnsigned.empty () || isUnsigned.size () == types.size ()) &&
88+ " Signedness info doesn't match" );
89+ std::string s;
90+ llvm::raw_string_ostream os (s);
91+ llvm::SmallDenseMap<Type, unsigned > substitutions;
92+ os << " _Z" << baseName.size () << baseName;
93+ for (auto [idx, type] : llvm::enumerate (types)) {
94+ auto it = substitutions.find (type);
95+ if (it != substitutions.end ()) {
96+ os << " S" ;
97+ // First substitution is `S_`, second is `S0_`, and so on.
98+ if (unsigned firstIdx = it->getSecond (); firstIdx > 0 )
99+ os << firstIdx - 1 ;
100+ os << " _" ;
101+ } else {
102+ if (!type.isIntOrFloat ())
103+ substitutions[type] = substitutions.size ();
104+ os << getTypeMangling (type, isUnsigned.empty () ? false : isUnsigned[idx]);
105+ }
106+ }
107+ return os.str ();
108+ }
109+
82110template <typename OpType>
83111static std::optional<ArrayAttr>
84112getCacheControlMetadata (ConversionPatternRewriter &rewriter, OpType op,
@@ -115,13 +143,15 @@ getCacheControlMetadata(ConversionPatternRewriter &rewriter, OpType op,
115143 return rewriter.getArrayAttr (combinedAttrs);
116144}
117145
118- static LLVM::CallOp
119- createDeviceFunctionCall (ConversionPatternRewriter &rewriter,
120- StringRef funcName, Type retType,
121- ArrayRef<Type> argTypes, ArrayRef<Value> args,
122- ArrayRef<std::pair<unsigned , StringRef>> paramAttrs,
123- LLVMFuncAttributeOptions funcAttributeOptions) {
124- auto moduleOp = rewriter.getBlock ()->getParent ()->getParentOfType <ModuleOp>();
146+ static LLVM::CallOp createDeviceFunctionCall (
147+ ConversionPatternRewriter &rewriter, StringRef funcName, Type retType,
148+ ArrayRef<Type> argTypes, ArrayRef<Value> args,
149+ mlir::ArrayRef<std::pair<unsigned , mlir::StringRef>> paramAttrs,
150+ LLVMFuncAttributeOptions funcAttributeOptions) {
151+ auto moduleOp = rewriter.getBlock ()
152+ ->getParentOp ()
153+ ->getParentWithTrait <OpTrait::SymbolTable>();
154+ assert (moduleOp && " Expecting module" );
125155 MLIRContext *ctx = rewriter.getContext ();
126156 Location loc = UnknownLoc::get (ctx);
127157
@@ -144,6 +174,96 @@ createDeviceFunctionCall(ConversionPatternRewriter &rewriter,
144174 return callOp;
145175}
146176
177+ class DPASToOCLPattern : public OpConversionPattern <xevm::DPASOp> {
178+ using OpConversionPattern::OpConversionPattern;
179+ LogicalResult
180+ matchAndRewrite (xevm::DPASOp op, xevm::DPASOp::Adaptor adaptor,
181+ ConversionPatternRewriter &rewriter) const override {
182+ constexpr uint32_t bitWidthPackedA{16 };
183+ constexpr uint32_t bitWidthPackedB{32 };
184+ auto loc = op.getLoc ();
185+
186+ auto castIfNeeded = [&](Value val, Type packedType) -> Value {
187+ VectorType origTy = cast<VectorType>(val.getType ());
188+ const uint32_t vecBitSize =
189+ origTy.getNumElements () *
190+ origTy.getElementType ().getIntOrFloatBitWidth ();
191+ VectorType newTy = VectorType::get (
192+ vecBitSize / packedType.getIntOrFloatBitWidth (), packedType);
193+ if (origTy != newTy)
194+ val = rewriter.create <LLVM::BitcastOp>(loc, newTy, val);
195+ return val;
196+ };
197+
198+ Value a = op.getA ();
199+ Type packedAType = (op.getPa () == xevm::PrecisionType::TF32)
200+ ? cast<Type>(rewriter.getF32Type ())
201+ : rewriter.getIntegerType (bitWidthPackedA);
202+ a = castIfNeeded (a, packedAType);
203+
204+ Value b = op.getB ();
205+ Type packedBType = (op.getPb () == xevm::PrecisionType::TF32)
206+ ? cast<Type>(rewriter.getF32Type ())
207+ : rewriter.getIntegerType (bitWidthPackedB);
208+ b = castIfNeeded (b, packedBType);
209+
210+ Value c = op.getC ();
211+ VectorType cOrigTy = cast<VectorType>(c.getType ());
212+ assert (cOrigTy == op->getResultTypes ()[0 ] &&
213+ " Accumulator and result type mismatch" );
214+ // OCL builtins encode bfloat16 as int16
215+ VectorType cTy =
216+ cOrigTy.getElementType ().isBF16 ()
217+ ? VectorType::get (cOrigTy.getShape (), rewriter.getIntegerType (16 ))
218+ : cOrigTy;
219+ if (cOrigTy != cTy)
220+ c = rewriter.create <LLVM::BitcastOp>(loc, cTy, c);
221+
222+ constexpr int32_t systolicDepth{8 };
223+ std::string fnName =
224+ llvm::formatv (" intel_sub_group_{0}_{1}_matrix_mad_k{2}" ,
225+ stringifyPrecisionType (op.getPa ()).str (),
226+ stringifyPrecisionType (op.getPb ()).str (),
227+ systolicDepth * getNumOperandsPerDword (op.getPa ()))
228+ .str ();
229+ SmallVector<Type> argTypes{a.getType (), b.getType (), cTy};
230+ fnName = mangle (fnName, argTypes);
231+ SmallVector<Value> args{a, b, c};
232+
233+ auto memAttr = rewriter.getAttr <LLVM::MemoryEffectsAttr>(
234+ /* other=*/ LLVM::ModRefInfo::NoModRef,
235+ /* argMem=*/ LLVM::ModRefInfo::NoModRef,
236+ /* inaccessibleMem=*/ LLVM::ModRefInfo::NoModRef);
237+ auto funcAttrs = convergentNoUnwindWillReturnAttrs;
238+ funcAttrs.memEffectsAttr = memAttr;
239+ Value result = createDeviceFunctionCall (rewriter, fnName, cTy, argTypes,
240+ args, {}, funcAttrs)
241+ ->getResult (0 );
242+
243+ if (cOrigTy != cTy)
244+ result = rewriter.create <LLVM::BitcastOp>(loc, cOrigTy, result);
245+
246+ rewriter.replaceOp (op, result);
247+ return success ();
248+ }
249+
250+ private:
251+ static unsigned getNumOperandsPerDword (xevm::PrecisionType pTy) {
252+ switch (pTy) {
253+ case xevm::PrecisionType::TF32:
254+ return 1 ;
255+ case xevm::PrecisionType::BF16:
256+ case xevm::PrecisionType::FP16:
257+ return 2 ;
258+ case xevm::PrecisionType::U8:
259+ case xevm::PrecisionType::S8:
260+ return 4 ;
261+ default :
262+ llvm_unreachable (" unsupported xevm::PrecisionType" );
263+ }
264+ }
265+ };
266+
147267template <typename OpType>
148268class LoadStorePrefetchToOCLPattern : public OpConversionPattern <OpType> {
149269 using OpConversionPattern<OpType>::OpConversionPattern;
@@ -291,10 +411,11 @@ struct ConvertXeVMToLLVMPass
291411// ===----------------------------------------------------------------------===//
292412
293413void mlir::populateXeVMToLLVMConversionPatterns (RewritePatternSet &patterns) {
294- patterns.add <LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
295- LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
296- LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>>(
297- patterns.getContext ());
414+ patterns
415+ .add <LoadStorePrefetchToOCLPattern<BlockLoad2dOp>,
416+ LoadStorePrefetchToOCLPattern<BlockStore2dOp>,
417+ LoadStorePrefetchToOCLPattern<BlockPrefetch2dOp>, DPASToOCLPattern>(
418+ patterns.getContext ());
298419}
299420
300421// ===----------------------------------------------------------------------===//
0 commit comments