@@ -46,7 +46,7 @@ struct ConvertLayoutOpConversion
4646 Attribute srcLayout = srcTy.getEncoding ();
4747 Attribute dstLayout = dstTy.getEncoding ();
4848 if (isSupported (srcLayout, dstLayout)) {
49- return lowerDistributedToDistributed (op, adaptor, rewriter);
49+ return lowerDistributedToDistributed (op, adaptor, rewriter, targetInfo );
5050 }
5151 return failure ();
5252 }
@@ -115,10 +115,9 @@ struct ConvertLayoutOpConversion
115115 shapePerCTA);
116116 Value offset = linearize (rewriter, loc, multiDimOffsetWrapped,
117117 paddedRepShape, outOrd);
118- auto elemPtrTy = ptr_ty (rewriter. getContext (), 3 );
118+ auto elemPtrTy = smemBase. getType ( );
119119 Value ptr = gep (elemPtrTy, llvmElemTy, smemBase, offset);
120120 auto vecTy = vec_ty (llvmElemTy, vec);
121- ptr = bitcast (ptr, ptr_ty (rewriter.getContext (), 3 ));
122121 if (stNotRd) {
123122 Value valVec = undef (vecTy);
124123 for (unsigned v = 0 ; v < vec; ++v) {
@@ -150,7 +149,8 @@ struct ConvertLayoutOpConversion
150149 // Data padding in shared memory to avoid bank conflict.
151150 LogicalResult
152151 lowerDistributedToDistributed (ConvertLayoutOp op, OpAdaptor adaptor,
153- ConversionPatternRewriter &rewriter) const {
152+ ConversionPatternRewriter &rewriter,
153+ const TargetInfoBase &targetInfo) const {
154154 auto loc = op.getLoc ();
155155 auto typeConverter = getTypeConverter ();
156156 RankedTensorType srcTy = op.getSrc ().getType ();
@@ -168,9 +168,7 @@ struct ConvertLayoutOpConversion
168168 }
169169
170170 Value smemBase =
171- LLVM::getSharedMemoryBase (loc, rewriter, op.getOperation ());
172- auto elemPtrTy = ptr_ty (rewriter.getContext (), 3 );
173- smemBase = bitcast (smemBase, elemPtrTy);
171+ LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
174172 auto shape = dstTy.getShape ();
175173 unsigned rank = dstTy.getRank ();
176174 SmallVector<unsigned > numReplicates (rank);
@@ -447,8 +445,6 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
447445 MLIRContext *ctx = op.getContext ();
448446 auto loc = op.getLoc ();
449447
450- auto sharedPtrTy = ptr_ty (ctx, /* addressSpace=*/ 3 );
451-
452448 StringAttr kRegister = str_attr (" register" );
453449 StringAttr kLane = str_attr (" lane" );
454450 StringAttr kWarp = str_attr (" warp" );
@@ -508,7 +504,8 @@ struct ConvertLayoutOpUsingLinearLayoutsConversion
508504 collectRegsForIter (ctx, shmemLoadLayout);
509505
510506 Value smemBase =
511- LLVM::getSharedMemoryBase (loc, rewriter, op.getOperation ());
507+ LLVM::getSharedMemoryBase (loc, rewriter, targetInfo, op.getOperation ());
508+ auto sharedPtrTy = smemBase.getType ();
512509 Type elemTy = inVals[0 ].getType ();
513510 auto outSize = shmemLoadLayout.getInDimSize (kRegister );
514511 auto iterations = sharedLayout.getInDimSize (kIteration );
0 commit comments