@@ -77,7 +77,37 @@ bool linalg::isaCopyOpInterface(LinalgOp op) {
7777// ===----------------------------------------------------------------------===//
7878// FillOpInterface implementation
7979// ===----------------------------------------------------------------------===//
80- std::optional<Value> linalg::isaFillOpInterface (GenericOp op) {
80+ // / Detects if a linalg.generic operation represents a fill with an inlined
81+ // / constant. If so, returns the constant value. Otherwise, returns
82+ // / std::nullopt.
83+ static std::optional<Value> isaInlinedFillOp (GenericOp op) {
84+ if (!op.isAllParallelLoops () || op.getNumDpsInits () != 1 ||
85+ op.getNumDpsInputs () != 0 )
86+ return std::nullopt ;
87+
88+ // Init should not be referenced.
89+ if (op.payloadUsesValueFromOperand (op.getDpsInitOperand (0 )))
90+ return std::nullopt ;
91+
92+ Block *body = op.getBody ();
93+ if (body->getOperations ().size () != 1 )
94+ return std::nullopt ;
95+
96+ auto yieldOp = dyn_cast<linalg::YieldOp>(body->back ());
97+ if (!yieldOp || yieldOp.getNumOperands () != 1 )
98+ return std::nullopt ;
99+
100+ Value yieldOperand = yieldOp->getOperand (0 );
101+ if (!yieldOperand.getDefiningOp <arith::ConstantOp>() &&
102+ !yieldOperand.getDefiningOp <complex ::ConstantOp>())
103+ return std::nullopt ;
104+
105+ return yieldOperand;
106+ }
107+
108+ // / Detects if a linalg.generic operation represents an external scalar input.
109+ // / If so, returns the constant value. Otherwise, returns std::nullopt.
110+ static std::optional<Value> isaExternalFillOp (GenericOp op) {
81111 // Structural.
82112 if (!op.isAllParallelLoops () || !op.isSingleInputOutput () ||
83113 !op.isSingleYieldOp ())
@@ -94,6 +124,12 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp op) {
94124 return value->get ();
95125}
96126
127+ std::optional<Value> linalg::isaFillOpInterface (GenericOp op) {
128+ if (auto fillVal = isaInlinedFillOp (op))
129+ return fillVal;
130+ return isaExternalFillOp (op);
131+ }
132+
97133// ===----------------------------------------------------------------------===//
98134// BroadcastOpInterface implementation
99135// ===----------------------------------------------------------------------===//
0 commit comments