@@ -119,6 +119,10 @@ struct LoadStoreConversionBase {
119119 return axisAnalysisPass.getMaskAlignment (mask);
120120 }
121121
122+ unsigned getPtrAlignment (Value ptr) const {
123+ return axisAnalysisPass.getPtrAlignment (ptr);
124+ }
125+
122126protected:
123127 const AMD::TargetInfo &targetInfo;
124128 ModuleAxisInfoAnalysis &axisAnalysisPass;
@@ -193,7 +197,9 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
193197 // vectorized iteration through all the pointer/mask/other elements
194198 const int valueElemNBits =
195199 std::max (8u , valueElemTy.getIntOrFloatBitWidth ());
200+ const size_t valueElemNBytes = valueElemNBits / 8 ;
196201 const int numVecs = numElems / vec;
202+ int64_t ptrAlignmentBytes = getPtrAlignment (ptr) * valueElemNBytes;
197203
198204 auto cacheMod = op.getCache ();
199205 SmallVector<Value> loadedVals;
@@ -230,8 +236,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
230236 falseVal = v;
231237 }
232238
233- auto loadVal =
234- llLoad (rewriter, loc, ptr, vecTy, pred, falseVal , cacheMod);
239+ Value loadVal = llLoad (rewriter, loc, ptr, vecTy, pred, falseVal,
240+ ptrAlignmentBytes , cacheMod);
235241 for (size_t ii = 0 ; ii < vec; ++ii) {
236242 Value vecIdx = createIndexAttrConstant (
237243 rewriter, loc, this ->getTypeConverter ()->getIndexType (), ii % vec);
@@ -294,9 +300,10 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
294300 vec = std::min (vec, maskAlign);
295301 }
296302
297- const size_t dtsize =
298- std::max<int >(1 , valueElemTy.getIntOrFloatBitWidth () / 8 );
299- const size_t valueElemNBits = dtsize * 8 ;
303+ const size_t valueElemNBits =
304+ std::max<int >(8 , valueElemTy.getIntOrFloatBitWidth ());
305+ const size_t valueElemNBytes = valueElemNBits / 8 ;
306+ int64_t ptrAlignmentBytes = getPtrAlignment (ptr) * valueElemNBytes;
300307
301308 auto cacheMod = op.getCache ();
302309 const int numVecs = elemsPerThread / vec;
@@ -328,7 +335,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
328335 rewriter, loc, this ->getTypeConverter ()->getIndexType (), s);
329336 storeVal = insert_element (vecTy, storeVal, otherElem, indexVal);
330337 }
331- llStore (rewriter, loc, ptr, storeVal, pred, cacheMod);
338+ llStore (rewriter, loc, ptr, storeVal, pred, ptrAlignmentBytes, cacheMod);
332339 } // end vec
333340 rewriter.eraseOp (op);
334341 return success ();
0 commit comments