@@ -2373,10 +2373,100 @@ std::optional<SmallVector<int64_t, 4>> FMAOp::getShapeForUnroll() {
23732373 return llvm::to_vector<4 >(getVectorType ().getShape ());
23742374}
23752375
2376+ // ===----------------------------------------------------------------------===//
2377+ // ToElementsOp
2378+ // ===----------------------------------------------------------------------===//
2379+
2380+ // / Returns true if all the `operands` are defined by `defOp`.
2381+ // / Otherwise, returns false.
2382+ static bool haveSameDefiningOp (OperandRange operands, Operation *defOp) {
2383+ if (operands.empty ())
2384+ return false ;
2385+
2386+ return llvm::all_of (operands, [&](Value operand) {
2387+ Operation *currentDef = operand.getDefiningOp ();
2388+ return currentDef == defOp;
2389+ });
2390+ }
2391+
2392+ // / Folds vector.to_elements(vector.from_elements(%e0, %e1, ...)) into
2393+ // / (%e0, %e1, ...). For example:
2394+ // /
2395+ // / %0 = vector.from_elements %a, %b, %c : vector<3xf32>
2396+ // / %1:3 = vector.to_elements %0 : vector<3xf32>
2397+ // / user_op %1#0, %1#1, %1#2
2398+ // /
2399+ // / becomes:
2400+ // /
2401+ // / user_op %a, %b, %c
2402+ // /
2403+ static LogicalResult
2404+ foldToElementsFromElements (ToElementsOp toElementsOp,
2405+ SmallVectorImpl<OpFoldResult> &results) {
2406+ auto fromElementsOp = toElementsOp.getSource ().getDefiningOp <FromElementsOp>();
2407+ if (!fromElementsOp)
2408+ return failure ();
2409+
2410+ results.append (fromElementsOp.getElements ().begin (),
2411+ fromElementsOp.getElements ().end ());
2412+ return success ();
2413+ }
2414+
2415+ LogicalResult ToElementsOp::fold (FoldAdaptor adaptor,
2416+ SmallVectorImpl<OpFoldResult> &results) {
2417+ if (succeeded (foldToElementsFromElements (*this , results)))
2418+ return success ();
2419+ return failure ();
2420+ }
2421+
23762422// ===----------------------------------------------------------------------===//
23772423// FromElementsOp
23782424// ===----------------------------------------------------------------------===//
23792425
2426+ // / Folds vector.from_elements(vector.to_elements(%vector)) into %vector.
2427+ // /
2428+ // / Case #1: Input and output vectors are the same.
2429+ // /
2430+ // / %0:3 = vector.to_elements %a : vector<3xf32>
2431+ // / %1 = vector.from_elements %0#0, %0#1, %0#2 : vector<3xf32>
2432+ // / user_op %1
2433+ // /
2434+ // / becomes:
2435+ // /
2436+ // / user_op %a
2437+ // /
2438+ static OpFoldResult foldFromElementsToElements (FromElementsOp fromElementsOp) {
2439+ auto fromElemsOperands = fromElementsOp.getElements ();
2440+
2441+ if (fromElemsOperands.empty ())
2442+ return {};
2443+
2444+ auto toElementsOp = fromElemsOperands[0 ].getDefiningOp <ToElementsOp>();
2445+ if (!toElementsOp)
2446+ return {};
2447+
2448+ if (!haveSameDefiningOp (fromElemsOperands, toElementsOp))
2449+ return {};
2450+
2451+ // Case #1: Input and output vectors are the same. Forward the input vector.
2452+ Value toElementsInput = toElementsOp.getSource ();
2453+ if (fromElementsOp.getType () == toElementsInput.getType () &&
2454+ llvm::equal (fromElemsOperands, toElementsOp.getResults ())) {
2455+ return toElementsInput;
2456+ }
2457+
2458+ // TODO: Support cases with different input and output shapes and different
2459+ // number of elements.
2460+
2461+ return {};
2462+ }
2463+
2464+ OpFoldResult FromElementsOp::fold (FoldAdaptor adaptor) {
2465+ if (auto result = foldFromElementsToElements (*this ))
2466+ return result;
2467+ return {};
2468+ }
2469+
23802470// / Rewrite a vector.from_elements into a vector.splat if all elements are the
23812471// / same SSA value. E.g.:
23822472// /
0 commit comments