@@ -77,6 +77,47 @@ bool isValidLayoutForUnbroadcast(const LinearLayout &linearLayout,
7777 linearLayout, numWorkGroupPos, rewriter);
7878}
7979
80+ // / Generic checks for the operation not looking at the tensor type.
81+ bool isCandidateOp (Operation *op) {
82+ // Rely on this for a simpler pass.
83+ if (!op->hasTrait <OpTrait::SameOperandsAndResultType>() ||
84+ op->getNumResults () != 1 )
85+ return false ;
86+
87+ // Skip complex operations.
88+ if (op->hasSuccessors () || op->getNumRegions () != 0 )
89+ return false ;
90+
91+ return true ;
92+ }
93+
94+ bool optimizationDoesNotWorsenRegisterPressure (
95+ Value value, RankedTensorType newType, SmallPtrSetImpl<Value> &visited) {
96+ if (!visited.insert (value).second )
97+ return true ;
98+ // All users must be operations we will optimize too or layout conversions we
99+ // will introduce later.
100+ return llvm::all_of (value.getUses (), [&visited, newType](OpOperand &operand) {
101+ Operation *owner = operand.getOwner ();
102+
103+ // We will be introducing just this operation later.
104+ if (auto convertLayout = dyn_cast<ConvertLayoutOp>(owner))
105+ return convertLayout.getResult ().getType () == newType;
106+
107+ // Only allow candidates. Check only operation constraints. We do not have
108+ // to check the type as we did already.
109+ if (!owner->hasTrait <OpTrait::Elementwise>() || !isCandidateOp (owner))
110+ return false ;
111+
112+ // Check other operands fit the constraints.
113+ return llvm::all_of (owner->getOperands (),
114+ [&visited, newType](Value operand) {
115+ return optimizationDoesNotWorsenRegisterPressure (
116+ operand, newType, visited);
117+ });
118+ });
119+ }
120+
80121// / Get optimized unbroadcasted tensor type.
81122// /
82123// / Get optimized ranked tensor type after unbroadcasting. As we only support 1D
@@ -110,13 +151,10 @@ struct ElementwiseOptPattern final
110151
111152 LogicalResult matchAndRewrite (Operation *op,
112153 PatternRewriter &rewriter) const final {
113- // Rely on this for a simpler pass.
114- if (!op->hasTrait <OpTrait::SameOperandsAndResultType>() ||
115- op->getNumResults () != 1 )
116- return failure ();
154+ LLVM_DEBUG (llvm::dbgs () << " Checking operation:\n " << *op << " \n " );
117155
118- // Skip complex operations .
119- if (op-> hasSuccessors () || op-> getNumRegions () != 0 )
156+ // Rely on this for a simpler pass .
157+ if (! isCandidateOp (op) )
120158 return failure ();
121159
122160 // Layout optimizations only apply to tensors.
@@ -132,19 +170,30 @@ struct ElementwiseOptPattern final
132170 return failure ();
133171 std::optional<LinearLayout> linearLayout =
134172 toLinearLayout (type.getShape (), layout);
135- if (!linearLayout || !isValidLayoutForUnbroadcast (*linearLayout, rewriter))
136- return failure ();
137173
138- // Check the operands are not used by other operations. This will prevent
139- // register pressure increase:
140- if (! llvm::all_of (op-> getOperands (),
141- [](Value val) { return val. hasOneUse (); } ))
174+ LLVM_DEBUG ( llvm::dbgs () << " Checking linear layout: \n "
175+ << linearLayout << " \n " );
176+
177+ if (!linearLayout || ! isValidLayoutForUnbroadcast (*linearLayout, rewriter ))
142178 return failure ();
143179
144180 // As we are dealing with 1D tensors, we can do a simple transform to obtain
145181 // a more optimized operation.
146182 Location loc = op->getLoc ();
147183 RankedTensorType newType = getOptimizedType (type, *linearLayout, rewriter);
184+
185+ LLVM_DEBUG (llvm::dbgs () << " Would convert to type:\n " << newType << " \n " );
186+
187+ // Check the operands are not used by other operations. This will prevent
188+ // register pressure increase:
189+ if (SmallPtrSet<Value, 2 > visited;
190+ !llvm::all_of (op->getOperands (), [&visited, newType](Value operand) {
191+ return optimizationDoesNotWorsenRegisterPressure (operand, newType,
192+ visited);
193+ }))
194+ return failure ();
195+
196+ // Obtain converted operands.
148197 SmallVector<Value> newOperands (op->getNumOperands ());
149198 llvm::transform (op->getOperands (), std::begin (newOperands),
150199 [&rewriter, loc, newType](Value operand) {
@@ -164,6 +213,8 @@ struct ElementwiseOptPattern final
164213 Value newValue = newElementwiseOp->getResult (0 );
165214 rewriter.replaceOpWithNewOp <ConvertLayoutOp>(op, type, newValue);
166215
216+ LLVM_DEBUG (llvm::dbgs () << " Conversion took place.\n " );
217+
167218 return success ();
168219 }
169220};
0 commit comments