Skip to content

Commit f906b9b

Browse files
authored
[AMD] Make buffer op stride optional (#5908)
We may not be able to deduce it; in that case we cannot set the swizzling factor. Using zero is confusing for such cases, as it can mean broadcasting; we want to explicitly use nullptr instead. Also cleaned up some style nits along the way.
1 parent 56a9adf commit f906b9b

File tree

4 files changed

+84
-70
lines changed

4 files changed

+84
-70
lines changed

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
99
// CHECK: %[[offset:.*]] = llvm.select %[[c_mask]]
1010
// CHECK: %[[aux:.*]] = llvm.mlir.constant(3 : i32) : i32
1111
// CHECK: rocdl.raw.ptr.buffer.load {{.*}}, %[[offset]], {{.*}}, %[[aux]]
12-
%c0 = arith.constant 0 : i32
13-
%ret = amdgpu.buffer_load %arg0[%offset] cacheModifier = cs stride = %c0 : tensor<128xf32, #blocked0>
12+
%ret = amdgpu.buffer_load %arg0[%offset] cacheModifier = cs : tensor<128xf32, #blocked0>
1413
tt.return
1514
}
1615
}

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
566566
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
567567
%6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
568568
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
569-
// CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] stride = %c0_i32
569+
// CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
570570
%8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
571571
tt.return %8 : tensor<1024xf32, #blocked>
572572
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 76 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ include "TritonAMDGPUAttrDefs.td"
4141

4242

4343
class TT_AMDGPU_Op<string mnemonic, list<Trait> traits = []> :
44-
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])> {
45-
}
44+
Op<TritonAMDGPU_Dialect, mnemonic, !listconcat(traits, [])>;
4645

4746
//
4847
// Interfaces
@@ -53,8 +52,7 @@ def GlobalMemory : Resource<"::mlir::triton::GlobalMemory">;
5352
// ExtractSliceOp
5453
//===----------------------------------------------------------------------===//
5554

