Skip to content

Commit 69b5786

Browse files
committed
Feedback
1 parent 5f7c8f3 commit 69b5786

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -339,8 +339,8 @@ struct WgToSgElementwiseOp : public ConversionPattern {
339339
LogicalResult
340340
matchAndRewrite(Operation *op, ArrayRef<ValueRange> operands,
341341
ConversionPatternRewriter &rewriter) const override {
342-
// Only match ops with elementwise trait
343-
if (!OpTrait::hasElementwiseMappableTraits(op))
342+
// Only match ops with elementwise trait and single result.
343+
if (!OpTrait::hasElementwiseMappableTraits(op) || op->getNumResults() != 1)
344344
return failure();
345345

346346
auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
@@ -353,9 +353,12 @@ struct WgToSgElementwiseOp : public ConversionPattern {
353353
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
354354

355355
size_t numVariants = operands.empty() ? 0 : operands.front().size();
356-
for (auto &operandVec : operands)
357-
if (operandVec.size() != numVariants)
358-
return failure();
356+
// Only VectorType operands are supported here.
357+
// TODO: Support other types.
358+
if (llvm::any_of(operands, [&](const ValueRange &operandVec) {
359+
return operandVec.size() != numVariants;
360+
}))
361+
return failure();
359362

360363
SmallVector<Value> newResults;
361364
VectorType newResultType =

0 commit comments

Comments
 (0)