@@ -116,9 +116,20 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
116116 RankedTensorType dstTy = op.getType ();
117117 Attribute srcLayout = srcTy.getEncoding ();
118118 Attribute dstLayout = dstTy.getEncoding ();
119+ // FIXME [Dot LL]
120+ // Do for all DotOperandEncodingAttr once we have LLs for all of them
121+ auto isAmpereLargeKWidth = [](Attribute layout) {
122+ if (auto dot = dyn_cast<DotOperandEncodingAttr>(layout)) {
123+ if (auto mma = dyn_cast<NvidiaMmaEncodingAttr>(dot.getParent ())) {
124+ return mma.isAmpere () && dot.getKWidth () == 8 ;
125+ }
126+ }
127+ return false ;
128+ };
119129 if (isa<SharedEncodingAttr>(srcLayout) &&
120- isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
121- dstLayout)) {
130+ (isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
131+ dstLayout) ||
132+ isAmpereLargeKWidth (dstLayout))) {
122133 return lowerSharedToDistributed (op, adaptor, getTypeConverter (),
123134 rewriter);
124135 }
@@ -170,6 +181,37 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
170181 SmallVector<Value> outVals = loadSharedToDistributed (
171182 dstTy, srcTy, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
172183
184+ // FIXME [Dot LL]
185+ // Ampere case
186+ // In this case, we need to pack the outputs into i32
187+ if (isa<DotOperandEncodingAttr>(dstTy.getEncoding ())) {
188+ if (elemLlvmTy.isInteger (8 )) {
189+ auto concat = [&](Value a1, Value a2, Value a3, Value a4) {
190+ return or_ (or_ (zext (i32_ty, a1), shl (zext (i32_ty, a2), i32_val (8 ))),
191+ or_ (shl (zext (i32_ty, a3), i32_val (16 )),
192+ shl (zext (i32_ty, a4), i32_val (24 ))));
193+ };
194+ SmallVector<Value> outVals32 (outVals.size () / 4 );
195+ for (int i = 0 ; i < outVals32.size (); ++i) {
196+ outVals32[i] = concat (outVals[4 * i], outVals[4 * i + 1 ],
197+ outVals[4 * i + 2 ], outVals[4 * i + 3 ]);
198+ }
199+ outVals = outVals32;
200+ } else {
201+ assert (elemLlvmTy.isBF16 () && " Unexpected element type" );
202+ auto concat = [&](Value a, Value b) {
203+ return or_ (zext (i32_ty, bitcast (a, i16_ty)),
204+ shl (zext (i32_ty, bitcast (b, i16_ty)), i32_val (16 )));
205+ };
206+
207+ SmallVector<Value> outVals32 (outVals.size () / 2 );
208+ for (int i = 0 ; i < outVals32.size (); ++i) {
209+ outVals32[i] = concat (outVals[2 * i], outVals[2 * i + 1 ]);
210+ }
211+ outVals = outVals32;
212+ }
213+ }
214+
173215 Value result = packLLElements (loc, typeConverter, outVals, rewriter, dstTy);
174216 rewriter.replaceOp (op, result);
175217
0 commit comments