
Commit 71a23b2

Merge commit '9e626543e6c2c9702c81ef6a4500b7b3e835e1b6'
2 parents dd36563 + 9e62654 commit 71a23b2

File tree: 11 files changed, +50 / -57 lines


include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 5 additions & 1 deletion
@@ -14,7 +14,11 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
 public:
   using TypeConverter::convertType;
 
-  TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
+  TritonGPUToLLVMTypeConverter(MLIRContext *ctx,
+                               const LowerToLLVMOptions &option,
+                               const TargetInfoBase &targetInfo,
+                               const DataLayoutAnalysis *analysis = nullptr);
+  TritonGPUToLLVMTypeConverter(MLIRContext *ctx,
                                const TargetInfoBase &targetInfo,
                                const DataLayoutAnalysis *analysis = nullptr);
 

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::MemDescType;
 
 TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
-    MLIRContext *ctx, LowerToLLVMOptions &options,
+    MLIRContext *ctx, const TargetInfoBase &targetInfo,
+    const DataLayoutAnalysis *analysis)
+    : TritonGPUToLLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), targetInfo,
+                                   analysis) {}
+
+TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
+    MLIRContext *ctx, const LowerToLLVMOptions &options,
     const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis)
     : LLVMTypeConverter(ctx, options, analysis) {
  addConversion([ctx](triton::PointerType type) -> std::optional<Type> {
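
A minimal usage sketch of the new overload, assuming a backend that wants non-default lowering options (the tweaked option and the surrounding variables are hypothetical, not part of the patch); the two-argument constructor keeps the old behavior by delegating with LowerToLLVMOptions(ctx):

  // Hedged sketch: construct the converter with explicit, caller-owned options.
  // `ctx` and `targetInfo` are placeholders supplied by the backend.
  mlir::LowerToLLVMOptions options(ctx);
  options.overrideIndexBitwidth(32); // assumption: an option a backend might set
  TritonGPUToLLVMTypeConverter typeConverter(ctx, options, targetInfo);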

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 2 additions & 2 deletions
@@ -187,7 +187,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
 #blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
 module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK-LABEL: buffer_atomic
-  tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>) {
+  tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>, %stride: i32 {tt.divisibility=16:i32}) {
     %c128_i32 = arith.constant 128 : i32
     %0 = tt.get_program_id x : i32
     %1 = arith.muli %0, %c128_i32 : i32
@@ -203,7 +203,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
     // CHECK: %[[offset:.*]] = llvm.select %[[mask1]]
 
     // We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call.
-    %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask : tensor<128xf32, #blocked0>
+    %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask stride = %stride : tensor<128xf32, #blocked0>
 
     // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
     // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 1 addition & 1 deletion
@@ -566,7 +566,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
     %5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
     %6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
     %7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
-    // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
+    // CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] stride = %c0_i32
     %8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
     tt.return %8 : tensor<1024xf32, #blocked>
   }

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 7 additions & 2 deletions
@@ -228,11 +228,11 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
     TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
     TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
     TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
+                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
     TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
     TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
     TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
-                   "($_op.getOperands().size() <= 3) || std::equal_to<>()">,
+                   "($_op.getOperands().size() <= 4) || std::equal_to<>()">,
 ]>{
   let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
   let description = [{
@@ -242,13 +242,17 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
     the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed).
     Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with
     the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped.
+    The `stride` is the distance between the beginnings of contiguous memory chunks. When performing an RMW, the `stride`
+    is the address difference, in bytes, between the first elements of consecutive rows. The compiler tries to derive the
+    `stride` when converting to buffer ops because it is important for optimizing cache memory accesses.
   }];
   let arguments = (
     ins
     TT_AtomicRMWAttr:$atomic_rmw_op,
     TT_Ptr:$ptr,
     I32Tensor:$offsets,
     TT_Tensor:$value,
+    I32:$stride,
     TT_MemSemanticAttr:$sem,
     TT_MemSyncScopeAttr:$scope,
     Optional<TT_BoolTensor>:$mask
@@ -257,6 +261,7 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
 
   let assemblyFormat = [{
     $atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
+    `stride` `=` $stride
     attr-dict `:` type($result)
   }];
 }
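
Since the description defines `stride` in bytes, a quick worked example may help (an illustration under assumed shapes, not taken from the patch): a contiguous, row-major tile of f32 with 128 elements per row has consecutive rows starting 128 * 4 = 512 bytes apart, so the stride operand would carry 512.

  // Hedged arithmetic sketch; the row width and element type are assumptions.
  #include <cstdint>
  constexpr std::int64_t kElemsPerRow = 128;                           // assumed tile width
  constexpr std::int64_t kStrideBytes = kElemsPerRow * sizeof(float);  // 512 bytes between row starts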

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr,
   Value stride = b.int_val(16, 0);
   if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4},
                          targetInfo.getISAFamily())) {
-    if (blockStride) { // TODO: BufferAtomicRMWOp is unsupported
+    if (blockStride) {
       Value enableSwizzle = b.int_val(16, 16384);
       Value mask14b = b.int_val(16, 16383);
       // Cache swizzle supports only upto 8k stride. Also simply swizzling the

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
@@ -691,6 +691,7 @@ struct BufferAtomicRMWOpConversion
     Value llOffset = adaptor.getOffsets();
     Value llMask = adaptor.getMask();
     Value llData = adaptor.getValue();
+    Value llStride = adaptor.getStride();
 
     // Determine the vectorization size
     Type valueTy = data.getType();
@@ -751,7 +752,7 @@ struct BufferAtomicRMWOpConversion
       emitReleaseFence = true;
     }
 
-    Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr);
+    Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride);
     Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
     SmallVector<Value> loadedVals;

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 2 additions & 2 deletions
@@ -366,10 +366,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
     Value maybeMask{};
     if (op.getMask() && !isZeroConst(op.getMask()))
       maybeMask = op.getMask();
-
+    Value blockStride = getBlockStride(op->getLoc(), tensorOffset, rewriter);
     rewriter.replaceOpWithNewOp<triton::amdgpu::BufferAtomicRMWOp>(
         op, op.getVal().getType(), atomicRmwOp, basePtr, tensorOffset,
-        op.getVal(), sem, scope, maybeMask);
+        op.getVal(), blockStride, sem, scope, maybeMask);
 
     return success();
   }
third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 2 deletions
@@ -307,15 +307,13 @@ def make_llir(self, src, metadata, options, capability):
         nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.convert.add_index_to_llvmir(pm)
         passes.ttgpuir.add_allocate_shared_memory(pm)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
-        passes.convert.add_arith_to_llvmir(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
@@ -348,6 +346,10 @@ def make_llir(self, src, metadata, options, capability):
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
 
         # Get some metadata
+        # warp-specialization mutates num_warps
+        num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta")
+        if num_warp_groups is not None:
+            metadata["num_warps"] *= num_warp_groups
         metadata["shared"] = src.get_int_attr("ttg.shared")
         metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
         metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")

third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
                              "mlir::triton::TritonDialect",
                              "mlir::triton::gpu::TritonGPUDialect",
                              "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+                             "mlir::triton::nvgpu::NVGPUDialect",
                              "mlir::NVVM::NVVMDialect"];
 
   let options = [
