Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,17 @@ def ApplyVectorContractToPackedTypeDotProductPatternsOp : Op<Transform_Dialect,
let assemblyFormat = "attr-dict";
}

// Transform op exposing populateSinkVectorProducerOpsPatterns to the
// transform dialect as apply_patterns.x86vector.sink_vector_producer_ops.
def ApplySinkVectorProducerOpsPatternsOp : Op<Transform_Dialect,
    "apply_patterns.x86vector.sink_vector_producer_ops",
    [DeclareOpInterfaceMethods<PatternDescriptorOpInterface>]> {
  let description = [{
    Collect patterns to sink vector producer operations forward in a block to
    place them immediately before their first use.
  }];

  let assemblyFormat = "attr-dict";
}


#endif // X86VECTOR_TRANSFORM_OPS

4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/X86Vector/Transforms.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,10 @@ void populateVectorContractToFMAPatterns(RewritePatternSet &patterns);
void populateVectorContractToPackedTypeDotProductPatterns(
RewritePatternSet &patterns);

/// Performs forward scheduling of vector producer ops (e.g. vector.load and
/// vector.transfer_read) to minimize their live range by placing them at
/// their earliest legal use site.
void populateSinkVectorProducerOpsPatterns(RewritePatternSet &patterns);

//===----------------------------------------------------------------------===//
/// Helpers extracted from:
/// - clang/lib/Headers/avxintrin.h
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ void mlir::transform::ApplyVectorContractToPackedTypeDotProductPatternsOp::
x86vector::populateVectorContractToPackedTypeDotProductPatterns(patterns);
}

/// Populates the producer-sinking patterns when the transform op
/// apply_patterns.x86vector.sink_vector_producer_ops is applied.
void mlir::transform::ApplySinkVectorProducerOpsPatternsOp::populatePatterns(
    RewritePatternSet &patterns) {
  x86vector::populateSinkVectorProducerOpsPatterns(patterns);
}

//===----------------------------------------------------------------------===//
// Transform op registration
//===----------------------------------------------------------------------===//
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ add_mlir_dialect_library(MLIRX86VectorTransforms
LegalizeForLLVMExport.cpp
VectorContractToFMA.cpp
VectorContractToPackedTypeDotProduct.cpp
SinkVectorProducerOps.cpp

LINK_LIBS PUBLIC
MLIRArithDialect
Expand Down
111 changes: 111 additions & 0 deletions mlir/lib/Dialect/X86Vector/Transforms/SinkVectorProducerOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
//===- SinkVectorProducerOps.cpp ------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/Dialect/Vector/Utils/VectorUtils.h"
#include "mlir/Dialect/X86Vector/Transforms.h"
#include "mlir/Dialect/X86Vector/X86VectorDialect.h"

#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/PatternMatch.h"

#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

using namespace mlir;
using namespace mlir::vector;
using namespace mlir::x86vector;

/// Sink vector producers forward to reduce live ranges.
/// This pattern applies to ops such as vector.load and vector.transfer_read.
/// The producer is moved immediately before its first user within the same
/// block; the op's successor is also sunk when legal, so interleaved
/// producer/consumer chains settle in one pass.
template <typename ProducerOp>
struct SinkVectorProducerOps final : public OpRewritePattern<ProducerOp> {
  using OpRewritePattern<ProducerOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ProducerOp op,
                                PatternRewriter &rewriter) const override {
    // Collect all users of the producer op.
    llvm::SmallVector<Operation *> opUsers;
    for (OpResult result : op->getResults())
      for (Operation *user : result.getUsers())
        opUsers.push_back(user);

    // If there are no users, nothing to sink.
    if (opUsers.empty())
      return failure();

    // If the producer is the last op in its block there is nothing to sink
    // past. This also guards the nextOp dereferences below against null.
    Operation *nextOp = op->getNextNode();
    if (!nextOp)
      return failure();

    // If the next op is already a user, do not move.
    if (llvm::is_contained(opUsers, nextOp))
      return failure();

    // Prevent pathological looping:
    // If two producers are used by the same consumer, moving them would
    // ping-pong forever. For example:
    //   %1 = prod1
    //   %2 = prod2
    //   %3 = op %1, %2
    llvm::SmallVector<Operation *> nextOpUsers;
    for (OpResult result : nextOp->getResults())
      for (Operation *user : result.getUsers())
        nextOpUsers.push_back(user);

    // Bail out when both producers share the single user of `op`.
    // NOTE: the empty() guard is required; without it, an op with no
    // results/users after `op` would make front() undefined behavior.
    if (opUsers.size() == 1 && !nextOpUsers.empty() &&
        nextOpUsers.size() != 1 &&
        llvm::is_contained(opUsers, nextOpUsers.front()))
      return failure();

    // Find the first user of `op` and of `nextOp` by scanning forward from
    // the op's successor.
    Operation *opFirstUser = op->getNextNode();
    Operation *nextOpFirstUser = op->getNextNode();

    while (opFirstUser && !llvm::is_contained(opUsers, opFirstUser))
      opFirstUser = opFirstUser->getNextNode();

    while (nextOpFirstUser && !llvm::is_contained(nextOpUsers, nextOpFirstUser))
      nextOpFirstUser = nextOpFirstUser->getNextNode();

    // No user follows `op` in this block (users live in other blocks).
    if (!opFirstUser)
      return failure();

    // The op's first user and the next op's first user are the same.
    // Break here to avoid the shift cycle looping.
    if (opFirstUser == nextOpFirstUser)
      return failure();

    // Both ops must be in the same block to safely move.
    if (op->getBlock() != opFirstUser->getBlock())
      return failure();

    // Move the producer immediately before its first user. Use the rewriter
    // (not Operation::moveBefore) so the greedy driver is notified of the
    // mutation.
    rewriter.moveOpBefore(op, opFirstUser);

    // Also sink nextOp to its first user when it stays within its block.
    if (nextOpFirstUser && nextOpFirstUser->getBlock() == nextOp->getBlock())
      rewriter.moveOpBefore(nextOp, nextOpFirstUser);

    return success();
  }
};

