@@ -37,44 +37,38 @@ std::pair<Value, Value> getTileSizes(ConversionPatternRewriter &rewriter,
3737 rewriter.create <LLVM::ConstantOp>(loc, llvmInt16Type, nattr));
3838}
3939
40- /// Verifies if the stride matches proper tile access: succeeds only when the memref has rank >= 2, its strides can be extracted, and the innermost stride is 1.
41- LogicalResult verifyStride (MemRefType mType ) {
42- if (mType .getRank () < 2 )
43- return failure ();
44- int64_t last = mType .getRank () - 1 ;
45- int64_t offset;
46- SmallVector<int64_t , 4 > strides;
// Fails both when the strides cannot be extracted and when the last (innermost) stride is not 1.
47- if (failed (getStridesAndOffset (mType , strides, offset)) || strides[last] != 1 )
48- return failure ();
49- return success ();
50- }
51-
5240/// Maps the 2-dim memref shape to the 64-bit stride, i.e. the stride of the second-to-last dimension scaled by the element size in bytes. Note that the buffer
5341/// shape may "envelop" the actual tile shape, and may be dynamically sized.
54- Value getStride (ConversionPatternRewriter &rewriter,
55- const LLVMTypeConverter &typeConverter, MemRefType mType ,
56- Value base, Location loc) {
57- assert (mType .getRank () >= 2 );
42+ /// Returns failure if proper stride couldn't be found: rank below 2, strides that cannot be extracted, or an innermost stride other than 1.
43+ FailureOr<Value> getStride (ConversionPatternRewriter &rewriter,
44+ const LLVMTypeConverter &typeConverter,
45+ MemRefType mType , Value base, Location loc) {
// The rank precondition is reported as a failure on the '+' side instead of being asserted as on the '-' side.
46+ if (mType .getRank () < 2 )
47+ return failure ();
5848 int64_t preLast = mType .getRank () - 2 ;
5949 Type llvmInt64Type = IntegerType::get (&typeConverter.getContext (), 64 );
6050 unsigned width = mType .getElementType ().getIntOrFloatBitWidth ();
// The element width must be a power of two of at least 8 bits so that it converts to a whole number of bytes below.
6151 assert (llvm::isPowerOf2_64 (width) && width >= 8 );
6252 unsigned bytes = width >> 3 ;
6353 int64_t offset;
6454 SmallVector<int64_t , 4 > strides;
65- getStridesAndOffset (mType , strides, offset);
// On the '+' side the result of getStridesAndOffset is checked, together with the innermost-stride == 1 requirement that the removed verifyStride helper used to enforce.
55+ if (failed (getStridesAndOffset (mType , strides, offset)) ||
56+ strides.back () != 1 )
57+ return failure ();
6658 if (strides[preLast] == ShapedType::kDynamic ) {
6759 // Dynamic stride needs code to compute the stride at runtime: the descriptor's stride for dimension preLast is multiplied by the element size in bytes.
6860 MemRefDescriptor memrefDescriptor (base);
6961 auto attr = rewriter.getI64IntegerAttr (bytes);
7062 Value scale = rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
71- return rewriter.create <LLVM::MulOp>(
72- loc, llvmInt64Type, scale,
73- memrefDescriptor.stride (rewriter, loc, preLast));
// NOTE(review): .getResult() is presumably needed so the created op converts to the FailureOr<Value> return type - confirm.
63+ return rewriter
64+ .create <LLVM::MulOp>(loc, llvmInt64Type, scale,
65+ memrefDescriptor.stride (rewriter, loc, preLast))
66+ .getResult ();
7467 }
7568 // Use direct constant for static stride: stride * element-size-in-bytes is folded into a single i64 constant.
7669 auto attr = rewriter.getI64IntegerAttr (strides[preLast] * bytes);
77- return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr);
70+ return rewriter.create <LLVM::ConstantOp>(loc, llvmInt64Type, attr)
71+ .getResult ();
7872}
7973
8074struct TileZeroConversion : public ConvertOpToLLVMPattern <TileZeroOp> {
@@ -106,16 +100,16 @@ struct TileLoadConversion : public ConvertOpToLLVMPattern<TileLoadOp> {
106100 std::pair<Value, Value> tsz =
107101 getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
108102 // Determine stride.
109- if (failed (verifyStride (mType )))
103+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
104+ adaptor.getBase (), op.getLoc ());
105+ if (failed (stride))
110106 return failure ();
111- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
112- adaptor.getBase (), op.getLoc ());
113107 // Replace operation with intrinsic.
114108 Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
115109 adaptor.getIndices (), rewriter);
116110 Type resType = typeConverter->convertType (vType);
117111 rewriter.replaceOpWithNewOp <amx::x86_amx_tileloadd64>(
118- op, resType, tsz.first , tsz.second , ptr, stride);
112+ op, resType, tsz.first , tsz.second , ptr, stride. value () );
119113 return success ();
120114 }
121115};
@@ -132,15 +126,15 @@ struct TileStoreConversion : public ConvertOpToLLVMPattern<TileStoreOp> {
132126 std::pair<Value, Value> tsz =
133127 getTileSizes (rewriter, *getTypeConverter (), vType, op.getLoc ());
134128 // Determine stride.
135- if (failed (verifyStride (mType )))
129+ auto stride = getStride (rewriter, *getTypeConverter (), mType ,
130+ adaptor.getBase (), op.getLoc ());
131+ if (failed (stride))
136132 return failure ();
137- Value stride = getStride (rewriter, *getTypeConverter (), mType ,
138- adaptor.getBase (), op.getLoc ());
139133 // Replace operation with intrinsic.
140134 Value ptr = getStridedElementPtr (op.getLoc (), mType , adaptor.getBase (),
141135 adaptor.getIndices (), rewriter);
142136 rewriter.replaceOpWithNewOp <amx::x86_amx_tilestored64>(
143- op, tsz.first , tsz.second , ptr, stride, adaptor.getVal ());
137+ op, tsz.first , tsz.second , ptr, stride. value () , adaptor.getVal ());
144138 return success ();
145139 }
146140};
0 commit comments