
Conversation

Contributor

@AviadCo AviadCo commented Feb 2, 2025

No description provided.

@llvmbot
Member

llvmbot commented Feb 2, 2025

@llvm/pr-subscribers-mlir-linalg
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Aviad Cohen (AviadCo)

Changes

Full diff: https://github.com/llvm/llvm-project/pull/125401.diff

4 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h (+28)
  • (modified) mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt (+1)
  • (modified) mlir/test/lib/Dialect/Linalg/CMakeLists.txt (+1)
  • (modified) mlir/tools/mlir-opt/mlir-opt.cpp (+2)
diff --git a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
index 1dc42f71e10eff..4b325aaeab87ca 100644
--- a/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
+++ b/mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1893,6 +1893,34 @@ void populateDecomposeWinogradOpsPatterns(RewritePatternSet &patterns);
 /// convert to a `linalg.dot`.
 void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);
 
+/// Add patterns to fold a `linalg.fill` into a consuming `linalg.reduce`
+/// by creating a fused `linalg.generic` operation.
+/// The fill is expected to apply only on the first index of the
+/// reduction dimension. Currently only one reduction dimension is
+/// supported. Given the pattern:
+///   %empty = tensor.empty() : tensor<i8>
+///   %filled = linalg.fill ins(%c0 : i8) outs(%empty : tensor<i8>) -> tensor<i8>
+///   %reduced = linalg.reduce ins(%0 : tensor<147456xi8>)
+///       outs(%filled : tensor<i8>) dimensions = [0]
+///     (%in: i8, %init: i8) {
+///       %3 = arith.addi %in, %init : i8
+///       linalg.yield %3 : i8
+///     }
+/// The pattern is rewritten into:
+///   %empty = tensor.empty() : tensor<i8>
+///   %reduced = linalg.generic ins(%0 : tensor<147456xi8>)
+///       outs(%empty : tensor<i8>) {
+///     ^bb0(%in: i8, %init: i8):
+///       %cst = arith.constant 0 : index
+///       %index = linalg.index 0 : index
+///       %cmp = arith.cmpi eq, %cst, %index : index
+///       %zero = arith.constant 0 : i8
+///       %sum = arith.select %cmp, %zero, %init : i8
+///       %res = arith.addi %in, %sum : i8
+///       linalg.yield %res : i8
+///   }
+void populateFuseFillOpWithReduceOpPatterns(RewritePatternSet &patterns);
+
 } // namespace linalg
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
index 3594b084138124..cace3dcb6cbfca 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -15,6 +15,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
   EraseUnusedOperandsAndResults.cpp
   FoldAddIntoDest.cpp
   FusePadOpWithLinalgProducer.cpp
+  FuseFillOpWithReduceOp.cpp
   Fusion.cpp
   Generalization.cpp
   Hoisting.cpp
diff --git a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
index eb6f581252181a..2c2cef60428743 100644
--- a/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
+++ b/mlir/test/lib/Dialect/Linalg/CMakeLists.txt
@@ -6,6 +6,7 @@ add_mlir_library(MLIRLinalgTestPasses
   TestLinalgElementwiseFusion.cpp
   TestLinalgFusionTransforms.cpp
   TestLinalgRankReduceContractionOps.cpp
+  TestLinalgFuseFillOpWithReduceOp.cpp
   TestLinalgTransforms.cpp
   TestPadFusion.cpp
 
diff --git a/mlir/tools/mlir-opt/mlir-opt.cpp b/mlir/tools/mlir-opt/mlir-opt.cpp
index 74007d01347ae8..7e92095ff2fae7 100644
--- a/mlir/tools/mlir-opt/mlir-opt.cpp
+++ b/mlir/tools/mlir-opt/mlir-opt.cpp
@@ -111,6 +111,7 @@ void registerTestLinalgDropUnitDims();
 void registerTestLinalgElementwiseFusion();
 void registerTestLinalgGreedyFusion();
 void registerTestLinalgRankReduceContractionOps();
+void registerTestLinalgFuseFillOpWithReduceOp();
 void registerTestLinalgTransforms();
 void registerTestLivenessAnalysisPass();
 void registerTestLivenessPass();
@@ -251,6 +252,7 @@ void registerTestPasses() {
   mlir::test::registerTestLinalgElementwiseFusion();
   mlir::test::registerTestLinalgGreedyFusion();
   mlir::test::registerTestLinalgRankReduceContractionOps();
+  mlir::test::registerTestLinalgFuseFillOpWithReduceOp();
   mlir::test::registerTestLinalgTransforms();
   mlir::test::registerTestLivenessAnalysisPass();
   mlir::test::registerTestLivenessPass();

@AviadCo AviadCo force-pushed the linalg/fuseFillReduce branch from 5ebcb8c to 56afd0f Compare February 2, 2025 12:51
/// convert to a `linalg.dot`.
void populateContractionOpRankReducingPatterns(RewritePatternSet &patterns);

/// Add patterns to fuse a linalg fill operation with a linalg operation.
Contributor

In general this is not the preferred way of fusing a fill with a reduction. The preferred way is to use a tile + fuse approach and fuse at tile granularity (since fusing a fill with its consumer reduction operation results in an inherently imperfectly nested loop computation). The main issue here is that this adds a conditional to the innermost loop computation, which generally is not performant.
But this still seems valid. Could you add some comments explaining the alternatives?
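The trade-off the reviewer describes can be modeled in scalar form. A hedged Python sketch (function names are illustrative, not from the patch): the fused pattern selects the init value via an index comparison inside the innermost loop, while tile-and-fuse hoists the fill out of the reduction loop, keeping the hot loop branch-free.

```python
def reduce_rows_fused_conditional(x):
    """Row-wise sum where the fill is fused into the reduction body:
    the init value is chosen by an index comparison inside the
    innermost loop, mirroring what the proposed pattern lowers to."""
    out = []
    for row in x:
        acc = 0
        for j, v in enumerate(row):
            init = 0 if j == 0 else acc  # conditional in the hot loop
            acc = init + v
        out.append(acc)
    return out


def reduce_rows_tile_and_fuse(x):
    """Tile-and-fuse shape: the fill is hoisted to the start of each
    parallel iteration, so the innermost loop stays branch-free."""
    out = []
    for row in x:
        acc = 0  # fill fused at tile granularity, outside the j-loop
        for v in row:
            acc += v
        out.append(acc)
    return out
```

Both produce the same result; the difference is only in where the initialization happens relative to the innermost loop.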

Contributor

@MaheshRavishankar MaheshRavishankar left a comment

Marking this request-changes for the documentation request for now. Will review the change itself in a bit.

@AviadCo
Contributor Author

AviadCo commented Feb 3, 2025

Hey @MaheshRavishankar ,
We are using the tile-and-fuse transform in addition to the regular linalg fusion.
When lowered to loops, I would like to achieve the following pseudo-code:

for (int i = 0; i < N; ++i) {
    int sum = 0;
    for (j = 0; j < M; ++j) {
        int x = load[i][j];
        sum += x;
    }
}

I can lower the linalg.generic of fill and reduce into loops and do it there, but at that point the flow is much more complicated for identifying the fill + reduce pattern.
This pattern is useful for us although it causes imperfectly nested loops.
Moreover, our HW knows how to handle such calculations even though the loops are not perfectly nested.

I do have some patterns to optimize the imperfectly nested loops after the lowering, but they are more HW specific.

I think that some other people might use this transformation pattern as well.

Comment on lines +1916 to +1917
/// %index = linalg.index %c0 : index
/// %cmp = arith.cmpi eq, %cst, %index : i1
Member

The iteration order of reduction and parallel iterators in a linalg operation is undefined. I don't think you can assume that the first iteration of the iterator is 0.

Member

To be clear, I don't think you can do this fusion while adding linalg.index to the body, because that would mean you are assuming the first iteration index to be something.

@Groverkss
Member

Groverkss commented Feb 3, 2025

Hey @MaheshRavishankar , We are using the tile-and-fuse transform in addition to the regular linalg fusion. When lowered to loops, I would like to achieve the following pseudo-code:

for (int i = 0; i < N; ++i) {
    int sum = 0;
    for (j = 0; j < M; ++j) {
        int x = load[i][j];
        sum += x;
    }
}

I can lower the linalg.generic of fill and reduce into loops and do it there, but at that point the flow is much more complicated for identifying the fill + reduce pattern. This pattern is useful for us although it causes imperfectly nested loops. Moreover, our HW knows how to handle such calculations even though the loops are not perfectly nested.

I do have some patterns to optimize the imperfectly nested loops after the lowering, but they are more HW specific.

I think that some other people might use this transformation pattern as well.

I don't think doing this at the linalg.generic level is correct, because you cannot assume a single iteration order. IIUC, what you are trying to do here, is tile the following example on parallel and reduction dims:

%empty = linalg.fill
linalg.generic ins(...) outs(%empty) { iterator_types = [parallel, reduction] }
  1. TileAndFuse along the reduction dimension (the fill will not fuse, because it doesn't have the reduction dimension):
%empty = linalg.fill
scf.for %j = 0 to ... init_args(%arg0 = %empty) {
  %out = linalg.generic ins(...) outs(%arg0)
  yield %out
}
  2. TileAndFuse along the parallel dimension (the fill will fuse, because it has the parallel dimension):
scf.for %i = 0 to ... {
   %empty = linalg.fill
   scf.for %j = 0 to ... init_args(%arg0 = %empty) {
    %out = linalg.generic ins(...) outs(%arg0)
    yield %out
  }
}
  3. This is the form you were looking for. Now, if you still want a perfectly nested loop form, you can write a loop-sinking pass. You can sink any init_args (with value-based semantics at least, i.e. tensors in this case) with an if condition:
%em = tensor.empty()
scf.for %i = 0 to ... {
   scf.for %j = 0 to ... init_args(%arg0 = %em) {
    %filled = scf.if (%j == 0) init_args(%arg1 = %arg0) {
       %fill = linalg.fill outs(%arg1)
       yield %fill
    } else {
      yield %arg1
    }
    linalg.generic ins(...) outs(%filled)
  }
}
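The loop-sinking step in the last snippet can be mimicked in scalar form: because the carried values have value semantics, re-applying the fill under a `j == 0` guard is equivalent to hoisting it before the inner loop. A minimal Python sketch of that equivalence (function names are hypothetical):

```python
def reduce_hoisted_fill(x, fill_value=0):
    """Fill applied once per outer iteration, before the inner loop."""
    out = []
    for row in x:
        acc = fill_value  # fill hoisted above the reduction loop
        for v in row:
            acc += v
        out.append(acc)
    return out


def reduce_sunk_fill(x, fill_value=0):
    """Fill sunk into the inner loop behind an `if j == 0` guard,
    analogous to wrapping the fill in a conditional on the
    reduction induction variable."""
    out = []
    for row in x:
        acc = None  # stands in for the uninitialized carried value
        for j, v in enumerate(row):
            filled = fill_value if j == 0 else acc  # the if/else branch
            acc = filled + v
        out.append(acc if acc is not None else fill_value)
    return out
```

The two forms compute identical results; the sunk form trades the perfectly nested structure for a per-iteration condition.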

@MaheshRavishankar
Contributor

To be clear, I don't think you can do this fusion while adding linalg.index to the body, because that would mean you are assuming the first iteration index to be something.

Thanks @Groverkss . That is correct. For a moment there I forgot this and went into a rabbit hole of "why do we not do this again".

btw,

for (int i = 0; i < N; ++i) {
    int sum = 0;
    for (j = 0; j < M; ++j) {
        int x = load[i][j];
        sum += x;
    }
}

This example is kind of interesting... you are initializing sum on every i iteration and never storing it, so you will only get the value for the (N-1)th iteration. I don't think this is what you meant. You probably meant


for (int i = 0; i < N; ++i) {
    int sum = 0;
    for (j = 0; j < M; ++j) {
        int x = load[i][j];
        sum += x;
    }
   output[i] = sum;
}

That is this sequence

%cst = arith.constant 0.0 : f32
%empty = tensor.empty(%N) : tensor<?xf32> 
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>)
%generic = linalg.generic {
    iterator_types = ["parallel", "reduction"],
    indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
    ins(%load : tensor<?x?xf32>) outs(%fill : tensor<?xf32>) {
  ^bb0(%b0 : f32, %b1 : f32):
    %0 = arith.addf %b0, %b1 : f32
    linalg.yield %0 : f32
} -> tensor<?xf32>

So to get the final loop sequence, you

  1. Tile the outer loop by 1:
%cst = arith.constant 0.0 : f32
%empty = tensor.empty(%N) : tensor<?xf32> 
%fill = linalg.fill ins(%cst : f32) outs(%empty : tensor<?xf32>) -> tensor<?xf32>
%result = scf.for %iv0 = 0 to %N step 1 outs(%init = %fill) {
  %slice = tensor.extract_slice %load[%iv0, 0][1, %M][1, 1]: tensor<?x?xf32> to tensor<1x?xf32>
  %outs = tensor.extract_slice %fill[%iv0][1][1]: tensor<?xf32> to tensor<1xf32>
  %generic = linalg.generic {
      iterator_types = ["parallel", "reduction"],
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
      ins(%slice : tensor<1x?xf32>) outs(%outs : tensor<1xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %0 = arith.addf %b0, %b1 : f32
      linalg.yield %0 : f32
  } -> tensor<1xf32>
  %inserted = tensor.insert_slice %generic into %init[%iv0][1][1] : tensor<1xf32> into tensor<?xf32>
  scf.yield %inserted : tensor<?xf32>
} -> tensor<?xf32>
  2. Fuse the fill in:
%cst = arith.constant 0.0 : f32
%empty = tensor.empty(%N) : tensor<?xf32> 
%result = scf.for %iv0 = 0 to %N step 1 outs(%init = %empty) {
  %outs = tensor.extract_slice %init[%iv0][1][1]: tensor<?xf32> to tensor<1xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%outs : tensor<1xf32>) -> tensor<1xf32>
  %slice = tensor.extract_slice %load[%iv0, 0][1, %M][1, 1]: tensor<?x?xf32> to tensor<1x?xf32>
  %generic = linalg.generic {
      iterator_types = ["parallel", "reduction"],
      indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
      ins(%slice : tensor<1x?xf32>) outs(%fill : tensor<1xf32>) {
    ^bb0(%b0 : f32, %b1 : f32):
      %0 = arith.addf %b0, %b1 : f32
      linalg.yield %0 : f32
  } -> tensor<1xf32>
  %inserted = tensor.insert_slice %generic into %init[%iv0][1][1] : tensor<1xf32> into tensor<?xf32>
  scf.yield %inserted : tensor<?xf32>
} -> tensor<?xf32>
  3. Tile the linalg.generic reduction dimension by 1:
%cst = arith.constant 0.0 : f32
%empty = tensor.empty(%N) : tensor<?xf32> 
%result = scf.for %iv0 = 0 to %N step 1 outs(%init = %empty) {
  %outs = tensor.extract_slice %init[%iv0][1][1]: tensor<?xf32> to tensor<1xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%outs : tensor<1xf32>) -> tensor<1xf32>
  %slice = tensor.extract_slice %load[%iv0, 0][1, %M][1, 1]: tensor<?x?xf32> to tensor<1x?xf32>
  %reduction = scf.for %iv1 = 0 to %M step 1 outs(%init0 = %fill) {
    %slice0 = tensor.extract_slice %slice[0, %iv1][1, 1][1, 1] : tensor<1x?xf32> to tensor<1x1xf32>
    %generic = linalg.generic {
        iterator_types = ["parallel", "reduction"],
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
        ins(%slice0 : tensor<1x1xf32>) outs(%init0 : tensor<1xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %0 = arith.addf %b0, %b1 : f32
        linalg.yield %0 : f32
    } -> tensor<1xf32>
    %inserted0 = tensor.insert_slice %generic into %init0[0][1][1] : tensor<1xf32> into tensor<1xf32>
    scf.yield %inserted0 : tensor<1xf32>
  }
  %inserted = tensor.insert_slice %reduction into %init[%iv0][1][1] : tensor<1xf32> into tensor<?xf32>
  scf.yield %inserted : tensor<?xf32>
} -> tensor<?xf32>

Now bufferization + lowering to loops will give you what you expect.
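To make the end state concrete, here is a hedged Python model of what the loop nest looks like after bufferization and lowering to loops, with a configurable reduction tile size (the function name and the default tile size are illustrative assumptions, not part of the patch):

```python
def tiled_row_reduce(x, tile=4):
    """Per-row reduction with the fill fused at the parallel level and
    the reduction dimension walked in tiles of `tile` elements."""
    n = len(x)
    out = [0.0] * n
    for i in range(n):                        # tiled parallel loop (step 1)
        out[i] = 0.0                          # fused linalg.fill
        m = len(x[i])
        for j0 in range(0, m, tile):          # tiled reduction loop
            for j in range(j0, min(j0 + tile, m)):
                out[i] += x[i][j]             # reduction body
    return out
```

The tile size only changes how the reduction dimension is chunked; the fill still runs exactly once per output element, outside the reduction loops.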

@AviadCo
Contributor Author

AviadCo commented Feb 4, 2025

@MaheshRavishankar @Groverkss
I really appreciate your deep review and suggestions. I agree that this transform as it is should not be merged.
I think that linalg could benefit from an optional "order" attribute (in our case we are not allowed to reorder floating-point operations, for example, and must apply them serially).

%cst = arith.constant 0.0 : f32

%empty = tensor.empty(%N) : tensor<?xf32>
%result = scf.for %iv0 = 0 to %N step 1 outs(%init = %empty) {
  %outs = tensor.extract_slice %init[%iv0][1][1]: tensor<?xf32> to tensor<1xf32>
  %fill = linalg.fill ins(%cst : f32) outs(%outs : tensor<1xf32>) -> tensor<1xf32>
  %slice = tensor.extract_slice %load[%iv0, 0][1, %M][1, 1]: tensor<?x?xf32> to tensor<1x?xf32>
  %reduction = scf.for %iv1 = 0 to %M step 1 outs(%init0 = %fill) {
    %slice0 = tensor.extract_slice %slice[0, %iv1][1, 1][1, 1] : tensor<1x?xf32> to tensor<1x1xf32>
    %generic = linalg.generic {
        iterator_types = ["parallel", "reduction"],
        indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0)>]}
        ins(%slice0 : tensor<1x1xf32>) outs(%init0 : tensor<1xf32>) {
      ^bb0(%b0 : f32, %b1 : f32):
        %0 = arith.addf %b0, %b1 : f32
        linalg.yield %0 : f32
    } -> tensor<1xf32>
    %inserted0 = tensor.insert_slice %generic into %init0[0][1][1] : tensor<1xf32> into tensor<1xf32>
    scf.yield %inserted0 : tensor<1xf32>
  }
  %inserted = tensor.insert_slice %reduction into %init[%iv0][1][1] : tensor<1xf32> into tensor<?xf32>
  scf.yield %inserted : tensor<?xf32>
} -> tensor<?xf32>

Unfortunately, this flow makes the final linalg.generic too naive (it works on one element), and our general flow depends on the fact that the linalg.generic is the actual heavy compute.

We do use the tile-and-fuse pattern, and we do co-tile the linalg.fill and linalg.reduce; I will try to do the fusion further down the road, where those operations are already lowered to loops.

@AviadCo AviadCo closed this Feb 4, 2025
@MaheshRavishankar
Contributor

Unfortunately, this flow makes the final linalg.generic too naive (it works on one element), and our general flow depends on the fact that the linalg.generic is the actual heavy compute.

This was an example; you can use any tile size you want. I wrote it that way just to show the loop structure that tile and fuse would generate.
You might want to do fusion of loops instead, but that is strictly harder. If you want more suggestions based on what you are trying to do, please start a Discourse thread with your needs.


Labels

mlir:core MLIR Core Infrastructure mlir:linalg mlir