/// Registers the producer-sinking patterns for vector.transfer_read and
/// vector.load.
void x86vector::populateSinkVectorProducerOpsPatterns(
    RewritePatternSet &patterns) {
  MLIRContext *ctx = patterns.getContext();
  patterns.add<SinkVectorProducerOps<vector::TransferReadOp>>(ctx);
  patterns.add<SinkVectorProducerOps<vector::LoadOp>>(ctx);
}
199 changes: 199 additions & 0 deletions mlir/test/Dialect/X86Vector/sink-vector-producer-ops.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
// RUN: mlir-opt %s -transform-interpreter -cse -split-input-file | FileCheck %s

// Checks that each vector.load is sunk to just before its first vector.fma
// user, so loads pair up with the fma that consumes them.
func.func @sink_vector_loads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.load %arg0[%c8, %c0] : memref<16x16xf32>, vector<8xf32>
  %3 = vector.load %arg0[%c8, %c8] : memref<16x16xf32>, vector<8xf32>
  %4 = vector.fma %0, %1, %arg1 : vector<8xf32>
  %5 = vector.fma %2, %3, %4 : vector<8xf32>
  return %5 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_loads
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as @sink_vector_loads but for vector.transfer_read on a memref:
// each read is sunk to just before its first vector.fma user.
func.func @sink_vector_transfer_reads(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : memref<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Same as @sink_vector_transfer_reads but reading from a tensor instead of a
// memref, to show the pattern is independent of the source type.
func.func @sink_vector_transfer_reads_tensor(%arg0: tensor<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = ub.poison : f32
  %1 = vector.transfer_read %arg0[%c0, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %2 = vector.transfer_read %arg0[%c0, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %3 = vector.transfer_read %arg0[%c8, %c0], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %4 = vector.transfer_read %arg0[%c8, %c8], %0 {in_bounds = [true]} : tensor<16x16xf32>, vector<8xf32>
  %5 = vector.fma %1, %2, %arg1 : vector<8xf32>
  %6 = vector.fma %3, %4, %5 : vector<8xf32>
  return %6 : vector<8xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_tensor
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

#map = affine_map<(d0, d1, d2, d3, d4) -> (d0, d2, d4, d1)>
#map1 = affine_map<(d0, d1, d2, d3, d4) -> (d0, d4, d3, d1)>
#map2 = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3)>

// Reads with multiple consumers: each transfer_read is reused by two
// vector.contract ops, so reads are interleaved with the contract chain
// rather than paired 1:1 (see the CHECK sequence below).
func.func @sink_vector_transfer_reads_bf16(%arg0: tensor<4x64x32x2xbf16>, %arg1: tensor<4x32x64x2xbf16>, %arg2: vector<1x16xf32>) -> vector<1x16xf32> {
  %0 = ub.poison : bf16
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c16 = arith.constant 16 : index
  %extracted_slice = tensor.extract_slice %arg0[%c0, %c0, %c0, 0] [1, 4, 1, 2] [1, 1, 1, 1] : tensor<4x64x32x2xbf16> to tensor<1x4x1x2xbf16>
  %extracted_slice_0 = tensor.extract_slice %arg1[%c0, %c0, %c0, 0] [1, 1, 32, 2] [1, 1, 1, 1] : tensor<4x32x64x2xbf16> to tensor<1x1x32x2xbf16>
  %1 = vector.transfer_read %extracted_slice[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %2 = vector.transfer_read %extracted_slice[%c0, %c1, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x4x1x2xbf16>, vector<1x1x1x2xbf16>
  %3 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c0, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %4 = vector.transfer_read %extracted_slice_0[%c0, %c0, %c16, %c0], %0 {in_bounds = [true, true, true, true]} : tensor<1x1x32x2xbf16>, vector<1x1x16x2xbf16>
  %5 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %3, %arg2 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %6 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %1, %4, %5 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %7 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %3, %6 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  %8 = vector.contract {indexing_maps = [#map, #map1, #map2], iterator_types = ["reduction", "reduction", "parallel", "parallel", "reduction"], kind = #vector.kind<add>} %2, %4, %7 {unroll_shape = array<i64: 1, 2, 1, 16, 1>} : vector<1x1x1x2xbf16>, vector<1x1x16x2xbf16> into vector<1x16xf32>
  return %8 : vector<1x16xf32>
}

// CHECK-LABEL: @sink_vector_transfer_reads_bf16
// CHECK: vector.transfer_read
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.transfer_read
// CHECK-NEXT: vector.contract
// CHECK-NEXT: vector.contract

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: two producers feeding the same single consumer must not be
// repeatedly swapped past each other; the pattern bails out and the order is
// unchanged.
func.func @negative_no_infinite_looping(%arg0: memref<16x16xf32>, %arg1: vector<8xf32>) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<16x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<16x16xf32>, vector<8xf32>
  %2 = vector.fma %0, %1, %arg1 : vector<8xf32>
  return %2: vector<8xf32>
}

// CHECK-LABEL: @negative_no_infinite_looping
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: vector.fma

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}

// -----

// Negative test: the loads' only users live inside the scf.if region (a
// different block), so the producers must not be moved.
func.func @negative_no_sink_outside_block(%arg0: memref<8x16xf32>, %arg1: i1) -> vector<8xf32> {
  %c0 = arith.constant 0 : index
  %c8 = arith.constant 8 : index
  %0 = vector.load %arg0[%c0, %c0] : memref<8x16xf32>, vector<8xf32>
  %1 = vector.load %arg0[%c0, %c8] : memref<8x16xf32>, vector<8xf32>
  %2 = scf.if %arg1 -> (vector<8xf32>) {
    scf.yield %0 : vector<8xf32>
  } else {
    scf.yield %1 : vector<8xf32>
  }
  return %2 : vector<8xf32>
}

// CHECK-LABEL: @negative_no_sink_outside_block
// CHECK: vector.load
// CHECK-NEXT: vector.load
// CHECK-NEXT: scf.if
// CHECK-NEXT: scf.yield

module attributes {transform.with_named_sequence} {
  transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
    %0 = transform.structured.match ops{["func.func"]} in %arg0 : (!transform.any_op) -> !transform.any_op
    transform.apply_patterns to %0 {
      transform.apply_patterns.x86vector.sink_vector_producer_ops
    } : !transform.any_op
    transform.yield
  }
}