@@ -23,21 +23,20 @@ using ttn::OperandsAndConstraints;
2323
2424namespace {
2525
26- const std::string Wgmma_Fence_Op = " wgmma.fence.sync.aligned;" ;
27- const std::string Wgmma_Commit_Group_Op = " wgmma.commit_group.sync.aligned;" ;
28- const std::string Cluster_Wait_Op = " barrier.cluster.wait.aligned;" ;
29- const std::string Fence_Mbarrier_Init_Op =
30- " fence.mbarrier_init.release.cluster;" ;
31- const std::string Cluster_Cta_Id_Op = " {\n "
32- " .reg .u32 a<5>; \n "
33- " mov.u32 a0, %cluster_ctaid.x;\n " // x
34- " mov.u32 a1, %cluster_ctaid.y;\n " // y
35- " mov.u32 a2, %cluster_ctaid.z;\n " // z
36- " mov.u32 a3, %cluster_nctaid.x;\n " // nx
37- " mov.u32 a4, %cluster_nctaid.y;\n " // ny
38- " mad.lo.u32 a1, a2, a4, a1; \n "
39- " mad.lo.u32 $0, a1, a3, a0; \n "
40- " }" ;
26+ const std::string kWgmmaFenceOp = " wgmma.fence.sync.aligned;" ;
27+ const std::string kWgmmaCommitGroupOp = " wgmma.commit_group.sync.aligned;" ;
28+ const std::string kClusterWaitOp = " barrier.cluster.wait.aligned;" ;
29+ const std::string kFenceMbarrierInitOp = " fence.mbarrier_init.release.cluster;" ;
30+ const std::string kClusterCtaIdOp = " {\n "
31+ " .reg .u32 a<5>; \n "
32+ " mov.u32 a0, %cluster_ctaid.x;\n " // x
33+ " mov.u32 a1, %cluster_ctaid.y;\n " // y
34+ " mov.u32 a2, %cluster_ctaid.z;\n " // z
35+ " mov.u32 a3, %cluster_nctaid.x;\n " // nx
36+ " mov.u32 a4, %cluster_nctaid.y;\n " // ny
37+ " mad.lo.u32 a1, a2, a4, a1; \n "
38+ " mad.lo.u32 $0, a1, a3, a0; \n "
39+ " }" ;
4140
4241bool isNumber (const std::string &s) {
4342 return !s.empty () && std::find_if (s.begin (), s.end (), [](unsigned char c) {
@@ -235,46 +234,141 @@ class ClusterArriveOpPattern : public OpRewritePattern<ttn::ClusterArriveOp> {
235234 }
236235};
237236
238- class StoreMatrixOpPattern : public OpRewritePattern <ttn::StoreMatrixOp> {
237+ // Base class for Matrix Operation Patterns
238+ template <typename MatrixOpType, typename ConcreteMatrixOpPattern>
239+ class MatrixOpPattern : public OpRewritePattern <MatrixOpType> {
239240public:
240- using OpRewritePattern<ttn::StoreMatrixOp >::OpRewritePattern;
241+ using OpRewritePattern<MatrixOpType >::OpRewritePattern;
241242
242- LogicalResult matchAndRewrite (ttn::StoreMatrixOp op,
243+ LogicalResult matchAndRewrite (MatrixOpType op,
243244 PatternRewriter &rewriter) const override {
244- return rewriteAsPtxAsm (op, rewriter, getPtxAsm (op),
245- getOperandsAndConstraints (op));
246- }
247-
248- OperandsAndConstraints
249- getOperandsAndConstraints (ttn::StoreMatrixOp op) const {
250- OperandsAndConstraints operandsAndTypes;
251- auto addr = op.getAddr ();
252- auto datas = op.getDatas ();
253- operandsAndTypes.push_back ({addr, " r" });
254- for (unsigned i = 0 ; i < datas.size (); i++) {
255- operandsAndTypes.push_back ({datas[i], " r" });
256- }
257- return operandsAndTypes;
245+ unsigned vecSize = getVectorSize (op);
246+ bool trans = op->hasAttr (" trans" )
247+ ? op->template getAttrOfType <BoolAttr>(" trans" ).getValue ()
248+ : false ;
249+
250+ // Template method for PTX assembly generation
251+ std::string ptxAsm =
252+ (llvm::Twine (ConcreteMatrixOpPattern::kOpCode ) +
253+ getPtxModifiers (vecSize, trans) + " " + getOperands (op, vecSize) + " ;" )
254+ .str ();
255+
256+ OperandsAndConstraints operandAndConstraints =
257+ getOperandsAndConstraints (op, vecSize);
258+ Constraints outputConstraints = getOutputConstraints (op, vecSize);
259+
260+ return rewriteAsPtxAsm (op, rewriter, ptxAsm, operandAndConstraints,
261+ outputConstraints);
258262 }
259263
260- std::string getPtxAsm (ttn::StoreMatrixOp op) const {
261- auto datas = op.getDatas ();
262- std::string ptxAsm;
263- switch (datas.size ()) {
264+ protected:
265+ // Shared helper methods
266+ std::string getPtxModifiers (unsigned vecSize, bool trans) const {
267+ auto ptxAsmBase = llvm::Twine (" .sync.aligned.m8n8" );
268+ const std::string suffix = trans ? " .trans.shared.b16" : " .shared.b16" ;
269+ switch (vecSize) {
264270 case 1 :
265- ptxAsm = " stmatrix.sync.aligned.m8n8.x1.shared.b16 [$0], {$1};" ;
266- break ;
271+ return (ptxAsmBase + " .x1" + suffix).str ();
267272 case 2 :
268- ptxAsm = " stmatrix.sync.aligned.m8n8.x2.shared.b16 [$0], {$1, $2};" ;
269- break ;
273+ return (ptxAsmBase + " .x2" + suffix).str ();
270274 case 4 :
271- ptxAsm =
272- " stmatrix.sync.aligned.m8n8.x4.shared.b16 [$0], {$1, $2, $3, $4};" ;
273- break ;
275+ return (ptxAsmBase + " .x4" + suffix).str ();
274276 default :
275- assert (false && " Invalid size" );
277+ assert (false && " Invalid vector size" );
276278 }
277- return ptxAsm;
279+ }
280+
281+ std::string getPtxRegOperands (unsigned startIdx, unsigned count) const {
282+ llvm::SmallString<20 > regOperands;
283+ llvm::raw_svector_ostream stream (regOperands);
284+ stream << " {" ;
285+ for (unsigned i = 0 ; i < count; i++) {
286+ stream << " $" + llvm::utostr (startIdx + i);
287+ if (i != count - 1 )
288+ stream << " , " ;
289+ }
290+ stream << " }" ;
291+ return std::string (regOperands.str ());
292+ }
293+
294+ std::string getPtxAddrOperand (unsigned idx) const {
295+ return (llvm::Twine (" [$" ) + llvm::utostr (idx) + " ]" ).str ();
296+ }
297+
298+ virtual std::string getOperands (MatrixOpType op, unsigned vecSize) const = 0;
299+ virtual OperandsAndConstraints
300+ getOperandsAndConstraints (MatrixOpType op, unsigned vecSize) const = 0 ;
301+ virtual Constraints getOutputConstraints (MatrixOpType op,
302+ unsigned vecSize) const = 0;
303+ virtual unsigned getVectorSize (MatrixOpType op) const = 0;
304+ };
305+
306+ // StoreMatrixOp Pattern
307+ class StoreMatrixOpPattern
308+ : public MatrixOpPattern<ttn::StoreMatrixOp, StoreMatrixOpPattern> {
309+ public:
310+ using MatrixOpPattern<ttn::StoreMatrixOp,
311+ StoreMatrixOpPattern>::MatrixOpPattern;
312+ static constexpr const char *kOpCode = " stmatrix" ;
313+
314+ protected:
315+ unsigned getVectorSize (ttn::StoreMatrixOp op) const override {
316+ return op.getVals ().size ();
317+ }
318+
319+ std::string getOperands (ttn::StoreMatrixOp op,
320+ unsigned vecSize) const override {
321+ return (llvm::Twine (getPtxAddrOperand (0 )) + " , " +
322+ getPtxRegOperands (1 , vecSize))
323+ .str ();
324+ }
325+
326+ OperandsAndConstraints
327+ getOperandsAndConstraints (ttn::StoreMatrixOp op,
328+ unsigned vecSize) const override {
329+ OperandsAndConstraints constraints = {{op.getAddr (), " r" }};
330+ for (unsigned i = 0 ; i < vecSize; i++) {
331+ constraints.push_back ({op.getVals ()[i], " r" });
332+ }
333+ return constraints;
334+ }
335+
336+ Constraints getOutputConstraints (ttn::StoreMatrixOp op,
337+ unsigned vecSize) const override {
338+ return {}; // No output constraints for StoreMatrixOp
339+ }
340+ };
341+
342+ // LoadMatrixOp Pattern
343+ class LoadMatrixOpPattern
344+ : public MatrixOpPattern<ttn::LoadMatrixOp, LoadMatrixOpPattern> {
345+ public:
346+ using MatrixOpPattern<ttn::LoadMatrixOp,
347+ LoadMatrixOpPattern>::MatrixOpPattern;
348+ static constexpr const char *kOpCode = " ldmatrix" ;
349+
350+ protected:
351+ unsigned getVectorSize (ttn::LoadMatrixOp op) const override {
352+ auto resultType = cast<LLVM::LLVMStructType>(op.getType ());
353+ return resultType.getBody ().size ();
354+ }
355+
356+ std::string getOperands (ttn::LoadMatrixOp op,
357+ unsigned vecSize) const override {
358+ return (llvm::Twine (getPtxRegOperands (0 , vecSize)) + " , " +
359+ getPtxAddrOperand (vecSize))
360+ .str ();
361+ }
362+
363+ OperandsAndConstraints
364+ getOperandsAndConstraints (ttn::LoadMatrixOp op,
365+ unsigned vecSize) const override {
366+ return {{op.getAddr (), " r" }};
367+ }
368+
369+ Constraints getOutputConstraints (ttn::LoadMatrixOp op,
370+ unsigned vecSize) const override {
371+ return Constraints (vecSize, " =r" );
278372 }
279373};
280374
@@ -507,17 +601,16 @@ class ConvertNVGPUToLLVM : public ConvertNVGPUToLLVMBase<ConvertNVGPUToLLVM> {
507601#define POPULATE_NVGPU_OP (SRC_OP, ASM ) \
508602 patterns.add <NVGPUOpGenericPattern<SRC_OP>>(context, ASM, Constraints (), \
509603 Constraints ());
510- POPULATE_NVGPU_OP (ttn::WGMMAFenceOp, Wgmma_Fence_Op )
511- POPULATE_NVGPU_OP (ttn::WGMMACommitGroupOp, Wgmma_Commit_Group_Op )
512- POPULATE_NVGPU_OP (ttn::ClusterWaitOp, Cluster_Wait_Op )
604+ POPULATE_NVGPU_OP (ttn::WGMMAFenceOp, kWgmmaFenceOp )
605+ POPULATE_NVGPU_OP (ttn::WGMMACommitGroupOp, kWgmmaCommitGroupOp )
606+ POPULATE_NVGPU_OP (ttn::ClusterWaitOp, kClusterWaitOp )
513607#undef POPULATE_NVGPU_OP
514608 patterns.add <NVGPUOpGenericPattern<ttn::ClusterCTAIdOp>>(
515- context, Cluster_Cta_Id_Op , Constraints ({" =r" }), Constraints ());
609+ context, kClusterCtaIdOp , Constraints ({" =r" }), Constraints ());
516610
517- patterns
518- .add <FenceAsyncSharedOpPattern, StoreMatrixOpPattern,
519- ClusterArriveOpPattern, WGMMAOpPattern, WGMMAWaitGroupOpPattern>(
520- context);
611+ patterns.add <FenceAsyncSharedOpPattern, LoadMatrixOpPattern,
612+ StoreMatrixOpPattern, ClusterArriveOpPattern, WGMMAOpPattern,
613+ WGMMAWaitGroupOpPattern>(context);
521614
522615 if (applyPatternsAndFoldGreedily (mod, std::move (patterns)).failed ())
523616 signalPassFailure ();
0 commit comments