@@ -52,9 +52,23 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
5252 return failure ();
5353 Value src = localAllocOp.getSrc ();
5454 auto srcType = cast<RankedTensorType>(src.getType ());
55- auto srcLayout = cast<ttg::BlockedEncodingAttr>(srcType.getEncoding ());
55+ auto srcLayout = srcType.getEncoding ();
56+ auto accTMemEncoding = dyn_cast<ttng::TensorMemoryEncodingAttr>(
57+ tcGen5MMAOp.getD ().getType ().getEncoding ());
58+ ArrayRef<unsigned > CTASplitNum =
59+ triton::gpu::getCTALayout (srcLayout).getCTASplitNum ();
60+ // TMem encoding for A operand is the same as for D (Acc), but packed.
61+ auto aTMemEncoding = ttng::TensorMemoryEncodingAttr::get (
62+ context, accTMemEncoding.getBlockM (), lhs.getType ().getShape ()[1 ],
63+ /* unpacked=*/ false , CTASplitNum[0 ], CTASplitNum[1 ]);
64+ Attribute tensorMemorySpace =
65+ triton::nvidia_gpu::TensorMemorySpaceAttr::get (context);
66+ ttg::MemDescType lhsMemDescType = ttg::MemDescType::get (
67+ lhs.getType ().getShape (), lhs.getType ().getElementType (), aTMemEncoding,
68+ tensorMemorySpace,
69+ /* mutableMemory=*/ false );
5670 bool layoutTmemCompatible = ttng::isDistributedLayoutTMemCompatible (
57- tcGen5MMAOp, srcType, tcGen5MMAOp. getD (). getType () );
71+ tcGen5MMAOp, srcType, lhsMemDescType );
5872 Attribute newLayout = srcLayout;
5973 if (!layoutTmemCompatible) {
6074 if (triton::tools::getBoolEnv (" ALLOW_LHS_TMEM_LAYOUT_CONVERSION" )) {
@@ -70,19 +84,6 @@ template <class MMAOpTy> class LHSToTMem : public OpRewritePattern<MMAOpTy> {
7084 RankedTensorType::get (ty.getShape (), ty.getElementType (), newLayout);
7185 src = rewriter.create <ttg::ConvertLayoutOp>(loc, newTy, src);
7286 }
73- auto accTMemEncoding = dyn_cast<ttng::TensorMemoryEncodingAttr>(
74- tcGen5MMAOp.getD ().getType ().getEncoding ());
75- ArrayRef<unsigned > CTASplitNum = srcLayout.getCTALayout ().getCTASplitNum ();
76- // TMem encoding for A operand is the same as for D (Acc), but unpacked.
77- auto aTMemEncoding = ttng::TensorMemoryEncodingAttr::get (
78- context, accTMemEncoding.getBlockM (), lhs.getType ().getShape ()[1 ],
79- /* unpacked=*/ false , CTASplitNum[0 ], CTASplitNum[1 ]);
80- Attribute tensorMemorySpace =
81- triton::nvidia_gpu::TensorMemorySpaceAttr::get (context);
82- Type lhsMemDescType = triton::gpu::MemDescType::get (
83- lhs.getType ().getShape (), lhs.getType ().getElementType (), aTMemEncoding,
84- tensorMemorySpace,
85- /* mutableMemory=*/ false );
8687 Value tMemAlloc =
8788 rewriter.create <ttng::TMEMAllocOp>(loc, lhsMemDescType, src);
8889 tcGen5MMAOp.getAMutable ().assign (tMemAlloc);
@@ -100,8 +101,6 @@ class TritonNvidiaGPUPromoteLHSToTMemPass
100101 TritonNvidiaGPUPromoteLHSToTMemPassBase;
101102
102103 void runOnOperation () override {
103- if (!triton::tools::getBoolEnv (" ENABLE_LHS_TO_TMEM" ))
104- return ;
105104 MLIRContext *context = &getContext ();
106105 ModuleOp m = getOperation ();
107106
0 commit comments