diff --git a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir index 26a77b8113..b27480f2f1 100644 --- a/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir +++ b/test/TritonIntelGPU/backward_combine_dpas_dot_layout.mlir @@ -345,3 +345,57 @@ module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 32 : i32} tt.return } } + +// ----- + +// CHECK: #[[BLOCKED:.+]] = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> +// CHECK: #[[DPAS:.+]] = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +#blocked = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [2, 2], order = [1, 0]}> +#blocked1 = #ttg.blocked<{sizePerThread = [1, 1], threadsPerWarp = [1, 16], warpsPerCTA = [1, 4], order = [1, 0]}> +#dpas = #ttig.dpas<{repeatCount = 8, systolicDepth = 8, executionSize = 16, opsPerChan = 2, threadsPerWarp = 16, warpsPerCTA = [1, 4], repCluster = [1, 1], A = [8, 16], B = [16, 16], C = [8, 16]}> +#dot0 = #ttg.dot_op<{opIdx = 0, parent = #dpas, kWidth=1}> +#dot1 = #ttg.dot_op<{opIdx = 1, parent = #dpas, kWidth=2}> +module attributes {"ttg.num-warps" = 4 : i32, "ttg.threads-per-warp" = 16 : i32, "ttig.support_sg_2d_block"} { + tt.func public @reduce_loop_carried_values(%arg0: !tt.ptr, %arg1: !tt.ptr, %arg2: !tt.ptr, %arg5: i32) { + %c1_i64 = arith.constant 1 : i64 + %c0_i32 = arith.constant 0 : i32 + %c0_i64 = arith.constant 0 : i64 + %c32_i32 = arith.constant 32 : i32 + %cst = arith.constant dense<0.000000e+00> : tensor<64x256xf32, #dpas> + %18 = tt.make_tensor_ptr %arg0, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %22 = tt.make_tensor_ptr %arg1, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + %23:3 = scf.for 
%arg9 = %c0_i32 to %arg5 step %c32_i32 iter_args(%arg10 = %cst, %arg11 = %18, %arg12 = %22) -> (tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr>) : i32 { + // COM: Ensure there are only 3 loop results and no layout conversion in the loop. + // CHECK: [[LOOP_RES:%.*]]:3 = scf.for + // CHECK-NOT: ttg.convert_layout + // CHECK: scf.yield + %28 = tt.load %arg11 {boundaryCheck = array, ttig.block_io = "row_major" } : !tt.ptr> + %29 = tt.load %arg12 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + %30 = ttg.convert_layout %28 : tensor<64x32xf16, #blocked> -> tensor<64x32xf16, #dot0> + %31 = ttg.convert_layout %29 : tensor<32x256xf16, #blocked1> -> tensor<32x256xf16, #dot1> + %32 = tt.dot %30, %31, %arg10, inputPrecision = tf32 : tensor<64x32xf16, #dot0> * tensor<32x256xf16, #dot1> -> tensor<64x256xf32, #dpas> + %33 = tt.advance %arg11, [%c0_i32, %c32_i32] : > + %34 = tt.advance %arg12, [%c32_i32, %c0_i32] : > + scf.yield %32, %33, %34 : tensor<64x256xf32, #dpas>, !tt.ptr>, !tt.ptr> + } + %24 = arith.truncf %23#0 : tensor<64x256xf32, #dpas> to tensor<64x256xf16, #dpas> + %27 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %27, %24 {boundaryCheck = array} : !tt.ptr> + + // CHECK: [[LOAD1:%.*]] = tt.load [[LOOP_RES]]#1 {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr>> + // CHECK: [[CONV1:%.*]] = ttg.convert_layout [[LOAD1]] : tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> -> tensor<64x32xf16, #[[BLOCKED]]> + // CHECK: [[PTR:%.*]] = tt.make_tensor_ptr {{.*}}, {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}], {{\[}}{{.*}}, {{.*}}] {order = array} : > + // CHECK: tt.store [[PTR]], [[CONV1]] {boundaryCheck = array} : !tt.ptr> + %28 = tt.load %23#1 {boundaryCheck = array, ttig.block_io = "row_major" } : !tt.ptr> + %29 = tt.make_tensor_ptr %arg2, [%c0_i64, %c0_i64], [%c0_i64, %c1_i64], [%c0_i32, %c0_i32] {order = array} : > + tt.store %29, %28
{boundaryCheck = array} : !tt.ptr> + + // CHECK: [[LOAD2:%.*]] = tt.load [[PTR]] {boundaryCheck = array, ttig.block_io = "row_major"} : !tt.ptr> + // CHECK: [[CONV2:%.*]] = ttg.convert_layout [[LOAD2]] : tensor<64x32xf16, #[[BLOCKED]]> -> tensor<64x32xf16, #ttg.dot_op<{opIdx = 0, parent = #[[DPAS]], kWidth = 1}>> + // CHECK: tt.store [[LOOP_RES]]#1, [[CONV2]] {boundaryCheck = array} : !tt.ptr>> + %30 = tt.load %29 {boundaryCheck = array, ttig.block_io = "row_major" } : !tt.ptr> + tt.store %23#1, %30 {boundaryCheck = array} : !tt.ptr> + + tt.return + } +} diff --git a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp index 5dde69dcc9..6e4bc21960 100644 --- a/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp +++ b/third_party/intel/lib/TritonIntelGPUTransforms/RemoveLayoutConversions.cpp @@ -25,6 +25,7 @@ #include "triton/Dialect/TritonGPU/IR/Dialect.h" #include "triton/Dialect/TritonGPU/Transforms/TritonGPUConversion.h" #include "triton/Dialect/TritonGPU/Transforms/Utility.h" +#include "llvm/ADT/TypeSwitch.h" #include namespace mlir::triton::gpu::intel { @@ -1267,6 +1268,78 @@ void LayoutRematerialization::backwardRematerialization() { convertOp.getResult()); } } + + // Reduce loop carried values if the value can be removed by using another + // loop yielded value plus a convert layout operation. + for (auto [pair, val] : rematMapping) { + if (!isa(pair.first)) + continue; + + auto arg = cast(pair.first); + if (!isTensorPointerType(arg.getType())) + continue; + + if (auto loopOp = + dyn_cast(arg.getOwner()->getParentOp())) { + // Loop arguments that correspond to a loop result which is not used are + not interesting. + OpResult loopRes = loopOp.getTiedLoopResult(arg); + if (loopRes.getNumUses() == 0) + continue; + + // Replace the loop result corresponding to the argument with an + equivalent loop result.
+ auto rematArg = cast(val); + OpResult rematRes = loopOp.getTiedLoopResult(rematArg); + + for (OpOperand &use : loopRes.getUses()) { + Operation *user = use.getOwner(); + Location loc = user->getLoc(); + OpBuilder rewriter(user); + + TypeSwitch(user) + .Case([&](auto loadOp) { + auto newLoadOp = + rewriter.create(loc, rematRes, loadOp->getAttrs()); + auto convOp = rewriter.create( + loc, loadOp.getType(), newLoadOp.getResult()); + loadOp->replaceAllUsesWith(convOp); + opToDelete.insert(loadOp); + LLVM_DEBUG({ + DBGS() << "Replaced:\n\t" << *loadOp << "\n"; + DBGS() << "with:\n\t" << *newLoadOp << "\n" + << "\t" << *convOp << "\n"; + }); + }) + .Case([&](auto storeOp) { + Value data = storeOp.getOperand(1); + auto dataType = cast(data.getType()); + auto newPtrType = cast(rematRes.getType()); + Attribute encoding = + cast(newPtrType.getPointeeType()) + .getEncoding(); + RankedTensorType newDataType = + dataType.cloneWithEncoding(encoding); + auto convOp = + rewriter.create(loc, newDataType, data); + auto newStoreOp = rewriter.create( + loc, rematRes, convOp, storeOp.getBoundaryCheck(), + storeOp.getCache(), storeOp.getEvict()); + opToDelete.insert(storeOp); + LLVM_DEBUG({ + DBGS() << "Replaced:\n\t" << *storeOp << "\n"; + DBGS() << "with:\n\t" << *convOp << "\n" + << "\t" << *newStoreOp << "\n"; + }); + }) + .Default([](auto op) { + llvm::report_fatal_error(llvm::Twine( + "Unsupported operation in backward rematerialization: '" + + op->getName().getStringRef() + "'")); + }); + } + } + } } void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {