@@ -131,20 +131,27 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
131131 LogicalResult
132132 matchAndRewrite (LocalLoadOp op, OpAdaptor adaptor,
133133 ConversionPatternRewriter &rewriter) const override {
134- MemDescType srcTy = op.getSrc ().getType ();
135134 RankedTensorType dstTy = op.getType ();
136- Attribute srcLayout = srcTy.getEncoding ();
137135 Attribute dstLayout = dstTy.getEncoding ();
138- if (isa<SharedEncodingAttr>(srcLayout) &&
139- isa<BlockedEncodingAttr, MmaEncodingTrait, SliceEncodingAttr>(
140- dstLayout)) {
141- return lowerSharedToDistributed (op, adaptor, getTypeConverter (),
142- rewriter);
143- }
144136 if (isa<DotOperandEncodingAttr>(dstLayout)) {
145- return lowerSharedToDotOperand (op, adaptor, getTypeConverter (), rewriter);
137+ auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
138+ if (auto dpasLayout =
139+ dyn_cast_or_null<DpasEncodingAttr>(dotLayout.getParent ())) {
140+ auto sharedLayout =
141+ cast<SharedEncodingAttr>(op.getSrc ().getType ().getEncoding ());
142+ int K;
143+ if (dotLayout.getOpIdx () == 0 ) // $a
144+ K = op.getType ().getShape ()[sharedLayout.getOrder ()[0 ]];
145+ else // $b
146+ K = op.getType ().getShape ()[sharedLayout.getOrder ()[1 ]];
147+ bool isOuter = K == 1 ;
148+ rewriter.replaceOp (op, lowerSharedToDotOperandDPAS (
149+ op, adaptor, getTypeConverter (), rewriter,
150+ dpasLayout, dotLayout, isOuter));
151+ return success ();
152+ }
146153 }
147- return failure ( );
154+ return lowerSharedToDistributed (op, adaptor, getTypeConverter (), rewriter );
148155 }
149156
150157private:
@@ -174,53 +181,13 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
174181 return res;
175182 }
176183
177- LogicalResult
178- lowerSharedToDotOperand (LocalLoadOp op, LocalLoadOpAdaptor adaptor,
179- const LLVMTypeConverter *typeConverter,
180- ConversionPatternRewriter &rewriter) const {
181- auto loc = op.getLoc ();
182- RankedTensorType dstTy = op.getType ();
183- Attribute dstLayout = dstTy.getEncoding ();
184- auto dotLayout = cast<DotOperandEncodingAttr>(dstLayout);
185- auto sharedLayout =
186- cast<SharedEncodingAttr>(op.getSrc ().getType ().getEncoding ());
187-
188- int K;
189- if (dotLayout.getOpIdx () == 0 ) // $a
190- K = op.getType ().getShape ()[sharedLayout.getOrder ()[0 ]];
191- else // $b
192- K = op.getType ().getShape ()[sharedLayout.getOrder ()[1 ]];
193- bool isOuter = K == 1 ;
194-
195- Value res;
196- if (auto dpasLayout =
197- dyn_cast_or_null<DpasEncodingAttr>(dotLayout.getParent ())) {
198- res = lowerSharedToDotOperandDPAS (op, adaptor, typeConverter, rewriter,
199- dpasLayout, dotLayout, isOuter);
200- } else if (auto blockedLayout = dyn_cast_or_null<BlockedEncodingAttr>(
201- dotLayout.getParent ())) {
202- auto thread = getThreadId (rewriter, loc);
203- res = SharedToDotOperandFMA::convertLayout (
204- dotLayout.getOpIdx (), op.getSrc (), adaptor.getSrc (), blockedLayout,
205- thread, loc, getTypeConverter (), rewriter);
206- } else {
207- assert (false && " Unsupported dot operand layout found" );
208- }
209-
210- rewriter.replaceOp (op, res);
211- return success ();
212- }
213184 LogicalResult
214185 lowerSharedToDistributed (LocalLoadOp op, LocalLoadOpAdaptor adaptor,
215186 const LLVMTypeConverter *typeConverter,
216187 ConversionPatternRewriter &rewriter) const {
217188 auto loc = op.getLoc ();
218189 auto srcTy = op.getSrc ().getType ();
219190 auto dstTy = op.getResult ().getType ();
220- auto dstShape = dstTy.getShape ();
221- auto srcSharedLayout = cast<SharedEncodingAttr>(srcTy.getEncoding ());
222- assert (!isa<DotOperandEncodingAttr>(dstTy.getEncoding ()) &&
223- " Unexpected rank of ConvertLayout(shared->blocked)" );
224191
225192 auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
226193 loc, adaptor.getSrc (),