@@ -25,6 +25,7 @@ namespace {
2525using namespace mlir ;
2626using namespace mlir ::triton;
2727using namespace mlir ::triton::gpu;
28+ namespace ttng = triton::nvidia_gpu;
2829
2930// pass named attrs (e.g., tt.contiguity) from Triton to Triton
3031static void addNamedAttrs (Operation *op, DictionaryAttr dictAttrs) {
@@ -466,6 +467,72 @@ struct GatherScatterOpPattern : public OpConversionPattern<OpT> {
466467 }
467468};
468469
470+ // Given a tensor and its representation in tensor memory, determine its
471+ // distributed layout.
472+ static RankedTensorType getTMEMTensorLayout (const TypeConverter *tc,
473+ RankedTensorType type,
474+ MemDescType memdesc,
475+ unsigned numWarps) {
476+ Attribute encoding;
477+ type = cast<RankedTensorType>(tc->convertType (type));
478+ if (isa<ttng::TensorMemoryScalesEncodingAttr>(memdesc.getEncoding ())) {
479+ encoding = LinearEncodingAttr::get (
480+ type.getContext (), getScaleTMEMStoreLinearLayout (type, numWarps));
481+ } else {
482+ auto tmemEnc = cast<ttng::TensorMemoryEncodingAttr>(memdesc.getEncoding ());
483+ encoding = ttng::getTmemCompatibleLayout (
484+ tmemEnc.getBlockM (), tmemEnc.getBlockN (), type, numWarps);
485+ }
486+ return RankedTensorType::get (type.getShape (), type.getElementType (),
487+ encoding);
488+ }
489+
490+ struct TMEMLoadOpPattern : public OpConversionPattern <ttng::TMEMLoadOp> {
491+ using OpConversionPattern::OpConversionPattern;
492+
493+ LogicalResult
494+ matchAndRewrite (ttng::TMEMLoadOp op, OpAdaptor adaptor,
495+ ConversionPatternRewriter &rewriter) const override {
496+ RankedTensorType type = getTMEMTensorLayout (
497+ typeConverter, op.getType (), op.getSrc ().getType (), lookupNumWarps (op));
498+ rewriter.modifyOpInPlace (op, [&] { op.getResult ().setType (type); });
499+ return success ();
500+ }
501+ };
502+
503+ struct TMEMStoreOpPattern : public OpConversionPattern <ttng::TMEMStoreOp> {
504+ using OpConversionPattern::OpConversionPattern;
505+
506+ LogicalResult
507+ matchAndRewrite (ttng::TMEMStoreOp op, OpAdaptor adaptor,
508+ ConversionPatternRewriter &rewriter) const override {
509+ RankedTensorType type =
510+ getTMEMTensorLayout (typeConverter, op.getSrc ().getType (),
511+ op.getDst ().getType (), lookupNumWarps (op));
512+ Value src =
513+ rewriter.create <ConvertLayoutOp>(op.getLoc (), type, adaptor.getSrc ());
514+ rewriter.modifyOpInPlace (op, [&] { op.getSrcMutable ().assign (src); });
515+ return success ();
516+ }
517+ };
518+
519+ struct TMEMAllocOpPattern : public OpConversionPattern <ttng::TMEMAllocOp> {
520+ using OpConversionPattern::OpConversionPattern;
521+
522+ LogicalResult
523+ matchAndRewrite (ttng::TMEMAllocOp op, OpAdaptor adaptor,
524+ ConversionPatternRewriter &rewriter) const override {
525+ if (!op.getSrc ())
526+ return success ();
527+ RankedTensorType type = getTMEMTensorLayout (
528+ typeConverter, op.getSrc ().getType (), op.getType (), lookupNumWarps (op));
529+ Value src =
530+ rewriter.create <ConvertLayoutOp>(op.getLoc (), type, adaptor.getSrc ());
531+ rewriter.modifyOpInPlace (op, [&] { op.getSrcMutable ().assign (src); });
532+ return success ();
533+ }
534+ };
535+
469536struct TritonTransPattern : public OpConversionPattern <TransOp> {
470537 using OpConversionPattern::OpConversionPattern;
471538
@@ -592,40 +659,61 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
592659 MLIRContext *context = patterns.getContext ();
593660 patterns.insert < // TODO: view should have custom pattern that views the
594661 // layout
662+ // clang-format off
595663 GenericOpPattern<triton::AdvanceOp>,
596664 GenericOpPattern<triton::MakeTensorPtrOp>,
597- GenericOpPattern<triton::ReshapeOp>, GenericOpPattern<triton::BitcastOp>,
598- GenericOpPattern<triton::FpToFpOp>, GenericOpPattern<triton::IntToPtrOp>,
599- GenericOpPattern<triton::PtrToIntOp>, GenericOpPattern<triton::SplatOp>,
600- TritonBroadcastPattern, GenericOpPattern<triton::AddPtrOp>,
601- TritonCatPattern, TritonJoinOpPattern, TritonSplitOpPattern,
665+ GenericOpPattern<triton::ReshapeOp>,
666+ GenericOpPattern<triton::BitcastOp>,
667+ GenericOpPattern<triton::FpToFpOp>,
668+ GenericOpPattern<triton::IntToPtrOp>,
669+ GenericOpPattern<triton::PtrToIntOp>,
670+ GenericOpPattern<triton::SplatOp>,
671+ GenericOpPattern<triton::AddPtrOp>,
672+ TritonBroadcastPattern,
673+ TritonCatPattern,
674+ TritonJoinOpPattern,
675+ TritonSplitOpPattern,
602676 GenericOpPattern<triton::ClampFOp>,
603677 GenericOpPattern<triton::PreciseSqrtOp>,
604678 GenericOpPattern<triton::PreciseDivFOp>,
605679 GenericOpPattern<triton::MulhiUIOp>,
606- GenericOpPattern<triton::ElementwiseInlineAsmOp>, TritonReducePattern,
607- GenericOpPattern<triton::ReduceReturnOp>, TritonScanPattern,
680+ GenericOpPattern<triton::ElementwiseInlineAsmOp>,
681+ TritonReducePattern,
682+ GenericOpPattern<triton::ReduceReturnOp>,
683+ TritonScanPattern,
608684 GenericOpPattern<triton::ScanReturnOp>,
609- GenericOpPattern<triton::MakeRangeOp>, TritonExpandDimsPattern,
610- TritonTransPattern, TritonDotPattern,
685+ GenericOpPattern<triton::MakeRangeOp>,
686+ TritonExpandDimsPattern,
687+ TritonTransPattern,
688+ TritonDotPattern,
611689 GatherScatterOpPattern<DescriptorGatherOp>,
612690 GatherScatterOpPattern<DescriptorScatterOp>,
613- GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAGatherOp>,
614- GatherScatterOpPattern<triton::nvidia_gpu::AsyncTMAScatterOp>,
615- GenericOpPattern<triton::LoadOp>, GenericOpPattern<triton::StoreOp>,
616- GenericOpPattern<triton::HistogramOp>, GenericOpPattern<triton::GatherOp>,
691+ GatherScatterOpPattern<ttng::AsyncTMAGatherOp>,
692+ GatherScatterOpPattern<ttng::AsyncTMAScatterOp>,
693+ TMEMLoadOpPattern,
694+ TMEMStoreOpPattern,
695+ TMEMAllocOpPattern,
696+ GenericOpPattern<triton::LoadOp>,
697+ GenericOpPattern<triton::StoreOp>,
698+ GenericOpPattern<triton::HistogramOp>,
699+ GenericOpPattern<triton::GatherOp>,
617700 GenericOpPattern<triton::ExternElementwiseOp>,
618- GenericOpPattern<triton::PrintOp>, GenericOpPattern<triton::AssertOp>,
701+ GenericOpPattern<triton::PrintOp>,
702+ GenericOpPattern<triton::AssertOp>,
619703 GenericOpPattern<triton::AtomicCASOp>,
620- GenericOpPattern<triton::AtomicRMWOp>, GenericOpPattern<ReturnOp>,
704+ GenericOpPattern<triton::AtomicRMWOp>,
621705 GenericOpPattern<triton::DescriptorLoadOp>,
622706 GenericOpPattern<triton::DescriptorStoreOp>,
623707 GenericOpPattern<triton::DescriptorReduceOp>,
624708 GenericOpPattern<triton::ExperimentalTensormapCreateOp>,
625709 GenericOpPattern<triton::ExperimentalTensormapFenceproxyAcquireOp>,
626710 // this assumes the right layout will be set later for dot scaled.
627- GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
628- TritonFuncOpPattern>(typeConverter, context);
711+ GenericOpPattern<triton::DotScaledOp>,
712+ GenericOpPattern<triton::CallOp>,
713+ GenericOpPattern<ReturnOp>,
714+ TritonFuncOpPattern
715+ // clang-format on
716+ >(typeConverter, context);
629717}
630718// Proton patterns
631719// NOTE: Because Proton's inputs are scalars and not tensors this conversion
0 commit comments