@@ -91,6 +91,13 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
9191 Type llvmI32 = this ->typeConverter ->convertType (i32 );
9292 Type llvmI16 = this ->typeConverter ->convertType (rewriter.getI16Type ());
9393
// Local helper: yields `val` as a value of type `llvmI32`.
// If `val` already has type `llvmI32` it is returned unchanged; otherwise an
// LLVM::TruncOp to `llvmI32` is created at `loc` through `rewriter`.
// NOTE(review): presumably callers only pass integers at least as wide as
// i32 (e.g. memref descriptor sizes/strides/offsets) — confirm, since a
// truncation is emitted unconditionally for every non-i32 input.
94+ auto toI32 = [&](Value val) -> Value {
95+ if (val.getType () == llvmI32)
96+ return val;
97+
// Type differs from llvmI32: narrow the value with a truncation.
98+ return rewriter.create <LLVM::TruncOp>(loc, llvmI32, val);
99+ };
100+
94101 int64_t elementByteWidth = memrefType.getElementTypeBitWidth () / 8 ;
95102 Value byteWidthConst = createI32Constant (rewriter, loc, elementByteWidth);
96103
@@ -166,22 +173,22 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
166173 Value stride = rewriter.create <LLVM::ConstantOp>(
167174 loc, llvmI16, rewriter.getI16IntegerAttr (0 ));
168175 Value numRecords;
169- if (memrefType.hasStaticShape ()) {
176+ if (memrefType.hasStaticShape () && memrefType. getLayout (). isIdentity () ) {
170177 numRecords = createI32Constant (
171178 rewriter, loc,
172179 static_cast <int32_t >(memrefType.getNumElements () * elementByteWidth));
173180 } else {
174181 Value maxIndex;
175182 for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i) {
176- Value size = memrefDescriptor.size (rewriter, loc, i);
177- Value stride = memrefDescriptor.stride (rewriter, loc, i);
183+ Value size = toI32 ( memrefDescriptor.size (rewriter, loc, i) );
184+ Value stride = toI32 ( memrefDescriptor.stride (rewriter, loc, i) );
178185 stride = rewriter.create <LLVM::MulOp>(loc, stride, byteWidthConst);
179186 Value maxThisDim = rewriter.create <LLVM::MulOp>(loc, size, stride);
180187 maxIndex = maxIndex ? rewriter.create <LLVM::MaximumOp>(loc, maxIndex,
181188 maxThisDim)
182189 : maxThisDim;
183190 }
184- numRecords = rewriter. create <LLVM::TruncOp>(loc, llvmI32, maxIndex) ;
191+ numRecords = maxIndex;
185192 }
186193
187194 // Flag word:
@@ -218,7 +225,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
218225 Value strideOp;
219226 if (ShapedType::isDynamic (strides[i])) {
220227 strideOp = rewriter.create <LLVM::MulOp>(
221- loc, memrefDescriptor.stride (rewriter, loc, i), byteWidthConst);
228+ loc, toI32 (memrefDescriptor.stride (rewriter, loc, i)),
229+ byteWidthConst);
222230 } else {
223231 strideOp =
224232 createI32Constant (rewriter, loc, strides[i] * elementByteWidth);
@@ -240,7 +248,7 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
240248 sgprOffset = createI32Constant (rewriter, loc, 0 );
241249 if (ShapedType::isDynamic (offset))
242250 sgprOffset = rewriter.create <LLVM::AddOp>(
243- loc, memrefDescriptor.offset (rewriter, loc), sgprOffset);
251+ loc, toI32 ( memrefDescriptor.offset (rewriter, loc) ), sgprOffset);
244252 else if (offset > 0 )
245253 sgprOffset = rewriter.create <LLVM::AddOp>(
246254 loc, sgprOffset, createI32Constant (rewriter, loc, offset));
0 commit comments