@@ -109,27 +109,30 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
109109 : ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
110110 }
111111
112+ // FIXME [Dot LL]
113+ // Do for all DotOperandEncodingAttr once we have LLs for all of them
114+  static bool isSupportedDotOpLayout(Attribute layout) {
115+    if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
116+      if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent())) {
117+        return mma.isAmpere() && dot.getKWidth() == 8;
118+      }
119+      if (isa<AMDMfmaEncodingAttr>(dot.getParent()))
120+        return true;
121+    }
122+    return false;
123+  };
124+
112125 LogicalResult
113126 matchAndRewrite (LocalLoadOp op, OpAdaptor adaptor,
114127 ConversionPatternRewriter &rewriter) const override {
115128 MemDescType srcTy = op.getSrc ().getType ();
116129 RankedTensorType dstTy = op.getType ();
117130 Attribute srcLayout = srcTy.getEncoding ();
118131 Attribute dstLayout = dstTy.getEncoding ();
119- // FIXME [Dot LL]
120- // Do for all DotOperandEncodingAttr once we have LLs for all of them
121- auto isAmpereLargeKWidth = [](Attribute layout) {
122- if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
123- if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent ())) {
124- return mma.isAmpere () && dot.getKWidth () == 8 ;
125- }
126- }
127- return false ;
128- };
129132 if (isa<SharedEncodingAttr>(srcLayout) &&
130133 (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
131134 dstLayout) ||
132-       isAmpereLargeKWidth(dstLayout))) {
135+       isSupportedDotOpLayout(dstLayout))) {
133136 return lowerSharedToDistributed (op, adaptor, getTypeConverter (),
134137 rewriter);
135138 }
@@ -167,10 +170,10 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
167170 auto srcTy = op.getSrc ().getType ();
168171 auto dstTy = op.getResult ().getType ();
169172 auto dstShape = dstTy.getShape ();
170-    assert(dstShape.size() <= 2 &&
171-           "Unexpected rank of ConvertLayout(shared->blocked)");
172173 auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding ());
173174 auto dstLayout = dstTy.getEncoding ();
175+    assert((dstShape.size() <= 2 || isSupportedDotOpLayout(dstLayout)) &&
176+           "Unexpected rank of ConvertLayout(shared->distributed)");
174177 auto inOrd = getOrder (srcSharedLayout);
175178
176179 auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
@@ -184,31 +187,36 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
184187 // FIXME [Dot LL]
185188 // Ampere case
186189 // In this case, we need to pack the outputs into i32
187- if (isa<DotOperandEncodingAttr>(dstTy.getEncoding ())) {
188- if (elemLlvmTy.isInteger (8 )) {
189- auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
190- return or_ (or_ (zext (i32_ty, a1), shl (zext (i32_ty, a2), i32_val (8 ))),
191- or_ (shl (zext (i32_ty, a3), i32_val (16 )),
192- shl (zext (i32_ty, a4), i32_val (24 ))));
193- };
194- SmallVector<Value> outVals32 (outVals.size () / 4 );
195- for (int i = 0 ; i < outVals32.size (); ++i) {
196- outVals32[i] = concat (outVals[4 * i], outVals[4 * i + 1 ],
197- outVals[4 * i + 2 ], outVals[4 * i + 3 ]);
198- }
199- outVals = outVals32;
200- } else {
201-        assert(elemLlvmTy.isBF16() && "Unexpected element type");
202- auto concat = [&](Value a, Value b) {
203- return or_ (zext (i32_ty, bitcast (a, i16_ty)),
204- shl (zext (i32_ty, bitcast (b, i16_ty)), i32_val (16 )));
205- };
190+ if (auto dotOp = dyn_cast<DotOperandEncodingAttr>(dstTy.getEncoding ())) {
191+ if (auto parent = dyn_cast<NvidiaMmaEncodingAttr>(dotOp.getParent ())) {
192+ if (parent.isAmpere ()) {
193+ if (elemLlvmTy.isInteger (8 )) {
194+ auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
195+ return or_ (
196+ or_ (zext (i32_ty, a1), shl (zext (i32_ty, a2), i32_val (8 ))),
197+ or_ (shl (zext (i32_ty, a3), i32_val (16 )),
198+ shl (zext (i32_ty, a4), i32_val (24 ))));
199+ };
200+ SmallVector<Value> outVals32 (outVals.size () / 4 );
201+ for (int i = 0 ; i < outVals32.size (); ++i) {
202+ outVals32[i] = concat (outVals[4 * i], outVals[4 * i + 1 ],
203+ outVals[4 * i + 2 ], outVals[4 * i + 3 ]);
204+ }
205+ outVals = outVals32;
206+ } else {
207+            assert(elemLlvmTy.isBF16() && "Unexpected element type");
208+ auto concat = [&](Value a, Value b) {
209+ return or_ (zext (i32_ty, bitcast (a, i16_ty)),
210+ shl (zext (i32_ty, bitcast (b, i16_ty)), i32_val (16 )));
211+ };
206212
207- SmallVector<Value> outVals32 (outVals.size () / 2 );
208- for (int i = 0 ; i < outVals32.size (); ++i) {
209- outVals32[i] = concat (outVals[2 * i], outVals[2 * i + 1 ]);
213+ SmallVector<Value> outVals32 (outVals.size () / 2 );
214+ for (int i = 0 ; i < outVals32.size (); ++i) {
215+ outVals32[i] = concat (outVals[2 * i], outVals[2 * i + 1 ]);
216+ }
217+ outVals = outVals32;
218+ }
210219 }
211- outVals = outVals32;
212220 }
213221 }
214222
0 commit comments