Skip to content

Commit b593976

Browse files
authored
[Transform][Fusion] yield fused producer if necessary (#289)
1 parent 9326650 commit b593976

File tree

2 files changed

+106
-5
lines changed

2 files changed

+106
-5
lines changed

lib/gc/Transforms/TilingUsingInterfaceX.cpp

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
#include "mlir/IR/PatternMatch.h"
2222
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
2323
#include "mlir/Interfaces/TilingInterface.h"
24+
#include "mlir/Transforms/RegionUtils.h"
2425
#include "llvm/ADT/TypeSwitch.h"
2526
#include "llvm/Support/Debug.h"
2627
#include <optional>
@@ -255,6 +256,28 @@ SmallVector<LoopLikeOpInterface> mlir::scfX::getOuterNestLoopsWhile(
255256
return {nestLoops.rbegin(), nestLoops.rend()};
256257
}
257258

259+
/// A listener that watches which ops were erased.
260+
struct ErasedOpListener : public RewriterBase::Listener {
261+
private:
262+
/// Pointers to all erased operations and blocks.
263+
DenseSet<void *> erased;
264+
// Hook old listener.
265+
OpBuilder::Listener *oldListenerHook = nullptr;
266+
267+
public:
268+
ErasedOpListener() = default;
269+
ErasedOpListener(OpBuilder::Listener *oldListener)
270+
: oldListenerHook(oldListener) {}
271+
void notifyOperationErased(Operation *op) override {
272+
// Call old listener hook.
273+
if (auto *oldListener =
274+
dyn_cast_if_present<RewriterBase::Listener>(oldListenerHook))
275+
oldListener->notifyOperationErased(op);
276+
erased.insert(op);
277+
}
278+
bool isErased(Operation *op) { return erased.count(op); }
279+
};
280+
258281
/// Enhanced version of `tileAndFuseProducerOfSliceImpl`, which can deal with
259282
/// multi-level `extractSliceOp`. E.g.
260283
///
@@ -296,6 +319,55 @@ mlir::scfX::tileAndFuseProducerOfSlice(RewriterBase &rewriter,
296319
tileAndFuseProducerOfSliceImpl(rewriter, sliceOp, outerLoops);
297320
if (!fuseProducerResult)
298321
return std::nullopt;
322+
323+
// Cache old listener.
324+
OpBuilder::Listener *oldListener = rewriter.getListener();
325+
// Set new listener.
326+
ErasedOpListener newListener = ErasedOpListener(oldListener);
327+
rewriter.setListener(&newListener);
328+
329+
auto producerOp =
330+
cast<TilingInterface>(fuseProducerResult->origProducer.getDefiningOp());
331+
unsigned resultNumber = fuseProducerResult->origProducer.getResultNumber();
332+
// cache candidate slice
333+
auto extractSliceOp = cast<tensor::ExtractSliceOp>(candidateSliceOp);
334+
SmallVector<OpFoldResult> offsets = extractSliceOp.getMixedOffsets(),
335+
sizes = extractSliceOp.getMixedSizes(),
336+
strides = extractSliceOp.getMixedStrides();
337+
// Explicitly execute DCE.
338+
(void)mlir::simplifyRegions(rewriter, {*producerOp->getParentRegion()});
339+
// If fused producer has multiple users.
340+
bool yieldReplacement = !newListener.isErased(producerOp);
341+
// Reset to old listener.
342+
rewriter.setListener(oldListener);
343+
344+
if (yieldReplacement) {
345+
OpBuilder::InsertionGuard g(rewriter);
346+
// Set insertPoint right before tiled op.
347+
rewriter.setInsertionPoint(fuseProducerResult->tiledOps[0]);
348+
// Manually clone new candidate slice.
349+
auto clonedExtractSliceOp = rewriter.create<tensor::ExtractSliceOp>(
350+
producerOp->getLoc(), producerOp->getResult(resultNumber), offsets,
351+
sizes, strides);
352+
// Yield replacement for fused producer to avoid repeated computation.
353+
if (failed(scf::yieldReplacementForFusedProducer(
354+
rewriter, clonedExtractSliceOp, fuseProducerResult.value(),
355+
outerLoops)))
356+
return std::nullopt;
357+
// Erase cloned candidate slice.
358+
rewriter.eraseOp(clonedExtractSliceOp);
359+
360+
unsigned loopNumResults = outerLoops.front()->getNumResults(),
361+
producerNumResults = producerOp->getNumResults();
362+
// Replace other users of fused producer with new loop results.
363+
for (auto &&[index, result] : llvm::enumerate(producerOp->getResults())) {
364+
rewriter.replaceAllUsesWith(
365+
result, outerLoops.front()->getResult(loopNumResults -
366+
producerNumResults + index));
367+
}
368+
// Erase fused producer op.
369+
rewriter.eraseOp(producerOp);
370+
}
299371
}
300372
return fuseProducerResult;
301373
}

