@@ -2295,6 +2295,218 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+namespace {
+// Decompose scaled dot product attention into a matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
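+// The rewrite materializes the standard attention formula
+//   output = softmax(query @ key.transpose(-2, -1) * scale, dim=-1) @ value
+// with aten matmul, transpose.int, mul.Scalar, max.dim, exp, sum.dim_IntList,
+// and div ops.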
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(
+          op, "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryTensorType = dyn_cast<BaseTensorType>(query.getType());
+    auto keyTensorType = dyn_cast<BaseTensorType>(key.getType());
+    auto valueTensorType = dyn_cast<BaseTensorType>(value.getType());
+    if (!queryTensorType || !keyTensorType || !valueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected tensor inputs");
+    if (!queryTensorType.hasSizes() || !keyTensorType.hasSizes() ||
+        !valueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(queryTensorType);
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(keyTensorType);
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(valueTensorType);
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op,
+                                         "expected value tensor semantics");
+
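+    // Query/key/value are expected as either [batch, seq_len, head_dim] or
+    // [batch, num_heads, seq_len, head_dim], so the size queries below index
+    // from the back to cover both ranks.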
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    bool hasExplicitHeadDim = querySizes.size() >= 4;
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
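+    // Without an explicit scale operand, fall back to the default
+    // 1 / sqrt(head_dim) scaling of scaled dot product attention.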
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat = ConstantFloatOp::create(
+          rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat =
+          AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
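+    // Swap the last two key dimensions so that query @ key^T produces one
+    // score per (query position, key position) pair; the aten.view below pins
+    // the result shape using the operand's static sizes where available.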
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    if (keyTransposedSizes.size() < 2)
+      return rewriter.notifyMatchFailure(
+          op, "expected key tensor rank >= 2 for transpose");
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    Type keyTransposedType =
+        keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+            keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+            keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                  : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (keyTransposedSizes.size() == 4) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
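+    // The attention scores have shape [batch, (num_heads,) seq_len,
+    // key_seq_len]; gather whatever static sizes the operand types provide.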
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
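+    // Softmax over the key dimension, written with the overflow-safe
+    // decomposition softmax(x) = exp(x - max(x)) / sum(exp(x - max(x))).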
+    SmallVector<int64_t> reducedSizes(scoresSizes.begin(), scoresSizes.end());
+    reducedSizes.back() = 1;
+    ArrayRef<int64_t> reducedSizesRef(reducedSizes);
+    std::optional<ArrayRef<int64_t>> reducedSizesOpt(reducedSizesRef);
+    Type reducedValueType =
+        ValueTensorType::get(op->getContext(), reducedSizesOpt,
+                             queryValueTensorType.getOptionalDtype());
+    Type reducedIndexType = ValueTensorType::get(
+        op->getContext(), reducedSizesOpt,
+        IntegerType::get(op->getContext(), 64, IntegerType::Signed));
+    Value keepDimTrue =
+        ConstantBoolOp::create(rewriter, loc, rewriter.getBoolAttr(true));
+    auto maxOp =
+        AtenMaxDimOp::create(rewriter, loc, reducedValueType, reducedIndexType,
+                             scaledScores, negOne, keepDimTrue);
+    Value softmaxMax = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, maxOp.getValues());
+    Value centered =
+        createTensorSub(rewriter, loc, scoresType, scaledScores, softmaxMax);
+    Value unNormalizedExp =
+        AtenExpOp::create(rewriter, loc, scoresType, centered);
+    SmallVector<Value, 1> softmaxDims{negOne};
+    Value dimList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, softmaxDims);
+    Value noneValue = ConstantNoneOp::create(rewriter, loc);
+    Value softmaxDenominator = AtenSumDimIntListOp::create(
+        rewriter, loc, reducedValueType, unNormalizedExp, dimList, keepDimTrue,
+        noneValue);
+    softmaxDenominator = TensorStaticInfoCastOp::create(
+        rewriter, loc, reducedValueType, softmaxDenominator);
+    Value softmax = AtenDivTensorOp::create(
+        rewriter, loc, scoresType, unNormalizedExp, softmaxDenominator);
+
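+    // The attention output is the probability-weighted combination of the
+    // value rows.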
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
@@ -13084,6 +13296,8 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(
+        patterns);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);