diff --git a/bin/RegisterTritonDialects.h b/bin/RegisterTritonDialects.h index 3589bac675ce..aa92771c311a 100644 --- a/bin/RegisterTritonDialects.h +++ b/bin/RegisterTritonDialects.h @@ -63,6 +63,7 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { mlir::triton::registerOptimizeAMDLDSUsage(); // TritonAMDGPUTransforms passes + mlir::registerTritonAMDGPUMembarAnalysis(); mlir::registerTritonAMDGPUAccelerateMatmul(); mlir::registerTritonAMDGPUOptimizeEpilogue(); mlir::registerTritonAMDGPUHoistLayoutConversions(); @@ -80,6 +81,8 @@ inline void registerTritonDialects(mlir::DialectRegistry ®istry) { // NVWS passes mlir::registerNVWSTransformsPasses(); + mlir::registerTritonAMDGPURefineOps(); + mlir::registerTritonAMDGPURescheduleOps(); registry.insert< mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect, diff --git a/include/triton/Conversion/TritonGPUToLLVM/Utility.h b/include/triton/Conversion/TritonGPUToLLVM/Utility.h index 7410ddae7454..3f1bf19119c3 100644 --- a/include/triton/Conversion/TritonGPUToLLVM/Utility.h +++ b/include/triton/Conversion/TritonGPUToLLVM/Utility.h @@ -394,16 +394,6 @@ class SharedMemoryObject { return offsets[dim]; } - // TODO(Keren): deprecate the method once AMD backend has cleaned up - Value getBaseBeforeSlice(int dim, Location loc, - RewriterBase &rewriter) const { - auto b = TritonLLVMOpBuilder(loc, rewriter); - Value cSwizzleOffset = getCSwizzleOffset(dim); - Value offset = b.sub(b.i32_val(0), cSwizzleOffset); - Type type = base.getType(); - return b.gep(type, baseElemType, base, offset); - } - private: static SmallVector getOrderForShape(ArrayRef shape, ArrayRef layoutOrder) { diff --git a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp index edfc2c2f775c..da849d14cca1 100644 --- a/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp +++ b/lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp @@ -457,7 +457,22 @@ AMDMfmaEncodingAttr::toLinearLayout(ArrayRef shape) const { identityStandardND(S("warp"), getWarpsPerCTA(), order); LinearLayout ctaLayout = tileLayout * warpLayout; - return combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); + auto combinedLayout = + combineCtaCgaWithShape(ctaLayout, getCTALayout(), shape); + + auto bases = combinedLayout.getBases(); + std::vector> newRegBases; + for (const auto &basis : bases[S("register")]) { + if (llvm::any_of(basis, [](int b) { return b != 0; })) { + newRegBases.push_back(basis); + } + } + bases[S("register")] = newRegBases; + + auto result = LinearLayout(std::move(bases), + llvm::to_vector(combinedLayout.getOutDimNames())); + + return result; } LinearLayout chooseDotDsReadB64TrLayout(DotOperandEncodingAttr dotMfmaLayout, diff --git a/test/Conversion/amd/async_ops_to_llvm.mlir b/test/Conversion/amd/async_ops_to_llvm.mlir index 2caf672b1621..e7f41d3c3959 100644 --- a/test/Conversion/amd/async_ops_to_llvm.mlir +++ b/test/Conversion/amd/async_ops_to_llvm.mlir @@ -1,5 +1,5 @@ -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 -// RUN: triton-opt %s -split-input-file --allocate-shared-memory --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s +// RUN: triton-opt %s -split-input-file --allocate-shared-memory --triton-amdgpu-membar-analysis --convert-triton-amdgpu-to-llvm=arch=gfx950 | FileCheck %s --check-prefix=GFX950 +// RUN: triton-opt %s -split-input-file --allocate-shared-memory 
--triton-amdgpu-membar-analysis --convert-triton-amdgpu-to-llvm=arch=gfx942 --verify-diagnostics | FileCheck %s #blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 64], warpsPerCTA = [4, 1], order = [1, 0]}> #shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [1, 0]}> diff --git a/test/TritonGPU/amd/amd-extractslice-op.mlir b/test/TritonGPU/amd/amd-extractslice-op.mlir index dfdb88231cf7..d9f45a61597a 100644 --- a/test/TritonGPU/amd/amd-extractslice-op.mlir +++ b/test/TritonGPU/amd/amd-extractslice-op.mlir @@ -12,3 +12,40 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, tt.return } } + +#blocked3 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [8, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @extract_slice_slice_1(%arg0: tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @extract_slice_slice_1 + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %8 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + // CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)> + %1 = amdgpu.extract_slice %arg0 [128] : tensor<256xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> to tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked3}>> + tt.return + } +} + +#blocked4 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @extract_slice_slice_0(%arg0: tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> {tt.divisibility = 16 : i32}) { + // CHECK: llvm.func @extract_slice_slice_0 + // CHECK-COUNT-8: %{{[0-9]*}} = llvm.extractvalue %arg0[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32, i32, i32, i32, i32)> + // CHECK: %8 = llvm.mlir.undef : !llvm.struct<(i32, i32, i32, i32)> + // CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)> + %0 = amdgpu.extract_slice %arg0 [0] : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> to tensor<64xi32, #ttg.slice<{dim = 0, parent = #blocked4}>> + tt.return + } +} + +#blocked5 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [4, 16], warpsPerCTA = [4, 1], order = [1, 0], CTAsPerCGA = [1, 1], CTASplitNum = [1, 1], CTAOrder = [0, 1]}> +module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 8 : i32, "ttg.threads-per-warp" = 64 : i32} { + tt.func @extract_slice_slice_2() { + %0 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> + %1 = tt.make_range {end = 128 : i32, start = 0 : i32} : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> + // CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)> + %2 = amdgpu.extract_slice %0 [0] : tensor<128xi32, #ttg.slice<{dim = 0, parent = #blocked5}>> to tensor<64xi32, #ttg.slice<{dim = 0, parent = 
#blocked5}>> + // CHECK-COUNT-4: %{{[0-9]*}} = llvm.insertvalue %{{[0-9]*}}, %{{[0-9]*}}[{{[0-9]*}}] : !llvm.struct<(i32, i32, i32, i32)> + %3 = amdgpu.extract_slice %1 [0] : tensor<128xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> to tensor<64xi32, #ttg.slice<{dim = 1, parent = #blocked5}>> + tt.return + } +} diff --git a/test/TritonGPU/amd/ops-refinement/elementwise.mlir b/test/TritonGPU/amd/ops-refinement/elementwise.mlir new file mode 100644 index 000000000000..5c173482aaf8 --- /dev/null +++ b/test/TritonGPU/amd/ops-refinement/elementwise.mlir @@ -0,0 +1,166 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-refine-ops='arch=gfx942' | FileCheck %s + +// CHECK-LABEL: @exp_kernel +// CHECK-DAG: [[VALUE_1:%.*]] = amdgpu.extract_slice {{.*}} [0, 0] +// CHECK-DAG: [[VALUE_2:%.*]] = math.exp2 [[VALUE_1]] +// CHECK-DAG: [[VALUE_3:%.*]] = amdgpu.extract_slice {{.*}} [0, 16] +// CHECK-DAG: [[VALUE_4:%.*]] = math.exp2 [[VALUE_3]] +// CHECK-DAG: [[VALUE_5:%.*]] = amdgpu.extract_slice {{.*}} [64, 0] +// CHECK-DAG: [[VALUE_6:%.*]] = math.exp2 [[VALUE_5]] +// CHECK-DAG: [[VALUE_7:%.*]] = amdgpu.extract_slice {{.*}} [64, 16] +// CHECK-DAG: [[VALUE_8:%.*]] = math.exp2 [[VALUE_7]] +// CHECK-DAG: [[VALUE_9:%.*]] = amdgpu.concat [[VALUE_2]], [[VALUE_4]], [[VALUE_6]], [[VALUE_8]] +// CHECK-DAG: tt.return [[VALUE_9]] +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @exp_kernel(%arg0: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %0 = math.exp2 %arg0 : tensor<128x32xf32, #blocked> + tt.return %0 : tensor<128x32xf32, #blocked> + } +} + +// ----- + +// CHECK-LABEL: mul_kernel +// CHECK-DAG: [[VALUE_1:%.*]] = amdgpu.extract_slice {{.*}} [0, 0] +// CHECK-DAG: [[VALUE_2:%.*]] = amdgpu.extract_slice {{.*}} [0, 0] +// CHECK-DAG: [[VALUE_3:%.*]] = arith.mulf [[VALUE_1]], [[VALUE_2]] +// CHECK-DAG: [[VALUE_4:%.*]] = amdgpu.extract_slice {{.*}} [0, 16] +// CHECK-DAG: [[VALUE_5:%.*]] = amdgpu.extract_slice {{.*}} [0, 16] +// CHECK-DAG: [[VALUE_6:%.*]] = arith.mulf [[VALUE_4]], [[VALUE_5]] +// CHECK-DAG: [[VALUE_7:%.*]] = amdgpu.extract_slice {{.*}} [64, 0] +// CHECK-DAG: [[VALUE_8:%.*]] = amdgpu.extract_slice {{.*}} [64, 0] +// CHECK-DAG: [[VALUE_9:%.*]] = arith.mulf [[VALUE_7]], [[VALUE_8]] +// CHECK-DAG: [[VALUE_10:%.*]] = amdgpu.extract_slice {{.*}} [64, 16] +// CHECK-DAG: [[VALUE_11:%.*]] = amdgpu.extract_slice {{.*}} [64, 16] +// CHECK-DAG: [[VALUE_12:%.*]] = arith.mulf [[VALUE_10]], [[VALUE_11]] +// CHECK-DAG: [[VALUE_13:%.*]] = amdgpu.concat [[VALUE_3]], [[VALUE_6]], [[VALUE_9]], [[VALUE_12]] +// CHECK-DAG: tt.return [[VALUE_13]] +#blocked = #ttg.blocked<{sizePerThread = [2, 2], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public 
@mul_kernel(%arg0: tensor<128x32xf32, #blocked>, %arg1: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %0 = arith.mulf %arg0, %arg1 : tensor<128x32xf32, #blocked> + tt.return %0 : tensor<128x32xf32, #blocked> + } +} + +// ----- + +// CHECK-LABEL: @multiple_operations_kernel + +// CHECK-COUNT-4: amdgpu.extract_slice {{.*}} +// CHECK: [[OP1:%.*]] = amdgpu.concat +// CHECK-COUNT-4: amdgpu.extract_slice [[OP1]] +// CHECK: [[OP2:%.*]] = amdgpu.concat +// CHECK-COUNT-4: amdgpu.extract_slice [[OP2]] +// CHECK: [[OP3:%.*]] = amdgpu.concat +// CHECK: tt.return [[OP3]] +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @multiple_operations_kernel(%arg0: tensor<128x32xf32, #mma>, %arg1: tensor<128x32xf32, #mma>) -> tensor<128x32xf32, #mma> attributes {noinline = false} { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %0 = math.exp2 %arg0 : tensor<128x32xf32, #mma> + %1 = math.exp2 %0 : tensor<128x32xf32, #mma> + %2 = math.exp2 %1 : tensor<128x32xf32, #mma> + tt.return %2 : tensor<128x32xf32, #mma> + } +} + +// ----- + +// CHECK-LABEL: @nested_operations_kernel +// CHECK-COUNT-8: amdgpu.extract_slice +// CHECK: mulf +// CHECK: amdgpu.concat +// CHECK: scf.for +// CHECK-COUNT-4: amdgpu.extract_slice +// CHECK: math.exp2 +// CHECK: amdgpu.concat +// CHECK: } +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @nested_operations_kernel(%arg0: tensor<128x32xf32, #blocked>, %arg1: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %0 = arith.mulf %arg0, %arg1 : tensor<128x32xf32, #blocked> + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c4 = arith.constant 4 : i32 + %1 = scf.for %arg2 = %c0 to %c4 step %c1 iter_args(%arg3 = %0) -> (tensor<128x32xf32, #blocked>) : i32 { + %2 = math.exp2 %0 : 
tensor<128x32xf32, #blocked> + scf.yield %2 : tensor<128x32xf32, #blocked> + } + tt.return %1 : tensor<128x32xf32, #blocked> + } +} + +// ----- + +// CHECK-LABEL: @peer_operations_kernel +// CHECK: scf.for +// CHECK-COUNT-4: amdgpu.extract_slice +// CHECK: math.exp2 +// CHECK: amdgpu.concat +// CHECK: scf.for +// CHECK-NOT: amdgpu.extract_slice +// CHECK: math.exp2 +// CHECK-NOT: amdgpu.concat +// CHECK: } +#blocked = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @peer_operations_kernel(%arg0: tensor<128x32xf32, #blocked>) -> tensor<128x32xf32, #blocked> attributes {noinline = false} { + %c0 = arith.constant 0 : i32 + %c1 = arith.constant 1 : i32 + %c4 = arith.constant 4 : i32 + %1 = scf.for %arg1 = %c0 to %c4 step %c1 iter_args(%arg2 = %arg0) -> (tensor<128x32xf32, #blocked>) : i32 { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %2 = math.exp2 %arg2 : tensor<128x32xf32, #blocked> + scf.yield %2 : tensor<128x32xf32, #blocked> + } + %3 = scf.for %arg3 = %c0 to %c4 step %c1 iter_args(%arg4 = %1) -> (tensor<128x32xf32, #blocked>) : i32 { + %4 = math.exp2 %arg4 : tensor<128x32xf32, #blocked> + scf.yield %4 : tensor<128x32xf32, #blocked> + } + tt.return %3 : tensor<128x32xf32, #blocked> + } +} + +// ----- + +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [32, 32], isTransposed = true}> +#mma1 = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [4, 1], instrShape = [16, 16], isTransposed = true}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_layout(%arg0: tensor<128x64xf16, #mma>) attributes {noinline = false} { + // CHECK-LABEL: convert_layout + + // CHECK: [[ES_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma> + // CHECK: [[CL_0:%.*]] = ttg.convert_layout [[ES_0]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + // CHECK: [[ES_1:%.*]] = amdgpu.extract_slice %arg0 [0, 16] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma> + // CHECK: [[CL_1:%.*]] = ttg.convert_layout [[ES_1]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + // CHECK: [[ES_2:%.*]] = amdgpu.extract_slice %arg0 [0, 32] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma> + // CHECK: [[CL_2:%.*]] = ttg.convert_layout [[ES_2]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + // CHECK: [[ES_3:%.*]] = amdgpu.extract_slice %arg0 [0, 48] : tensor<128x64xf16, #mma> to tensor<128x16xf16, #mma> + // CHECK: [[CL_3:%.*]] = ttg.convert_layout [[ES_3]] : tensor<128x16xf16, #mma> -> tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + // CHECK: %8 = amdgpu.concat [[CL_0]], [[CL_1]], [[CL_2]], [[CL_3]] [1, 
4] {loweringOrder = array} : tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>>, tensor<128x16xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + + %0 = ttg.convert_layout %arg0 : tensor<128x64xf16, #mma> -> tensor<128x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma1, kWidth = 4}>> + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + tt.return + } +} + +// ----- + +// blocked layout cta tile has size of whole tensor, no transformation should happen +// CHECK-LABEL: @convert_layout_kernel_neg +// CHECK-NOT: amdgpu.extract_slice +#blocked1 = #ttg.blocked<{sizePerThread = [4, 1], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +#blocked2 = #ttg.blocked<{sizePerThread = [1, 4], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @convert_layout_kernel_neg(%arg0: tensor<128x32xf32, #blocked1>) -> tensor<128x32xf32, #blocked2> attributes {noinline = false} { + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + %0 = ttg.convert_layout %arg0 : tensor<128x32xf32, #blocked1> -> tensor<128x32xf32, #blocked2> + tt.return %0 : tensor<128x32xf32, #blocked2> + } +} diff --git a/test/TritonGPU/amd/ops-refinement/local_alloc.mlir b/test/TritonGPU/amd/ops-refinement/local_alloc.mlir new file mode 100644 index 000000000000..4ece55d7606d --- /dev/null +++ b/test/TritonGPU/amd/ops-refinement/local_alloc.mlir @@ -0,0 +1,35 @@ +// RUN: triton-opt %s -split-input-file -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize | FileCheck %s + +#shared = #ttg.swizzled_shared<{vec = 1, perPhase = 1, maxPhase = 1, order = [0, 1]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [64, 1], warpsPerCTA = [1, 4], order = [0, 1]}> +#smem = #ttg.shared_memory + + +// CHECK-LABEL: @local_alloc_refinement +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.shared = 16384 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + tt.func public @local_alloc_refinement(%arg0: tensor<64x16xf16, #blocked>) attributes {noinline = false} { + + // CHECK: [[OFFSET_12:%.*]] = arith.constant 12 : i32 + // CHECK: [[OFFSET_8:%.*]] = arith.constant 8 : i32 + // CHECK: [[OFFSET_4:%.*]] = arith.constant 4 : i32 + // CHECK: [[OFFSET_0:%.*]] = arith.constant 0 : i32 + // CHECK: [[ALLOC:%.*]] = ttg.local_alloc : () -> !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> + // CHECK: 
[[SUBVIEW_0:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_0]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SLICE_0:%.*]] = amdgpu.extract_slice %arg0 [0, 0] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked> + // CHECK: ttg.local_store [[SLICE_0]], [[SUBVIEW_0]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SUBVIEW_1:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_4]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SLICE_1:%.*]] = amdgpu.extract_slice %arg0 [0, 4] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked> + // CHECK: ttg.local_store [[SLICE_1]], [[SUBVIEW_1]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SUBVIEW_2:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_8]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SLICE_2:%.*]] = amdgpu.extract_slice %arg0 [0, 8] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked> + // CHECK: ttg.local_store [[SLICE_2]], [[SUBVIEW_2]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SUBVIEW_3:%.*]] = ttg.memdesc_subview [[ALLOC]][[[OFFSET_0]], [[OFFSET_0]], [[OFFSET_12]]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: [[SLICE_3:%.*]] = amdgpu.extract_slice %arg0 [0, 12] : tensor<64x16xf16, #blocked> to tensor<64x4xf16, #blocked> + // CHECK: ttg.local_store [[SLICE_3]], [[SUBVIEW_3]] : tensor<64x4xf16, #blocked> -> !ttg.memdesc<64x4xf16, #shared, #smem, mutable, 1x64x16> + // CHECK: amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + // CHECK: ttg.local_dealloc [[ALLOC]] : !ttg.memdesc<1x64x16xf16, #shared, #smem, mutable> + %0 = ttg.local_alloc %arg0 : (tensor<64x16xf16, #blocked>) -> !ttg.memdesc<64x16xf16, #shared, #smem> + amdgpu.instruction_sched_hint {isBufferLoadsAEnabled = false, isBufferLoadsBEnabled = false, numDsReadsA = #amdgpu.InstCounter<0, none>, numDsReadsB = #amdgpu.InstCounter<0, none>, numDsWritesA = #amdgpu.InstCounter<0, none>, numDsWritesB = #amdgpu.InstCounter<0, none>, numGlobalLoadsA = #amdgpu.InstCounter<0, none>, numGlobalLoadsB = #amdgpu.InstCounter<0, none>, numMMAs = #amdgpu.InstCounter<0, none>, variant = #amdgpu.SchedHintVariant} + tt.return + } +} diff --git a/test/TritonGPU/amd/ops-refinement/simple-dot.mlir b/test/TritonGPU/amd/ops-refinement/simple-dot.mlir new file mode 100644 index 000000000000..cf13440d2b24 --- /dev/null +++ b/test/TritonGPU/amd/ops-refinement/simple-dot.mlir @@ -0,0 +1,42 @@ +// RUN: triton-opt %s -split-input-file -tritonamdgpu-stream-pipeline='num_stages=2' -cse -canonicalize -triton-amdgpu-refine-ops='arch=gfx942' -canonicalize | FileCheck %s + +#mma = #ttg.amd_mfma<{versionMajor = 3, versionMinor = 0, warpsPerCTA = [2, 
2], instrShape = [32, 32], isTransposed = true}> +#blocked = #ttg.blocked<{sizePerThread = [1, 8], threadsPerWarp = [8, 8], warpsPerCTA = [4, 1], order = [1, 0]}> + +module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} { + // CHECK: @matmul_kernel + tt.func public @matmul_kernel( + %arg0: tensor<256x64x!tt.ptr, #blocked> {tt.contiguity=16 : i32, tt.divisibility=16: i32, tt.constancy=16: i32}, + %arg1: tensor<64x128x!tt.ptr, #blocked> {tt.contiguity=16 : i32, tt.divisibility=16: i32, tt.constancy=16: i32}) -> tensor<256x128xf32, #mma> attributes {noinline = false} { + + %output = arith.constant dense<0.000000e+00> : tensor<256x128xf32, #mma> + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + %c64_i32 = arith.constant 64 : i32 + + %shift0 = arith.constant dense<64> : tensor<256x64xi32, #blocked> + %shift1 = arith.constant dense<64> : tensor<64x128xi32, #blocked> + + %0:3 = scf.for %arg2 = %c0_i32 to %c64_i32 step %c1_i32 iter_args( + %loop_arg0 = %output, + %loop_arg1 = %arg0, + %loop_arg2 = %arg1) -> ( + tensor<256x128xf32, #mma>, + tensor<256x64x!tt.ptr, #blocked>, + tensor<64x128x!tt.ptr, #blocked>) : i32 { + %1 = tt.load %loop_arg1 : tensor<256x64x!tt.ptr, #blocked> + %2 = tt.load %loop_arg2 : tensor<64x128x!tt.ptr, #blocked> + %3 = ttg.convert_layout %1 : tensor<256x64xf16, #blocked> -> tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> + %4 = ttg.convert_layout %2 : tensor<64x128xf16, #blocked> -> tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> + %5 = tt.dot %3, %4, %loop_arg0 : tensor<256x64xf16, #ttg.dot_op<{opIdx = 0, parent = #mma, kWidth = 4}>> * tensor<64x128xf16, #ttg.dot_op<{opIdx = 1, parent = #mma, kWidth = 4}>> -> tensor<256x128xf32, #mma> + %6 = tt.addptr %loop_arg1, %shift0 : tensor<256x64x!tt.ptr, #blocked>, tensor<256x64xi32, #blocked> + %7 = tt.addptr %loop_arg2, %shift1 : tensor<64x128x!tt.ptr, #blocked>, tensor<64x128xi32, #blocked> + scf.yield %5, %6, %7 : tensor<256x128xf32, #mma>, tensor<256x64x!tt.ptr, #blocked>, tensor<64x128x!tt.ptr, #blocked> + } + + tt.return %0#0 : tensor<256x128xf32, #mma> + } +} + + +// TODO: add TT GEMM case to the test diff --git a/third_party/amd/backend/compiler.py b/third_party/amd/backend/compiler.py index 3b763c3c84c4..d99735a816e6 100644 --- a/third_party/amd/backend/compiler.py +++ b/third_party/amd/backend/compiler.py @@ -299,6 +299,10 @@ def make_llir(src, metadata, options): passes.convert.add_index_to_llvmir(pm) passes.ttgpuir.add_allocate_shared_memory(pm) + amd.passes.ttgpuir.add_membar_analysis(pm) + amd.passes.ttgpuir.add_refine_amdgpu_ops(pm, options.arch) + passes.common.add_canonicalizer(pm) + amd.passes.ttgpuir.add_reschedule_amdgpu_ops(pm, options.arch) ## __HIP_FTZ is used to control the denorm flushing behavior of exp2 op as follows: ## 1. If __HIP_FTZ = 1, exp2 flushes denorms in input and output regardless ## of the value of kernel arg `allow_flush_denorm`. 
diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td index 101b7bd444be..c54ac326b021 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUAttrDefs.td @@ -75,15 +75,32 @@ class TritonAMDGPU_I32EnumAttr : def SchedHintCaseNone : I32EnumAttrCase<"none", 0>; def SchedHintCaseLocalPrefetch : I32EnumAttrCase<"local_prefetch", 1>; def SchedHintCaseAttention : I32EnumAttrCase<"attention", 2>; +def SchedHintCaseRefineOps : I32EnumAttrCase<"refine_ops", 4>; def TritonAMDGPU_SchedHintsEnum : TritonAMDGPU_I32Enum< "SchedHint", "Instruction Scheduling Hints for AMD GPUs", [ SchedHintCaseNone, SchedHintCaseLocalPrefetch, SchedHintCaseAttention, + SchedHintCaseRefineOps ]>; def TritonAMDGPU_SchedHintVariantAttr : TritonAMDGPU_I32EnumAttr<"SchedHintVariant", TritonAMDGPU_SchedHintsEnum>; +def TritonAMDGPU_DotTile : TritonAMDGPU_Attr<"DotTile"> { + let cppNamespace = "::mlir::triton::amdgpu"; + let mnemonic = "DotTile"; + let summary = "Information regarding the dot-tile this op belongs to."; + let description = [{ + The attribute is a way to describe which input argument of the target + operation (e.g., `tt.dot`) the result of a given operation belongs to. + The parameters tile{M,N,K,Serial} refer to the dot-tile's id within the dot, + while the element{M,N,K,Serial} refer to the element's id within the dot-tile. + + }]; + let parameters = (ins "int32_t":$tileM, "int32_t":$tileN, "int32_t":$tileK, "uint32_t":$tileSerial, "int32_t":$elementM, "int32_t":$elementN, "int32_t":$elementK, "uint32_t":$elementSerial ); + let assemblyFormat = "`<` $tileM `,` $tileN `,` $tileK `>` `[` $tileSerial `]` `,` `<` $elementM `,` $elementN `,` $elementK `>` `[` $elementSerial `]`"; +} + #endif diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td index 17d9409468d8..177c41f6a137 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td +++ b/third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td @@ -117,6 +117,38 @@ def ExtractSliceOp : TT_AMDGPU_Op<"extract_slice", [Pure]> { }]; let hasVerifier = 1; + let hasCanonicalizer = 1; +} + +def ConcatOp : TT_AMDGPU_Op<"concat", [Pure]> { + let summary = "concat operation"; + let description = [{ + The "concat" operation joins slices of a tensor together + in the row-major style. 
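+    The `coords` attribute specifies how many source tiles are placed along
+    each result dimension: its product must equal the number of sources, and
+    each result dimension must equal the matching source dimension scaled by
+    the corresponding `coords` entry. The `loweringOrder` attribute controls
+    the order in which the source tiles are packed during lowering; the
+    default builder uses the reversed dimension order, i.e. row-major. For
+    example (see the ops-refinement tests), concatenating four
+    `tensor<128x16xf16>` values with `coords = [1, 4]` yields a
+    `tensor<128x64xf16>`.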
+ + TODO: complete + }]; + + let arguments = (ins Variadic:$sources, DenseI64ArrayAttr:$coords, DenseI64ArrayAttr:$loweringOrder); + let results = (outs AnyRankedTensor:$result); + + let builders = [ + OpBuilder<(ins "::mlir::Type":$result, "::mlir::ValueRange":$sources, "::mlir::DenseI64ArrayAttr":$coords), [{ + auto ctx = $_state.getContext(); + std::vector order(coords.getSize()); + int64_t counter = 0; + std::for_each(order.rbegin(), order.rend(), [&counter](int64_t &val) { val = counter++; }); + auto loweringOrderAttr = DenseI64ArrayAttr::get(ctx, order); + build($_builder, $_state, result, sources, coords, loweringOrderAttr); + }]> + ]; + + let assemblyFormat = [{ + $sources $coords attr-dict `:` type($sources) `->` type($result) + }]; + + let hasVerifier = 1; + let hasCanonicalizer = 1; } //===----------------------------------------------------------------------===// diff --git a/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h b/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h index 5b599f5e757c..3f93944ca306 100644 --- a/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h +++ b/third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h @@ -3,9 +3,74 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "triton/Dialect/Triton/IR/Dialect.h" +#include namespace mlir::triton::AMD { SmallVector getLeafForOps(triton::FuncOp funcOp); + +class CoordinateMapper { +public: + CoordinateMapper(llvm::ArrayRef layout) : layout(layout) { + bounds.resize(layout.size()); + std::exclusive_scan(layout.rbegin(), layout.rend(), bounds.begin(), 1, + std::multiplies<>()); + } + + SmallVector map(int64_t index) { + SmallVector coords(bounds.size(), 0); + for (size_t i = 1; i < bounds.size(); ++i) { + size_t d = bounds.size() - i; + coords[d] = index / bounds[d]; + index = index % bounds[d]; + } + coords[0] = index; + std::reverse(coords.begin(), coords.end()); + return coords; + } + + // TODO (Ravil): add C++ + template + static std::vector> cartesian(const std::vector &ranges, + const std::vector &order) { + assert(ranges.size() == order.size()); + auto imageSize = + std::accumulate(ranges.begin(), ranges.end(), 1, std::multiplies{}); + auto product = + std::vector>(imageSize, std::vector(ranges.size())); + + auto strides = CoordinateMapper::getDeviders(ranges, order); + for (size_t vec = 0; vec < imageSize; ++vec) { + for (size_t elem = 0; elem < ranges.size(); ++elem) { + product[vec][elem] = (vec / strides[elem]) % ranges[elem]; + } + } + + return product; + } + +private: + template + static std::vector getDeviders(const std::vector &dims, + const std::vector &order) { + std::vector orderedDims(dims.size()); + for (size_t i = 0; i < dims.size(); ++i) { + orderedDims[i] = dims[order[i]]; + } + std::vector strides(dims.size()); + std::exclusive_scan(orderedDims.begin(), orderedDims.end(), strides.begin(), + static_cast(1), std::multiplies<>()); + + std::vector orderedDeviders(dims.size()); + for (size_t d = 0; d < dims.size(); ++d) { + orderedDeviders[d] = strides[order[d]]; + } + return orderedDeviders; + } + + llvm::ArrayRef layout; + std::vector bounds; +}; + } // namespace mlir::triton::AMD #endif // TRITON_THIRD_PARTY_AMD_INCLUDE_DIALECT_TRITONAMDGPU_UTILITY_COMMONUTILS_H_ diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td index ecf2dccb1327..2094c884ee7c 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td +++ 
b/third_party/amd/include/TritonAMDGPUToLLVM/Passes.td @@ -77,4 +77,5 @@ def TritonAMDGPULowerInstructionSchedHints : Pass<"triton-amdgpu-lower-insert-in ]; } + #endif diff --git a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h index 6763de2eba22..721dcc232ef8 100644 --- a/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h +++ b/third_party/amd/include/TritonAMDGPUToLLVM/PatternTritonAMDGPUToLLVM.h @@ -12,6 +12,10 @@ void populateExtractSliceOpToLLVMPatterns( void populateInThreadTransposeOpToTTGPatterns(mlir::RewritePatternSet &patterns, mlir::PatternBenefit benefit); +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit); + } // namespace mlir::triton::AMD #endif // TRITON_THIRD_PARTY_AMD_INCLUDE_TRITONAMDGPUTOLLVM_PATTERNTRITONAMDGPUTOLLVM_H_ diff --git a/third_party/amd/include/TritonAMDGPUTransforms/DotTiling.h b/third_party/amd/include/TritonAMDGPUTransforms/DotTiling.h new file mode 100644 index 000000000000..f892a52e5683 --- /dev/null +++ b/third_party/amd/include/TritonAMDGPUTransforms/DotTiling.h @@ -0,0 +1,334 @@ +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#undef DEBUG_TYPE +#define DEBUG_TYPE "tritonamdgpu-refine-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +namespace { +/* + TODO - this needs to be MUCH more official. +*/ +unsigned getCyclesPerMfma(DotOp dotOp) { + // Get mfma op type. + auto mfmaLayout = cast( + cast(dotOp.getResult().getType()).getEncoding()); + MLIRContext *ctx = dotOp->getContext(); + Value a = dotOp.getA(); + Value b = dotOp.getB(); + auto aTensorTy = cast(a.getType()); + auto bTensorTy = cast(b.getType()); + auto elemTyA = aTensorTy.getElementType(); + auto elemTyB = bTensorTy.getElementType(); + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + const auto kDimOperandSize = aTensorTy.getShape().back(); + // auto kDim = mfmaLayout.getKDim(); + auto mfmaVersion = mfmaLayout.getVersionMajor(); + bool allowXF32 = + dotOp.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3; + + FailureOr maybeMfmaInsn = MfmaIntrinsic::selectFor( + mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB, + /*withScale=*/false, allowXF32); + + if (failed(maybeMfmaInsn)) + llvm::report_fatal_error("No match found in MFMA database\n"); + // Estimate rate of mfma op type. + unsigned maxBitWidth = + std::max(maybeMfmaInsn->aElementType.getIntOrFloatBitWidth(), + maybeMfmaInsn->bElementType.getIntOrFloatBitWidth()); + // Estimate throughput as fma's per cycle. 
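+  // Worked example of the throughput model below (illustrative numbers, not a
+  // hardware spec): a 32x32x8 f16 MFMA performs 32*32*8 = 8192 FMAs, so at
+  // 256 FMAs/cycle it is counted as 8192/256 = 32 cycles, while a 16x16x32
+  // fp8 MFMA (also 8192 FMAs) at 512 FMAs/cycle counts as 16 cycles.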
+ unsigned opsPerCycle; + if (maxBitWidth <= 8) { // fp8, bf8, i8 + opsPerCycle = 512; + } else if (maxBitWidth <= 16) { // fp16, bf16 + opsPerCycle = 256; + } else if (maxBitWidth <= 32) { // fp32 + opsPerCycle = 128; + } else { + opsPerCycle = 64; // fp64 + } + // total floating point mfmas + int64_t totalOps = + maybeMfmaInsn->mDim * maybeMfmaInsn->nDim * maybeMfmaInsn->kDim; + unsigned cyclesPerMfma = static_cast(totalOps / opsPerCycle); + LDBG(maybeMfmaInsn->name << " = " << cyclesPerMfma << " cycles\n"); + return cyclesPerMfma; +} + +/* +Calculate how many mfmas are in a rep, e.g. 1x1x2. +// TODO(dtanner) Is there a more direct method for this? +*/ +SmallVector getMfmasPerRep(const ArrayRef &ctaTile, + const ArrayRef &warpsPerCta, + const ArrayRef &numReps, + const ArrayRef &mfmaShape) { + LDBG("ctaTile: " << ctaTile[0] << "x" << ctaTile[1] << "x" << ctaTile[2]); + LDBG("warpsPerCtaTile: " << warpsPerCta[0] << "x" << warpsPerCta[1]); + LDBG("numReps: " << numReps[0] << "x" << numReps[1] << "x" << numReps[2]); + LDBG("mfmaShape: " << mfmaShape[0] << "x" << mfmaShape[1] << "x" + << mfmaShape[2]); + // Tile shape per warp. + SmallVector warpTile = { + ctaTile[0] / warpsPerCta[0], + ctaTile[1] / warpsPerCta[1], + ctaTile[2], + }; + LDBG("warpTile: " << warpTile[0] << "x" << warpTile[1] << "x" << warpTile[2]); + // Tile shape per rep. + SmallVector repTile = { + warpTile[0] / numReps[0], + warpTile[1] / numReps[1], + warpTile[2] / numReps[2], + }; + LDBG("repTile: " << repTile[0] << "x" << repTile[1] << "x" << repTile[2]); + SmallVector mfmasPerRep = { + static_cast(repTile[0] / mfmaShape[0]), + static_cast(repTile[1] / mfmaShape[1]), + static_cast(repTile[2] / mfmaShape[2])}; + LDBG("mfmasPerRep: " << mfmasPerRep[0] << "x" << mfmasPerRep[1] << "x" + << mfmasPerRep[2]); + if (mfmasPerRep[0] < 1 || mfmasPerRep[1] < 1 || mfmasPerRep[2] < 1) { + llvm::errs() << "DotTiling::getMfmasPerRep() - Invalid combination of " + "ctaTile, warpsPerCta and mfmaShape.\n"; + return SmallVector({1, 1, 1}); + } + return mfmasPerRep; +} + +/* + Returns the ideal dot-tile shape (in number of reps, not number of mfmas). + + The dot-tile shape is chosen such that: + (1) A dot-tile's worth of mfma cycles hides the local_load data latency. + If this criterial leads the dot tile shape to be larger than the + dot itself, this means local loads need to be issued more + than one dot tile in advance. + (2) A dot-tile's worth of local_load A,B can + can be issued during a dot-tile's worth of mfmas + without overruning the hardware queues; + this is called issue *rate* below. + Note, this ensures the dot-tile will not be LDS bandwidth bound. + If this criteria leads the dot-tile shape to be larger than the + dot itself, this means the dot will be LDS bandwidth bound. + (3) A dot-tile's worth of local_load_a and local_load_b can + can be issued and have their issue cycles hidden by the mfmas. + If this criterial leads the dot-tile shape to be larger than the + dot itself, this means the dot will be bottlenecked + by issuing local loads; this this case need to increase + ds_read_u16 to ds_read_b32 for example. + (4) The dot-tile is as small and square as possible. + + Typical shapes when mfmasPerRep = 1x1x2 for b128 and localLoadRate=32 + for b128 + - 2x2 for fp16 (128 mfma cycles per tile) and + - 4x4 for fp8 (256 mfma cycles per tile). + + Args: + - mfmasPerRep - shape of number of mfmas in decomposed dot, e.g. 1x1x2. + - preferLargerM - prefer M > N if dot-tile cannot be square. 
+ - cyclesPerMfma - how many cycles does mfma take in total. + - localLoadRateA,B - cycles between issuing consecutive ds_reads to not + overrun hardware queues. This is estimated to be + b128 -> 32 cycles, + b64 -> 16 cycles, + b32 -> 8 cycles, + b16 -> 4 cycles. + Default is 32 cycles, which assumes all ds_read_b128. + - numLoadsLoadsPerMfmaA - Note, this does not yet + handle the case for not enough cycles to issue all the loads, e.g. + mfma_16x16x16 // 16 cycles to compute - 4 cycles to issue = 12 cycles of + hiding ds_read_u16 // hidden ds_read_u16 // hidden ds_read_u16 // hidden + ds_read_u16 + ds_read_u16 + ds_read_u16 + ds_read_u16 + ds_read_u16 + - numLocalLoadsPerMfmaA,B - number of local loads required to load the + A,B operands of a single mfma. + - localLoadDataLatency - cycles between issuing ds_read and waiting for + data; rounded up to pow2. + + Notes: + - The intended scheduling of dot-tiles is: + + -------------------------------- + local_load_a[n] // Hide A,B load issue cycles and rate. + local_load_b[n] + DotTile[n-2] + -------------------------------- + DotTile[n-1] // Hide load A,B data latency. + -------------------------------- + DotTile[n] // A,B data is ready by the first mfma of tile. + -------------------------------- + + - Dot-tile shapes can be further refined if the data latency becomes much + larger than the issue rate; in this case we can remove the condition that one + tile hides all the data latency (which could make the tiles huge and waste + registers), and intead local load issue rate is the only criteria and we + retroactively calculate how many tiles are needed to hide the data latency. + - Dot-tile shapes can be further refined so that a dot-tile only needs to + load a or b, and not both a and b. + - At this time it is assumed that dot-tile-shape[K] = 1 since K's don't + interact with eachother. + // TODO(dtanner) - This should be simplifiable to not require so many + low-level details of the device. + For example, just a ratio of flops/byte for the given mfma precision. + And maybe info regarding transpose and how many loads of what precision + to know if we'll have +*/ +using DotTileShapeType = SmallVector; +DotTileShapeType +calcDotTileShape(const SmallVector + &mfmasPerRep, // = 16x16x64 / 16x16x32 = 1x1x2 + bool preferLargerM, unsigned cyclesPerMfma = 8, + unsigned localLoadRateA = 32, unsigned localLoadRateB = 32, + unsigned numLocalLoadsPerRepA = 1, + unsigned numLocalLoadsPerRepB = 1, + unsigned localLoadDataLatency = 128) { + DotTileShapeType tileShape = {1, 1, 1}; + assert(mfmasPerRep[0] >= 1); + assert(mfmasPerRep[1] >= 1); + assert(mfmasPerRep[2] >= 1); + assert(cyclesPerMfma >= 1); + + bool localLoadDataLatencyExposed = true; + bool localLoadRateExposed = true; + bool localLoadIssueExposed = true; + + // Try a finite number of times to increase the TileShape meet performance + // criteria. + // TODO - possibly need to refactor loop bounds before production ready. + constexpr int maxTries = 12; // sufficient to create 64x64 tile. + for (int i = 0; i < maxTries; ++i) { + // If TileShape meets performance criteria, return. + if (!(localLoadDataLatencyExposed || localLoadRateExposed || + localLoadIssueExposed)) { + return tileShape; + } + // TileShape doesn't meet performance criteria, increase it's size. + // Enforce criteria #4 - small square. 
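+    // Growing the smaller of the two in-tile extents keeps the tile as square
+    // as possible. Worked example (assuming mfmasPerRep = {1, 1, 2} and
+    // cyclesPerMfma = 16, i.e. the fp16 case quoted above, with the default
+    // load parameters): the tile grows 1x1 -> 2x1 (4 mfmas, 64 cycles: data
+    // latency and load rate still exposed) -> 2x2 (8 mfmas, 128 cycles: no
+    // criterion exposed), so the loop settles on the 2x2x1 tile mentioned in
+    // the comment above.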
+ if ((tileShape[0] * mfmasPerRep[0] < tileShape[1] * mfmasPerRep[1]) || + ((tileShape[0] * mfmasPerRep[0] == tileShape[1] * mfmasPerRep[1]) && + preferLargerM)) { + tileShape[0] *= 2; + } else { + tileShape[1] *= 2; + } + // Check criteria #1 - local load data latency. + int64_t numMfmas = tileShape[0] * tileShape[1] * mfmasPerRep[0] * + mfmasPerRep[1] * mfmasPerRep[2]; + int64_t mfmaCycles = numMfmas * cyclesPerMfma; + localLoadDataLatencyExposed = mfmaCycles < localLoadDataLatency; + // Check criteria #2 - local load rate. + int64_t loadRateCycles = + tileShape[0] * mfmasPerRep[0] * numLocalLoadsPerRepA * localLoadRateA + + tileShape[1] * mfmasPerRep[1] * numLocalLoadsPerRepB * localLoadRateB; + localLoadRateExposed = mfmaCycles < loadRateCycles; + // Check criteria #3 - issue cycles. + constexpr unsigned mfmaIssueCycles = 4; // num cycles to issue mfma + constexpr unsigned loadIssueCycles = 4; // num cycles to issue local load + int64_t totalLoadIssueCycles = + tileShape[0] * mfmasPerRep[0] * numLocalLoadsPerRepA * loadIssueCycles + + tileShape[1] * mfmasPerRep[1] * numLocalLoadsPerRepB * loadIssueCycles; + int64_t totalMfmaIssueCycles = numMfmas * mfmaIssueCycles; + localLoadIssueExposed = + (mfmaCycles - totalMfmaIssueCycles) < totalLoadIssueCycles; + } + // Fallback to 2x2x1 tile shape. + return DotTileShapeType({2, 2, 1}); +} + +/* + DotTiling creates tiles of mfmas while they are decomposed from a dot + operation. A tile of mfmas is a set of mfmas that will be co-scheduled because + they use the same A,B operands; co-scheduling mfmas with same operands allows + finer control over prefetching from LDS and register usage for these operands. + Args: + - inputNumRepM - total number of [decomposed] dot ops along m. + - inputNumRepN - total number of [decomposed] dot ops along n. + - inputTileShapeM - number of [decomposed] dot ops along m per tile. + - inputTileShapeN - number of [decomposed] dot ops along n per tile. + - inputOuterLoopM - should be set to (warpTileM >= warpTileN). True means m + should be outer loop of mfma ops so that inner loop is smaller dimension which + leads to smallest number of registers carrying A,B operands. E.g. numRep = + 8x4, tileShape=2x2. +*/ +class DotTileOrder { + const int numRepM; + const int numRepN; + const int tileShapeM; + const int tileShapeN; + const int numTilesM; + const int numTilesN; + bool outerTileM; + int tileShapeOuter; + int tileShapeInner; + int numTilesOuter; + int numTilesInner; + +public: + explicit DotTileOrder(int inputNumRepM, int inputNumRepN, int inputTileShapeM, + int inputTileShapeN, bool inputOuterLoopM) + : numRepM(inputNumRepM), numRepN(inputNumRepN), + tileShapeM(inputTileShapeM), tileShapeN(inputTileShapeN), + numTilesM(numRepM / tileShapeM), numTilesN(numRepN / tileShapeN), + outerTileM(inputOuterLoopM) { + // Num mfmas must evenly divide into tiles. + assert(numTilesM * tileShapeM == numRepM); + assert(numTilesN * tileShapeN == numRepN); + // Assign M and N to be outer vs inner tile loop. + if (outerTileM) { + // M is tile of outer loop. + tileShapeOuter = tileShapeM; + tileShapeInner = tileShapeN; + numTilesOuter = numTilesM; + numTilesInner = numTilesN; + } else { + // N is tile of outer loop. 
+ tileShapeOuter = tileShapeN; + tileShapeInner = tileShapeM; + numTilesOuter = numTilesN; + numTilesInner = numTilesM; + } + } + int getTileShapeM() const { return tileShapeM; } + int getTileShapeN() const { return tileShapeN; } + int getNumTilesOuter() const { return numTilesOuter; } + int getNumTilesInner() const { return numTilesInner; } + int getTileStartM(int tileOuterIdx, int tileInnerIdx) const { + if (outerTileM) { + return tileOuterIdx * tileShapeOuter; // M is outer tile loop. + } else { + return tileInnerIdx * tileShapeInner; // M is inner tile loop. + } + } + int getTileStartN(int tileOuterIdx, int tileInnerIdx) const { + if (outerTileM) { + return tileInnerIdx * tileShapeInner; // N is inner tile loop. + } else { + return tileOuterIdx * tileShapeOuter; // N is outer tile loop. + } + } + int getNumTilesM() const { return numTilesM; } + int getNumTilesN() const { return numTilesN; } + int getOuterTileM() const { return outerTileM; } +}; + +} // namespace diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h index fccb65d061ab..ee4a536468a4 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.h +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.h @@ -46,6 +46,14 @@ createTritonAMDGPUUpdateAsyncWaitCountPass(std::string archGenName = {}); std::unique_ptr createTritonAMDGPUFoldTrueCmpIPass(); +std::unique_ptr createTritonAMDGPUMembarAnalysisPass(); + +std::unique_ptr> +createTritonAMDGPURefineOpsPass(StringRef targetArch); + +std::unique_ptr> +createTritonAMDGPURescheduleOpsPass(StringRef targetArch); + /// Generate the code for registering passes. #define GEN_PASS_REGISTRATION #include "TritonAMDGPUTransforms/Passes.h.inc" diff --git a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td index 91bd40000222..be47b8d9ba35 100644 --- a/third_party/amd/include/TritonAMDGPUTransforms/Passes.td +++ b/third_party/amd/include/TritonAMDGPUTransforms/Passes.td @@ -279,4 +279,52 @@ def TritonAMDFoldTrueCmpI: Pass<"tritonamdgpu-fold-true-cmpi", "mlir::ModuleOp"> } +def TritonAMDGPUMembarAnalysis : Pass<"triton-amdgpu-membar-analysis", "mlir::ModuleOp"> { + let summary = "Perform the memory-barrier analysis"; + let description = [{ + This pass allocates shared memory and set barriers. 
+ }]; + + let constructor = "mlir::createTritonAMDGPUMembarAnalysisPass()"; +} + + +def TritonAMDGPURefineOps : Pass<"triton-amdgpu-refine-ops", "mlir::triton::FuncOp"> { + let summary = "Convert Triton ops into fine-grained ones"; + let constructor = "mlir::createTritonAMDGPURefineOpsPass(\"\")"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::ROCDL::ROCDLDialect"]; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942"> + ]; +} + +def TritonAMDGPURescheduleOps : Pass<"triton-amdgpu-reschedule-ops", "mlir::ModuleOp"> { + let summary = "Reschedule fine-grained Triton ops"; + let constructor = "mlir::createTritonAMDGPURescheduleOpsPass(\"\")"; + + let dependentDialects = ["mlir::arith::ArithDialect", + "mlir::math::MathDialect", + "mlir::gpu::GPUDialect", + "mlir::scf::SCFDialect", + "mlir::triton::TritonDialect", + "mlir::triton::amdgpu::TritonAMDGPUDialect", + "mlir::triton::gpu::TritonGPUDialect", + "mlir::ROCDL::ROCDLDialect"]; + + let options = [ + Option<"arch", "arch", "std::string", /*default*/"\"\"", + "gfx target device architecture, e.g., gfx942"> + ]; +} + #endif diff --git a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp index 7543805fc084..407285345c69 100644 --- a/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp +++ b/third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp @@ -76,23 +76,20 @@ LogicalResult ExtractSliceOp::verify() { if (srcTy.getRank() != resultTy.getRank()) { return emitError("result rank must be equal to source rank"); } - if (srcTy.getRank() != 2) { - return emitError("currently only 2D tensors are supported"); - } + int64_t rank = srcTy.getRank(); auto srcShape = srcTy.getShape(); + auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcTy); // ExtractSlice only supports slicing where offsets and sizes are multiples of // shapePerCTATile. This condition ensures that slice has the same layout as // the original tensor. 
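+  // For instance, with shapePerCTATile = [64, 16] (the blocked layout used in
+  // the ops-refinement tests), extracting a 64x16 slice of a 128x32 tensor at
+  // offsets such as [64, 16] is accepted, whereas offsets like [32, 8] or a
+  // 48x16 result would be rejected by the checks below.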
auto offsets = getStaticOffsets(); - if (offsets.size() != 2) { - return emitError("invalid offset shape ") << offsets; + if (offsets.size() != rank) { + return emitError("offsets rank must equal source rank ") << offsets; } - - SmallVector sizes; - for (auto i = 0; i < 2; ++i) { + for (auto i = 0; i < rank; ++i) { auto resultDimSize = resultTy.getDimSize(i); auto srcDimSize = srcTy.getDimSize(i); if (resultDimSize == 0) { @@ -110,29 +107,50 @@ LogicalResult ExtractSliceOp::verify() { return emitError("invalid offset ") << offsets[i] << " at dimension " << i; } - sizes.push_back(resultDimSize); - } + int64_t size = resultDimSize; + + int64_t dimSizePerCTATile = + std::min(static_cast(srcShape[i]), shapePerCTATile[i]); + if (size % dimSizePerCTATile != 0) { + return emitError() << "size [" << size + << "] must be a multiple of shapePerCTATile [" + << dimSizePerCTATile << "]"; + } - auto shapePerCTATile = mlir::triton::gpu::getShapePerCTATile(srcTy); - shapePerCTATile[0] = - std::min(static_cast(srcShape[0]), shapePerCTATile[0]); - shapePerCTATile[1] = - std::min(static_cast(srcShape[1]), shapePerCTATile[1]); - if (sizes[0] % shapePerCTATile[0] != 0 || - sizes[1] % shapePerCTATile[1] != 0) { - return emitError() << "sizes [" << sizes - << "] must be a multiple of shapePerCTATile [" - << shapePerCTATile << "]"; + if (offsets[i] % dimSizePerCTATile != 0) { + return emitError() << "offset [" << offsets + << "] must be a multiple of shapePerCTATile [" + << dimSizePerCTATile << "]"; + } } - if (offsets[0] % shapePerCTATile[0] != 0 || - offsets[1] % shapePerCTATile[1] != 0) { - return emitError() << "offset [" << offsets - << "] must be a multiple of shapePerCTATile [" - << shapePerCTATile << "]"; + return success(); +} + +struct CanonicalizeExtractSliceOp + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(amdgpu::ExtractSliceOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + auto resultType = cast(result.getType()); + auto source = op.getSource(); + auto sourceType = cast(source.getType()); + auto offsets = op.getStaticOffsets(); + + if (resultType == sourceType) { + result.replaceAllUsesWith(source); + return success(); + } + return failure(); } +}; - return success(); +void ExtractSliceOp::getCanonicalizationPatterns( + mlir::RewritePatternSet &patterns, mlir::MLIRContext *context) { + patterns.add(context); } LogicalResult UpcastMXFPOp::verify() { @@ -302,4 +320,188 @@ InThreadTransposeOp::deduceOutputLayout(ArrayRef shape, return transposedLL; } +LogicalResult ConcatOp::verify() { + auto sources = getSources(); + auto coords = getCoords(); + + auto expectedNumSources = product(coords); + if (sources.size() != expectedNumSources) { + return emitError() << "dims spec [" << coords + << "] does not match the number of provided sources [" + << sources.size() << "]"; + } + + auto srcType = dyn_cast(sources.front().getType()); + if (!srcType) + return emitError() << "expected source type is `RankedTensorType`"; + + for (auto source : sources) { + auto currType = dyn_cast(source.getType()); + if (srcType != currType) + return emitError() << "sources are expected to have the same type"; + } + + auto result = getResult(); + auto dstType = dyn_cast(result.getType()); + if (dstType.getElementType() != srcType.getElementType()) + return emitError() << "sources and the destination are expected to have " + "the same element type"; + + auto dstShape = dstType.getShape(); + auto srcShape = 
srcType.getShape(); + if (dstShape.size() != srcShape.size()) + return emitError() + << "sources and the destination must have the same shape size"; + + if (dstShape.size() != coords.size()) + return emitError() << "shape size of the destination and concat. coords " + "must be the same"; + + for (auto [idx, coordValue] : llvm::enumerate(coords)) { + auto scaledSrcDim = srcShape[idx] * coordValue; + if (dstShape[idx] != scaledSrcDim) { + return emitError() << "mismatch along dim [" << idx + << "]. Expected size `" << dstShape[idx] << "`; give `" + << scaledSrcDim << "` after concatenation"; + } + } + + return success(); +} + +struct CanonicalizeConcatOpFromExtractSlice + : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(amdgpu::ExtractSliceOp op, + PatternRewriter &rewriter) const override { + auto concatOp = op.getSource().getDefiningOp(); + if (!concatOp) + return failure(); + + auto offset = op.getStaticOffsets(); + auto coords = concatOp.getCoords(); + if (coords.size() != offset.size()) + return failure(); + + auto sliceResult = op.getResult(); + auto sliceResultType = sliceResult.getType(); + auto sliceResultShape = sliceResultType.getShape(); + + auto concatItem = concatOp.getSources().front(); + auto concatItemType = dyn_cast(concatItem.getType()); + if (!concatItemType) + return failure(); + + if (sliceResultType != concatItemType) + return failure(); + + auto concatItemShape = concatItemType.getShape(); + SmallVector dimScales(concatItemShape.size(), 1); + int64_t concatItemIndex = 0; + std::exclusive_scan(coords.rbegin(), coords.rend(), dimScales.rbegin(), 1, + std::multiplies<>()); + for (auto [idx, itemDimSize] : llvm::enumerate(concatItemShape)) { + if ((offset[idx] % itemDimSize) != 0) + return failure(); + const auto sliceCoords = offset[idx] / itemDimSize; + concatItemIndex += sliceCoords * dimScales[idx]; + } + assert(concatItemIndex < concatOp->getNumOperands() && + "concat index must be in bounds"); + Value concreteConcatItem = concatOp->getOperand(concatItemIndex); + rewriter.replaceOp(op, concreteConcatItem); + + return success(); + } +}; + +struct CanonicalizeConcatOp : public mlir::OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(amdgpu::ConcatOp op, + PatternRewriter &rewriter) const override { + + auto result = op.getResult(); + auto sources = op.getSources(); + auto offsets = op.getCoords(); + if (sources.size() == 1) { + assert(product(offsets) == 1); + auto source = sources.front(); + result.replaceAllUsesWith(source); + return success(); + } + + return failure(); + } +}; + +struct DotOpPropagateAttrPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + using AttrType = amdgpu::DotTileAttr; + + void propagate(Operation *op, AttrType targetAttr) const { + Block *currBlock = op->getBlock(); + for (auto operand : op->getOperands()) { + if (mlir::isa(operand)) + continue; + + auto definingOp = operand.getDefiningOp(); + if (definingOp->getBlock() != currBlock) + continue; + + bool allowedDialects = + mlir::isa(definingOp->getDialect()); + allowedDialects |= + mlir::isa(definingOp->getDialect()); + if (!allowedDialects) + continue; + + auto operandArrayAttr = + definingOp->getAttrOfType(attrName); + if (!operandArrayAttr) { + operandArrayAttr = mlir::ArrayAttr::get(op->getContext(), {}); + } + + SmallVector updatedAttrs(operandArrayAttr.getValue()); + auto result = operandArrayAttr.walk([&targetAttr](AttrType 
itemAttr) { + return itemAttr == targetAttr ? WalkResult::interrupt() + : WalkResult::advance(); + }); + + if (!result.wasInterrupted()) { + updatedAttrs.push_back(targetAttr); + } + + if (!updatedAttrs.empty()) + definingOp->setAttr( + attrName, mlir::ArrayAttr::get(op->getContext(), updatedAttrs)); + + propagate(definingOp, targetAttr); + } + } + + LogicalResult matchAndRewrite(triton::DotOp dotOp, + PatternRewriter &rewriter) const override { + amdgpu::DotTileAttr opAttr = dotOp->getAttrOfType(attrName); + if (!opAttr) + return failure(); + + propagate(dotOp, opAttr); + + return success(); + } + +private: + static const inline StringLiteral attrName = AttrType::getMnemonic(); +}; + +void ConcatOp::getCanonicalizationPatterns(mlir::RewritePatternSet &patterns, + mlir::MLIRContext *context) { + patterns.add(context); + patterns.add(context); + patterns.add(context); +} } // namespace mlir::triton::amdgpu diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt index 693bd41bc55a..35310b86eecd 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/CMakeLists.txt @@ -2,6 +2,7 @@ add_triton_library(TritonAMDGPUDialectToLLVM TritonAMDGPUToLLVMPatterns.cpp ExtractSliceOpToLLVM.cpp InThreadTransposeOpToTTG.cpp + ConcatOpToLLVM.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp new file mode 100644 index 000000000000..12c151502255 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp @@ -0,0 +1,70 @@ +#include "Dialect/TritonAMDGPU/IR/Dialect.h" +#include "TritonAMDGPUToLLVM/GCNAsmFormat.h" +#include "mlir/Conversion/LLVMCommon/Pattern.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h" +#include "triton/Analysis/Utility.h" +#include "triton/Conversion/MLIRTypes.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" + +using namespace mlir; +using namespace mlir::triton; + +namespace { + +inline size_t getSourceSize(Value source) { + ArrayRef types = cast(source.getType()).getBody(); + return types.size(); +} + +struct ConcatOpConversion : public ConvertOpToLLVMPattern { + explicit ConcatOpConversion(LLVMTypeConverter &typeConverter, + PatternBenefit benefit = 1) + : ConvertOpToLLVMPattern(typeConverter, benefit) {} + + LogicalResult + matchAndRewrite(amdgpu::ConcatOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + auto resultTy = cast(op.getResult().getType()); + auto dims = op.getCoords(); + auto order = op.getLoweringOrder(); + + auto sources = adaptor.getSources(); + auto elemPerSource = getSourceSize(sources.front()); + llvm::SmallVector resultVals; + + auto coords = + mlir::triton::AMD::CoordinateMapper::cartesian(dims.vec(), order.vec()); + std::vector strides(dims.size(), 1); + std::exclusive_scan(dims.rbegin(), dims.rend(), strides.rbegin(), 1, + std::multiplies<>()); + + for (const auto &vec : coords) { + int linearIndex = 0; + for (size_t i = 0; i < vec.size(); ++i) { + linearIndex += vec[i] * strides[i]; + } + + auto elements = unpackLLElements(loc, sources[linearIndex], rewriter); + for (auto elem : elements) { + resultVals.push_back(elem); + } + } + + Value ret = packLLElements(loc, this->getTypeConverter(), resultVals, + rewriter, resultTy); 
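// Worked example of the linearization above, with assumed dims = {2, 4}:
// the exclusive_scan over the reversed dims gives strides = {4, 1}, so a
// concat coordinate (i, j) selects source i * 4 + j, i.e. the sources are
// indexed row-major over the concat grid. The lowering `order` only changes
// the order in which coordinates are visited, and therefore the order in
// which the unpacked elements are appended to resultVals; it does not change
// which source a given coordinate maps to.
//   std::vector<int64_t> dims = {2, 4}, strides(2, 1);
//   std::exclusive_scan(dims.rbegin(), dims.rend(), strides.rbegin(), 1,
//                       std::multiplies<>());
//   // strides == {4, 1}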
+ + rewriter.replaceOp(op, ret); + return llvm::success(); + } +}; +} // namespace + +namespace mlir::triton::AMD { +void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter, + mlir::RewritePatternSet &patterns, + mlir::PatternBenefit benefit) { + patterns.add(typeConverter, benefit); +} +} // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp index 07cf91870fed..a5f6bc63f2df 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp @@ -65,50 +65,60 @@ struct ExtractSliceOpConversion auto resultTy = cast(op.getType()); auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter); auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy); - auto contigPerThread = triton::gpu::getContigPerThread(srcTy); + auto contigPerThread = + triton::gpu::toLinearEncoding(srcTy).getSizePerThread(); + auto totalContigPerThread = product(contigPerThread); auto order = triton::gpu::getOrder(srcTy); // Calculate valid total number of workers in each dimension auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy); - shapePerCTATile[0] = - std::min(static_cast(srcShape[0]), shapePerCTATile[0]); - shapePerCTATile[1] = - std::min(static_cast(srcShape[1]), shapePerCTATile[1]); - - // Rank == 2 checked in the verifier - SmallVector sizes; - for (auto i = 0; i < 2; ++i) { - sizes.push_back(resultTy.getDimSize(i)); + for (auto i = 0; i < shapePerCTATile.size(); ++i) { + shapePerCTATile[i] = + std::min(static_cast(srcShape[i]), shapePerCTATile[i]); } auto offsets = op.getStaticOffsets(); // Calculate offsets and sizes in terms of CTA units. - std::array CTAOffsets{offsets[0] / shapePerCTATile[0], - offsets[1] / shapePerCTATile[1]}; - std::array CTASizes{sizes[0] / shapePerCTATile[0], - sizes[1] / shapePerCTATile[1]}; - std::array CTAPerShape{srcShape[0] / shapePerCTATile[0], - srcShape[1] / shapePerCTATile[1]}; - - // The diagram above illustrates the graphical representation of the - // skipElems, tensorStride, and lastIdx variables. - auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] * - contigPerThread[order[1]]) + - CTAOffsets[order[0]] * totalContigPerThread; - auto tensorStride = - (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread; - auto lastIdx = - (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * - elemsPerThread[order[0]] * contigPerThread[order[1]] + - (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread; + SmallVector sizes; + SmallVector CTAOffsets; + SmallVector CTASizes; + SmallVector CTAPerShape; + for (auto i = 0; i < resultTy.getRank(); ++i) { + sizes.push_back(resultTy.getDimSize(i)); + CTAOffsets.push_back(offsets[i] / shapePerCTATile[i]); + CTASizes.push_back(sizes[i] / shapePerCTATile[i]); + CTAPerShape.push_back(srcShape[i] / shapePerCTATile[i]); + } + + // SliceLayout uses 1d offsets. + auto skipElems = CTAOffsets[0] * totalContigPerThread; + auto tensorStride = (CTAPerShape[0] - CTASizes[0]) * totalContigPerThread; + auto lastIdx = (CTAOffsets[0] + CTASizes[0]) * totalContigPerThread; + auto numElemsPerVec = totalContigPerThread * CTASizes[0]; + + auto sliceLayout = mlir::dyn_cast(srcLayout); + if (!sliceLayout) { + // The diagram above illustrates the graphical representation of the + // skipElems, tensorStride, and lastIdx variables. 
+ skipElems = CTAOffsets[order[1]] * + (elemsPerThread[order[0]] * contigPerThread[order[1]]) + + CTAOffsets[order[0]] * totalContigPerThread; + tensorStride = + (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread; + lastIdx = + (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) * + elemsPerThread[order[0]] * contigPerThread[order[1]] + + (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread; + numElemsPerVec = totalContigPerThread * CTASizes[order[0]]; + } assert(lastIdx <= vals.size()); SmallVector resultVals; for (int i = skipElems; i < lastIdx; i += tensorStride) { - for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) { + for (int j = 0; j < numElemsPerVec; ++j, ++i) { assert(i < lastIdx); resultVals.push_back(vals[i]); } @@ -124,9 +134,16 @@ struct ExtractSliceOpConversion matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { auto srcTy = op.getSource().getType(); - if (isa( - op.getSource().getType().getEncoding())) { + auto encoding = srcTy.getEncoding(); + if (isa( + encoding)) { return processLayout(op, adaptor, rewriter); + } else if (auto sliceLayout = mlir::dyn_cast(encoding)) { + auto parent = sliceLayout.getParent(); + if (isa( + parent)) { + return processLayout(op, adaptor, rewriter); + } } return failure(); } diff --git a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp index a84d84b2819d..c0cf0fb5fefa 100644 --- a/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp +++ b/third_party/amd/lib/TritonAMDGPUDialectToLLVM/TritonAMDGPUToLLVMPatterns.cpp @@ -7,5 +7,6 @@ void populateTritonAMDGPUToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit) { populateExtractSliceOpToLLVMPatterns(typeConverter, patterns, benefit); populateInThreadTransposeOpToTTGPatterns(patterns, benefit); + populateConcatOpToLLVMPatterns(typeConverter, patterns, benefit); } } // namespace mlir::triton::AMD diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp index 849a73443e6e..0f707317f38c 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM/SharedToDotOperandMFMA.cpp @@ -138,17 +138,25 @@ bool hasSwizzleEnabled(const SwizzledSharedEncodingAttr &srcEncoding) { /// \param warpsPerBlock number of warps per horizontal axis /// \param numOfElems number of elements accessed by threads per repetition /// \param reps number of instructions repretition to fully cover dot operand -/// \param cSwizzleOffset +/// \param smemStrides shared memory strides +/// \param smemOffsets shared memory offsets llvm::SmallVector fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, const ArrayRef &elemsPerInstr, Value warpId, Value laneId, int warpsPerBlock, int numOfElems, - ArrayRef reps, Value cSwizzleOffset) { + ArrayRef reps, ArrayRef smemStrides, + ArrayRef smemOffsets) { auto b = TritonLLVMOpBuilder(loc, rewriter); auto numK = reps[1]; auto numN = reps[2]; SmallVector offsets(numK * numN * numOfElems); + auto smemStrideRow = smemStrides[0]; + auto smemStrideCol = smemStrides[1]; + + auto smemOffsetRow = smemOffsets[0]; + auto smmemOffsetCol = smemOffsets[1]; + auto iKDim = elemsPerInstr[0]; auto 
iNonKDim = elemsPerInstr[1]; int lineSize = warpsPerBlock * iNonKDim * numN; @@ -156,33 +164,38 @@ fastPathComputeOffsets(ConversionPatternRewriter &rewriter, Location loc, Value warpOffset = b.mul(warpId, b.i32_val(iNonKDim)); Value colOffset = b.urem(laneId, _nonKDim); + // halfOffset is an offset related to wrapping of warp in the tile. + // for example, mfma 32 case (mapping of tensor elements to lane ids in + // warp): + // + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 0 1 2 3 ... 31 + // 32 33 34 35 ... 63 <- at this point warp is wrapping + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + // 32 33 34 35 ... 63 + Value halfOffset; + if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) + halfOffset = b.i32_val(0); + else + halfOffset = b.mul(b.udiv(laneId, _nonKDim), b.i32_val(numOfElems)); + for (int block = 0; block < numN; ++block) { Value blockOffset = b.i32_val(block * iNonKDim * warpsPerBlock); for (int tile = 0; tile < numK; ++tile) { - Value tileOffset = b.i32_val(tile * iKDim * lineSize); + Value tileOffset = b.i32_val(tile * iKDim); for (int elem = 0; elem < numOfElems; ++elem) { - // halfOffset is an offset related to wrapping of warp in the tile. - // for example, mfma 32 case (mapping of tensor elements to lane ids in - // warp): - // - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 0 1 2 3 ... 31 - // 32 33 34 35 ... 63 <- at this point warp is wrapping - // 32 33 34 35 ... 63 - // 32 33 34 35 ... 63 - // 32 33 34 35 ... 63 - Value halfOffset; - if ((iKDim == 1 || iKDim == 4) && iNonKDim == 4) - halfOffset = b.i32_val(0); - else - halfOffset = - b.mul(b.udiv(laneId, _nonKDim), b.i32_val(numOfElems * lineSize)); - Value rowOffset = b.add(b.i32_val(elem * lineSize), halfOffset); - Value elemOffset = b.add(rowOffset, colOffset); - Value offset = b.add(b.add(b.add(warpOffset, blockOffset), tileOffset), - elemOffset); + + Value rowOffset = b.add(b.i32_val(elem), halfOffset); + Value row = b.add(b.add(rowOffset, tileOffset), smemOffsetRow); + Value col = b.add(b.add(blockOffset, b.add(warpOffset, colOffset)), + smmemOffsetCol); + + Value offset = + b.add(b.mul(row, smemStrideRow), b.mul(col, smemStrideCol)); + offsets[numK * numOfElems * block + numOfElems * tile + elem] = offset; } } @@ -302,6 +315,7 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, SmallVector offsets; Value smemBase; auto smemStrides = smemObj.getStrides(aTensorTy, loc, rewriter); + auto smemOffsets = smemObj.getOffsets(); bool isFastPath = !AMD::isKContig(order, opIdx) && !hasSwizzleEnabled(sharedLayout); if (isFastPath) { @@ -309,14 +323,15 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, // disabled, in which case offsets computation can be simplified // TODO (zhanglx): later when we enable vector access to LDS for non k-major // tensors, we'll refactor the scope of fast and normal path - Value cSwizzleOffset = smemObj.getCSwizzleOffset(order[0]); if (opIdx == 0) { if (isColMajor(order)) { SmallVector elemsPerInstr{mfmaInstrK, mfmaInstrNonK}; SmallVector reps{numReps[0], numReps[2], numReps[1]}; + SmallVector tStrides = {smemStrides[1], smemStrides[0]}; + SmallVector tOffsets = {smemOffsets[1], smemOffsets[0]}; offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, spatialWarpId, lane, warpsPerBlockNonK, - numOfElems, reps, cSwizzleOffset); + numOfElems, reps, tStrides, tOffsets); } else { llvm_unreachable( "row major operand A should be handled in the normal path"); @@ -326,12 +341,11 @@ Value convertLayout(int opIdx, 
ConversionPatternRewriter &rewriter, llvm_unreachable( "col major operand B should be handled in the normal path"); } else { - offsets = fastPathComputeOffsets(rewriter, loc, elemsPerInstr, - spatialWarpId, lane, warpsPerBlockNonK, - numOfElems, numReps, cSwizzleOffset); + offsets = fastPathComputeOffsets( + rewriter, loc, elemsPerInstr, spatialWarpId, lane, + warpsPerBlockNonK, numOfElems, numReps, smemStrides, smemOffsets); } } - smemBase = smemObj.getBaseBeforeSlice(order[0], loc, rewriter); } else { // normal path // Normal path handles tensors that fall into either of the following three // cases: @@ -351,8 +365,8 @@ Value convertLayout(int opIdx, ConversionPatternRewriter &rewriter, spatialWarpId, lane, warpsPerBlockNonK, numOfElems, numReps, smemObj, smemStrides, sharedLayout, nDim, mfmaInstrK); } - smemBase = AMD::computeBasePtr(rewriter, loc, smemObj, smemStrides); } + smemBase = AMD::computeBasePtr(rewriter, loc, smemObj, smemStrides); Type resElemTy = typeConverter->convertType(elemTy); Type smemPtrTy = ptr_ty(rewriter.getContext(), 3); diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp index f77abbf66771..936b44596d7b 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp @@ -2,7 +2,6 @@ #include "Utility.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "triton/Analysis/Allocation.h" #include "triton/Conversion/TritonGPUToLLVM/ElementwiseOpToLLVMBase.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" @@ -1890,8 +1889,8 @@ struct PreciseSqrtOpConversion namespace mlir::triton::AMD { void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, - ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, - const TargetInfo &targetInfo, PatternBenefit benefit) { + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfo &targetInfo, + PatternBenefit benefit) { // fmin (return NaN if either op is NaN) patterns.add>( diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h index 0c4915cf547f..79fef0b537f8 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/PatternTritonGPUOpToLLVM.h @@ -23,8 +23,8 @@ void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter, PatternBenefit benefit); void populateElementwiseOpToLLVMPatterns( LLVMTypeConverter &typeConverter, RewritePatternSet &patterns, bool ftz, - ModuleAxisInfoAnalysis &axisInfoAnalysis, ModuleAllocation &allocation, - const TargetInfo &targetInfo, PatternBenefit benefit); + ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfo &targetInfo, + PatternBenefit benefit); void populateLoadStoreOpToLLVMPatterns(LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo, RewritePatternSet &patterns, diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp index e4a6816c7ca5..b8633a13dace 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/SchedInstructions.cpp @@ -6,6 +6,7 @@ #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include 
"mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h" #include "triton/Conversion/TritonGPUToLLVM/Utility.h" namespace mlir::triton { @@ -522,6 +523,19 @@ struct TritonAMDGPUInsertInstructionSchedHints return; } + if (schedHint == mlir::triton::amdgpu::SchedHint::refine_ops) { + mod->walk([&](triton::FuncOp funcOp) { + auto forOps = AMD::getLeafForOps(funcOp); + for (auto forOp : forOps) { + OpBuilder rewriter(ctx); + rewriter.setInsertionPointToStart(forOp.getBody()); + rewriter.create(forOp->getLoc(), + schedHint); + } + }); + return; + } + switch (schedHint) { case mlir::triton::amdgpu::SchedHint::local_prefetch: mod.walk([&](scf::ForOp forOp) { diff --git a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp index 6f87ecb6796d..619e827dbbd0 100644 --- a/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp +++ b/third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp @@ -15,7 +15,6 @@ #include "mlir/Dialect/LLVMIR/ROCDLDialect.h" #include "mlir/Pass/Pass.h" #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" -#include "triton/Analysis/Allocation.h" #include "triton/Analysis/AxisInfo.h" #include "triton/Analysis/Membar.h" #include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h" @@ -57,6 +56,7 @@ class TritonLLVMConversionTarget : public ConversionTarget { addIllegalDialect(); addIllegalDialect(); addIllegalDialect(); + addIllegalDialect(); addLegalOp(); addLegalOp(); } @@ -95,11 +95,6 @@ struct ConvertTritonAMDGPUToLLVM int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod); int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod); - // Allocate shared memory and set barrier - ModuleAllocation allocation(mod); - ModuleMembarAnalysis membarPass(&allocation); - membarPass.run(); - // Lower functions { TritonLLVMFunctionConversionTarget funcTarget(*context); @@ -142,8 +137,7 @@ struct ConvertTritonAMDGPUToLLVM // patterns int AMDBenefit = commonBenefit + 1; auto populatePatterns1 = [&](auto populateFunc, int benefit) { - populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, - benefit); + populateFunc(typeConverter, patterns, axisInfoAnalysis, benefit); }; auto populatePatterns5 = [&](auto populateFunc, int benefit) { @@ -151,8 +145,8 @@ struct ConvertTritonAMDGPUToLLVM }; auto populatePatterns6 = [&](auto populateFunc, int benefit) { - populateFunc(typeConverter, patterns, axisInfoAnalysis, allocation, - targetInfo, benefit); + populateFunc(typeConverter, patterns, axisInfoAnalysis, targetInfo, + benefit); }; auto populatePatterns7 = [&](auto populateFunc, int benefit) { @@ -165,9 +159,8 @@ struct ConvertTritonAMDGPUToLLVM typeConverter, targetInfo, patterns, commonBenefit); AMD::populateDotOpToLLVMPatterns(typeConverter, patterns, axisInfoAnalysis, AMDBenefit); - AMD::populateElementwiseOpToLLVMPatterns(typeConverter, patterns, ftz, - axisInfoAnalysis, allocation, - targetInfo, AMDBenefit); + AMD::populateElementwiseOpToLLVMPatterns( + typeConverter, patterns, ftz, axisInfoAnalysis, targetInfo, AMDBenefit); AMD::populateLoadStoreOpToLLVMPatterns(typeConverter, targetInfo, patterns, axisInfoAnalysis, AMDBenefit); populatePatterns7(mlir::triton::populateReduceOpToLLVMPatterns, diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt index 836720b43901..45d9600eb8e1 100644 --- 
a/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt +++ b/third_party/amd/lib/TritonAMDGPUTransforms/CMakeLists.txt @@ -13,6 +13,9 @@ add_triton_library(TritonAMDGPUTransforms FoldTrueCmpIOp.cpp UpdateAsyncWaitCount.cpp Utility.cpp + MembarAnalysis.cpp + RefineOps.cpp + RescheduleOps.cpp DEPENDS TritonAMDGPUIR diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/MembarAnalysis.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/MembarAnalysis.cpp new file mode 100644 index 000000000000..bda73f274406 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/MembarAnalysis.cpp @@ -0,0 +1,28 @@ +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Pass/Pass.h" +#include "triton/Analysis/Allocation.h" +#include "triton/Analysis/Membar.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +namespace { +struct TritonAMDGPUMembarAnalysis + : public mlir::TritonAMDGPUMembarAnalysisBase { + + void runOnOperation() override { + mlir::ModuleOp mod = getOperation(); + + // Allocate shared memory and set barrier + mlir::ModuleAllocation allocation(mod); + mlir::ModuleMembarAnalysis membarPass(&allocation); + membarPass.run(); + } +}; +} // namespace + +namespace mlir { +std::unique_ptr createTritonAMDGPUMembarAnalysisPass() { + return std::make_unique(); +} +} // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp new file mode 100644 index 000000000000..d2166fd10ec1 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp @@ -0,0 +1,1293 @@ +#include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/Utility/CommonUtils.h" +#include "third_party/amd/include/TritonAMDGPUTransforms/DotTiling.h" +#include "third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#undef DEBUG_TYPE +#define DEBUG_TYPE "tritonamdgpu-refine-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +using namespace mlir; +namespace tt = mlir::triton; +namespace ttg = mlir::triton::gpu; + +template +llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, ArrayRef vec) { + for (size_t i = 0; i < vec.size(); ++i) { + const char delim = (i != vec.size() - 1) ? 
',' : '\n'; + stream << vec[i] << delim; + } + return stream; +} + +template +llvm::raw_ostream &operator<<(llvm::raw_ostream &stream, SmallVector vec) { + stream << ArrayRef(vec); + return stream; +} + +namespace { +SmallVector createOffset(llvm::ArrayRef valueOffset, + llvm::ArrayRef intOffset, + OpBuilder &rewriter, Location loc) { + SmallVector values; + for (auto item : valueOffset) { + values.push_back(item); + } + + for (auto item : intOffset) { + Value value = rewriter.create(loc, item, 32); + values.push_back(value); + } + return values; +} + +inline bool isRowMajor(::llvm::ArrayRef order) { + auto rank = order.size(); + return order[rank - 1] == 0; +} + +inline RankedTensorType rankedTType(Value tensor) { + return cast(tensor.getType()); +} + +SmallVector getRefinedShapePerCTATile(Type type) { + auto tensorType = cast(type); + return mlir::triton::gpu::getShapePerCTATile(tensorType); +} + +struct RefinedBlock { + RefinedBlock(ArrayRef shape, Type elemType, + BlockedEncodingAttr encoding) + : encoding(encoding), elemType(elemType) { + auto ctaOrder = encoding.getCTAOrder(); + auto warpsPerCTA = encoding.getWarpsPerCTA(); + auto threadsPerWarp = encoding.getThreadsPerWarp(); + auto sizePerThread = encoding.getSizePerThread(); + + numDims = warpsPerCTA.size(); + elementsPerWorkGroup.resize(numDims); + numPerDims.resize(numDims); + refinedShape.resize(numDims); + numSubTiles = 1; + for (size_t dim = 0; dim < numDims; ++dim) { + elementsPerWorkGroup[dim] = + sizePerThread[dim] * threadsPerWarp[dim] * warpsPerCTA[dim]; + numPerDims[dim] = shape[dim] / elementsPerWorkGroup[dim]; + refinedShape[dim] = shape[dim] / numPerDims[dim]; + numSubTiles *= numPerDims[dim]; + } + + tensorType = + RankedTensorType::get(elementsPerWorkGroup, elemType, encoding); + } + + BlockedEncodingAttr encoding; + Type elemType; + SmallVector elementsPerWorkGroup; + SmallVector numPerDims; + SmallVector refinedShape; + size_t numDims; + size_t numSubTiles; + RankedTensorType tensorType; +}; + +template +struct RefineRewritePattern : public OpRewritePattern { + RefineRewritePattern(MLIRContext *context, PatternBenefit benefit = 1) + : OpRewritePattern(context, benefit) {} + + virtual LogicalResult apply(OpTy op, PatternRewriter &rewriter) const = 0; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const final { + if (!isRefinable(op)) + return failure(); + return apply(op, rewriter); + } + +private: + bool isRefinable(Operation *op) const { + mlir::Block *block = op->getBlock(); + while (block) { + for (auto &op : block->getOperations()) { + if (auto hint = dyn_cast(op)) { + if (hint.getVariant() == triton::amdgpu::SchedHint::refine_ops) { + return true; + } + } + } + block = block->getParentOp()->getBlock(); + } + return false; + } +}; + +struct DotOpMFMAConverter { + AMDMfmaEncodingAttr mfmaLayout; + PatternRewriter &rewriter; + Location loc; + MLIRContext *ctx{}; + + explicit DotOpMFMAConverter(AMDMfmaEncodingAttr mfmaLayout, + PatternRewriter &rewriter, Location loc) + : mfmaLayout(mfmaLayout), rewriter(rewriter), loc(loc), + ctx(mfmaLayout.getContext()) {} + + LogicalResult convert(DotOp dotOp, DotOpAdaptor adaptor) const { + InputPrecisionAttr precisionAttr = dotOp.getInputPrecisionAttr(); + auto warpsPerCTA = mfmaLayout.getWarpsPerCTA(); + auto mDim = mfmaLayout.getMDim(); + auto nDim = mfmaLayout.getNDim(); + + Value a = dotOp.getA(); + Value b = dotOp.getB(); + Value c = dotOp.getC(); + Value d = dotOp.getD(); + + auto aTensorTy = cast(a.getType()); + auto bTensorTy = 
cast(b.getType()); + auto cTensorTy = cast(c.getType()); + auto dTensorTy = cast(d.getType()); + + auto elemTyA = aTensorTy.getElementType(); + auto elemTyB = bTensorTy.getElementType(); + auto elemTyC = cTensorTy.getElementType(); + auto elemTyD = dTensorTy.getElementType(); + + auto encodeA = cast(aTensorTy.getEncoding()); + auto encodeB = cast(bTensorTy.getEncoding()); + auto encodeC = cast(cTensorTy.getEncoding()); + auto encodeD = cast(dTensorTy.getEncoding()); + + auto shapeA = aTensorTy.getShape(); + auto shapeB = bTensorTy.getShape(); + auto shapeC = cTensorTy.getShape(); + auto shapeD = dTensorTy.getShape(); + + const auto kDimOperandSize = aTensorTy.getShape().back(); + + int kWidth = encodeA.getKWidth(); + auto repA = mfmaLayout.getRepForOperand(aTensorTy.getShape(), kWidth, 0); + auto repB = mfmaLayout.getRepForOperand(bTensorTy.getShape(), kWidth, 1); + assert(repA[2] == repB[1]); + + Value loadedA = adaptor.getA(); + Value loadedB = adaptor.getB(); + Value loadedC = adaptor.getC(); + + const auto numRepM = repA[1]; + const auto numRepN = repB[2]; + + // TODO(dtanner) This is a temporary workaround so that local_load and dot + // are decomposed the same and the intervening extract_slice and concat can + // be canonicalized away. Re-enable slicing dots along K when we know we can + // slice local_load along K too. + const auto numRepK = repA[2]; + // const int numRepK = 1; + const auto numRepB = repA[0]; + SmallVector numRepShape = {numRepM, numRepN, numRepK}; + LDBG("totalReps: " << numRepShape[0] << "x" << numRepShape[1] << "x" + << numRepShape[2]); + SmallVector refinedShapeA = {shapeA[0] / numRepM, + shapeA[1] / numRepK}; + SmallVector refinedShapeB = {shapeB[0] / numRepK, + shapeB[1] / numRepN}; + SmallVector refinedShapeCD = {shapeC[0] / numRepM, + shapeC[1] / numRepN}; + + // Calculate mfmas per rep. + SmallVector ctaTile = {shapeC[0], shapeC[1], shapeA[1]}; + SmallVector warpTile = { + shapeC[0] / warpsPerCTA[0], + shapeC[1] / warpsPerCTA[1], + shapeA[1], + }; + auto mfmaVersion = mfmaLayout.getVersionMajor(); + bool allowXF32 = + dotOp.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3; + + FailureOr maybeMfmaInsn = MfmaIntrinsic::selectFor( + mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB, + /*withScale=*/false, allowXF32); + + SmallVector mfmaShape = {16, 16, 16}; + if (failed(maybeMfmaInsn)) { + llvm::errs() << "No match found in MFMA database\n"; + } else { + mfmaShape[0] = maybeMfmaInsn->mDim; + mfmaShape[1] = maybeMfmaInsn->nDim; + mfmaShape[2] = maybeMfmaInsn->kDim; + } + + auto mfmasPerRep = + getMfmasPerRep(ctaTile, warpsPerCTA, numRepShape, mfmaShape); + + // Calculate Dot-Tiling. + unsigned cyclesPerMfma = getCyclesPerMfma(dotOp); + // Prefer tile to be skinny along inner loop dimension to minimize + // registers. + const bool preferOuterLoopM = + (warpTile[0] >= warpTile[1]); // true: row-major when tall warp-tile + const bool preferTileLargerM = + !preferOuterLoopM; // true: tall tiles when wide warp-tile + // Calculate dot-tile shape (in reps per dot-tile). 
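// For illustration (assumed values; the actual shape comes from
// calcDotTileShape): if calcDotTileShape returns {4, 2, 4} but the dot only
// has numRepM = 4, numRepN = 2, numRepK = 2 repetitions, the std::min clamps
// below shrink the tile to {4, 2, 2}, so a dot-tile never spans more reps
// than the dot actually has along any dimension.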
+ DotTileShapeType tileShape = + calcDotTileShape(mfmasPerRep, preferTileLargerM, cyclesPerMfma); + + tileShape[0] = std::min(tileShape[0], static_cast(numRepM)); + tileShape[1] = std::min(tileShape[1], static_cast(numRepN)); + tileShape[2] = std::min(tileShape[2], static_cast(numRepK)); + + LDBG("repsPerDotTile: " << tileShape[0] << "x" << tileShape[1] << "x" + << tileShape[2]); + const int tileShapeM = tileShape[0]; + const int tileShapeN = tileShape[1]; + const int tileShapeK = tileShape[2]; + const DotTileOrder dotTileOrder(numRepM, numRepN, tileShapeM, tileShapeN, + preferOuterLoopM); + + // Extract slices for A operands. + int64_t elementsPerSliceM = refinedShapeCD[0]; + int64_t elementsPerSliceN = refinedShapeCD[1]; + int64_t elementsPerSliceK = refinedShapeA[1]; + auto extractSliceTypeA = + RankedTensorType::get(refinedShapeA, elemTyA, encodeA); + rewriter.setInsertionPointAfter(dotOp); + SmallVector> subtilesA; + unsigned tileIdx = 0; + for (int32_t k = 0; k < numRepK; ++k) { + SmallVector subtilesK; + for (int32_t i = 0; i < numRepM; ++i) { + LDBG("local_load_a[" << i << "][" << k << "] extract_slice"); + int32_t shiftM = i * elementsPerSliceM; + int32_t shiftK = k * elementsPerSliceK; + auto extract = rewriter.create( + loc, Type{extractSliceTypeA}, Value{a}, + DenseI64ArrayAttr::get(ctx, {shiftM, shiftK})); + // Add dot-tile info to local_load's slice; + // this specifies which dot-tile this load is needed for. + int32_t tileM = i / tileShapeM; + int32_t tileN = -1; + int32_t tileK = k / tileShapeK; + int32_t tileSerial = dotTileOrder.getOuterTileM() + ? tileM * dotTileOrder.getNumTilesN() + : tileM; + tileSerial += + k * dotTileOrder.getNumTilesM() * dotTileOrder.getNumTilesN(); + int32_t elementM = i % tileShapeM; // dots are n-major within tile + int32_t elementN = -1; + int32_t elementK = k % tileShapeK; + int32_t elementSerial = + elementM * tileShapeN; // dots are n-major within tile + auto dotTileAttr = triton::amdgpu::DotTileAttr::get( + ctx, tileM, tileN, tileK, tileSerial, elementM, elementN, elementK, + elementSerial); + extract->setAttr(triton::amdgpu::DotTileAttr::getMnemonic(), + dotTileAttr); + subtilesK.push_back(extract); + } + subtilesA.push_back(subtilesK); + } + + // Extract slices for B operands. + auto extractSliceTypeB = + RankedTensorType::get(refinedShapeB, elemTyB, encodeB); + SmallVector> subtilesB; + tileIdx = 0; + for (int32_t k = 0; k < numRepK; ++k) { + SmallVector subtilesK; + for (int32_t j = 0; j < numRepN; ++j) { + LDBG("local_load_b[" << k << "][" << j << "] extact_slice"); + int32_t shiftN = j * elementsPerSliceN; + int32_t shiftK = k * elementsPerSliceK; + auto extract = rewriter.create( + loc, Type{extractSliceTypeB}, Value{b}, + DenseI64ArrayAttr::get(ctx, {shiftK, shiftN})); + // Add dot-tile info to local_load's slice; + // this specifies which dot-tile this load is needed for. + int32_t tileM = -1; + int32_t tileN = j / tileShapeN; + int32_t tileK = k / tileShapeK; + int32_t tileSerial = dotTileOrder.getOuterTileM() + ? 
tileN + : tileN * dotTileOrder.getNumTilesM(); + tileSerial += + k * dotTileOrder.getNumTilesM() * dotTileOrder.getNumTilesN(); + int32_t elementM = -1; + int32_t elementN = j % tileShapeN; // dots are n-major within tile + int32_t elementK = k % tileShapeK; + int32_t elementSerial = elementN; // dots are n-major within tile + auto dotTileAttr = triton::amdgpu::DotTileAttr::get( + ctx, tileM, tileN, tileK, tileSerial, elementM, elementN, elementK, + elementSerial); + extract->setAttr(triton::amdgpu::DotTileAttr::getMnemonic(), + dotTileAttr); + subtilesK.push_back(extract); + } + subtilesB.push_back(subtilesK); + } + + auto refinedTensorTypeC = + RankedTensorType::get(refinedShapeCD, elemTyC, encodeC); + auto refinedTensorTypeD = + RankedTensorType::get(refinedShapeCD, elemTyD, encodeD); + SmallVector refinedDotValues; + // Extract slices for refined C operands for first slice of K. + // Create these in same order that concat wants them. + for (int m = 0; m < numRepM; ++m) { + for (int n = 0; n < numRepN; ++n) { + SmallVector offset = {m * elementsPerSliceM, + n * elementsPerSliceN}; + auto refinedTensorC = rewriter.create( + loc, Type{refinedTensorTypeC}, Value{c}, offset); + refinedDotValues.push_back(refinedTensorC); + } + } + auto dotAttrs = dotOp->getAttrs(); + int32_t tileSerial = 0; + // Iterate over dot-tiles. + for (int32_t tileIdxK = 0; tileIdxK < numRepK / tileShapeK; ++tileIdxK) { + for (int tileOuterIdx = 0; tileOuterIdx < dotTileOrder.getNumTilesOuter(); + ++tileOuterIdx) { + for (int tileInnerIdx = 0; + tileInnerIdx < dotTileOrder.getNumTilesInner(); ++tileInnerIdx) { + const int tileStartM = + dotTileOrder.getTileStartM(tileOuterIdx, tileInnerIdx); + const int tileStartN = + dotTileOrder.getTileStartN(tileOuterIdx, tileInnerIdx); + for (int k = tileIdxK * tileShapeK; k < (tileIdxK + 1) * tileShapeK; + ++k) { + int32_t elementSerial = 0; + LDBG("dot-tile[" << tileSerial << "]"); + // Iterate over dots within dot-tile. + for (int m = tileStartM; m < tileStartM + tileShapeM; ++m) { + for (int n = tileStartN; n < tileStartN + tileShapeN; ++n) { + LDBG(" dot[" << m << "][" << n << "][" << k << "]"); + auto refinedTensorA = subtilesA[k][m]; + auto refinedTensorB = subtilesB[k][n]; + auto dotOp = rewriter.create( + loc, refinedTensorTypeD, + ValueRange{refinedTensorA, refinedTensorB, + refinedDotValues[int32_t(m * numRepN + n)]}, + dotAttrs); + // Add dot-tile info to dot. 
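// Note on the accumulator handling in this loop (illustrative counts): each
// refined dot reads refinedDotValues[m * numRepN + n] as its accumulator and
// writes its partial result back to the same slot, so across the K loop the
// (m, n) slot forms a chain of numRepK dependent tt.dot ops. With
// numRepM = numRepN = numRepK = 2 this produces 8 refined dots forming 4
// accumulator chains of length 2; the final values are reassembled into the
// full D tensor by the {numRepM, numRepN} amdgpu.concat below.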
+ int32_t tileM = tileStartM / tileShapeM; + int32_t tileN = tileStartN / tileShapeN; + int32_t tileK = k; + int32_t elementM = m - tileStartM; + int32_t elementN = n - tileStartN; + int32_t elementK = 0; + auto dotTileAttr = triton::amdgpu::DotTileAttr::get( + ctx, tileM, tileN, tileK, tileSerial, elementM, elementN, + elementK, elementSerial); + dotOp->setAttr(triton::amdgpu::DotTileAttr::getMnemonic(), + dotTileAttr); + refinedDotValues[int32_t(m * numRepN + n)] = dotOp; + elementSerial++; + } + } + } + tileSerial++; + } + } + } + + auto concatDims = DenseI64ArrayAttr::get(ctx, {numRepM, numRepN}); + auto joinedDotsResult = rewriter.create( + loc, dTensorTy, refinedDotValues, concatDims); + + d.replaceAllUsesWith(joinedDotsResult); + + // Note: dangling localLoadA or/and localLoadB (if exist) + // should be removed by the dead code elimination pass + rewriter.eraseOp(dotOp); + return success(); + } +}; + +LogicalResult rewriteMFMA(PatternRewriter &rewriter, triton::DotOp op) { + if (!(isa(rankedTType(op.getA()).getEncoding()) && + isa(rankedTType(op.getB()).getEncoding()))) { + LDBG("Both $a and %b should be DotOperand layout"); + return failure(); + } + + auto cTensorTy = rankedTType(op.getC()); + auto dTensorTy = rankedTType(op.getD()); + if (!isa(cTensorTy.getEncoding())) { + LDBG("Currently, we only support $c with a mfma layout"); + return failure(); + } + + if (!(cTensorTy.getShape()[0] == dTensorTy.getShape()[0] && + cTensorTy.getShape()[1] == dTensorTy.getShape()[1])) { + LDBG("DotOp's $c operand should pass the same number of values as $d"); + return failure(); + } + + auto loc = op.getLoc(); + auto mfmaLayout = cast( + cast(op.getResult().getType()).getEncoding()); + + DotOpMFMAConverter converter(mfmaLayout, rewriter, loc); + return converter.convert(op, DotOpAdaptor(op)); +} + +struct DotOpPattern : public RefineRewritePattern { + DotOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + LogicalResult apply(triton::DotOp op, + PatternRewriter &rewriter) const override { + auto result = rewriteMFMA(rewriter, op); + if (failed(result)) { + LDBG("failed to refine tt.Dot: " << *op); + } + return result; + } +}; + +struct LocalLoadOpPattern + : public RefineRewritePattern { + LocalLoadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + LogicalResult apply(triton::gpu::LocalLoadOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) { + return failure(); + } + + auto *ctx = op->getContext(); + auto loc = op->getLoc(); + + auto resultType = cast(op.getType()); + auto resultElementType = resultType.getElementType(); + auto resultEncode = cast(resultType.getEncoding()); + auto resultShape = resultType.getShape(); + + const auto rank = resultShape.size(); + assert(rank == 2); + + auto opIdx = resultEncode.getOpIdx(); + const int kDimIdx = opIdx == 0 ? rank - 1 : rank - 2; + const int nonKDimIdx = opIdx == 0 ? rank - 2 : rank - 1; + + auto mfmaLayout = cast(resultEncode.getParent()); + int kWidth = resultEncode.getKWidth(); + auto numReps = mfmaLayout.getRepForOperand(resultShape, kWidth, opIdx); + + // indices into 3D numReps + int kRepsIdx = opIdx == 0 ? 2 : 1; + int nonKRepsIdx = opIdx == 0 ? 1 : 2; + int bRepsIdx = 0; + + // 2D shape which drops batch dimension. 
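// For example (assumed operand-A shapes): a 128x64 ttg.local_load result
// with 4 repetitions along M and 2 along K gives numReps = {1, 4, 2}, so
// numReps2D = {4, 2} and each refined load below reads a 32x32 subview of
// the shared-memory buffer; the amdgpu.concat at the end reassembles the
// original 128x64 dot-operand tensor.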
+ SmallVector numReps2D = {numReps[1], numReps[2]}; + + auto numRepsNonK = numReps[nonKRepsIdx]; + auto numRepsK = numReps[kRepsIdx]; + auto numRepsB = numReps[bRepsIdx]; + + auto memDesc = op->getOperand(0); + auto memDescType = cast(memDesc.getType()); + auto memDescEncoding = memDescType.getEncoding(); + SmallVector order; + if (auto enc = dyn_cast( + memDescEncoding)) { + order = decltype(order)(enc.getOrder()); + } + if (auto enc = dyn_cast( + memDescEncoding)) { + order = decltype(order)(enc.getOrder()); + } + assert(!order.empty()); + + SmallVector refinedShape = {resultShape[0] / numReps2D[0], + resultShape[1] / numReps2D[1]}; + LDBG("refinedShape: " << refinedShape[0] << "x" << refinedShape[1]); + + auto refinedTensorType = + RankedTensorType::get(refinedShape, resultElementType, resultEncode); + + constexpr bool mutableMemory = true; + auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + auto subviewType = ttg::MemDescType::get( + refinedShape, memDescType.getElementType(), memDescType.getEncoding(), + sharedMemorySpace, mutableMemory, memDescType.getAllocShape()); + + rewriter.setInsertionPointAfter(op); + SmallVector subtiles; + for (int32_t i = 0; i < numReps2D[0]; ++i) { + for (int32_t j = 0; j < numReps2D[1]; ++j) { + int32_t offset0 = i * refinedShape[0]; + int32_t offset1 = j * refinedShape[1]; + auto offset = createOffset({}, {offset0, offset1}, rewriter, loc); + auto refinedView = rewriter.create( + loc, subviewType, memDesc, offset); + LDBG("RefinedLocalLoadSubvew: " << *refinedView); + + auto refinedLoad = rewriter.create( + loc, refinedTensorType, refinedView); + subtiles.push_back(refinedLoad); + } + } + + // concat dims is correct shape 8x1 vs 1x8, else gives wrong output shape. + std::vector loweringOrder(numReps2D.size()); + int64_t counter = 0; + auto increment = [&counter](int64_t &val) { val = counter++; }; + if (opIdx == 0) + std::for_each(loweringOrder.rbegin(), loweringOrder.rend(), increment); + else + std::for_each(loweringOrder.begin(), loweringOrder.end(), increment); + + auto joinedResult = rewriter.create( + loc, resultType, subtiles, numReps2D, loweringOrder); + LDBG("ConcatOp: " << *joinedResult); + + rewriter.replaceOp(op, joinedResult); + return success(); + } +}; + +struct LoadOpPattern : public RefineRewritePattern { + LoadOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + LogicalResult apply(triton::LoadOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 1) { + return failure(); + } + + auto ctx = op->getContext(); + auto loc = op.getLoc(); + + Value origSrc = op->getOperand(0); + Value origResult = op.getResult(); + Type origResultType = op.getResult().getType(); + auto origPtrs = rankedTType(origSrc); + auto origShape = origPtrs.getShape(); + auto elemType = origPtrs.getElementType(); + auto encoding = dyn_cast(origPtrs.getEncoding()); + if (encoding == nullptr) + return failure(); + + RefinedBlock refinedBlock(origShape, elemType, encoding); + + rewriter.setInsertionPointAfter(op); + SmallVector refinedTensors; + + Value mask = op.getMask(); + Value other = op.getOther(); + auto boundaryCheck = op.getBoundaryCheck(); + auto padding = op.getPadding(); + auto cache = op.getCache(); + auto evict = op.getEvict(); + auto isVolatile = op.getIsVolatile(); + + AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims); + for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles; + ++linearIdx) { + auto coords = 
coordsMapper.map(linearIdx); + SmallVector offset(refinedBlock.numDims, 0); + for (auto [dim, coord] : llvm::enumerate(coords)) { + offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim]; + } + + auto slice = rewriter.create( + loc, Type{refinedBlock.tensorType}, Value{origSrc}, offset); + + auto refinedTensor = rewriter.create( + loc, slice, mask, other, boundaryCheck, padding, cache, evict, + isVolatile); + refinedTensors.push_back(refinedTensor); + } + + auto concatDims = DenseI64ArrayAttr::get(ctx, refinedBlock.numPerDims); + auto joinedResult = rewriter.create( + loc, origResultType, refinedTensors, concatDims); + + origResult.replaceAllUsesWith(joinedResult); + return success(); + } +}; + +struct LocalStoreOpPattern + : public RefineRewritePattern { + LocalStoreOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + LogicalResult apply(triton::gpu::LocalStoreOp op, + PatternRewriter &rewriter) const override { + if (op->getNumOperands() != 2) { + return failure(); + } + + auto ctx = op->getContext(); + auto loc = op.getLoc(); + + Value origSrc = op->getOperand(0); + auto origMemViewOp = + cast(op->getOperand(1).getDefiningOp()); + Value origMemView = origMemViewOp->getOperand(0); + Value selectValue = origMemViewOp.getOffsets().front(); + + auto origSrcType = rankedTType(origSrc); + auto blockEncoding = + dyn_cast(origSrcType.getEncoding()); + if (blockEncoding == nullptr) + return failure(); + + auto origMemViewType = cast(origMemView.getType()); + auto sharedEncoding = cast( + origMemViewType.getEncoding()); + if (sharedEncoding == nullptr) + return failure(); + + RefinedBlock refinedBlock(origSrcType.getShape(), + origSrcType.getElementType(), blockEncoding); + + constexpr bool mutableMemory = true; + auto sharedMemorySpace = triton::gpu::SharedMemorySpaceAttr::get(ctx); + + auto subviewType = ttg::MemDescType::get( + refinedBlock.refinedShape, refinedBlock.elemType, sharedEncoding, + sharedMemorySpace, mutableMemory, origMemViewType.getAllocShape()); + + rewriter.setInsertionPointAfter(op); + AMD::CoordinateMapper coordsMapper(refinedBlock.numPerDims); + for (size_t linearIdx = 0; linearIdx < refinedBlock.numSubTiles; + ++linearIdx) { + auto coords = coordsMapper.map(linearIdx); + SmallVector offset(refinedBlock.numDims, 0); + for (auto [dim, coord] : llvm::enumerate(coords)) { + offset[dim] = coord * refinedBlock.elementsPerWorkGroup[dim]; + } + auto offsetValues = createOffset({selectValue}, offset, rewriter, loc); + auto slicedSharedMemView = rewriter.create( + loc, subviewType, origMemView, offsetValues); + + auto slice = rewriter.create( + loc, Type{refinedBlock.tensorType}, Value{origSrc}, offset); + + rewriter.create(loc, slice, slicedSharedMemView); + } + + rewriter.eraseOp(op); + return success(); + } +}; + +struct LocalAllocOpPattern + : public RefineRewritePattern { + LocalAllocOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + // Refines non-mutable memory `LocalAllocOp` ops. The non-mutable variant + // is used as a not-pipelined version of the op. To be able to refine the op, + // we replace the non-mutable variant with the mutable one that requires + // `LocalDeallocOp` after the last user of the result of `LocalAllocOp`. + // The `LocalStoreOp` is used to move data from registers to the LDS. + // The refinement of the resulting `LocalStoreOp` is left to the dedicated + // rewrite pattern. 
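// Rough before/after sketch (illustrative types; IR syntax approximate and
// the subview op spelling is assumed):
//   before:  %0 = ttg.local_alloc %src
//              : (tensor<64x64xf16, #blocked>) -> !ttg.memdesc<64x64xf16, #shared>
//   after:   %1 = ttg.local_alloc : !ttg.memdesc<1x64x64xf16, #shared, mutable>
//            %2 = ttg.memdesc_subview %1[0, 0, 0]
//              : !ttg.memdesc<64x64xf16, #shared, mutable>
//            ttg.local_store %src, %2
//            ... former uses of %0 now use %2 ...
//            ttg.local_dealloc %1   // placed after the last user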
+ LogicalResult apply(triton::gpu::LocalAllocOp op, + PatternRewriter &rewriter) const override { + auto ctx = op->getContext(); + auto loc = op.getLoc(); + auto alignment = op.getAlignment(); + + if (op->getNumOperands() == 0) + return failure(); + + auto allocType = cast(op.getResult().getType()); + auto origShape = allocType.getShape(); + SmallVector newShape(origShape); + SmallVector newAllocShape(allocType.getAllocShape()); + + if (newShape.size() == 2) { + newShape.insert(newShape.begin(), 1); + } + assert(newShape.size() == 3); + + if (newAllocShape.size() == 2) { + newAllocShape.insert(newAllocShape.begin(), 1); + } + assert(newAllocShape.size() == 3); + + auto newAllocType = triton::gpu::MemDescType::get( + ctx, newShape, allocType.getElementType(), allocType.getEncoding(), + allocType.getMemorySpace(), + /*mutableMemory=*/true, newAllocShape); + + rewriter.setInsertionPointAfter(op); + auto newAlloc = + rewriter.create(loc, newAllocType); + newAlloc->setAttrs(op->getAttrs()); + + auto newSubviewType = triton::gpu::MemDescType::get( + ctx, origShape, allocType.getElementType(), allocType.getEncoding(), + allocType.getMemorySpace(), + /*mutableMemory=*/true, newAllocShape); + + auto offset = createOffset({}, {0, 0, 0}, rewriter, loc); + auto newSubview = rewriter.create( + loc, newSubviewType, newAlloc, offset); + rewriter.create(loc, op.getOperand(0), newSubview); + + mlir::Operation *lastUser = nullptr; + for (auto *user : op.getResult().getUsers()) { + if (!lastUser || user->isBeforeInBlock(lastUser) == false) { + lastUser = user; + } + } + + Operation &lastOpInBlock = op->getBlock()->back(); + const bool noUsers = lastUser == nullptr; + const bool isLastInstr = noUsers + ? false + : mlir::OperationEquivalence::isEquivalentTo( + &lastOpInBlock, lastUser, + mlir::OperationEquivalence::Flags::None); + ; + if (noUsers || isLastInstr) { + rewriter.setInsertionPoint(&lastOpInBlock); + } else { + rewriter.setInsertionPointAfter(lastUser); + } + + rewriter.create(loc, newAlloc.getResult()); + + op.replaceAllUsesWith(newSubview.getResult()); + rewriter.eraseOp(op); + + return success(); + } +}; + +struct ReduceOpPattern : public RefineRewritePattern { + ReduceOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + // Reduce ops have different intput and output shapes and produce + // sliced layouts. + // This currently only supports 2d inputs. + LogicalResult apply(triton::ReduceOp op, + PatternRewriter &rewriter) const override { + auto ctx = op->getContext(); + auto loc = op.getLoc(); + uint32_t axisReduce = op.getAxis(); + uint32_t axisNonReduce = (axisReduce + 1) % 2; + if (op.getNumOperands() != 1) + return failure(); + + // Calculate refined shape. + auto src = op->getOperand(0); + auto srcType = rankedTType(src); + if (srcType.getRank() != 2) + return failure(); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + auto srcShapePerCtaTile = triton::gpu::getShapePerCTATile(srcType); + SmallVector repShape = {srcShape[0] / srcShapePerCtaTile[0], + srcShape[1] / srcShapePerCtaTile[1]}; + int numReps = repShape[axisNonReduce]; + SmallVector refinedSrcShape = {srcShape[0], srcShape[1]}; + refinedSrcShape[axisNonReduce] /= numReps; + int64_t elementsPerRep = refinedSrcShape[axisNonReduce]; + auto elemTy = srcType.getElementType(); + auto refinedTensorType = + RankedTensorType::get(refinedSrcShape, elemTy, srcEncoding); + + // Create refined ops. 
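// Worked example with assumed shapes: for a 128x64 input reduced along
// axis 1 with shapePerCTATile = {64, 32}, repShape = {2, 2} and
// numReps = repShape[axisNonReduce] = 2, so refinedSrcShape = {64, 64} and
// elementsPerRep = 64. The loop below emits two extract_slice ops at row
// offsets 0 and 64, reduces each 64x64 slice along axis 1, and the final
// amdgpu.concat with dims = {2} rebuilds the original 128-element result.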
+ rewriter.setInsertionPointAfter(op); + SmallVector refinedReduces; + for (int i = 0; i < numReps; ++i) { + SmallVector offset(refinedSrcShape.size(), 0); + offset[axisReduce] = 0; + offset[axisNonReduce] = i * elementsPerRep; + auto sliceOp = rewriter.create( + loc, Type{refinedTensorType}, Value{src}, offset); + auto reduceOp = rewriter.create( + loc, ValueRange{sliceOp}, axisReduce); + IRMapping mapping; + mapping.map(reduceOp.getOperand(0), sliceOp); + op.getCombineOp().cloneInto(&reduceOp->getRegion(0), mapping); + refinedReduces.push_back(reduceOp->getResult(0)); + } + + // Concat reduce slices. + auto reduceResultType = op.getResultTypes()[0]; + SmallVector concatDimShape = {numReps}; + auto concatDims = DenseI64ArrayAttr::get(ctx, concatDimShape); + auto concatOp = rewriter.create( + loc, reduceResultType, refinedReduces, concatDims); + auto origOpResult = op.getResult(); + origOpResult.replaceAllUsesWith(concatOp); + rewriter.eraseOp(op); + return success(); + } +}; + +template +struct ElementWiseOpPattern : public RefineRewritePattern { + ElementWiseOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + // Refine ops with distributed layouts. + // Assumes same layout for operands. + LogicalResult rewriteElementWiseOp(PatternRewriter &rewriter, OpTy op) const { + // Verify opd[0] is valid. + int numOperands = op->getNumOperands(); + if (op->getNumOperands() < 1) + return failure(); + auto src = op->getOperand(0); + if (!isa(src.getType())) + return failure(); + auto srcType = rankedTType(src); + auto rank = srcType.getRank(); + if (rank != 2) { // TODO(dtanner) remove me + return failure(); + } + + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + auto srcLL = ttg::toLinearEncoding(srcType); + auto srcShapePerCtaTile = getRefinedShapePerCTATile(srcType); + + // Verify subsequent operands match opd[0]. + for (int i = 1; i < numOperands; ++i) { + if (!isa(op->getOperand(i).getType())) + return failure(); + if (rankedTType(op->getOperand(i)).getRank() != rank) + return failure(); + if (getRefinedShapePerCTATile(op->getOperand(i).getType()) != + srcShapePerCtaTile) + return failure(); + } + + // Result tensor. + auto res = op->getResult(0); + if (!isa(res.getType())) + return failure(); + auto resType = rankedTType(res); + auto resShape = resType.getShape(); + if (resShape != srcShape) + return failure(); + + LDBG("rewriteElementWiseOp(): " << op); + + // DEBUG check if concat op results in correct linear layout + auto leRes = ttg::toLinearEncoding(resType); + auto llRes = leRes.getLinearLayout(); + + auto resEncoding = resType.getEncoding(); + auto resShapePerCtaTile = getRefinedShapePerCTATile(resType); + + // Calculate refined shapes. + SmallVector refinedShape; + SmallVector numReps; + for (int i = 0; i < rank; ++i) { + // src and res can have different refineable shapes if different layouts. + const auto refinedDim = + std::max(srcShapePerCtaTile[i], resShapePerCtaTile[i]); + refinedShape.push_back(refinedDim); + numReps.push_back(srcShape[i] / refinedDim); + } + + if (product(numReps) == 1) + return success(); + + // Create refined ops. 
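// Illustrative case (assumed layouts): a 128x128 ttg.convert_layout whose
// source CTA tile is 32x64 and whose result CTA tile is 64x64 gets
// refinedShape = {max(32, 64), max(64, 64)} = {64, 64} and numReps = {2, 2};
// the loops below slice every operand into four 64x64 pieces, apply the op
// per piece, and concat the four results with dims = {2, 2}.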
+ auto refinedTensorTypeSrc = RankedTensorType::get( + refinedShape, srcType.getElementType(), srcEncoding); + auto refinedTensorTypeRes = RankedTensorType::get( + refinedShape, resType.getElementType(), resEncoding); + + rewriter.setInsertionPointAfter(op); + SmallVector refinedOps; + SmallVector offset(rank, 0); + int outerIdx = 0; // rank-1; + int innerIdx = 1; // rank-2; + + auto sliceOperation = [&]() { + SmallVector slicedOperands; + for (int opdIdx = 0; opdIdx < numOperands; ++opdIdx) { + auto slicedOperand = rewriter.create( + op.getLoc(), Type{refinedTensorTypeSrc}, + Value{op->getOperand(opdIdx)}, offset); + slicedOperands.push_back(slicedOperand); + } + auto refinedOp = rewriter.create(op.getLoc(), refinedTensorTypeRes, + slicedOperands); + refinedOps.push_back(refinedOp->getResult(0)); + }; + + for (int i = 0; i < numReps[outerIdx]; ++i) { + offset[outerIdx] = i * refinedShape[outerIdx]; + + if (rank == 2) { + for (int j = 0; j < numReps[innerIdx]; ++j) { + offset[innerIdx] = j * refinedShape[innerIdx]; + sliceOperation(); + } + } else { + assert(rank == 1 && "rank is expected to be `1`"); + sliceOperation(); + } + } + + // Concat slices. + auto resultType = op->getResultTypes()[0]; + auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps); + auto concatOp = rewriter.create( + op.getLoc(), resultType, refinedOps, concatDims); + + auto origOpResult = op.getResult(); + origOpResult.replaceAllUsesWith(concatOp); + LDBG("rewriteElementWiseOp() - SUCCESS " << op); + rewriter.replaceOp(op, concatOp); + + return success(); + } + + LogicalResult apply(OpTy op, PatternRewriter &rewriter) const override { + auto result = rewriteElementWiseOp(rewriter, op); + if (failed(result)) { + LDBG("failed to refine elementwise op: " << *op); + } + return result; + } +}; + +struct ExpandDimsOpPattern : public RefineRewritePattern { + ExpandDimsOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + // Refine ExpandDims ops. + // Since expanding dims increases tensor rank, + // this refinement multipe intermediate shapes, + // ExSl ExpD Conct + // -> -> -> . + // TODO(dtanner) only need to support 1D sliceLayout input, same as + // ViewOpToLLVM.cpp ? + LogicalResult apply(triton::ExpandDimsOp op, + PatternRewriter &rewriter) const override { + int numOperands = op->getNumOperands(); + if (op->getNumOperands() != 1) + return failure(); + auto src = op->getOperand(0); + if (!isa(src.getType())) + return failure(); + auto srcType = rankedTType(src); + if (srcType.getElementTypeBitWidth() == 1) + return failure(); + + auto rank = srcType.getRank(); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + auto srcShapePerCtaTile = getRefinedShapePerCTATile(srcType); + + auto ll = triton::gpu::toLinearEncoding(srcType); + + // Calculate refined shape. + SmallVector refinedSrcShape; + SmallVector numReps; + for (int i = 0; i < rank; ++i) { + refinedSrcShape.push_back(srcShapePerCtaTile[i]); + numReps.push_back(srcShape[i] / srcShapePerCtaTile[i]); + } + + if (product(numReps) == 1) + return success(); + + auto refinedResultShape = refinedSrcShape; + refinedResultShape.insert(refinedResultShape.begin() + op.getAxis(), 1); + auto refinedSrcTensorType = RankedTensorType::get( + refinedSrcShape, srcType.getElementType(), srcEncoding); + + // Create refined ops. 
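// Worked example with an assumed 1-D input: for
// tensor<128xf32, #ttg.slice<{dim = 1, parent = #blocked}>> with
// shapePerCTATile = {64} and axis = 1, numReps = {2}; the loop below emits
// two 64-element extract_slice ops, expands each to 64x1 using the encoding
// inferred by inferExpandDimsOpEncoding, numReps then becomes {2, 1}, and
// the concat rebuilds the 128x1 result.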
+ rewriter.setInsertionPointAfter(op); + SmallVector refinedReduces; + SmallVector offset(rank, 0); + + auto sliceOperation = [&]() { + auto slicedOp = rewriter.create( + op.getLoc(), Type{refinedSrcTensorType}, Value{op->getOperand(0)}, + offset); + + auto sliceRes = + ::llvm::cast<::mlir::TypedValue<::mlir::RankedTensorType>>( + slicedOp->getResult(0)); + + auto sliceResTy = sliceRes.getType(); + Attribute refinedResultEncoding; + + if (auto refinedSrcEncoding = sliceResTy.getEncoding()) { + if (cast(&srcEncoding.getDialect()) + ->inferExpandDimsOpEncoding(refinedSrcEncoding, op.getAxis(), + refinedResultEncoding, op.getLoc()) + .failed()) { + return emitOptionalError(op.getLoc(), + "Failed to infer layout for ExpandDimsOp"); + } + } + + auto sliceResTensorType = + RankedTensorType::get(refinedResultShape, sliceResTy.getElementType(), + refinedResultEncoding); + + auto refinedOp = rewriter.create( + op.getLoc(), sliceResTensorType, sliceRes, op.getAxis()); + + refinedReduces.push_back(refinedOp->getResult(0)); + return success(); + }; + + for (int i = 0; i < numReps[rank - 1]; ++i) { + offset[rank - 1] = i * refinedSrcShape[rank - 1]; + + // TODO(dtanner) how to iterate over Nd array? + if (rank == 2) { + for (int j = 0; j < numReps[rank - 2]; ++j) { + offset[rank - 2] = j * refinedSrcShape[rank - 2]; + if (llvm::failed(sliceOperation())) + return failure(); + } + } else { + assert(rank == 1 && "rank is expected to be `1`"); + if (llvm::failed(sliceOperation())) + return failure(); + } + } + + // Concat refined ops. + auto reduceResultType = op->getResultTypes()[0]; + // Expand dims of numReps also before concat. + numReps.insert(numReps.begin() + op.getAxis(), 1); + auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps); + auto concatOp = rewriter.create( + op.getLoc(), reduceResultType, refinedReduces, concatDims); + auto origOpResult = op.getResult(); + + auto checkLL = triton::gpu::toLinearEncoding( + cast(refinedReduces.front().getType())); + + origOpResult.replaceAllUsesWith(concatOp); + rewriter.eraseOp(op); + return success(); + } +}; + +struct BroadcastOpPattern : public RefineRewritePattern { + BroadcastOpPattern(MLIRContext *context, PatternBenefit benefit = 1) + : RefineRewritePattern(context, benefit) {} + + // Refine Broadcast ops. + // Since inputs are roughtly 1D and outputs are roughly 2D, + // Then the op and outputs are sliced more than the inputs. + // In the below example, shapePerCtaTile is 64x32, + // so the input can be cut in half, while the BroadcastOp + // can be cut into fourths, and the Concat will have dims=2x2. + // Presumably this means the 1st and 3rd Broadcasts are redundant, + // the 2nd and 4th Broadcasts are redundant, and some will be + // eliminated by CSE in the backend compiler. + // Example: + // ExSl Brdcst Concat + //<128x1> -> <64x1> -> <64x32> -> 128x64 + // \ <64x32> / + // -> <64x1> -> <64x32> / + // <64x32> / + LogicalResult apply(triton::BroadcastOp op, + PatternRewriter &rewriter) const override { + // src tensor e.g. <128x1>. 
+ int numOperands = op->getNumOperands(); + if (op->getNumOperands() != 1) + return failure(); + auto src = op->getOperand(0); + if (!isa(src.getType())) + return failure(); + auto srcType = rankedTType(src); + auto rank = srcType.getRank(); + if (rank != 2) + return failure(); + if (srcType.getElementTypeBitWidth() == 1) + return failure(); + auto srcShape = srcType.getShape(); + auto srcEncoding = srcType.getEncoding(); + auto srcShapePerCtaTile = getRefinedShapePerCTATile(srcType); + + // Result tensor e.g. <128x64>. + auto res = op->getResult(0); + if (!isa(res.getType())) + return failure(); + auto resType = rankedTType(res); + auto resShape = resType.getShape(); + auto resEncoding = resType.getEncoding(); + auto resShapePerCtaTile = getRefinedShapePerCTATile(resType); + + // numReps + SmallVector refinedSrcShape; + SmallVector refinedResShape; + SmallVector numReps; + for (int i = 0; i < rank; ++i) { + refinedSrcShape.push_back(srcShapePerCtaTile[i]); + refinedResShape.push_back(resShapePerCtaTile[i]); + numReps.push_back(resShape[i] / resShapePerCtaTile[i]); + } + + if (product(numReps) == 1) + return success(); + + // Determine indices and values of reps. + // numRepsSrc is the non-one size, because the src can be sliced. + // numRepsRes is the one size, because the result will be repeated. + unsigned numRepsSrcIdx = 0; // <*128x 1> + unsigned numRepsResIdx = 1; // < 128x*1> + if (refinedSrcShape[numRepsSrcIdx] == 1) { + numRepsSrcIdx = 1; // < 1x*64> + numRepsResIdx = 0; // <*1x 64> + } + unsigned numRepsSrc = numReps[numRepsSrcIdx]; + unsigned numRepsRes = numReps[numRepsResIdx]; + + // Refined src/result tensor types. + auto refinedSrcTensorType = RankedTensorType::get( + refinedSrcShape, srcType.getElementType(), srcEncoding); + auto refinedResTensorType = RankedTensorType::get( + refinedResShape, srcType.getElementType(), srcEncoding); + + // Create refined ops. + rewriter.setInsertionPointAfter(op); + SmallVector refinedBroadcasts; + SmallVector offset(rank, 0); + for (int i = 0; i < numRepsSrc; ++i) { + offset[numRepsSrcIdx] = i * refinedSrcShape[numRepsSrcIdx]; + // Create slice. + auto slicedOp = rewriter.create( + op.getLoc(), Type{refinedSrcTensorType}, Value{op->getOperand(0)}, + offset); + auto sliceRes = + ::llvm::cast<::mlir::TypedValue<::mlir::RankedTensorType>>( + slicedOp->getResult(0)); + auto sliceResTensorType = RankedTensorType::get( + refinedResShape, srcType.getElementType(), resEncoding); + for (int j = 0; j < numRepsRes; ++j) { + // Create broadcast. + auto broadcastOp = rewriter.create( + op.getLoc(), sliceResTensorType, sliceRes); + refinedBroadcasts.push_back(broadcastOp->getResult(0)); + } + } + + // Concat refined ops. 
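The decomposition described in the broadcast comment above (a <128x1> source broadcast to <128x64> under a 64x32 CTA tile) works out to two source slices, each broadcast twice, i.e. four refined ops concatenated as 2x2. A standalone sketch of that counting, with the shapes taken from the comment and everything else hypothetical:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> srcShape = {128, 1};   // hypothetical broadcast source
  std::vector<int> resShape = {128, 64};  // hypothetical broadcast result
  std::vector<int> resTile = {64, 32};    // hypothetical shapePerCTATile

  std::vector<int> refinedSrcShape, numReps;
  for (size_t i = 0; i < resShape.size(); ++i) {
    refinedSrcShape.push_back(srcShape[i] == 1 ? 1 : resTile[i]);
    numReps.push_back(resShape[i] / resTile[i]);
  }
  // The source is sliced only along its non-unit dimension; the unit dimension
  // is covered by repeating the broadcast of each slice instead.
  int srcIdx = (refinedSrcShape[0] == 1) ? 1 : 0;
  int resIdx = 1 - srcIdx;

  int pieces = numReps[srcIdx] * numReps[resIdx];
  std::printf("%d source slices x %d broadcasts each = %d refined ops, concat %dx%d\n",
              numReps[srcIdx], numReps[resIdx], pieces, numReps[0], numReps[1]);
  return 0;
}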
+ auto reduceResultType = op->getResultTypes()[0]; + auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps); + auto concatOp = rewriter.create( + op.getLoc(), reduceResultType, refinedBroadcasts, concatDims); + + auto origOpResult = op.getResult(); + origOpResult.replaceAllUsesWith(concatOp); + rewriter.eraseOp(op); + return success(); + } +}; + +struct TritonAMDGPURefineOps + : public TritonAMDGPURefineOpsBase { + explicit TritonAMDGPURefineOps(StringRef targetArch) { + this->arch = targetArch.str(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + triton::FuncOp func = getOperation(); + mlir::triton::AMD::TargetInfo targetInfo(this->arch.getValue()); + if (targetInfo.getISAFamily() == mlir::triton::AMD::ISAFamily::Unknown) { + func.emitError("unsupported target: '") << this->arch.getValue() << "'"; + return signalPassFailure(); + } + + RewritePatternSet primaryPatterns(context); + primaryPatterns.add(context, /*benefit=*/1); + walkAndApplyPatterns(func, std::move(primaryPatterns)); + + RewritePatternSet patterns(context); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + patterns.add(context, /*benefit=*/1); + + // Elementwise patterns +#define REFINE_ELEMENTWISE_OP(OP_TYPE) \ + patterns.add>(context, /*benefit=*/1); + + REFINE_ELEMENTWISE_OP(math::RsqrtOp) + REFINE_ELEMENTWISE_OP(math::Exp2Op) + REFINE_ELEMENTWISE_OP(arith::TruncFOp) + REFINE_ELEMENTWISE_OP(arith::ExtFOp) + REFINE_ELEMENTWISE_OP(arith::FPToSIOp) + REFINE_ELEMENTWISE_OP(arith::SIToFPOp) + REFINE_ELEMENTWISE_OP(triton::FpToFpOp) + REFINE_ELEMENTWISE_OP(triton::PreciseSqrtOp) + REFINE_ELEMENTWISE_OP(math::SqrtOp) + REFINE_ELEMENTWISE_OP(math::ExpOp) + REFINE_ELEMENTWISE_OP(arith::SubIOp) + REFINE_ELEMENTWISE_OP(arith::AddIOp) + REFINE_ELEMENTWISE_OP(arith::MulIOp) + REFINE_ELEMENTWISE_OP(arith::DivSIOp) + REFINE_ELEMENTWISE_OP(arith::DivUIOp) + REFINE_ELEMENTWISE_OP(arith::RemFOp) + REFINE_ELEMENTWISE_OP(arith::RemSIOp) + REFINE_ELEMENTWISE_OP(arith::RemUIOp) + REFINE_ELEMENTWISE_OP(arith::AndIOp) + REFINE_ELEMENTWISE_OP(arith::OrIOp) + REFINE_ELEMENTWISE_OP(arith::XOrIOp) + REFINE_ELEMENTWISE_OP(arith::ShLIOp) + REFINE_ELEMENTWISE_OP(arith::ShRSIOp) + REFINE_ELEMENTWISE_OP(arith::ShRUIOp) + REFINE_ELEMENTWISE_OP(arith::MinNumFOp) + REFINE_ELEMENTWISE_OP(arith::MaxNumFOp) + REFINE_ELEMENTWISE_OP(arith::MinSIOp) + REFINE_ELEMENTWISE_OP(arith::MaxSIOp) + REFINE_ELEMENTWISE_OP(arith::MinUIOp) + REFINE_ELEMENTWISE_OP(arith::MaxUIOp) + REFINE_ELEMENTWISE_OP(arith::AddFOp) + REFINE_ELEMENTWISE_OP(arith::SubFOp) + REFINE_ELEMENTWISE_OP(arith::MulFOp) + REFINE_ELEMENTWISE_OP(arith::DivFOp) + REFINE_ELEMENTWISE_OP(arith::MaximumFOp) + REFINE_ELEMENTWISE_OP(arith::MinimumFOp) + REFINE_ELEMENTWISE_OP(triton::gpu::ConvertLayoutOp) + +#undef REFINE_ELEMENTWISE_OP + walkAndApplyPatterns(func, std::move(patterns)); + } +}; + +} // namespace + +namespace mlir { + +std::unique_ptr> +createTritonAMDGPURefineOpsPass(StringRef targetArch) { + return std::make_unique(targetArch); +} + +} // namespace mlir diff --git a/third_party/amd/lib/TritonAMDGPUTransforms/RescheduleOps.cpp b/third_party/amd/lib/TritonAMDGPUTransforms/RescheduleOps.cpp new file mode 100644 index 000000000000..85e6a8542338 --- /dev/null +++ b/third_party/amd/lib/TritonAMDGPUTransforms/RescheduleOps.cpp @@ -0,0 +1,615 @@ +#include 
"mlir/Conversion/ArithToLLVM/ArithToLLVM.h" +#include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h" +#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h" +#include "mlir/Dialect/LLVMIR/ROCDLDialect.h" +#include "mlir/Pass/Pass.h" +#include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h" +#include "third_party/amd/include/TritonAMDGPUTransforms/MfmaGroup.h" +#include "third_party/amd/lib/TritonAMDGPUToLLVM/TargetInfo.h" +#include "triton/Conversion/TritonGPUToLLVM/Utility.h" +#include "triton/Dialect/Triton/IR/Dialect.h" +#include "triton/Dialect/TritonGPU/IR/Dialect.h" +#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h" + +#define GEN_PASS_CLASSES +#include "TritonAMDGPUTransforms/Passes.h" + +#undef DEBUG_TYPE +#define DEBUG_TYPE "tritonamdgpu-reschedule-ops" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") + +namespace { + +// TODO (ravil): Note, took function from `SchedInstructions.cpp`. +// we need to combine these two implementations +Operation *createSchedBarrier(OpBuilder &rewriter, Location loc, + mlir::amdgpu::sched_barrier_opt_enum maskValue) { + IntegerAttr mask = + rewriter.getI32IntegerAttr(static_cast(maskValue)); + return rewriter.create(loc, mask); +} + +struct Node { + Node(Operation *op, int32_t weight = 0) : op(op), weight(weight) {} + enum class ChildType { Real, Artificials }; + + Operation *getOp() { return op; } + int32_t getWeight() { return weight; } + void setWeight(int32_t priority) { this->weight = priority; } + + template void add(Node *node) { + llvm::SetVector *children = + Type == ChildType::Real ? &realChildren : &artificialChildren; + children->insert(node); + } + void addParent(Node *node) { parents.insert(node); } + + size_t getNumParents() { return parents.size(); } + bool hasChildren() { + return !(realChildren.empty() && artificialChildren.empty()); + } + bool hasNoChildren() { return !hasChildren(); } + + const llvm::SetVector &getRealChildren() { return realChildren; } + const llvm::SetVector &getArtificialChildren() { + return artificialChildren; + } + const llvm::SetVector &getParents() { return parents; } + + void removeChild(Node *node) { + if (realChildren.contains(node)) { + realChildren.remove(node); + } + if (artificialChildren.contains(node)) { + artificialChildren.remove(node); + } + } + + void removeParent(Node *node) { + if (parents.contains(node)) { + parents.remove(node); + } + } + + void drainChildren() { + realChildren.clear(); + artificialChildren.clear(); + } + +private: + Operation *op; + int32_t weight; + llvm::SetVector realChildren; + llvm::SetVector artificialChildren; + llvm::SetVector parents; +}; + +struct NodeWeightStrategy { + virtual void set(llvm::SmallVector> &nodes) = 0; +}; + +struct BasicDotWeightStrategy : public NodeWeightStrategy { + void set(llvm::SmallVector> &nodes) override { + SmallVector dots; + SmallVector loads; + for (auto &node : nodes) { + if (dyn_cast(node->getOp())) + dots.push_back(node.get()); + if (dyn_cast(node->getOp())) + loads.push_back(node.get()); + } + + // Make sure that prio is never equal to `0` for dotOps + constexpr int32_t extraDefaultWeight = 10; + int32_t dotWeight = dots.size() + extraDefaultWeight; + const int32_t loadWeight = dotWeight + extraDefaultWeight; + for (auto loadNode : loads) { + propagate(loadNode, loadWeight); + } + + for (auto dotNode : dots) { + propagate(dotNode, dotWeight--); + } + } + +private: + void propagate(Node *node, int32_t weight) { + if (visited.contains(node)) + return; + + 
node->setWeight(weight); + visited.insert({node}); + + if (node->hasChildren()) { + for (auto child : node->getRealChildren()) { + propagate(child, weight); + } + for (auto child : node->getArtificialChildren()) { + propagate(child, weight); + } + } + } + + SetVector visited{}; +}; + +struct Graph { +public: + Graph(Block *mlirBlock) { + createNodes(mlirBlock); + createEdges(); + } + + Graph(const Graph &other) { + using iteratorType = decltype(other.nodes.begin()); + DenseMap map; + for (auto it = other.nodes.begin(); it != other.nodes.end(); ++it) { + auto newNode = + std::make_unique(it->get()->getOp(), it->get()->getWeight()); + map.insert({it->get(), it}); + lookup.insert({newNode->getOp(), newNode.get()}); + nodes.push_back(std::move(newNode)); + } + + for (auto [idx, otherNode] : llvm::enumerate(other.nodes)) { + auto &currNode = nodes[idx]; + for (auto otherChild : otherNode->getRealChildren()) { + auto otherChildIt = map.find(otherChild)->second; + auto childIdx = std::distance(other.nodes.begin(), otherChildIt); + currNode->add(nodes[childIdx].get()); + } + + for (auto otherChild : otherNode->getArtificialChildren()) { + auto otherChildIt = map.find(otherChild)->second; + auto childIdx = std::distance(other.nodes.begin(), otherChildIt); + currNode->add(nodes[childIdx].get()); + } + + for (auto otherParent : otherNode->getParents()) { + auto otherParentIt = map.find(otherParent)->second; + auto parentIdx = std::distance(other.nodes.begin(), otherParentIt); + currNode->addParent(nodes[parentIdx].get()); + } + nodes.push_back(std::make_unique(otherNode->getOp())); + } + } + + SmallVector getNodes() { + SmallVector copy(nodes.size(), nullptr); + for (auto [idx, node] : llvm::enumerate(nodes)) { + copy[idx] = node.get(); + } + return copy; + } + + void setNodesWeights(NodeWeightStrategy &&strategy) { strategy.set(nodes); } + +private: + void createNodes(Block *mlirBlock) { + for (auto it = mlirBlock->begin(); it != mlirBlock->end(); ++it) { + Operation *op = &(*it); + std::unique_ptr node = std::make_unique(op); + lookup.insert({op, node.get()}); + nodes.push_back(std::move(node)); + } + } + + enum class Traversal { Topdown, Bottomup }; + template void insertGPUBarrierEdges() { + + auto fwIt = nodes.begin(); + auto bkIt = nodes.rbegin(); + auto next = [&]() -> Node * { + if constexpr (Direction == Traversal::Topdown) { + if (fwIt == nodes.end()) + return nullptr; + return (fwIt++)->get(); + } + if constexpr (Direction == Traversal::Bottomup) { + if (bkIt == nodes.rend()) + return nullptr; + return (bkIt++)->get(); + } + return nullptr; + }; + + llvm::SmallVector ldsOpsNodes; + while (Node *node = next()) { + auto localLoad = dyn_cast(node->getOp()); + auto localStore = dyn_cast(node->getOp()); + auto localAlloc = dyn_cast(node->getOp()); + if (localLoad || localStore || localAlloc) { + ldsOpsNodes.push_back(node); + } + auto gpuBarrier = dyn_cast(node->getOp()); + if (gpuBarrier) { + Node *barrierNode = node; + for (auto ldsOpNode : ldsOpsNodes) { + if constexpr (Direction == Traversal::Topdown) { + barrierNode->add(ldsOpNode); + ldsOpNode->addParent(barrierNode); + } + if constexpr (Direction == Traversal::Bottomup) { + barrierNode->addParent(ldsOpNode); + ldsOpNode->add(barrierNode); + } + } + ldsOpsNodes.clear(); + } + } + } + + void createEdges() { + // insert edges imposed by def-use chains + for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { + auto &node = *it; + for (auto operandValue : node->getOp()->getOperands()) { + auto operandDefOp = operandValue.getDefiningOp(); 
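BasicDotWeightStrategy gives loads the highest weight, dots decreasing weights, and pushes each weight down through a node's children so that operands inherit their consumer's priority; the visited set keeps the first (load-anchored) assignment. A toy standalone version of that propagation, with made-up op names and weights:

#include <cstdio>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  int weight = 0;
  std::vector<ToyNode *> children;  // operands of this op
};

void propagate(ToyNode *n, int weight, std::set<ToyNode *> &visited) {
  if (!visited.insert(n).second)
    return;  // keep the weight assigned by the earlier (higher-priority) anchor
  n->weight = weight;
  for (ToyNode *c : n->children)
    propagate(c, weight, visited);
}

int main() {
  ToyNode cvt{"convert_layout"}, localLoad{"local_load"}, dot{"tt.dot"}, load{"tt.load"};
  dot.children = {&localLoad};
  localLoad.children = {&cvt};
  load.children = {&cvt};

  std::set<ToyNode *> visited;
  propagate(&load, /*weight=*/30, visited);  // loads first, highest weight
  propagate(&dot, /*weight=*/20, visited);   // then dots, decreasing weights
  for (ToyNode *n : {&cvt, &localLoad, &dot, &load})
    std::printf("%s -> weight %d\n", n->name.c_str(), n->weight);
  return 0;
}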
+ if (!lookup.contains(operandDefOp)) + continue; + Node *childNode = lookup.find(operandDefOp)->second; + node->add(childNode); + childNode->addParent(node.get()); + } + } + + // gpu.Barrier ops are orphans. Add edges to + // respect data dependencies in the block + insertGPUBarrierEdges(); + insertGPUBarrierEdges(); + + // connect orphans with the last op in the block + auto &lastNode = *(nodes.rbegin()); + for (auto it = std::next(nodes.rbegin()); it != nodes.rend(); ++it) { + auto &node = *it; + if (node->getNumParents() == 0) { + node->addParent(lastNode.get()); + lastNode->add(node.get()); + } + } + } + + llvm::SmallVector> nodes; + llvm::MapVector lookup; +}; + +llvm::raw_ostream &operator<<(llvm::raw_ostream &out, Graph &graph) { + out << "digraph \"dep-graph\" {\n"; + out << "rankdir=\"LR\"\n"; + for (auto [idx, node] : llvm::enumerate(graph.getNodes())) { + std::string name = std::to_string(reinterpret_cast(node)); + out << name << "\t[label=\"" << node->getOp()->getName() << " (" + << node->getWeight() << ") \"]\n"; + } + for (auto [idx, node] : llvm::enumerate(graph.getNodes())) { + std::string name = std::to_string(reinterpret_cast(node)); + for (auto child : node->getRealChildren()) { + std::string childName = std::to_string(reinterpret_cast(child)); + out << "\t" << childName << " -> " << name << ";\n"; + } + for (auto child : node->getArtificialChildren()) { + std::string childName = std::to_string(reinterpret_cast(child)); + out << "\t" << childName << " -> " << name + << " [style=\"dashed\", color=\"blue\"];\n"; + } + } + out << "}"; + return out; +} + +struct GraphManager { + GraphManager(Graph &graph) : graph(graph) { + for (auto [idx, node] : llvm::enumerate(graph.getNodes())) { + nodesIndices.insert({node, idx}); + if (node->hasNoChildren()) + leafs.insert(node); + } + } + + bool finished() { return leafs.empty(); } + + void removeLeaf(Node *node) { + assert(node->hasNoChildren()); + leafs.remove(node); + for (auto parent : node->getParents()) { + parent->removeChild(node); + if (parent->hasNoChildren()) { + leafs.insert(parent); + } + } + } + + size_t getNodeSourceCodeIndex(Node *node) { + assert(nodesIndices.contains(node)); + return nodesIndices[node]; + } + + const SetVector &getCurrentLeafs() { return leafs; } + +private: + Graph graph; + SetVector leafs; + DenseMap nodesIndices; +}; + +struct MachineModel { + struct Result { + Node *selectedNode{nullptr}; + SmallVector normPriorityNodes{}; + SmallVector lowPriorityNodes{}; + void set(Node *node) { + if (!selectedNode) { + selectedNode = node; + } else { + if (node->getWeight() > selectedNode->getWeight()) { + selectedNode = node; + } + } + } + }; + + Result select(const SetVector &readyNodes) { + Result result; + SmallVector ldsNodes; + SmallVector vmemNodes; + SmallVector mfmaNodes; + SmallVector barrierNodes; + + for (auto *node : readyNodes) { + Operation *op = node->getOp(); + if (dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op)) { + ldsNodes.push_back(node); + continue; + } + if (dyn_cast(op) || dyn_cast(op) || + dyn_cast(op) || + dyn_cast(op)) { + vmemNodes.push_back(node); + continue; + } + if (dyn_cast(op)) { + mfmaNodes.push_back(node); + continue; + } + if (dyn_cast(op)) { + barrierNodes.push_back(node); + continue; + } + result.normPriorityNodes.push_back(node); + } + + if (!vmemNodes.empty()) { + if (MachineModel::maxLoadStoreIssues > issuedLoadStoreCounter) { + for (auto *node : vmemNodes) { + result.set(node); + } + ++issuedLoadStoreCounter; + return result; + } else { + 
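GraphManager drives a standard list-scheduling loop: nodes whose children (dependencies) have all been scheduled form the ready set, and retiring a leaf may promote its parents into that set. A toy standalone version of the loop, where a trivial selector stands in for the machine model and weight-based selectors:

#include <cstdio>
#include <set>
#include <string>
#include <vector>

struct ToyNode {
  std::string name;
  std::set<ToyNode *> children;  // must be scheduled before this node
  std::vector<ToyNode *> parents;
};

int main() {
  ToyNode a{"a"}, b{"b"}, c{"c"};
  b.children = {&a};  a.parents = {&b};  // b depends on a
  c.children = {&b};  b.parents = {&c};  // c depends on b

  std::set<ToyNode *> ready;
  for (ToyNode *n : {&a, &b, &c})
    if (n->children.empty())
      ready.insert(n);

  std::vector<ToyNode *> schedule;
  while (!ready.empty()) {
    ToyNode *n = *ready.begin();  // a real selector would rank the ready leaves
    ready.erase(n);
    schedule.push_back(n);
    for (ToyNode *p : n->parents) {
      p->children.erase(n);
      if (p->children.empty())
        ready.insert(p);
    }
  }
  for (ToyNode *n : schedule)
    std::printf("%s ", n->name.c_str());
  std::printf("\n");  // expected order: a b c
  return 0;
}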
result.lowPriorityNodes.append(vmemNodes); + } + } + + if (!ldsNodes.empty()) { + if (MachineModel::maxLocalLoadStoreIssues > issuedLocalStoreLoadCounter) { + for (auto *node : ldsNodes) { + result.set(node); + } + if (!(dyn_cast( + result.selectedNode->getOp()))) + ++issuedLocalStoreLoadCounter; + return result; + } else { + result.lowPriorityNodes.append(ldsNodes); + } + } + + if (!barrierNodes.empty()) { + for (auto *node : barrierNodes) { + result.set(node); + } + return result; + } + + if (!mfmaNodes.empty()) { + issuedLocalStoreLoadCounter = + std::max(0, issuedLocalStoreLoadCounter - 1); + issuedLoadStoreCounter = std::max(0, issuedLoadStoreCounter - 1); + for (auto *node : mfmaNodes) { + result.set(node); + } + return result; + } + + return result; + } + + void printState(llvm::raw_ostream &stream) { + stream << "issuedLoadStoreCounter: " << issuedLoadStoreCounter << "; " + << "issuedLocalStoreLoadCounter: " << issuedLocalStoreLoadCounter + << '\n'; + } + +private: + const inline static int32_t maxLoadStoreIssues{2}; + const inline static int32_t maxLocalLoadStoreIssues{4}; + int32_t issuedLoadStoreCounter{0}; + int32_t issuedLocalStoreLoadCounter{0}; +}; + +struct TritonAMDGPURescheduleOps + : public TritonAMDGPURescheduleOpsBase { + explicit TritonAMDGPURescheduleOps(StringRef targetArch) { + this->arch = targetArch.str(); + } + + LogicalResult verify(Block *mlirBlock) { + // make sure that a block gets terminated with `cf::BranchOp` + if (!dyn_cast(&(mlirBlock->back()))) { + return failure(); + } + + // do't schedule if there is not enough operations in a block + if (mlirBlock->getOperations().size() < 3) + return failure(); + return success(); + } + /* + reschedule() is the top-level scheduling pass for a single block, + whose purpose is to improve performance and regalloc of backend compilers. + Before this pass, mfmas and local_loads (belonging to dots) + were already annotated with their dot-tile info. + The order of re-scheduling is: + - Place order dependencies on dots according to dot-tiling. + - Place order dependencies on local_loads according to dot-tiling. + - Determine min-register vs max-latency-hiding preference. + - Determine memory op order and co-scheduling. + - Place order dependencies between memory ops. + - Determine memory ops' early/late preference. + - Determine memory ops' preferred issue rate. + - Determine memory ops' supported issue rate. + - Place performance and anti-dependencies between memory ops and dots. + - Run scheduler with new dependencies in place. + Note that rescheduling can be run after any new dependencies are created to + visualize graph. 
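MachineModel throttles how many global-memory and LDS candidates can be picked back to back (the limits of 2 and 4 above), demotes the overflow to a low-priority bucket, and lets each issued MFMA pay back one slot of each counter. A toy standalone version of only the counter logic (the real model also exempts local_alloc from the LDS budget and handles barriers separately, which is omitted here):

#include <algorithm>
#include <cstdio>
#include <string>

struct ToyMachineModel {
  static constexpr int maxVmemIssues = 2;
  static constexpr int maxLdsIssues = 4;
  int vmemIssued = 0, ldsIssued = 0;

  // Returns "issue" or "defer" for a ready candidate of the given kind.
  std::string select(const std::string &kind) {
    if (kind == "vmem") {
      if (vmemIssued < maxVmemIssues) { ++vmemIssued; return "issue"; }
      return "defer";  // demoted to the low-priority bucket
    }
    if (kind == "lds") {
      if (ldsIssued < maxLdsIssues) { ++ldsIssued; return "issue"; }
      return "defer";
    }
    if (kind == "mfma") {  // an MFMA pays back one slot of each counter
      vmemIssued = std::max(0, vmemIssued - 1);
      ldsIssued = std::max(0, ldsIssued - 1);
      return "issue";
    }
    return "issue";  // everything else is normal priority
  }
};

int main() {
  ToyMachineModel model;
  for (const char *kind : {"vmem", "vmem", "vmem", "mfma", "vmem"})
    std::printf("%s -> %s\n", kind, model.select(kind).c_str());
  return 0;
}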
+ */ + void reschedule(Block *mlirBlock) { + + Graph graph(mlirBlock); + graph.setNodesWeights(BasicDotWeightStrategy()); + LDBG("Dependency graph in dot-format:\n" << graph); + + GraphManager manager(graph); + MachineModel machineModel; + SmallVector rescheduledOps; + + auto defaultSelector = [&](const SmallVector readyNodes) { + size_t minSourceCodeNodeIndex = std::numeric_limits::max(); + Node *earliestNodeToRun = nullptr; + for (auto node : readyNodes) { + const auto sourceCodeIndex = manager.getNodeSourceCodeIndex(node); + if (minSourceCodeNodeIndex > sourceCodeIndex) { + minSourceCodeNodeIndex = sourceCodeIndex; + earliestNodeToRun = node; + } + } + return earliestNodeToRun; + }; + + auto nodeWeightsSelector = [&](const SmallVector readyNodes) { + int32_t maxWeightValue = std::numeric_limits::min(); + Node *selectedNode = nullptr; + for (auto node : readyNodes) { + if (node->getWeight() > maxWeightValue) { + maxWeightValue = node->getWeight(); + selectedNode = node; + } + } + return selectedNode; + }; + + const bool verbose = false; + std::string dbgStr; + llvm::raw_string_ostream dbgStream(dbgStr); + while (!manager.finished()) { + const auto &readyNodes = manager.getCurrentLeafs(); + MachineModel::Result selectionResult = machineModel.select(readyNodes); + auto selectedNode = selectionResult.selectedNode; + bool selectedFromMachineModel = selectedNode ? true : false; + + if (!selectedNode) { + selectedNode = nodeWeightsSelector(selectionResult.normPriorityNodes); + } + + bool selectedFromNormPrioQueue = false; + if (!selectedNode) { + selectedNode = defaultSelector(selectionResult.normPriorityNodes); + selectedFromNormPrioQueue = true; + } + + bool selectedFromLowPrioqueue = false; + if (!selectedNode) { + selectedNode = defaultSelector(selectionResult.lowPriorityNodes); + selectedFromLowPrioqueue = true; + } + + assert(selectedNode != nullptr); + + if (verbose) { + dbgStream << std::string(80, '+') << "\n"; + for (auto n : selectionResult.normPriorityNodes) { + n->getOp()->print(dbgStream); + dbgStream << '\n'; + } + dbgStream << "\n\n\nSelected\n"; + selectedNode->getOp()->print(dbgStream); + dbgStream << '\n'; + machineModel.printState(dbgStream); + dbgStream << "selectedFromMachineModel: " << selectedFromMachineModel + << "; " + << "selectedFromNormPrioQueue: " << selectedFromNormPrioQueue + << "; " + << "selectedFromLowPrioqueue: " << selectedFromLowPrioqueue + << '\n'; + } + + manager.removeLeaf(selectedNode); + rescheduledOps.push_back(selectedNode->getOp()); + } + + if (verbose) + llvm::outs() << dbgStream.str() << '\n'; + + std::string outStr; + llvm::raw_string_ostream outStream(outStr); + outStream << "\n\n\n...." 
<< std::string(80, '-') << '\n'; + for (auto op : rescheduledOps) { + op->print(outStream); + outStream << "\n"; + } + + // re-order instruction based on the new schedule + // move instruction from the tail to the begining of the current BB + // one-by-one + for (auto it = rescheduledOps.rbegin(); it != rescheduledOps.rend(); ++it) { + (*it)->moveBefore(mlirBlock, mlirBlock->begin()); + } + + OpBuilder builder(&(mlirBlock->front())); + for (auto &op : mlirBlock->getOperations()) { + if (dyn_cast(&op)) { + auto barrier = createSchedBarrier( + builder, op.getLoc(), mlir::amdgpu::sched_barrier_opt_enum::none); + barrier->moveAfter(&op); + } + } + } + + void runOnOperation() override { + ModuleOp mod = getOperation(); + llvm::SmallVector blocks; + mod.walk([&](triton::amdgpu::InstructionSchedHint hint) { + if (hint.getVariant() == triton::amdgpu::SchedHint::refine_ops) { + blocks.push_back(hint->getBlock()); + hint->erase(); + } + }); + + for (auto block : blocks) { + if (succeeded(verify(block))) { + reschedule(block); + } + } + } +}; +} // namespace + +namespace mlir { +std::unique_ptr> +createTritonAMDGPURescheduleOpsPass(StringRef targetArch) { + return std::make_unique(targetArch); +} +} // namespace mlir diff --git a/third_party/amd/python/triton_amd.cc b/third_party/amd/python/triton_amd.cc index e0e870127cb5..442d72634538 100644 --- a/third_party/amd/python/triton_amd.cc +++ b/third_party/amd/python/triton_amd.cc @@ -37,6 +37,18 @@ const char *const amdTargetTriple = "amdgcn-amd-amdhsa"; void init_triton_amd_passes_ttgpuir(py::module &&m) { using namespace mlir::triton; + m.def("add_membar_analysis", [](mlir::PassManager &pm) { + pm.addPass(mlir::createTritonAMDGPUMembarAnalysisPass()); + }); + m.def("add_refine_amdgpu_ops", + [](mlir::PassManager &pm, const std::string &arch) { + pm.addNestedPass( + mlir::createTritonAMDGPURefineOpsPass(arch)); + }); + m.def("add_reschedule_amdgpu_ops", + [](mlir::PassManager &pm, const std::string &arch) { + pm.addPass(mlir::createTritonAMDGPURescheduleOpsPass(arch)); + }); m.def("add_to_llvmir", [](mlir::PassManager &pm, const std::string &arch, bool ftz) { pm.addPass(createConvertTritonAMDGPUToLLVMPass(arch, ftz));
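The bindings above expose the two new passes to Python; the equivalent C++ wiring would look roughly like the sketch below. Only the two create*Pass entry points come from this patch; the header names, the helper function name, and the surrounding pipeline are assumptions.

#include <string>

#include "mlir/Pass/PassManager.h"
#include "TritonAMDGPUTransforms/Passes.h"     // assumed to declare the pass factories
#include "triton/Dialect/Triton/IR/Dialect.h"  // for mlir::triton::FuncOp

// Refinement runs nested on each tt.func; rescheduling runs on the module.
void addRefineAndReschedulePasses(mlir::PassManager &pm, const std::string &arch) {
  pm.addNestedPass<mlir::triton::FuncOp>(
      mlir::createTritonAMDGPURefineOpsPass(arch));
  pm.addPass(mlir::createTritonAMDGPURescheduleOpsPass(arch));
}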