@@ -39,12 +39,13 @@ using namespace mlir;
3939
4040namespace {
4141
42- enum class NdDescI32Layout : uint32_t {
43- BasePtr = 0 ,
44- BaseShapeW = 2 ,
45- BaseShapeH = 3 ,
46- TensorOffsetW = 4 ,
47- TensorOffsetH = 5
42+ // Offsets to individual fields of the 8xi32 layout nd tensor descriptor.
43+ enum class NdTdescOffset : uint32_t {
44+ BasePtr = 0 , // Base pointer (i64)
45+ BaseShapeW = 2 , // Base shape width (i32)
46+ BaseShapeH = 3 , // Base shape height (i32)
47+ TensorOffsetW = 4 , // Tensor offset W (i32)
48+ TensorOffsetH = 5 // Tensor offset H (i32)
4849};
4950
5051static int32_t getNumericXeVMAddrSpace (xegpu::MemorySpace xeGpuMemspace) {
@@ -57,6 +58,7 @@ static int32_t getNumericXeVMAddrSpace(xegpu::MemorySpace xeGpuMemspace) {
5758 llvm_unreachable (" Unknown XeGPU memory space." );
5859}
5960
61+ // Get same bitwidth flat vector type of new element type.
6062static VectorType encodeVectorTypeTo (VectorType currentVecType,
6163 Type toElemType) {
6264 auto elemType = currentVecType.getElementType ();
@@ -221,20 +223,20 @@ class CreateNdDescToXeVMPattern
221223 vector::BitCastOp::create (rewriter, loc, payloadI64Ty, payload);
222224 payLoadAsI64 =
223225 vector::InsertOp::create (rewriter, loc, baseAddr, payLoadAsI64,
224- static_cast <int >(NdDescI32Layout ::BasePtr));
226+ static_cast <int >(NdTdescOffset ::BasePtr));
225227 payload = vector::BitCastOp::create (rewriter, loc, payloadTy, payLoadAsI64);
226228 payload =
227229 vector::InsertOp::create (rewriter, loc, baseShapeW, payload,
228- static_cast <int >(NdDescI32Layout ::BaseShapeW));
230+ static_cast <int >(NdTdescOffset ::BaseShapeW));
229231 payload =
230232 vector::InsertOp::create (rewriter, loc, baseShapeH, payload,
231- static_cast <int >(NdDescI32Layout ::BaseShapeH));
233+ static_cast <int >(NdTdescOffset ::BaseShapeH));
232234 payload = vector::InsertOp::create (
233235 rewriter, loc, offsetW, payload,
234- static_cast <int >(NdDescI32Layout ::TensorOffsetW));
236+ static_cast <int >(NdTdescOffset ::TensorOffsetW));
235237 payload = vector::InsertOp::create (
236238 rewriter, loc, offsetH, payload,
237- static_cast <int >(NdDescI32Layout ::TensorOffsetH));
239+ static_cast <int >(NdTdescOffset ::TensorOffsetH));
238240 rewriter.replaceOp (op, payload);
239241 return success ();
240242 }
@@ -249,6 +251,7 @@ class UpdateNdOffsetToXeVMPattern
249251 ConversionPatternRewriter &rewriter) const override {
250252 auto loc = op.getLoc ();
251253 auto mixedOffsets = op.getMixedOffsets ();
254+ // Only 2D offsets are supported for now.
252255 if (mixedOffsets.size () != 2 )
253256 return rewriter.notifyMatchFailure (op, " Expected 2D offsets." );
254257 auto tdesc = adaptor.getTensorDesc ();
@@ -264,9 +267,9 @@ class UpdateNdOffsetToXeVMPattern
264267 return vector::InsertOp::create (rewriter, loc, newOffset, tdesc,
265268 payloadPos);
266269 };
267- auto val =
268- updateOffset (0 , static_cast <int >(NdDescI32Layout ::TensorOffsetH));
269- val = updateOffset (1 , static_cast <int >(NdDescI32Layout ::TensorOffsetW));
270+ // Update offsets in the payload.
271+ auto val = updateOffset (0 , static_cast <int >(NdTdescOffset ::TensorOffsetH));
272+ val = updateOffset (1 , static_cast <int >(NdTdescOffset ::TensorOffsetW));
270273 rewriter.replaceOp (op, val);
271274 return success ();
272275 }
@@ -293,86 +296,74 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
293296 VectorType payloadI64Ty = VectorType::get (4 , rewriter.getI64Type ());
294297 Value payLoadAsI64 =
295298 vector::BitCastOp::create (rewriter, loc, payloadI64Ty, tdesc);
296- Value basePtr =
297- vector::ExtractOp::create (rewriter, loc, payLoadAsI64,
298- static_cast <int >(NdDescI32Layout::BasePtr));
299+ Value basePtr = vector::ExtractOp::create (
300+ rewriter, loc, payLoadAsI64, static_cast <int >(NdTdescOffset::BasePtr));
299301 Value baseShapeW = vector::ExtractOp::create (
300- rewriter, loc, tdesc, static_cast <int >(NdDescI32Layout ::BaseShapeW));
302+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset ::BaseShapeW));
301303 Value baseShapeH = vector::ExtractOp::create (
302- rewriter, loc, tdesc, static_cast <int >(NdDescI32Layout::BaseShapeH));
303- // Offsets can come from three sources:
304- // 1. Constant offsets, which are provided by the op.
305- // 2. Offsets as operands, which are provided by the op.
306- // 3. Offsets extracted from the tensor descriptor.
304+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::BaseShapeH));
305+ // Offsets provided in two ways:
306+ // 1. Offsets are extracted from the tensor descriptor.
307+ // 2. (Mixed) offsets which are provided by the op.
307308 Value offsetW;
308309 Value offsetH;
309- auto cOffsets = op.getConstOffsets ();
310- auto offsets = op.getOffsets ();
311- if (cOffsets) {
312- offsetW = arith::ConstantIntOp::create (
313- rewriter, loc, rewriter.getI32Type (), (*cOffsets)[0 ]);
314- offsetH = arith::ConstantIntOp::create (
315- rewriter, loc, rewriter.getI32Type (), (*cOffsets)[1 ]);
316- } else if (offsets.size () != 0 ) {
317- // offsets are provided as operands
318- if (offsets[0 ].getType () != rewriter.getI32Type ()) {
319- if (offsets[0 ].getType () != rewriter.getIndexType ()) {
320- return rewriter.notifyMatchFailure (
321- op, " Expected offsets to be of type i32 or index." );
322- }
323- offsetW = arith::IndexCastUIOp::create (
324- rewriter, loc, rewriter.getI32Type (), offsets[0 ]);
325- } else {
326- offsetW = offsets[0 ];
327- }
328- if (offsets[1 ].getType () != rewriter.getI32Type ()) {
329- if (offsets[1 ].getType () != rewriter.getIndexType ()) {
330- return rewriter.notifyMatchFailure (
331- op, " Expected offsets to be of type i32 or index." );
332- }
333- offsetH = arith::IndexCastUIOp::create (
334- rewriter, loc, rewriter.getI32Type (), offsets[1 ]);
335- } else {
336- offsetH = offsets[1 ];
337- }
310+ auto mixedOffsets = op.getMixedOffsets ();
311+ int64_t opOffsetsSize = mixedOffsets.size ();
312+ if (opOffsetsSize != 0 && opOffsetsSize != 2 ) {
313+ return rewriter.notifyMatchFailure (op,
314+ " Expected 2D offsets or no offsets." );
315+ }
316+ if (opOffsetsSize) {
317+ // If mixed offsets are provided by the op convert them to i32.
318+ offsetW = getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[1 ]);
319+ offsetW = getValueOrCreateCastToIndexLike (rewriter, loc,
320+ rewriter.getI32Type (), offsetW);
321+ offsetH = getValueOrCreateConstantIntOp (rewriter, loc, mixedOffsets[0 ]);
322+ offsetH = getValueOrCreateCastToIndexLike (rewriter, loc,
323+ rewriter.getI32Type (), offsetH);
338324 } else {
339325 // If offsets are not available, we need to extract them from the tensor
340326 // descriptor.
341327 offsetW = vector::ExtractOp::create (
342- rewriter, loc, tdesc,
343- static_cast <int >(NdDescI32Layout::TensorOffsetW));
328+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::TensorOffsetW));
344329 offsetH = vector::ExtractOp::create (
345- rewriter, loc, tdesc,
346- static_cast <int >(NdDescI32Layout::TensorOffsetH));
330+ rewriter, loc, tdesc, static_cast <int >(NdTdescOffset::TensorOffsetH));
347331 }
332+ // Get address space from tensor descriptor memory space.
348333 auto ptrTypeLLVM = LLVM::LLVMPointerType::get (
349334 ctxt, getNumericXeVMAddrSpace (tdescTy.getMemorySpace ()));
335+ // Convert base pointer (i64) to LLVM pointer type.
350336 Value basePtrLLVM =
351337 LLVM::IntToPtrOp::create (rewriter, loc, ptrTypeLLVM, basePtr);
338+ // Compute element byte size and surface width in bytes.
352339 auto elemType = tdescTy.getElementType ();
353340 auto elemBitSize = elemType.getIntOrFloatBitWidth ();
354341 Value elemByteSize = arith::ConstantIntOp::create (
355342 rewriter, loc, rewriter.getI32Type (), elemBitSize / 8 );
356343 Value surfaceW =
357344 arith::MulIOp::create (rewriter, loc, baseShapeW, elemByteSize);
358345
346+ // Get tile sizes and vblocks from the tensor descriptor type.
359347 auto tileW = tdescTy.getDimSize (1 );
360348 auto tileH = tdescTy.getDimSize (0 );
361349 int32_t vblocks = tdescTy.getArrayLength ();
362350 if constexpr (std::is_same_v<OpType, xegpu::StoreNdOp>) {
363- VectorType srcVecTy = cast<VectorType>(op.getValue ().getType ());
351+ VectorType srcVecTy = dyn_cast<VectorType>(adaptor.getValue ().getType ());
352+ if (!srcVecTy) {
353+ return rewriter.notifyMatchFailure (
354+ op, " Expected store value to be a vector type." );
355+ }
364356 auto storeCacheControl =
365357 translateStoreXeGPUCacheHint (op.getL1Hint (), op.getL3Hint ());
366- VectorType srcFlatVecTy =
367- VectorType::get (srcVecTy.getNumElements (), srcVecTy.getElementType ());
368- Value srcFlatVec = op.getValue ();
369- srcFlatVecTy = encodeVectorTypeTo (srcFlatVecTy,
370- rewriter.getIntegerType (elemBitSize));
371- srcFlatVec =
372- vector::BitCastOp::create (rewriter, loc, srcFlatVecTy, srcFlatVec);
358+ Value src = adaptor.getValue ();
359+ // Get flat vector type of integer type with matching element bit size.
360+ VectorType newSrcVecTy =
361+ encodeVectorTypeTo (srcVecTy, rewriter.getIntegerType (elemBitSize));
362+ if (srcVecTy != newSrcVecTy)
363+ src = vector::BitCastOp::create (rewriter, loc, newSrcVecTy, src);
373364 xevm::BlockStore2dOp::create (
374365 rewriter, loc, basePtrLLVM, surfaceW, baseShapeH, surfaceW, offsetW,
375- offsetH, elemBitSize, tileW, tileH, srcFlatVec ,
366+ offsetH, elemBitSize, tileW, tileH, src ,
376367 xevm::StoreCacheControlAttr::get (ctxt, storeCacheControl));
377368 rewriter.eraseOp (op);
378369 } else {
@@ -412,15 +403,14 @@ class LoadStorePrefetchNdToXeVMPattern : public OpConversionPattern<OpType> {
412403
413404// Add a builder that creates
414405// offset * elemByteSize + baseAddr
415- static auto addOffset = [](ConversionPatternRewriter &rewriter, Location loc,
416- Value baseAddr, Value offset,
417- int64_t elemByteSize) -> Value {
406+ static Value addOffset (ConversionPatternRewriter &rewriter, Location loc,
407+ Value baseAddr, Value offset, int64_t elemByteSize) {
418408 Value byteSize = arith::ConstantIntOp::create (
419409 rewriter, loc, rewriter.getI64Type (), elemByteSize);
420410 Value byteOffset = arith::MulIOp::create (rewriter, loc, offset, byteSize);
421411 Value newAddr = arith::AddIOp::create (rewriter, loc, baseAddr, byteOffset);
422412 return newAddr;
423- };
413+ }
424414
425415class CreateDescToXeVMPattern
426416 : public OpConversionPattern<xegpu::CreateDescOp> {
@@ -908,6 +898,10 @@ struct ConvertXeGPUToXeVMPass
908898 return IntegerType::get (&getContext (), 64 );
909899 });
910900
901+ // LLVM type converter puts unrealized casts for the following cases:
902+ // add materialization casts to handle them.
903+
904+ // Materialization to convert memref to i64
911905 auto memrefMaterializationCast = [](OpBuilder &builder, Type type,
912906 ValueRange inputs,
913907 Location loc) -> Value {
@@ -924,6 +918,7 @@ struct ConvertXeGPUToXeVMPass
924918 return {};
925919 };
926920
921+ // Materialization to convert ui64 to i64
927922 auto ui64MaterializationCast = [](OpBuilder &builder, Type type,
928923 ValueRange inputs,
929924 Location loc) -> Value {
@@ -940,6 +935,7 @@ struct ConvertXeGPUToXeVMPass
940935 return {};
941936 };
942937
938+ // Materialization to convert ui32 to i32
943939 auto ui32MaterializationCast = [](OpBuilder &builder, Type type,
944940 ValueRange inputs,
945941 Location loc) -> Value {
@@ -956,9 +952,13 @@ struct ConvertXeGPUToXeVMPass
956952 return {};
957953 };
958954
959- auto vector1DMaterializationCast = [](OpBuilder &builder, Type type,
960- ValueRange inputs,
961- Location loc) -> Value {
955+ // Materialization to convert
956+ // - single element 1D vector to scalar
957+ // - bitcast vector of same rank
958+ // - shape vector of different rank but same element type
959+ auto vectorMaterializationCast = [](OpBuilder &builder, Type type,
960+ ValueRange inputs,
961+ Location loc) -> Value {
962962 if (inputs.size () != 1 )
963963 return {};
964964 auto input = inputs.front ();
@@ -971,18 +971,30 @@ struct ConvertXeGPUToXeVMPass
971971 cast = arith::IndexCastUIOp::create (builder, loc, type, cast)
972972 .getResult ();
973973 return cast;
974+ } else if (auto targetVecTy = dyn_cast<VectorType>(type)) {
975+ // If the target type is a vector of same rank,
976+ // bitcast to the target type.
977+ if (targetVecTy.getRank () == vecTy.getRank ())
978+ return vector::BitCastOp::create (builder, loc, targetVecTy, input)
979+ .getResult ();
980+ else if (targetVecTy.getElementType () == vecTy.getElementType ()) {
981+ // If the target type is a vector of different rank but same element
982+ // type, reshape to the target type.
983+ return vector::ShapeCastOp::create (builder, loc, targetVecTy, input)
984+ .getResult ();
985+ }
974986 }
975987 }
976988 return {};
977989 };
978990 typeConverter.addSourceMaterialization (memrefMaterializationCast);
979991 typeConverter.addSourceMaterialization (ui64MaterializationCast);
980992 typeConverter.addSourceMaterialization (ui32MaterializationCast);
981- typeConverter.addSourceMaterialization (vector1DMaterializationCast );
993+ typeConverter.addSourceMaterialization (vectorMaterializationCast );
982994 typeConverter.addTargetMaterialization (memrefMaterializationCast);
983995 typeConverter.addTargetMaterialization (ui32MaterializationCast);
984996 typeConverter.addTargetMaterialization (ui64MaterializationCast);
985- typeConverter.addTargetMaterialization (vector1DMaterializationCast );
997+ typeConverter.addTargetMaterialization (vectorMaterializationCast );
986998 ConversionTarget target (getContext ());
987999 target.addLegalDialect <xevm::XeVMDialect, LLVM::LLVMDialect,
9881000 vector::VectorDialect, arith::ArithDialect,
0 commit comments