Skip to content

Commit e3dda29

Browse files
committed
Update fat_raw_buffer_cast for i64 buffer lengths
1 parent 5c1026d commit e3dda29

File tree

4 files changed

+53
-40
lines changed

4 files changed

+53
-40
lines changed

mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def AMDGPU_FatRawBufferCastOp :
235235
DeclareOpInterfaceMethods<InferTypeOpInterface>,
236236
ViewLikeOpInterface, AttrSizedOperandSegments]>,
237237
Arguments<(ins AnyMemRef:$source,
238-
Optional<I32>:$validBytes,
238+
Optional<I64>:$validBytes,
239239
Optional<I<14>>:$cacheSwizzleStride,
240240
DefaultValuedAttr<BoolAttr, "true">:$boundsCheck,
241241
UnitAttr:$resetOffset)>,
@@ -680,8 +680,8 @@ def AMDGPU_PermlaneSwapOp : AMDGPU_Op<"permlane_swap", [Pure, AllTypesMatch<["re
680680
* `$fetch_inactive`: Optional. Used to dertermine behavior of a fetch from a disabled lane.
681681
`fetch_inactive = false`: If the source lane is disabled, use `bound_ctrl` to determine the source value.
682682
`fetch_inactive = true`: If the source lane is disabled, fetch the source value anyway (ignoring `bound_ctrl`).
683-
* `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
684-
a disabled lane: use the value zero, or disable the write.
683+
* `$bound_ctrl`: Optional. Used to determine what a thread should do if its source operand is from
684+
a disabled lane: use the value zero, or disable the write.
685685
`bound_ctrl = false`: Do not write when source is from a disabled lane
686686
`bound_ctrl = true`: Use zero as input if source is from a disabled lane
687687

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,25 @@ static Value createI32Constant(ConversionPatternRewriter &rewriter,
6161
return LLVM::ConstantOp::create(rewriter, loc, i32, value);
6262
}
6363

64+
/// Convert an unsigned number `val` to i64.
65+
static Value convertUnsignedToI64(ConversionPatternRewriter &rewriter,
66+
Location loc, Value val) {
67+
IntegerType i64 = rewriter.getI64Type();
68+
// Force check that `val` is of int type.
69+
auto valTy = cast<IntegerType>(val.getType());
70+
if (i64 == valTy)
71+
return val;
72+
return valTy.getWidth() > 64
73+
? Value(LLVM::TruncOp::create(rewriter, loc, i64, val))
74+
: Value(LLVM::ZExtOp::create(rewriter, loc, i64, val));
75+
}
76+
77+
static Value createI64Constant(ConversionPatternRewriter &rewriter,
78+
Location loc, int64_t value) {
79+
Type i64 = rewriter.getI64Type();
80+
return LLVM::ConstantOp::create(rewriter, loc, i64, value);
81+
}
82+
6483
static Value createI1Constant(ConversionPatternRewriter &rewriter, Location loc,
6584
bool value) {
6685
Type llvmI1 = rewriter.getI1Type();
@@ -95,17 +114,15 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
95114
MemRefType memrefType,
96115
MemRefDescriptor &memrefDescriptor,
97116
ArrayRef<int64_t> strides,
98-
uint32_t elementByteWidth) {
117+
int64_t elementByteWidth) {
99118
if (memrefType.hasStaticShape() &&
100119
!llvm::any_of(strides, ShapedType::isDynamic)) {
101120
int64_t size = memrefType.getRank() == 0 ? 1 : 0;
102121
ArrayRef<int64_t> shape = memrefType.getShape();
103122
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i)
104123
size = std::max(shape[i] * strides[i], size);
105124
size = size * elementByteWidth;
106-
assert(size < std::numeric_limits<uint32_t>::max() &&
107-
"the memref buffer is too large");
108-
return createI32Constant(rewriter, loc, static_cast<int32_t>(size));
125+
return createI64Constant(rewriter, loc, static_cast<int32_t>(size));
109126
}
110127
Value maxIndex;
111128
for (uint32_t i = 0, e = memrefType.getRank(); i < e; ++i) {
@@ -116,9 +133,9 @@ static Value getNumRecords(ConversionPatternRewriter &rewriter, Location loc,
116133
? LLVM::UMaxOp::create(rewriter, loc, maxIndex, maxThisDim)
117134
: maxThisDim;
118135
}
119-
Value maxIndexI32 = convertUnsignedToI32(rewriter, loc, maxIndex);
120-
Value byteWidthConst = createI32Constant(rewriter, loc, elementByteWidth);
121-
return LLVM::MulOp::create(rewriter, loc, maxIndexI32, byteWidthConst);
136+
Value maxIndexI64 = convertUnsignedToI64(rewriter, loc, maxIndex);
137+
Value byteWidthConst = createI64Constant(rewriter, loc, elementByteWidth);
138+
return LLVM::MulOp::create(rewriter, loc, maxIndexI64, byteWidthConst);
122139
}
123140

124141
static Value makeBufferRsrc(ConversionPatternRewriter &rewriter, Location loc,

mlir/test/Conversion/AMDGPUToROCDL/amdgpu-to-rocdl.mlir

Lines changed: 25 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ func.func @fat_raw_buffer_cast(%buf: memref<8xi32, #gpu_global_addrspace>) -> me
1717
// CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
1818
// CHECK-DAG: %[[sizes:.*]] = llvm.extractvalue %[[desc]][3]
1919
// CHECK-DAG: %[[strides:.*]] = llvm.extractvalue %[[desc]][4]
20-
// CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i32) : i32
20+
// CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(32 : i64) : i64
2121
// CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
2222
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
2323
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
@@ -38,7 +38,7 @@ func.func @fat_raw_buffer_cast_0d(%buf: memref<i32, #gpu_global_addrspace>) -> m
3838
// CHECK: %[[desc:.*]] = builtin.unrealized_conversion_cast %{{.*}} : memref<i32, 1> to !llvm.struct<(ptr<1>, ptr<1>, i64)>
3939
// CHECK-DAG: %[[base:.*]] = llvm.extractvalue %[[desc]][1]
4040
// CHECK-DAG: %[[offset:.*]] = llvm.extractvalue %[[desc]][2]
41-
// CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32) : i32
41+
// CHECK-DAG: %[[numRecords:.*]] = llvm.mlir.constant(4 : i64) : i64
4242
// CHECK-DAG: %[[strideArg:.*]] = llvm.mlir.constant(0 : i16) : i16
4343
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
4444
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
@@ -57,9 +57,8 @@ func.func @fat_raw_buffer_cast_dyn_size_offset(%buf: memref<?xi32, strided<[1],
5757
// CHECK: %[[size0:.*]] = llvm.extractvalue %{{.*}}[3, 0]
5858
// CHECK: %[[stride0:.*]] = llvm.extractvalue %{{.*}}[4, 0]
5959
// CHECK: %[[maxVals:.*]] = llvm.mul %[[size0]], %[[stride0]]
60-
// CHECK: %[[maxValsI32:.*]] = llvm.trunc %[[maxVals]] : i64 to i32
61-
// CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i32) : i32
62-
// CHECK: %[[numRecords:.*]] = llvm.mul %[[maxValsI32]], %[[byteSize]]
60+
// CHECK: %[[byteSize:.*]] = llvm.mlir.constant(4 : i64) : i64
61+
// CHECK: %[[numRecords:.*]] = llvm.mul %[[maxVals]], %[[byteSize]]
6362
// CHECK: %[[offset:.*]] = llvm.extractvalue %{{.*}}[2]
6463
// CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
6564
// CHECK: llvm.insertvalue %[[offset]], %{{.*}}[2]
@@ -83,10 +82,10 @@ func.func @fat_raw_buffer_cast_reset_offset(%buf: memref<?xi32, strided<[1], off
8382

8483
// CHECK-LABEL: func @fat_raw_buffer_cast_valid_bytes
8584
func.func @fat_raw_buffer_cast_valid_bytes(%buf: memref<8xi32, #gpu_global_addrspace>) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
86-
// CHECK: %[[numRecords:.*]] = arith.constant -1 : i32
85+
// CHECK: %[[numRecords:.*]] = arith.constant -1 : i64
8786
// CHECK: rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
88-
%cu32_max = arith.constant 0xffffffff : i32
89-
%ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu32_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
87+
%cu64_max = arith.constant -1 : i64
88+
%ret = amdgpu.fat_raw_buffer_cast %buf validBytes(%cu64_max) : memref<8xi32, #gpu_global_addrspace> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
9089
return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
9190
}
9291

@@ -115,9 +114,7 @@ func.func @fat_raw_buffer_cast_cache_swizzle(%buf: memref<64x64xi32, #gpu_global
115114

116115
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_scalar_i32
117116
func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
118-
// Extra constant for byte width
119-
// CHECK: llvm.mlir.constant(4 : i32)
120-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i32)
117+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(4 : i64)
121118
// CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
122119
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
123120
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
@@ -130,7 +127,7 @@ func.func @gpu_gcn_raw_buffer_load_scalar_i32(%buf: memref<i32>) -> i32 {
130127

131128
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i32
132129
func.func @gpu_gcn_raw_buffer_load_i32(%buf: memref<64xi32>, %idx: i32) -> i32 {
133-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
130+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
134131
// CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16)
135132
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
136133
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
@@ -155,11 +152,10 @@ func.func @gpu_gcn_raw_buffer_load_i32_strided(%buf: memref<16x16xi32, strided<[
155152
// CHECK: %[[stride_j:.*]] = llvm.extractvalue %[[descriptor]][4, 1] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
156153
// CHECK: %[[ext_j:.*]] = llvm.mul %[[sz_j]], %[[stride_j]] : i64
157154
// CHECK: %[[num_records:.*]] = llvm.intr.umax(%[[ext_i]], %[[ext_j]]) : (i64, i64) -> i64
158-
// CHECK: %[[num_rec_i32:.*]] = llvm.trunc %[[num_records]] : i64 to i32
159-
// CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i32) : i32
160-
// CHECK: %[[num_rec_bytes_i32:.*]] = llvm.mul %[[num_rec_i32]], %[[elem_size_2]] : i32
155+
// CHECK: %[[elem_size_2:.*]] = llvm.mlir.constant(4 : i64) : i64
156+
// CHECK: %[[num_rec_bytes:.*]] = llvm.mul %[[num_records]], %[[elem_size_2]] : i64
161157
// CHECK: %[[stride:.*]] = llvm.mlir.constant(0 : i16) : i16
162-
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes_i32]], %{{.*}} : !llvm.ptr to <8>
158+
// CHECK: %[[rsrc:.*]] = rocdl.make.buffer.rsrc %[[ptr]], %[[stride]], %[[num_rec_bytes]], %{{.*}} : !llvm.ptr to <8>
163159
// CHECK: %[[stride_i_1:.*]] = llvm.extractvalue %[[descriptor]][4, 0] : !llvm.struct<(ptr, ptr, i64, array<2 x i64>, array<2 x i64>)>
164160
// CHECK: %[[stride_i_i32:.*]] = llvm.trunc %[[stride_i_1]] : i64 to i32
165161
// CHECK: %[[t_0:.*]] = llvm.mul %{{.*}}, %[[stride_i_i32]] : i32
@@ -207,7 +203,7 @@ func.func @gpu_gcn_raw_buffer_load_2xi32(%buf: memref<64xi32>, %idx: i32) -> vec
207203

208204
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_i8
209205
func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
210-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
206+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i64)
211207
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
212208
// CHECK: %[[ret:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i8
213209
// CHECK: return %[[ret]]
@@ -217,7 +213,7 @@ func.func @gpu_gcn_raw_buffer_load_i8(%buf: memref<64xi8>, %idx: i32) -> i8 {
217213

218214
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_2xi8
219215
func.func @gpu_gcn_raw_buffer_load_2xi8(%buf: memref<64xi8>, %idx: i32) -> vector<2xi8> {
220-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
216+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i64)
221217
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
222218
// CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i16
223219
// CHECK: %[[ret:.*]] = llvm.bitcast %[[loaded]] : i16 to vector<2xi8>
@@ -237,7 +233,7 @@ func.func @gpu_gcn_raw_buffer_load_16xi8(%buf: memref<64xi8>, %idx: i32) -> vect
237233

238234
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ
239235
func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx: i32) -> f8E5M2FNUZ {
240-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
236+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i64)
241237
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
242238
// CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i8
243239
// CHECK: %[[ret:.*]] = builtin.unrealized_conversion_cast %[[loaded]] : i8 to f8E5M2FNUZ
@@ -248,7 +244,7 @@ func.func @gpu_gcn_raw_buffer_load_f8E5M2FNUZ(%buf: memref<64xf8E5M2FNUZ>, %idx:
248244

249245
// CHECK-LABEL: func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ
250246
func.func @gpu_gcn_raw_buffer_load_4xf8E4M3FNUZ(%buf: memref<64xf8E4M3FNUZ>, %idx: i32) -> vector<4xf8E4M3FNUZ> {
251-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i32)
247+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(64 : i64)
252248
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %{{.*}}
253249
// CHECK: %[[loaded:.*]] = rocdl.raw.ptr.buffer.load %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i32
254250
// CHECK: %[[cast:.*]] = llvm.bitcast %[[loaded]] : i32 to vector<4xi8>
@@ -271,7 +267,7 @@ func.func @gpu_gcn_raw_buffer_store_scalar_i32(%value: i32, %buf: memref<i32>) {
271267

272268
// CHECK-LABEL: func @gpu_gcn_raw_buffer_store_i32
273269
func.func @gpu_gcn_raw_buffer_store_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
274-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
270+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
275271
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
276272
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
277273
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -307,7 +303,7 @@ func.func @gpu_gcn_raw_buffer_store_16xi8(%value: vector<16xi8>, %buf: memref<64
307303
// And more so for atomic add
308304
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_f32
309305
func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
310-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
306+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
311307
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
312308
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
313309
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -318,7 +314,7 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_f32(%value: f32, %buf: memref<64xf32>,
318314

319315
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2f16
320316
func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: memref<64xf16>, %idx: i32) {
321-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
317+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i64)
322318
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
323319
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
324320
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -329,7 +325,7 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_v2f16(%value: vector<2xf16>, %buf: mem
329325

330326
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16
331327
func.func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16(%value: vector<2xbf16>, %buf: memref<64xbf16>, %idx: i32) {
332-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i32)
328+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(128 : i64)
333329
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
334330
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
335331
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -340,7 +336,7 @@ func.func @gpu_gcn_raw_buffer_atomic_fadd_v2bf16(%value: vector<2xbf16>, %buf: m
340336

341337
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_fmax_f32
342338
func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>, %idx: i32) {
343-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
339+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
344340
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
345341
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
346342
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -351,7 +347,7 @@ func.func @gpu_gcn_raw_buffer_atomic_fmax_f32(%value: f32, %buf: memref<64xf32>,
351347

352348
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_smax_i32
353349
func.func @gpu_gcn_raw_buffer_atomic_smax_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
354-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
350+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
355351
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
356352
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
357353
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -362,7 +358,7 @@ func.func @gpu_gcn_raw_buffer_atomic_smax_i32(%value: i32, %buf: memref<64xi32>,
362358

363359
// CHECK-LABEL: func @gpu_gcn_raw_buffer_atomic_umin_i32
364360
func.func @gpu_gcn_raw_buffer_atomic_umin_i32(%value: i32, %buf: memref<64xi32>, %idx: i32) {
365-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
361+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
366362
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
367363
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
368364
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -376,7 +372,7 @@ func.func @gpu_gcn_raw_buffer_atomic_umin_i32(%value: i32, %buf: memref<64xi32>,
376372
func.func @amdgpu_raw_buffer_atomic_cmpswap_f32(%src : f32, %cmp : f32, %buf : memref<64xf32>, %idx: i32) -> f32 {
377373
// CHECK: %[[srcCast:.*]] = llvm.bitcast %[[src]] : f32 to i32
378374
// CHECK: %[[cmpCast:.*]] = llvm.bitcast %[[cmp]] : f32 to i32
379-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i32)
375+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(256 : i64)
380376
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
381377
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
382378
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]
@@ -390,7 +386,7 @@ func.func @amdgpu_raw_buffer_atomic_cmpswap_f32(%src : f32, %cmp : f32, %buf : m
390386
// CHECK-LABEL: func @amdgpu_raw_buffer_atomic_cmpswap_i64
391387
// CHECK-SAME: (%[[src:.*]]: i64, %[[cmp:.*]]: i64, {{.*}})
392388
func.func @amdgpu_raw_buffer_atomic_cmpswap_i64(%src : i64, %cmp : i64, %buf : memref<64xi64>, %idx: i32) -> i64 {
393-
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(512 : i32)
389+
// CHECK: %[[numRecords:.*]] = llvm.mlir.constant(512 : i64)
394390
// GFX9: %[[flags:.*]] = llvm.mlir.constant(159744 : i32)
395391
// RDNA: %[[flags:.*]] = llvm.mlir.constant(822243328 : i32)
396392
// CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %[[numRecords]], %[[flags]]

mlir/test/Dialect/AMDGPU/ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -360,7 +360,7 @@ func.func @fat_raw_buffer_cast_easy(%m: memref<8xi32>) -> memref<8xi32, #amdgpu.
360360
// CHECK-SAME: cacheSwizzleStride(%{{[^)]*}})
361361
// CHECK-SAME: boundsCheck(false)
362362
// CHECK-SAME: resetOffset
363-
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i32, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
363+
func.func @fat_raw_buffer_cast(%m: memref<8xi32, strided<[1], offset: ?>>, %validBytes: i64, %cacheSwizzle: i14) -> memref<8xi32, #amdgpu.address_space<fat_raw_buffer>> {
364364
%ret = amdgpu.fat_raw_buffer_cast %m validBytes(%validBytes) cacheSwizzleStride(%cacheSwizzle) boundsCheck(false) resetOffset
365365
: memref<8xi32, strided<[1], offset: ?>> to memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>
366366
func.return %ret : memref<8xi32, #amdgpu.address_space<fat_raw_buffer>>

0 commit comments

Comments
 (0)