@@ -25,6 +25,47 @@ static bool isScalarOrTensorOfSizeOne(Type t) {
2525 return t.isIntOrIndexOrFloat ();
2626}
2727
28+ // / This function checks whether the `genericOp` has any external captures,
29+ // / i.e., whether it uses any values that are defined outside of its body.
30+ // / %10 = linalg.generic {indexing_maps = [#map, #map],
31+ // / iterator_types = ["parallel", "parallel"]}
32+ // / ins(%5 : tensor<4096x64xi64>) outs(%9 : tensor<4096x64xf16>) {
33+ // / ^bb0(%in: i64, %out: f16):
34+ // / %14 = linalg.index 0 : index
35+ // / %15 = arith.index_cast %in : i64 to index
36+ // / %extracted = tensor.extract %4[%14, %15] : tensor<4096x64xf16>
37+ // / linalg.yield %extracted : f16
38+ // / } -> tensor<4096x64xf16>
39+ // / Here %4 is an external capture used via tensor.extract inside
40+ // / linalg.generic hence the above `genericOp` has an external capture.
41+ static bool hasExternalCapture (linalg::GenericOp genericOp) {
42+ Block &body = genericOp.getRegion ().front ();
43+ for (Operation &op : body.getOperations ()) {
44+ for (Value operand : op.getOperands ()) {
45+ if (auto bArg = dyn_cast<BlockArgument>(operand)) {
46+ // Check whether the operand lies in the same block.
47+ if (bArg.getOwner () == &body) {
48+ continue ;
49+ }
50+ return true ;
51+ }
52+ Operation *defOp = operand.getDefiningOp ();
53+ // Scalar constant is allowed.
54+ if (defOp && defOp->hasTrait <mlir::OpTrait::ConstantLike>()) {
55+ Type type = operand.getType ();
56+ if (type.isIntOrFloat () || type.isIndex ()) {
57+ continue ;
58+ }
59+ }
60+ // If defining op is not inside the block, it’s an external value.
61+ if (!defOp || defOp->getBlock () != &body) {
62+ return true ;
63+ }
64+ }
65+ }
66+ return false ; // All operands are locally defined or block arguments.
67+ }
68+
2869// / Rematerialize all parallel elementwise operations into its users within a
2970// / `flow.dispatch.region`.
3071struct RematerializeParallelOpsPattern
@@ -44,9 +85,13 @@ struct RematerializeParallelOpsPattern
4485
4586 // Find the first operand that is defined by another generic op on tensors.
4687 for (OpOperand &opOperand : genericOp->getOpOperands ()) {
47- if (!linalg::areElementwiseOpsFusable (&opOperand))
88+ if (!linalg::areElementwiseOpsFusable (&opOperand)) {
4889 continue ;
49-
90+ }
91+ auto producer = opOperand.get ().getDefiningOp <linalg::GenericOp>();
92+ if (producer && hasExternalCapture (producer)) {
93+ continue ;
94+ }
5095 FailureOr<linalg::ElementwiseOpFusionResult> fusionResult =
5196 linalg::fuseElementwiseOps (rewriter, &opOperand);
5297 if (succeeded (fusionResult)) {
0 commit comments