@@ -37,40 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
3737 rewriter.create <LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
3838}
3939
40- // / Verifies if the stride matches proper tile access.
41- LogicalResult verifyStride (MemRefType mType ) {
42- if (mType .getRank () < 2 )
43- return failure ();
44- int64_t last = mType .getRank () - 1 ;
45- int64_t offset;
46- SmallVector<int64_t , 4 > strides;
47- if (failed (getStridesAndOffset (mType , strides, offset)) || strides[last] != 1 )
48- return failure ();
49- return success ();
50- }
51-
5240// / Maps the 2-dim memref shape to the 64-bit stride. Note that the buffer
5341// / shape may "envelop" the actual tile shape, and may be dynamically sized.
54- Value getStride (ConversionPatternRewriter &rewriter,
55- const LLVMTypeConverter &typeConverter, MemRefType mType ,
56- Value base, Location loc) {
57- assert (mType .getRank () >= 2 );
58- int64_t last = mType .getRank () - 1 ;
42+ // / Returns failure if proper stride couldn't be found.
43+ FailureOr<Value> getStride (ConversionPatternRewriter &rewriter,
44+ const LLVMTypeConverter &typeConverter,
45+ MemRefType mType , Value base, Location loc) {
46+ if (mType .getRank () < 2 )
47+ return failure ();
48+ int64_t preLast = mType .getRank () - 2 ;
5949 Type llvmInt64Type = IntegerType::get (&typeConverter.getContext (), 64 );
6050 unsigned width = mType .getElementType ().getIntOrFloatBitWidth ();
6151 assert (llvm::isPowerOf2_64 (width) && width >= 8 );
6252 unsigned bytes = width >> 3 ;
63- if (mType .isDynamicDim (last)) {
64- // Dynamic size needs code to compute the stride at runtime.
53+ int64_t offset;
54+ SmallVector<int64_t , 4 > strides;
55+ if (failed (getStridesAndOffset (mType , strides, offset)) ||
56+ strides.back () != 1 )
57+ return failure ();
58+ if (strides[preLast] == ShapedType::kDynamic ) {
59+ // Dynamic stride needs code to compute the stride at runtime.
6560 MemRefDescriptor memrefDescriptor (base);
6661 auto attr = rewriter.getI64IntegerAttr (bytes);
6762 Value scale = rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68- return rewriter.create <LLVM::MulOp>(
69- loc, llvmInt64Type, scale, memrefDescriptor.size (rewriter, loc, last));
63+ return rewriter
64+ .create <LLVM::MulOp>(loc, llvmInt64Type, scale,
65+ memrefDescriptor.stride (rewriter, loc, preLast))
66+ .getResult ();
7067 }
71- // Use direct constant for static size.
72- auto attr = rewriter.getI64IntegerAttr (mType .getDimSize (last) * bytes);
73- return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
68+ // Use direct constant for static stride.
69+ auto attr = rewriter.getI64IntegerAttr (strides[preLast] * bytes);
70+ return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71+ .getResult ();
7472}
7573
7674struct TileZeroConversion : public ConvertOpToLLVMPattern <TileZeroOp> {
@@ -102,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
102100 std::pair<Value, Value> tsz =
103101 getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
104102 // Determine stride.
105- if (failed (verifyStride (mType )))
103+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
104+ adaptor.getBase (), op.getLoc ());
105+ if (failed (stride))
106106 return failure ();
107- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
108- adaptor.getBase (), op.getLoc ());
109107 // Replace operation with intrinsic.
110108 Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
111109 adaptor.getIndices (), rewriter);
112110 Type resType = typeConverter->convertType (vType);
113111 rewriter.replaceOpWithNewOp <amx::x86_amx_tileloadd64>(
114- op, resType, tsz.first , tsz.second , ptr, stride);
112+ op, resType, tsz.first , tsz.second , ptr, stride. value () );
115113 return success ();
116114 }
117115};
@@ -128,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
128126 std::pair<Value, Value> tsz =
129127 getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
130128 // Determine stride.
131- if (failed (verifyStride (mType )))
129+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
130+ adaptor.getBase (), op.getLoc ());
131+ if (failed (stride))
132132 return failure ();
133- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
134- adaptor.getBase (), op.getLoc ());
135133 // Replace operation with intrinsic.
136134 Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
137135 adaptor.getIndices (), rewriter);
138136 rewriter.replaceOpWithNewOp <amx::x86_amx_tilestored64>(
139- op, tsz.first , tsz.second , ptr, stride, adaptor.getVal ());
137+ op, tsz.first , tsz.second , ptr, stride. value () , adaptor.getVal ());
140138 return success ();
141139 }
142140};
0 commit comments