@@ -1030,46 +1030,6 @@ StringRef getAMDArch(Operation *module) {
10301030 return ref.drop_front (4 ); // drop the "hip:"
10311031}
10321032
1033- // Rough utility for obtaining a swizzled SharedEnc for a LinearEncoding,
1034- // as we've replaced DotOpEnc with Linear in some cases
1035- // (specifically, fp4ToFp and similar unpack-upcast thru join). Returns std::nullopt when the shape is not rank 2, when the encoding's order differs from globalOrder, or when the sub-4B transfer check near the end fails.
1036- std::optional<ttg::SwizzledSharedEncodingAttr>
1037- getSharedForLinear (ttg::LinearEncodingAttr enc,
1038- ArrayRef<unsigned int > globalOrder, ArrayRef<int64_t > shape,
1039- unsigned elemBitWidth, ttg::CTALayoutAttr ctaLayout) {
1040- auto ctx = enc.getContext ();
1041- auto ll = enc.getLinearLayout (); // NOTE(review): `ll` is never read in this function
1042- auto rank = shape.size ();
1043-
1044- if (rank != 2 )
1045- return std::nullopt ;
1046-
1047- auto order = enc.getOrder ();
1048- assert (globalOrder.size () == rank);
1049- // TODO: add memdesc_trans support for dot(trans(cvt(src) #linear) #dot_op); until then, bail out when the encoding order differs from the global order.
1050- if (order != globalOrder)
1051- return std::nullopt ;
1052-
1053- auto innerDim = order[0 ];
1054- auto outerDim = order[1 ];
1055- auto contigPerWarp = enc.getContigPerWarp ();
1056-
1057- constexpr unsigned BANK_SIZE{128 };
1058- auto elemBytes = elemBitWidth / 8 ; // NOTE(review): integer division yields 0 when elemBitWidth < 8, which makes rowSize 0 and the BANK_SIZE / rowSize division below divide by zero
1059-
1060- auto vec = contigPerWarp[innerDim];
1061- auto rowSize = elemBytes * (unsigned )shape[innerDim];
1062- auto perPhase = std::max (BANK_SIZE / rowSize, 1u );
1063- auto maxPhase = std::max (contigPerWarp[outerDim] / perPhase, 1u );
1064-
1065- // cp.async does not support transfer sizes < 4B, so give up when the contiguous vector is narrower than that and perPhase < maxPhase.
1066- if (vec * elemBytes < 4 && perPhase < maxPhase)
1067- return std::nullopt ;
1068-
1069- return ttg::SwizzledSharedEncodingAttr::get (ctx, vec, perPhase, maxPhase,
1070- order, ctaLayout);
1071- }
1072-
10731033// If all the transitive uses of the given value are used by a convert to
10741034// the same dot operand encoding, return the shared encoding that needs to be
10751035// used to be compatible with users' layouts. If there are incompatible shared
@@ -1096,28 +1056,18 @@ getSharedEncIfAllUsersAreDotEnc(Value val, bool &incompatible) {
10961056 } else {
10971057 if (!isa<ttg::LocalLoadOp, ttg::ConvertLayoutOp>(user))
10981058 return std::nullopt ;
1099- auto enc =
1059+ auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(
11001060 cast<triton::gpu::TensorOrMemDesc>(user->getResult (0 ).getType ())
1101- .getEncoding ();
1061+ .getEncoding ());
1062+ if (!dotOpEnc)
1063+ return std::nullopt ;
11021064 auto srcTy = cast<triton::gpu::TensorOrMemDesc>(val.getType ());
1103- auto ctaLayout = ttg::getCTALayout (srcTy.getEncoding ());
1065+ auto CTALayout = ttg::getCTALayout (srcTy.getEncoding ());
11041066 auto order = ttg::getOrder (srcTy.getEncoding ());
11051067 unsigned bitWidth = srcTy.getElementType ().getIntOrFloatBitWidth ();
1106-
1107- if (auto dotOpEnc = dyn_cast<ttg::DotOperandEncodingAttr>(enc)) {
1108- tempAttr = ttg::SwizzledSharedEncodingAttr::get (
1109- val.getContext (), dotOpEnc, srcTy.getShape (), order, ctaLayout,
1110- bitWidth, /* needTrans=*/ false );
1111- } else if (auto linearEnc = dyn_cast<ttg::LinearEncodingAttr>(enc)) {
1112-
1113- auto attrOpt = getSharedForLinear (linearEnc, order, srcTy.getShape (),
1114- bitWidth, ctaLayout);
1115- if (!attrOpt)
1116- return std::nullopt ;
1117- tempAttr = *attrOpt;
1118- } else {
1119- return std::nullopt ;
1120- }
1068+ tempAttr = ttg::SwizzledSharedEncodingAttr::get (
1069+ val.getContext (), dotOpEnc, srcTy.getShape (), order, CTALayout,
1070+ bitWidth, /* needTrans=*/ false );
11211071 }
11221072 // Check that the shared encodings needed by the users are compatible.
11231073 if (attr != nullptr && attr != tempAttr) {
0 commit comments