test/mlir/test/gc/Transforms/iterative-tiling-and-fusion.mlir

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -381,11 +381,11 @@ module {
381381
// -----
382382

383383
module {
384-
// CHECK: func.func @fuse_generic_matmul(
385-
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
386-
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
387-
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
388-
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> attributes {llvm.emit_c_interface} {
384+
/// CHECK-LABEL: @fuse_generic_matmul
385+
/// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<32x32xf32>
386+
/// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<2x16x16xf32>
387+
/// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<4x16x16xf32>
388+
func.func @fuse_generic_matmul(%arg0: tensor<32x32xf32>, %arg1: tensor<2x16x16xf32>, %arg2: tensor<4x16x16xf32>) -> tensor<32x64xf32> {
389389
/// CHECK: %[[EMPTY_OUT_0:.*]] = tensor.empty
390390
%0 = tensor.empty() : tensor<2x2x16x16xf32>
391391
%pack = tensor.pack %arg0 outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 16] into %0 : tensor<32x32xf32> -> tensor<2x2x16x16xf32>
@@ -429,4 +429,33 @@ module {
429429
/// CHECK: return %[[FINAL_RESULT]]#1
430430
return %unpack : tensor<32x64xf32>
431431
}
432+
}
433+
434+
// -----
435+
436+
// Test case: the result of `linalg.powf` is consumed by `linalg.reduce` and is
// also returned from the function. The expectations below place `linalg.powf`
// inside the `scf.forall` body, require the loop to produce two results, and
// require both returned values to be results of that loop.
module {
437+
/// CHECK-LABEL: @yield_fused_producer
438+
func.func @yield_fused_producer(%arg0: tensor<16x32x32xf32>) -> (tensor<16x32x32xf32>, tensor<16x32xf32>) {
439+
/// CHECK: arith.constant
440+
%cst_0 = arith.constant dense<2.000000e+00> : tensor<16x32x32xf32>
441+
/// CHECK-NEXT: tensor.empty
442+
%dest0 = tensor.empty() : tensor<16x32x32xf32>
443+
%0 = linalg.powf ins(%arg0, %cst_0 : tensor<16x32x32xf32>, tensor<16x32x32xf32>) outs(%dest0 : tensor<16x32x32xf32>) -> tensor<16x32x32xf32>
444+
/// CHECK-NEXT: tensor.empty
445+
%dest1 = tensor.empty() : tensor<16x32xf32>
446+
/// CHECK-NEXT: %[[FINAL_RESULT:.*]]:2 = scf.forall (%{{.*}}) in (16)
447+
/// CHECK-NEXT: tensor.extract_slice
448+
/// CHECK-NEXT: tensor.extract_slice
449+
/// CHECK-NEXT: tensor.extract_slice
450+
/// CHECK-NEXT: linalg.powf
451+
/// CHECK-NEXT: tensor.extract_slice
452+
/// CHECK-NEXT: linalg.reduce
453+
%1 = linalg.reduce { arith.addf } ins(%0 : tensor<16x32x32xf32>) outs(%dest1 : tensor<16x32xf32>) dimensions = [2]
454+
/// CHECK-NEXT: scf.forall.in_parallel
455+
/// CHECK-NEXT: tensor.parallel_insert_slice
456+
/// CHECK-NEXT: tensor.parallel_insert_slice
457+
/// CHECK-NEXT: }
458+
/// CHECK: return %[[FINAL_RESULT]]#1, %[[FINAL_RESULT]]#0
459+
return %0, %1 : tensor<16x32x32xf32>, tensor<16x32xf32>
460+
}
432461
}

0 commit comments

Comments
 (0)