diff --git a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
index 994e445e..bc78fe93 100644
--- a/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
+++ b/lib/gc/Transforms/GPU/LinalgToXeGPU.cpp
@@ -1312,6 +1312,105 @@ static FailureOr<Value> findAndReplaceTranspose(const Value &matmulOperand,
       defOp, "No transpose operation producing the operand was found");
 }
 
+// Checks whether the given linalgOp operand is produced by a
+// `linalg::BroadcastOp` that can be replaced by a simple subview
+// (for example, broadcast 7x128 -> 1x1x7x128) and ensures that the
+// broadcast result is only used by the linalgOp in question.
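+//
+// For illustration, this is the shape of IR the rewrite targets (a sketch
+// distilled from the accompanying tests; names are illustrative):
+//
+//   linalg.broadcast ins(%src : memref<7x128xf16>)
+//                    outs(%dst : memref<1x1x7x128xf16>) dimensions = [0, 1]
+//   linalg.add ins(%dst, ...) outs(...)
+//
+// The broadcast only adds unit dimensions and %dst has no other consumers,
+// so the broadcast can be erased and the squeezed %src used directly.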
+//
+// If a valid `linalg::BroadcastOp` is found, the function removes it
+// and returns the operand of the `linalg::BroadcastOp` as the new
+// linalgOp operand. Otherwise returns the original operand.
+static Value findAndReplaceBroadcast(linalg::LinalgOp linalgOp,
+                                     size_t operandIdx,
+                                     PatternRewriter &rewriter) {
+  auto operand = linalgOp.getDpsInputs()[operandIdx];
+  auto operandParent = operand;
+
+  // Walk over the users of 'value' and verify that it is only used by 'ops'.
+  std::function<bool(Value, SmallVector<Operation *> &)> onlyUsedByOp =
+      [&onlyUsedByOp](Value value, SmallVector<Operation *> &ops) -> bool {
+    bool result = true;
+    for (auto user : value.getUsers()) {
+      if (auto linalgOpUser = dyn_cast<linalg::LinalgOp>(user))
+        result &= std::find(ops.begin(), ops.end(), linalgOpUser) != ops.end();
+      else if (isa<memref::DeallocOp>(user))
+        continue; // allow deallocs as users
+      else if (auto subview = dyn_cast<memref::SubViewOp>(user))
+        result &= onlyUsedByOp(subview.getResult(), ops);
+      else
+        return false;
+    }
+    return result;
+  };
+
+  linalg::BroadcastOp broadcastOp = nullptr;
+  while (auto defOp = operandParent.getDefiningOp()) {
+    for (auto x : defOp->getUsers()) {
+      if (!isa<linalg::BroadcastOp>(x))
+        continue;
+
+      if (broadcastOp) {
+        rewriter.notifyMatchFailure(broadcastOp,
+                                    "Only one broadcast operation is allowed");
+        return operand;
+      }
+
+      broadcastOp = dyn_cast<linalg::BroadcastOp>(x);
+      auto broadcastRes = broadcastOp.getDpsInits()[0];
+      SmallVector<Operation *> ops({linalgOp, broadcastOp});
+
+      // Verify that there are no users of the broadcast result
+      // other than the linalgOp in question.
+      if (!onlyUsedByOp(broadcastRes, ops)) {
+        rewriter.notifyMatchFailure(
+            broadcastOp, "Broadcast result is used by more than one operation");
+        return operand;
+      }
+      break;
+    }
+
+    if (defOp->getNumOperands() == 0)
+      break;
+
+    operandParent = defOp->getOperand(0);
+  }
+  if (!broadcastOp) {
+    rewriter.notifyMatchFailure(
+        linalgOp, "No broadcast operation producing the operand was found");
+    return operand;
+  }
+
+  auto brInp = broadcastOp.getDpsInputs()[0];
+  auto brOut = broadcastOp.getDpsInits()[0];
+
+  auto inpType = dyn_cast<MemRefType>(brInp.getType());
+  auto outType = dyn_cast<MemRefType>(brOut.getType());
+  if (!inpType || !outType)
+    return operand;
+
+  auto inpShape = inpType.getShape();
+  auto outShape = outType.getShape();
+
+  if (inpShape.size() < 2) {
+    rewriter.notifyMatchFailure(
+        broadcastOp, "Only 2D and higher broadcast inputs are supported");
+    return operand;
+  }
+
+  if (!utils::canSqueezeDims(inpShape) || !utils::canSqueezeDims(outShape)) {
+    rewriter.notifyMatchFailure(broadcastOp,
+                                "Can't squeeze broadcast operands to 2D");
+    return operand;
+  }
+
+  auto res = utils::reduceMemrefDims(rewriter, broadcastOp.getLoc(), brInp);
+  if (failed(res))
+    return operand;
+
+  rewriter.eraseOp(broadcastOp);
+  return res.value();
+}
+
 // Create XeGPU DPAS kernel out of GEMM-like operation.
 static LogicalResult createDPASKernel(linalg::LinalgOp linalgOp,
                                       ArrayRef<int64_t> dpasTile, int kTile,
@@ -1690,7 +1789,9 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
 
   // Create descriptors and load values for all inputs.
   SmallVector<SmallVector<Value>> loadedInputs;
+  size_t operandIdx = 0;
   for (auto input : linalgOp.getDpsInputs()) {
+    input = findAndReplaceBroadcast(linalgOp, operandIdx, rewriter);
     SmallVector<Value> inputTiles =
         createDescriptorTiles(rewriter, loc, input, outputShape, tileShape);
@@ -1699,6 +1800,7 @@ LogicalResult createEltwiseKernel(linalg::LinalgOp linalgOp,
         /*vnniConf=*/std::nullopt,
         /*transpose=*/nullptr, /*transpose_bit=*/nullptr);
     loadedInputs.push_back(loadedVals);
+    operandIdx++;
   }
 
   // Extract SIMD sized sub-tiles from loaded tiles.
diff --git a/lib/gc/Transforms/GPU/Pipeline.cpp b/lib/gc/Transforms/GPU/Pipeline.cpp
index 130b25a1..a1878df7 100644
--- a/lib/gc/Transforms/GPU/Pipeline.cpp
+++ b/lib/gc/Transforms/GPU/Pipeline.cpp
@@ -37,6 +37,7 @@ void populateGPUPipeline(OpPassManager &pm,
 
   pm.addPass(createDecomposeTensorOperation());
   pm.addNestedPass<func::FuncOp>(createGpuTilingAndFusion());
+  pm.addPass(createCanonicalizerPass());
 
   pm.addPass(bufferization::createEmptyTensorEliminationPass());
   pm.addPass(bufferization::createEmptyTensorToAllocTensorPass());
diff --git a/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-broadcast-fold.mlir b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-broadcast-fold.mlir
new file mode 100644
index 00000000..bb0091ce
--- /dev/null
+++ b/test/mlir/test/gc/Transforms/GPU/linalg-to-xegpu-broadcast-fold.mlir
@@ -0,0 +1,109 @@
+// RUN: gc-opt %s -linalg-to-xegpu="dpas-tile=8,16,16 k-tile=16" -canonicalize -split-input-file | FileCheck %s
+
+// CHECK-LABEL: func.func @broadcast_eliminate_2d
+func.func @broadcast_eliminate_2d() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
+  %0 = memref.alloc() : memref<7x128xf16>
+  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
+  %2 = memref.alloc() : memref<1x1x7x128xf16>
+  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
+  %3 = memref.alloc() : memref<1x1x7x128xf16>
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
+    // CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
+    %slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
+    %1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>
+
+    // CHECK-NOT: linalg.broadcast
+    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0, 1]
+    // CHECK: xegpu.create_nd_tdesc %[[MEMREF_0]]
+    linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @broadcast_eliminate_3d
+func.func @broadcast_eliminate_3d() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<1x7x128xf16>
+  %0 = memref.alloc() : memref<1x7x128xf16>
+  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<1x1x7x128xf16>
+  %2 = memref.alloc() : memref<1x1x7x128xf16>
+  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<1x1x7x128xf16>
+  %3 = memref.alloc() : memref<1x1x7x128xf16>
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
+    // CHECK-NOT: memref.alloc() : memref<4x1x7x128xf16, 3>
+    %slm_base = memref.alloc() : memref<4x1x7x128xf16, 3>
+    %1 = memref.subview %slm_base[%arg6, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : memref<4x1x7x128xf16, 3> to memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>
+
+    // CHECK-NOT: linalg.broadcast
+    linalg.broadcast ins(%0 : memref<1x7x128xf16>) outs(%1 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>) dimensions = [0]
+    // The folded %0 is squeezed to 2D before being passed to 'linalg.add'.
+    // CHECK: %[[MEMREF0_SQUEEZ:.+]] = memref.subview %[[MEMREF_0]][0, 0, 0] [1, 7, 128] [1, 1, 1] :
+    // CHECK-SAME: memref<1x7x128xf16> to memref<7x128xf16, strided<[128, 1]>>
+    // CHECK: xegpu.create_nd_tdesc %[[MEMREF0_SQUEEZ]]
+    linalg.add ins(%1, %2 : memref<1x1x7x128xf16, strided<[896, 896, 128, 1], offset: ?>, 3>, memref<1x1x7x128xf16>) outs(%3 : memref<1x1x7x128xf16>)
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @complex_broadcast_3d
+func.func @complex_broadcast_3d() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
+  %0 = memref.alloc() : memref<7x128xf16>
+  // CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<7x7x128xf16>
+  %1 = memref.alloc() : memref<7x7x128xf16>
+  // CHECK: %[[MEMREF_2:.*]] = memref.alloc() : memref<7x7x128xf16>
+  %2 = memref.alloc() : memref<7x7x128xf16>
+  // CHECK: %[[MEMREF_3:.*]] = memref.alloc() : memref<7x7x128xf16>
+  %3 = memref.alloc() : memref<7x7x128xf16>
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
+    // This broadcast can't be replaced by a simple memref.subview, so it must be kept.
+    // CHECK: linalg.broadcast
+    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<7x7x128xf16>) dimensions = [0]
+    linalg.add ins(%1, %2 : memref<7x7x128xf16>, memref<7x7x128xf16>) outs(%3 : memref<7x7x128xf16>)
+    gpu.terminator
+  }
+  return
+}
+
+// -----
+
+// CHECK-LABEL: func.func @single_broadcast
+func.func @single_broadcast() {
+  %c1 = arith.constant 1 : index
+  %c2 = arith.constant 2 : index
+  %c4 = arith.constant 4 : index
+
+  // CHECK: %[[MEMREF_0:.*]] = memref.alloc() : memref<7x128xf16>
+  %0 = memref.alloc() : memref<7x128xf16>
+  // CHECK: %[[MEMREF_1:.*]] = memref.alloc() : memref<1x1x7x128xf16>
+  %1 = memref.alloc() : memref<1x1x7x128xf16>
+
+  gpu.launch blocks(%arg3, %arg4, %arg5) in (%arg11 = %c2, %arg12 = %c4, %arg13 = %c1) threads(%arg6, %arg7, %arg8) in (%arg14 = %c4, %arg15 = %c1, %arg16 = %c1) {
+    // The broadcast result is not consumed by any op lowered to XeGPU, so it is kept.
+    // CHECK: linalg.broadcast
+    linalg.broadcast ins(%0 : memref<7x128xf16>) outs(%1 : memref<1x1x7x128xf16>) dimensions = [0, 1]
+    gpu.terminator
+  }
+  return
+}
diff --git a/test/mlir/test/gc/gpu-runner/XeGPU/rope_sanity_test.mlir b/test/mlir/test/gc/gpu-runner/XeGPU/rope_sanity_test.mlir
new file mode 100644
index 00000000..98a0535c
--- /dev/null
+++ b/test/mlir/test/gc/gpu-runner/XeGPU/rope_sanity_test.mlir
@@ -0,0 +1,95 @@
+// RUN: gc-gpu-runner --shared-libs=%mlir_runner_utils,%mlir_c_runner_utils %s | FileCheck %s
+
+!dtype=f16
+!input_memref_type=memref<2x7x32x128x!dtype>
+!input_tensor_type=tensor<2x7x32x128x!dtype>
+!output_memref_type=memref<2x32x7x128x!dtype>
+!output_tensor_type=tensor<2x32x7x128x!dtype>
+!cos_sin_cache_memref_type=memref<1x1x2048x128x!dtype>
+!cos_sin_cache_tensor_type=tensor<1x1x2048x128x!dtype>
+!cos_sin_cache_tensor_shrink_type=tensor<1x1x7x128x!dtype>
+!pos_ids_memref_type=memref<1x7xindex>
+!pos_ids_tensor_type=tensor<1x7xindex>
+#map = affine_map<(xi, yi, zi) -> ((xi * 3 * 4 + yi * 4 + zi) * 2)>
+module @fragment_name {
+memref.global "private" constant @_cos_cache : !cos_sin_cache_memref_type = dense<3.000000e+00>
+memref.global "private" constant @_sin_cache : !cos_sin_cache_memref_type = dense<2.000000e+00>
+memref.global "private" constant @_iinput_const : !input_memref_type = dense<3.000000e+00>
+memref.global "private" constant @_ipos_ids_const : !pos_ids_memref_type = dense<1>
+memref.global "private" constant @_ipos_id_end_const : memref<1xindex> = dense<1>
+func.func @RoPE(%iinput: !input_memref_type, %ipos_ids: !pos_ids_memref_type, %ipos_id_end: memref<1xindex>, %out: !output_memref_type) {
+  %input = bufferization.to_tensor %iinput restrict : !input_memref_type
+  %cos_cache = memref.get_global @_cos_cache : !cos_sin_cache_memref_type
+  %sin_cache = memref.get_global @_sin_cache : !cos_sin_cache_memref_type
+  %cos_cache_tensor = bufferization.to_tensor %cos_cache restrict : !cos_sin_cache_memref_type
+  %sin_cache_tensor = bufferization.to_tensor %sin_cache restrict : !cos_sin_cache_memref_type
+  %pos_ids = bufferization.to_tensor %ipos_ids restrict : !pos_ids_memref_type
+  %pos_id_end = bufferization.to_tensor %ipos_id_end restrict : memref<1xindex>
+  %3 = tensor.empty() : !output_tensor_type
+
+  %transpose_in = linalg.transpose ins(%input : !input_tensor_type) outs(%3 : !output_tensor_type) permutation = [0, 2, 1, 3]
+
+  %c0 = arith.constant 0 : index
+  %c3 = arith.constant 3 : index
+  %cos_cache_slice = tensor.extract_slice %cos_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
+  %cos_cache_slice2 = tensor.collapse_shape %cos_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
+  %cos_cache_slice3 = tensor.collapse_shape %cos_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
+  %pos_ids_index = tensor.expand_shape %pos_ids [[0], [1, 2]] output_shape [1, 7, 1] : tensor<1x7xindex> into tensor<1x7x1xindex>
+
+  %cos_cache_slice4 = tensor.gather %cos_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
+
+  %cos_cache_slice5 = tensor.expand_shape %cos_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
+  %cos_cache_slice6 = tensor.collapse_shape %cos_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
+
+  %cos_cache_slice7 = linalg.broadcast ins(%cos_cache_slice6 : tensor<7x128x!dtype>) outs(%3 : !output_tensor_type) dimensions = [0, 1]
+  %input_apply_cos_cache = linalg.mul ins(%transpose_in, %cos_cache_slice7 : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type
+
+  %head_dim = tensor.dim %transpose_in, %c3 : !output_tensor_type
+  %c2 = arith.constant 2 : index
+  %half_head_dim = arith.floordivsi %head_dim, %c2 : index
+  %transpose_input_first_half = tensor.extract_slice %transpose_in[0, 0, 0, 0] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
+  %transpose_input_second_half = tensor.extract_slice %transpose_in[0, 0, 0, %half_head_dim] [2, 32, 7, 64] [1, 1, 1, 1] : !output_tensor_type to tensor<2x32x7x64x!dtype>
+  %cnegative1 = arith.constant dense<-1.000000e+00> : tensor<2x32x7x64x!dtype>
+  %empty_tensor = tensor.empty() : tensor<2x32x7x64x!dtype>
+  %transpose_input_second_half_opposite = linalg.mul ins(%transpose_input_second_half, %cnegative1 : tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) outs(%empty_tensor : tensor<2x32x7x64x!dtype>) -> tensor<2x32x7x64x!dtype>
+
+  %transformed_input = tensor.concat dim(3) %transpose_input_second_half_opposite, %transpose_input_first_half : (tensor<2x32x7x64x!dtype>, tensor<2x32x7x64x!dtype>) -> !output_tensor_type
+
+  %sin_cache_slice = tensor.extract_slice %sin_cache_tensor[0, 0, 0, 0] [1, 1, 7, 128] [1, 1, 1, 1] : !cos_sin_cache_tensor_type to !cos_sin_cache_tensor_shrink_type
+  %sin_cache_slice2 = tensor.collapse_shape %sin_cache_slice [[0, 1], [2], [3]] : tensor<1x1x7x128x!dtype> into tensor<1x7x128x!dtype>
+  %sin_cache_slice3 = tensor.collapse_shape %sin_cache_slice2 [[0, 1], [2]] : tensor<1x7x128x!dtype> into tensor<7x128x!dtype>
+  %sin_cache_slice4 = tensor.gather %sin_cache_slice3[%pos_ids_index] gather_dims([0]) : (tensor<7x128x!dtype>, tensor<1x7x1xindex>) -> tensor<1x7x128x!dtype>
+
+  %sin_cache_slice5 = tensor.expand_shape %sin_cache_slice4 [[0, 1], [2], [3]] output_shape [1, 1, 7, 128] : tensor<1x7x128x!dtype> into tensor<1x1x7x128x!dtype>
+  %sin_cache_slice6 = tensor.collapse_shape %sin_cache_slice5 [[0, 1, 2], [3]] : tensor<1x1x7x128x!dtype> into tensor<7x128x!dtype>
+
+  %sin_cache_slice7 = linalg.broadcast ins(%sin_cache_slice6 : tensor<7x128x!dtype>) outs(%3 : !output_tensor_type) dimensions = [0, 1]
+  %input_apply_sin_cache = linalg.mul ins(%transformed_input, %sin_cache_slice7 : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type
+
+  %result = linalg.add ins(%input_apply_cos_cache, %input_apply_sin_cache : !output_tensor_type, !output_tensor_type) outs(%3 : !output_tensor_type) -> !output_tensor_type
+  bufferization.materialize_in_destination %result in restrict writable %out : (!output_tensor_type, !output_memref_type) -> ()
+  return
+}
+
+func.func @main() {
+  %inp = memref.get_global @_iinput_const : !input_memref_type
+  %ipos_ids = memref.get_global @_ipos_ids_const : !pos_ids_memref_type
+  %ipos_id_end = memref.get_global @_ipos_id_end_const : memref<1xindex>
+
+  %out = memref.alloc() {alignment = 64 : i64} : !output_memref_type
+
+  func.call @RoPE(%inp, %ipos_ids, %ipos_id_end, %out) : (!input_memref_type, !pos_ids_memref_type, memref<1xindex>, !output_memref_type) -> ()
+
+  %out_subview = memref.subview %out[0, 0, 0, 0] [2, 1, 1, 1] [1, 1, 1, 1] : !output_memref_type to memref<2xf16, strided<[28672]>>
+  %cast = memref.cast %out_subview : memref<2xf16, strided<[28672]>> to memref<*xf16>
+  call @printMemrefF16(%cast) : (memref<*xf16>) -> ()
+
+  return
+}
+
+func.func private @printMemrefF16(%ptr : memref<*xf16>)
+}
+
+// CHECK: Unranked Memref base@{{(0x)?[-0-9a-fA-F]*}}
+// CHECK-SAME: rank = 1 offset = 0 sizes = [2] strides = [28672] data =
+// CHECK-NEXT: [3, 3]
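+
+// Sanity derivation: the input is filled with 3.0, the cos cache with 3.0,
+// and the sin cache with 2.0. Both checked elements sit in the first half of
+// the head dimension, where rotate-half contributes the negated input, so
+// each printed value is 3.0 * 3.0 + (-3.0) * 2.0 = 3.0.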