@@ -2295,17 +2295,216 @@ class DecomposeAtenTraceOp : public OpRewritePattern<AtenTraceOp> {
 };
 } // namespace
 
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter);
+
+namespace {
+// Decompose scaled dot product attention into a matmul/softmax pipeline when
+// there is no masking, dropout, causal, or GQA behaviour.
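+// The emitted pipeline computes
+//   output = matmul(softmax(matmul(Q, transpose(K, -2, -1)) * scale), V)
+// with the softmax taken over the last (key-sequence) dimension, matching
+// the unmasked scaled dot product attention definition.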
+class DecomposeAtenScaledDotProductAttentionOp
+    : public OpRewritePattern<AtenScaledDotProductAttentionOp> {
+public:
+  using OpRewritePattern::OpRewritePattern;
+  LogicalResult matchAndRewrite(AtenScaledDotProductAttentionOp op,
+                                PatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
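+    // Only the default configuration is handled: any mask, nonzero dropout,
+    // causal flag, or GQA flag makes the pattern bail out below.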
+    if (!isa<Torch::NoneType>(op.getAttnMask().getType()))
+      return rewriter.notifyMatchFailure(
+          op, "attention mask decomposition not implemented");
+
+    double dropoutP;
+    if (!matchPattern(op.getDropoutP(), m_TorchConstantFloat(&dropoutP)) ||
+        dropoutP != 0.0)
+      return rewriter.notifyMatchFailure(
+          op, "expected dropout_p to be the constant 0.0");
+
+    bool isCausal;
+    if (!matchPattern(op.getIsCausal(), m_TorchConstantBool(&isCausal)) ||
+        isCausal)
+      return rewriter.notifyMatchFailure(op,
+                                         "causal attention not supported yet");
+
+    bool enableGqa;
+    if (!matchPattern(op.getEnableGqa(), m_TorchConstantBool(&enableGqa)) ||
+        enableGqa)
+      return rewriter.notifyMatchFailure(op,
+                                         "grouped-query attention unsupported");
+
+    Value query = op.getQuery();
+    Value key = op.getKey();
+    Value value = op.getValue();
+
+    auto queryValueTensorType = dyn_cast<ValueTensorType>(query.getType());
+    auto keyValueTensorType = dyn_cast<ValueTensorType>(key.getType());
+    auto valueValueTensorType = dyn_cast<ValueTensorType>(value.getType());
+    if (!queryValueTensorType || !keyValueTensorType || !valueValueTensorType)
+      return rewriter.notifyMatchFailure(op, "expected value tensor semantics");
+    if (!queryValueTensorType.hasSizes() || !keyValueTensorType.hasSizes() ||
+        !valueValueTensorType.hasSizes())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have known shapes");
+    if (!queryValueTensorType.hasDtype() || !keyValueTensorType.hasDtype() ||
+        !valueValueTensorType.hasDtype())
+      return rewriter.notifyMatchFailure(
+          op, "expected tensor inputs to have dtypes");
+    Type queryDtype = queryValueTensorType.getDtype();
+    Type keyDtype = keyValueTensorType.getDtype();
+    Type valueDtype = valueValueTensorType.getDtype();
+    if (queryDtype != keyDtype || queryDtype != valueDtype)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share dtype");
+
+    ArrayRef<int64_t> querySizes = queryValueTensorType.getSizes();
+    int64_t queryRank = querySizes.size();
+    if (queryRank < 3 || queryRank > 4)
+      return rewriter.notifyMatchFailure(
+          op, "expected query tensor rank to be 3 or 4");
+    ArrayRef<int64_t> keySizes = keyValueTensorType.getSizes();
+    ArrayRef<int64_t> valueSizes = valueValueTensorType.getSizes();
+    if (static_cast<int64_t>(keySizes.size()) != queryRank ||
+        static_cast<int64_t>(valueSizes.size()) != queryRank)
+      return rewriter.notifyMatchFailure(
+          op, "expected query, key, and value to share rank");
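+    // Rank-3 inputs carry no explicit num-heads dimension, so the head count
+    // is treated as 1; rank-4 inputs are laid out as
+    // [batch, heads, seqLen, headDim].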
+    bool hasExplicitHeadDim = queryRank == 4;
+    Value oneInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(1));
+    Value zeroInt =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(0));
+    Value rank = AtenDimOp::create(rewriter, loc, query);
+    Value lastDim = AtenSubIntOp::create(rewriter, loc, rank, oneInt);
+    Value headDim = AtenSizeIntOp::create(rewriter, loc, query, lastDim);
+    Value seqDimIndex = AtenSubIntOp::create(rewriter, loc, lastDim, oneInt);
+    Value seqLen = AtenSizeIntOp::create(rewriter, loc, query, seqDimIndex);
+    Value keySeqLen = AtenSizeIntOp::create(rewriter, loc, key, seqDimIndex);
+    Value numHeadsSize =
+        hasExplicitHeadDim
+            ? (Value)AtenSizeIntOp::create(rewriter, loc, query, oneInt)
+            : oneInt;
+    Value batchSize = AtenSizeIntOp::create(rewriter, loc, query, zeroInt);
+    auto listIntType =
+        Torch::ListType::get(Torch::IntType::get(rewriter.getContext()));
+
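+    // Prefer compile-time sizes when they are known; otherwise fall back to
+    // the dynamic aten.size.int values computed above.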
+    auto getDimValue = [&](int64_t staticDim, Value fallback) -> Value {
+      if (staticDim != Torch::kUnknownSize)
+        return ConstantIntOp::create(rewriter, loc,
+                                     rewriter.getI64IntegerAttr(staticDim));
+      return fallback;
+    };
+
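+    // Default scaling follows PyTorch's SDPA semantics: when the scale
+    // operand is none, use 1 / sqrt(headDim) computed from the last query dim.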
+    Value scaleFloat;
+    if (isa<Torch::NoneType>(op.getScale().getType())) {
+      Value sqrtHeadDim = AtenSqrtIntOp::create(rewriter, loc, headDim);
+      Value oneFloat =
+          ConstantFloatOp::create(rewriter, loc, rewriter.getF64FloatAttr(1.0));
+      scaleFloat = AtenDivFloatOp::create(rewriter, loc, oneFloat, sqrtHeadDim);
+    } else {
+      scaleFloat = op.getScale();
+    }
+
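+    // Transpose the last two dims of K so matmul(Q, K^T) yields scores of
+    // shape [..., seqLen, keySeqLen]; the view below re-asserts the dynamic
+    // shape of the transposed key.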
+    Value negTwo =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-2));
+    Value negOne =
+        ConstantIntOp::create(rewriter, loc, rewriter.getI64IntegerAttr(-1));
+
+    SmallVector<int64_t> keyTransposedSizes(keySizes.begin(), keySizes.end());
+    std::swap(keyTransposedSizes[keyTransposedSizes.size() - 1],
+              keyTransposedSizes[keyTransposedSizes.size() - 2]);
+    Type keyTransposedType = keyValueTensorType.getWithSizesAndDtypeAndSparsity(
+        keyTransposedSizes, keyValueTensorType.getOptionalDtype(),
+        keyValueTensorType.getOptionalSparsity());
+    Value keyTransposed = AtenTransposeIntOp::create(
+        rewriter, loc, keyTransposedType, key, negTwo, negOne);
+    SmallVector<Value> keyDims;
+    auto getOrFallback = [&](ArrayRef<int64_t> staticDims, unsigned idx,
+                             Value fallback) -> Value {
+      return getDimValue(idx < staticDims.size() ? staticDims[idx]
+                                                 : Torch::kUnknownSize,
+                         fallback);
+    };
+    keyDims.push_back(getOrFallback(keyTransposedSizes, 0, batchSize));
+    if (hasExplicitHeadDim) {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, numHeadsSize));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 3, keySeqLen));
+    } else {
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 1, headDim));
+      keyDims.push_back(getOrFallback(keyTransposedSizes, 2, keySeqLen));
+    }
+    Value keyTransposeShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, keyDims);
+    keyTransposed = AtenViewOp::create(rewriter, loc, keyTransposedType,
+                                       keyTransposed, keyTransposeShapeList);
+
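+    // Assemble the static score shape [batch, (heads,) seqLen, keySeqLen]
+    // from whatever dimensions of Q and K are statically known.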
+    auto getStaticDim = [](ArrayRef<int64_t> sizes, int64_t index) {
+      if (index < 0)
+        index += sizes.size();
+      if (index < 0 || index >= static_cast<int64_t>(sizes.size()))
+        return Torch::kUnknownSize;
+      return sizes[index];
+    };
+    int64_t queryBatchStatic = getStaticDim(querySizes, 0);
+    int64_t querySeqStatic = getStaticDim(querySizes, -2);
+    int64_t keySeqStatic = getStaticDim(keySizes, -2);
+    int64_t queryHeadsStatic =
+        hasExplicitHeadDim ? getStaticDim(querySizes, 1) : 1;
+    SmallVector<int64_t, 4> scoresSizes;
+    if (hasExplicitHeadDim)
+      scoresSizes.assign(
+          {queryBatchStatic, queryHeadsStatic, querySeqStatic, keySeqStatic});
+    else
+      scoresSizes.assign({queryBatchStatic, querySeqStatic, keySeqStatic});
+    Type scoresType = ValueTensorType::get(
+        op->getContext(),
+        ArrayRef<int64_t>(scoresSizes.begin(), scoresSizes.end()),
+        queryValueTensorType.getOptionalDtype(),
+        queryValueTensorType.getOptionalSparsity());
+    Value scores =
+        AtenMatmulOp::create(rewriter, loc, scoresType, query, keyTransposed);
+    SmallVector<Value> scoresDims;
+    scoresDims.push_back(getDimValue(scoresSizes[0], batchSize));
+    unsigned seqIndex = 1;
+    if (hasExplicitHeadDim) {
+      scoresDims.push_back(getDimValue(scoresSizes[1], numHeadsSize));
+      seqIndex = 2;
+    }
+    scoresDims.push_back(getDimValue(scoresSizes[seqIndex], seqLen));
+    scoresDims.push_back(getDimValue(scoresSizes.back(), keySeqLen));
+    Value scoresShapeList =
+        PrimListConstructOp::create(rewriter, loc, listIntType, scoresDims);
+    scores =
+        AtenViewOp::create(rewriter, loc, scoresType, scores, scoresShapeList);
+    Value scaledScores =
+        AtenMulScalarOp::create(rewriter, loc, scoresType, scores, scaleFloat);
+
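+    // Normalize the scaled scores along the key-sequence (last) dimension,
+    // then contract the probabilities against V for the attention output.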
+    Value softmax = getSoftmaxResult(op.getOperation(), scaledScores, negOne,
+                                     scoresType, scoresType, rewriter);
+    if (!softmax)
+      return rewriter.notifyMatchFailure(op,
+                                         "failed to compute softmax scores");
+
+    Value output =
+        AtenMatmulOp::create(rewriter, loc, op.getType(), softmax, value);
+
+    rewriter.replaceOp(op, output);
+    return success();
+  }
+};
+} // namespace
+
 // Calculates the softmax function on the given `input` tensor. Softmax(x) =
 // exp(x)/sum(exp(x)).
 // To avoid overflow we use the following decomposition rule:
 // x_max = max(input, dim, keepdim = True)
 // unnorm = aten.exp(input - x_max)
 // softmax = unnorm / sum(unnorm, dim, keepdim = True)
-template <typename OpTy>
-static Value getSoftmaxResult(OpTy op, Value self, Type resultType,
-                              Type accumulatorType, PatternRewriter &rewriter) {
-  Location loc = op.getLoc();
-  Value dim = op.getDim();
+static Value getSoftmaxResult(Operation *op, Value self, Value dim,
+                              Type resultType, Type accumulatorType,
+                              PatternRewriter &rewriter) {
+  Location loc = op->getLoc();
   if (resultType != accumulatorType)
     self = convertTensorToDtype(rewriter, loc, self, accumulatorType);
   Value xMax =
@@ -2362,8 +2561,9 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern<AtenSoftmaxIntOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return failure();
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, op.getType(),
@@ -2411,8 +2611,9 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern<Aten_SoftmaxOp> {
 
     Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype);
 
-    Value result = getSoftmaxResult(op, self, resultTensorType,
-                                    accumulatorTensorType, rewriter);
+    Value result =
+        getSoftmaxResult(op.getOperation(), self, op.getDim(), resultTensorType,
+                         accumulatorTensorType, rewriter);
     if (!result)
       return op.emitError("failed to get softmax result");
     rewriter.replaceOpWithNewOp<TensorStaticInfoCastOp>(op, resultTensorType,
@@ -13103,6 +13304,9 @@ class DecomposeComplexOpsPass
     legalOpsSet.clear();
     legalOpsSet.insert(legalOps.begin(), legalOps.end());
 
+    addPatternIfTargetOpIsIllegal<DecomposeAtenScaledDotProductAttentionOp>(
+        patterns);
+
     addPatternIfTargetOpIsIllegal<DecomposeAten_WeightNormInterfaceOp>(
         patterns);
     addPatternIfTargetOpIsIllegal<DecomposeAtenSoftmaxIntOp>(patterns);