@@ -30,10 +30,23 @@ namespace mlir {
3030using namespace mlir ;
3131using namespace mlir ::amdgpu;
3232
33+ // / Convert an unsigned number `val` to i32.
34+ static Value convertUnsignedToI32 (ConversionPatternRewriter &rewriter,
35+ Location loc, Value val) {
36+ IntegerType i32 = rewriter.getI32Type ();
37+ // Force check that `val` is of int type.
38+ auto valTy = cast<IntegerType>(val.getType ());
39+ if (i32 == valTy)
40+ return val;
41+ return valTy.getWidth () > 32
42+ ? Value (rewriter.create <LLVM::TruncOp>(loc, i32 , val))
43+ : Value (rewriter.create <LLVM::ZExtOp>(loc, i32 , val));
44+ }
45+
3346static Value createI32Constant (ConversionPatternRewriter &rewriter,
3447 Location loc, int32_t value) {
35- Type llvmI32 = rewriter.getI32Type ();
36- return rewriter.create <LLVM::ConstantOp>(loc, llvmI32 , value);
48+ Type i32 = rewriter.getI32Type ();
49+ return rewriter.create <LLVM::ConstantOp>(loc, i32 , value);
3750}
3851
3952static Value createI1Constant (ConversionPatternRewriter &rewriter, Location loc,
@@ -42,6 +55,27 @@ static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
4255 return rewriter.create <LLVM::ConstantOp>(loc, llvmI1, value);
4356}
4457
58+ // / Returns the linear index used to access an element in the memref.
59+ static Value getLinearIndexI32 (ConversionPatternRewriter &rewriter,
60+ Location loc, MemRefDescriptor &memRefDescriptor,
61+ ValueRange indices, ArrayRef<int64_t > strides) {
62+ IntegerType i32 = rewriter.getI32Type ();
63+ Value index;
64+ for (auto [i, increment, stride] : llvm::enumerate (indices, strides)) {
65+ if (stride != 1 ) { // Skip if stride is 1.
66+ Value strideValue =
67+ ShapedType::isDynamic (stride)
68+ ? convertUnsignedToI32 (rewriter, loc,
69+ memRefDescriptor.stride (rewriter, loc, i))
70+ : rewriter.create <LLVM::ConstantOp>(loc, i32 , stride);
71+ increment = rewriter.create <LLVM::MulOp>(loc, increment, strideValue);
72+ }
73+ index =
74+ index ? rewriter.create <LLVM::AddOp>(loc, index, increment) : increment;
75+ }
76+ return index ? index : createI32Constant (rewriter, loc, 0 );
77+ }
78+
4579namespace {
4680// Define commonly used chipsets versions for convenience.
4781constexpr Chipset kGfx908 = Chipset(9 , 0 , 8 );
@@ -88,17 +122,12 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
88122 Type llvmWantedDataType = this ->typeConverter ->convertType (wantedDataType);
89123
90124 Type i32 = rewriter.getI32Type ();
91- Type llvmI32 = this ->typeConverter ->convertType (i32 );
92- Type llvmI16 = this ->typeConverter ->convertType (rewriter.getI16Type ());
125+ Type i16 = rewriter.getI16Type ();
93126
94- auto toI32 = [&](Value val) -> Value {
95- if (val.getType () == llvmI32)
96- return val;
97-
98- return rewriter.create <LLVM::TruncOp>(loc, llvmI32, val);
99- };
100-
101- int64_t elementByteWidth = memrefType.getElementTypeBitWidth () / 8 ;
127+ // Get the type size in bytes.
128+ DataLayout dataLayout = DataLayout::closest (gpuOp);
129+ int64_t elementByteWidth =
130+ dataLayout.getTypeSizeInBits (memrefType.getElementType ()) / 8 ;
102131 Value byteWidthConst = createI32Constant (rewriter, loc, elementByteWidth);
103132
104133 // If we want to load a vector<NxT> with total size <= 32
@@ -114,7 +143,8 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
114143 }
115144 if (auto dataVector = dyn_cast<VectorType>(wantedDataType)) {
116145 uint32_t vecLen = dataVector.getNumElements ();
117- uint32_t elemBits = dataVector.getElementTypeBitWidth ();
146+ uint32_t elemBits =
147+ dataLayout.getTypeSizeInBits (dataVector.getElementType ());
118148 uint32_t totalBits = elemBits * vecLen;
119149 bool usePackedFp16 =
120150 isa_and_present<RawBufferAtomicFaddOp>(*gpuOp) && vecLen == 2 ;
@@ -167,28 +197,36 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
167197
168198 MemRefDescriptor memrefDescriptor (memref);
169199
170- Value ptr = memrefDescriptor.alignedPtr (rewriter, loc);
200+ Value ptr = memrefDescriptor.bufferPtr (
201+ rewriter, loc, *this ->getTypeConverter (), memrefType);
171202 // The stride value is always 0 for raw buffers. This also disables
172203 // swizling.
173204 Value stride = rewriter.create <LLVM::ConstantOp>(
174- loc, llvmI16, rewriter.getI16IntegerAttr (0 ));
205+ loc, i16 , rewriter.getI16IntegerAttr (0 ));
206+ // Get the number of elements.
175207 Value numRecords;
176- if (memrefType.hasStaticShape () && memrefType.getLayout ().isIdentity ()) {
177- numRecords = createI32Constant (
178- rewriter, loc,
179- static_cast <int32_t >(memrefType.getNumElements () * elementByteWidth));
208+ if (memrefType.hasStaticShape () &&
209+ !llvm::any_of (strides, ShapedType::isDynamic)) {
210+ int64_t size = memrefType.getRank () == 0 ? 1 : 0 ;
211+ ArrayRef<int64_t > shape = memrefType.getShape ();
212+ for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i)
213+ size = std::max (shape[i] * strides[i], size);
214+ size = size * elementByteWidth;
215+ assert (size < std::numeric_limits<uint32_t >::max () &&
216+ " the memref buffer is too large" );
217+ numRecords = createI32Constant (rewriter, loc, static_cast <int32_t >(size));
180218 } else {
181219 Value maxIndex;
182220 for (uint32_t i = 0 , e = memrefType.getRank (); i < e; ++i) {
183- Value size = toI32 (memrefDescriptor.size (rewriter, loc, i));
184- Value stride = toI32 (memrefDescriptor.stride (rewriter, loc, i));
185- stride = rewriter.create <LLVM::MulOp>(loc, stride, byteWidthConst);
221+ Value size = memrefDescriptor.size (rewriter, loc, i);
222+ Value stride = memrefDescriptor.stride (rewriter, loc, i);
186223 Value maxThisDim = rewriter.create <LLVM::MulOp>(loc, size, stride);
187- maxIndex = maxIndex ? rewriter. create <LLVM::MaximumOp>(loc, maxIndex,
188- maxThisDim)
189- : maxThisDim;
224+ maxIndex =
225+ maxIndex ? rewriter. create <LLVM::UMaxOp>(loc, maxIndex, maxThisDim)
226+ : maxThisDim;
190227 }
191- numRecords = maxIndex;
228+ numRecords = rewriter.create <LLVM::MulOp>(
229+ loc, convertUnsignedToI32 (rewriter, loc, maxIndex), byteWidthConst);
192230 }
193231
194232 // Flag word:
@@ -218,40 +256,23 @@ struct RawBufferOpLowering : public ConvertOpToLLVMPattern<GpuOp> {
218256 args.push_back (resource);
219257
220258 // Indexing (voffset)
221- Value voffset = createI32Constant (rewriter, loc, 0 );
222- for (auto pair : llvm::enumerate (adaptor.getIndices ())) {
223- size_t i = pair.index ();
224- Value index = pair.value ();
225- Value strideOp;
226- if (ShapedType::isDynamic (strides[i])) {
227- strideOp = rewriter.create <LLVM::MulOp>(
228- loc, toI32 (memrefDescriptor.stride (rewriter, loc, i)),
229- byteWidthConst);
230- } else {
231- strideOp =
232- createI32Constant (rewriter, loc, strides[i] * elementByteWidth);
233- }
234- index = rewriter.create <LLVM::MulOp>(loc, index, strideOp);
235- voffset = rewriter.create <LLVM::AddOp>(loc, voffset, index);
236- }
237- if (adaptor.getIndexOffset ()) {
238- int32_t indexOffset = *gpuOp.getIndexOffset () * elementByteWidth;
239- Value extraOffsetConst = createI32Constant (rewriter, loc, indexOffset);
259+ Value voffset = getLinearIndexI32 (rewriter, loc, memrefDescriptor,
260+ adaptor.getIndices (), strides);
261+ if (std::optional<int32_t > indexOffset = adaptor.getIndexOffset ();
262+ indexOffset && *indexOffset > 0 ) {
263+ Value extraOffsetConst = createI32Constant (rewriter, loc, *indexOffset);
240264 voffset =
241265 voffset ? rewriter.create <LLVM::AddOp>(loc, voffset, extraOffsetConst)
242266 : extraOffsetConst;
243267 }
268+ voffset = rewriter.create <LLVM::MulOp>(loc, voffset, byteWidthConst);
244269 args.push_back (voffset);
245270
271+ // SGPR offset.
246272 Value sgprOffset = adaptor.getSgprOffset ();
247273 if (!sgprOffset)
248274 sgprOffset = createI32Constant (rewriter, loc, 0 );
249- if (ShapedType::isDynamic (offset))
250- sgprOffset = rewriter.create <LLVM::AddOp>(
251- loc, toI32 (memrefDescriptor.offset (rewriter, loc)), sgprOffset);
252- else if (offset > 0 )
253- sgprOffset = rewriter.create <LLVM::AddOp>(
254- loc, sgprOffset, createI32Constant (rewriter, loc, offset));
275+ sgprOffset = rewriter.create <LLVM::MulOp>(loc, sgprOffset, byteWidthConst);
255276 args.push_back (sgprOffset);
256277
257278 // bit 0: GLC = 0 (atomics drop value, less coherency)
0 commit comments