56-
def ExtractSliceOp
57-
: TT_AMDGPU_Op<"extract_slice", [Pure]> {
55+
def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> {
5856
let summary = "extract slice operation";
5957
let description = [{
6058
The "extract_slice" operation enables extracting a slice of a tensor in
@@ -92,8 +90,10 @@ def ExtractSliceOp
9290
size of the slice is determined by the result type.
9391
}];
9492

95-
let arguments = (ins AnyRankedTensor:$source,
96-
DenseI64ArrayAttr:$static_offsets);
93+
let arguments = (ins
94+
AnyRankedTensor:$source,
95+
DenseI64ArrayAttr:$static_offsets
96+
);
9797
let results = (outs AnyRankedTensor:$result);
9898

9999
let builders = [
@@ -117,6 +117,10 @@ def ExtractSliceOp
117117
let hasVerifier = 1;
118118
}
119119

120+
//===----------------------------------------------------------------------===//
121+
// InstructionSchedHint
122+
//===----------------------------------------------------------------------===//
123+
120124
def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
121125
let summary = "A placeholder op for instruction scheduling hints within a basic block";
122126
let description = [{
@@ -156,8 +160,11 @@ def InstructionSchedHint : TT_AMDGPU_Op<"instruction_sched_hint", []> {
156160
let assemblyFormat = [{ attr-dict }];
157161
}
158162

159-
def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
160-
Arguments<(ins I1:$pred)> {
163+
//===----------------------------------------------------------------------===//
164+
// CondBarrierOp
165+
//===----------------------------------------------------------------------===//
166+
167+
def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier"> {
161168
let summary = "Conditionally set barriers to synchronize partial threads in a block";
162169

163170
let description = [{
@@ -170,22 +177,25 @@ def CondBarrierOp : TT_AMDGPU_Op<"cond_barrier">,
170177
NB. This doesn't set any memory fence.
171178
}];
172179

180+
let arguments = (ins I1:$pred);
181+
173182
let assemblyFormat = "$pred attr-dict";
174183
}
175184

176-
//
177-
// AMD Buffer operations.
178-
//
185+
//===----------------------------------------------------------------------===//
186+
// BufferLoadOp
187+
//===----------------------------------------------------------------------===//
188+
179189
def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
180190
SameLoadStoreOperandsAndResultEncoding,
181191
AttrSizedOperandSegments,
182192
MemoryEffects<[MemRead<GlobalMemory>]>,
183193
TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
184194
TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
185195
TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
186-
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
196+
"(cast<BufferLoadOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
187197
TypesMatchWith<"result and other have the same type", "result", "other", "$_self",
188-
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
198+
"(cast<BufferLoadOp>($_op).getOther() == nullptr) || std::equal_to<>()">,
189199
]>{
190200
let summary = "Load from a scalar base pointer and a tensor offset";
191201
let description = [{
@@ -201,11 +211,10 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
201211
when it converts to the buffer ops because it is important for optimizing
202212
the cache memory access.
203213
}];
204-
let arguments = (
205-
ins
214+
let arguments = (ins
206215
TT_Ptr:$ptr,
207216
I32Tensor:$offsets,
208-
I32:$stride,
217+
Optional<I32>:$stride,
209218
DefaultValuedAttr<TT_CacheModifierAttr, "::mlir::triton::CacheModifier::NONE">:$cache,
210219
Optional<TT_BoolTensor>:$mask,
211220
Optional<TT_Tensor>:$other
@@ -215,24 +224,29 @@ def BufferLoadOp : TT_AMDGPU_Op<"buffer_load", [
215224
let assemblyFormat = [{
216225
$ptr `[` $offsets `]` (`,` $mask^)? (`,` $other^)?
217226
oilist(`cacheModifier` `=` $cache)
218-
`stride` `=` $stride
227+
(`stride` `=` $stride^)?
219228
attr-dict `:` type($result)
220229
}];
221230
}
222231

232+
//===----------------------------------------------------------------------===//
233+
// BufferAtomicRMWOp
234+
//===----------------------------------------------------------------------===//
235+
223236
def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
237+
AttrSizedOperandSegments,
224238
SameLoadStoreOperandsAndResultEncoding,
225239
MemoryEffects<[MemRead<GlobalMemory>]>,
226240
MemoryEffects<[MemWrite<GlobalMemory>]>,
227241
TypesMatchWith<"result element type matches the value type", "result", "value", "$_self">,
228242
TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
229243
TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
230244
TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
231-
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
245+
"(cast<BufferAtomicRMWOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
232246
TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
233247
TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
234248
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
235-
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
249+
"(cast<BufferAtomicRMWOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
236250
]>{
237251
let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
238252
let description = [{
@@ -246,13 +260,12 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
246260
the address difference between the first elements of each row in bytes. Compiler tries to obtain the `stride`
247261
when it converts to the buffer ops because it is important for optimizing the cache memory access.
248262
}];
249-
let arguments = (
250-
ins
263+
let arguments = (ins
251264
TT_AtomicRMWAttr:$atomic_rmw_op,
252265
TT_Ptr:$ptr,
253266
I32Tensor:$offsets,
254267
TT_Tensor:$value,
255-
I32:$stride,
268+
Optional<I32>:$stride,
256269
TT_MemSemanticAttr:$sem,
257270
TT_MemSyncScopeAttr:$scope,
258271
Optional<TT_BoolTensor>:$mask
@@ -261,46 +274,23 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
261274

262275
let assemblyFormat = [{
263276
$atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
264-
`stride` `=` $stride
277+
(`stride` `=` $stride^)?
265278
attr-dict `:` type($result)
266279
}];
267280
}
268281

269-
def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
270-
let summary = "Convert an mxfp tensor to bf16/fp16";
271-
272-
let hasVerifier = 1;
273-
274-
let description = [{
275-
Compute the bf16 encoded in the given mxfp number as per
276-
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
277-
}];
278-
let arguments = (
279-
ins
280-
TT_Tensor:$src,
281-
TT_Tensor:$scale,
282-
TT_ScaleDotElemTypeAttr:$fp_type,
283-
BoolAttr:$fastMath
284-
);
285-
let results = (outs TT_Tensor:$result);
286-
287-
let assemblyFormat = [{
288-
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
289-
}];
290-
291-
let extraClassDeclaration = [{
292-
static RankedTensorType deduceOutputType(
293-
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
294-
}];
295-
}
282+
//===----------------------------------------------------------------------===//
283+
// BufferStoreOp
284+
//===----------------------------------------------------------------------===//
296285

297286
def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
287+
AttrSizedOperandSegments,
298288
SameLoadStoreOperandsEncoding,
299289
MemoryEffects<[MemWrite<GlobalMemory>]>,
300290
TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
301291
TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
302292
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
303-
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
293+
"(cast<BufferStoreOp>($_op).getMask() == nullptr) || std::equal_to<>()">,
304294
]>{
305295
let summary = "Store into scalar base pointer and a tensor offset";
306296
let description = [{
@@ -316,22 +306,53 @@ def BufferStoreOp : TT_AMDGPU_Op<"buffer_store", [
316306
when it converts to the buffer ops because it is important for optimizing
317307
the cache memory access.
318308
}];
319-
let arguments = (
320-
ins
309+
let arguments = (ins
321310
TT_Tensor:$value,
322311
TT_Ptr:$ptr,
323312
I32Tensor:$offsets,
324-
I32:$stride,
313+
Optional<I32>:$stride,
325314
DefaultValuedAttr<TT_CacheModifierAttr, "mlir::triton::CacheModifier::NONE">:$cache,
326315
Optional<TT_BoolTensor>:$mask
327316
);
328317

329318
let assemblyFormat = [{
330319
$value `,` $ptr `[` $offsets `]` (`,` $mask^)?
331320
oilist(`cacheModifier` `=` $cache)
332-
`stride` `=` $stride
321+
(`stride` `=` $stride^)?
333322
attr-dict `:` type($value)
334323
}];
335324
}
336325

326+
//===----------------------------------------------------------------------===//
327+
// UpcastMXFPOp
328+
//===----------------------------------------------------------------------===//
329+
330+
def TTG_UpcastMXFPOp : TT_AMDGPU_Op<"upcast_mxfp", [Pure]> {
331+
let summary = "Convert an mxfp tensor to bf16/fp16";
332+
333+
let hasVerifier = 1;
334+
335+
let description = [{
336+
Compute the bf16 encoded in the given mxfp number as per
337+
https://www.opencompute.org/documents/ocp-microscaling-formats-mx-v1-0-spec-final-pdf
338+
}];
339+
let arguments = (
340+
ins
341+
TT_Tensor:$src,
342+
TT_Tensor:$scale,
343+
TT_ScaleDotElemTypeAttr:$fp_type,
344+
BoolAttr:$fastMath
345+
);
346+
let results = (outs TT_Tensor:$result);
347+
348+
let assemblyFormat = [{
349+
$src `,` $scale `fp_type` `=` $fp_type attr-dict `:` type($src) `,` type($scale) `->` type($result)
350+
}];
351+
352+
let extraClassDeclaration = [{
353+
static RankedTensorType deduceOutputType(
354+
TypedValue<RankedTensorType> inputTensor, ScaleDotElemType inputElemType, Type outputElemType);
355+
}];
356+
}
357+
337358
#endif

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 6 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -237,22 +237,16 @@ Value getBlockStride(Location loc, Value offset, PatternRewriter &rewriter) {
237237
// canonicalize pointer pass sets block stride via
238238
// `offset:add-broadcast-muli-splat`, backtrace that pattern to reach the
239239
// stride.
240-
if (auto maybeAdd = offset.getDefiningOp<arith::AddIOp>()) {
241-
for (auto addOpr : maybeAdd.getOperands()) {
240+
if (auto maybeAdd = offset.getDefiningOp<arith::AddIOp>())
241+
for (auto addOpr : maybeAdd.getOperands())
242242
if (auto maybeBC = addOpr.getDefiningOp<tt::BroadcastOp>()) {
243243
auto bcSrc = maybeBC.getSrc();
244-
if (auto maybeMul = bcSrc.getDefiningOp<arith::MulIOp>()) {
245-
for (auto mulOpr : maybeMul.getOperands()) {
246-
if (auto maybeSplat = mulOpr.getDefiningOp<tt::SplatOp>()) {
244+
if (auto maybeMul = bcSrc.getDefiningOp<arith::MulIOp>())
245+
for (auto mulOpr : maybeMul.getOperands())
246+
if (auto maybeSplat = mulOpr.getDefiningOp<tt::SplatOp>())
247247
return maybeSplat.getSrc();
248-
}
249-
}
250-
}
251248
}
252-
}
253-
}
254-
return rewriter.create<arith::ConstantIntOp>(loc, 0, 32);
255-
;
249+
return nullptr;
256250
}
257251

258252
} // namespace

0 commit comments

Comments
 (0)