11#include " PatternTritonGPUOpToLLVM.h"
22#include " TargetInfo.h"
33#include " Utility.h"
4+ #include " mlir/Conversion/LLVMCommon/TypeConverter.h"
5+ #include " mlir/Dialect/LLVMIR/LLVMDialect.h"
6+ #include " mlir/IR/PatternMatch.h"
7+ #include " mlir/Transforms/DialectConversion.h"
8+ #include " triton/Conversion/TritonGPUToLLVM/Utility.h"
49#include " triton/Dialect/TritonGPU/Transforms/Utility.h"
510
611using namespace mlir ;
@@ -86,6 +91,7 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
8691 }
8792 return mask;
8893}
94+
8995// Contains some helper functions for both Load and Store conversions.
9096struct LoadStoreConversionBase {
9197 explicit LoadStoreConversionBase (const AMD::TargetInfo &targetInfo,
@@ -192,7 +198,6 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
192198 auto cacheMod = op.getCache ();
193199 SmallVector<Value> loadedVals;
194200 for (size_t vecStart = 0 ; vecStart < numElems; vecStart += vec) {
195- // TODO: optimization when ptr is GEP with constant offset
196201 size_t in_off = 0 ;
197202
198203 const size_t maxWordWidth = std::max<size_t >(32 , valueElemNBits);
@@ -218,8 +223,8 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
218223 Value v = undef (vecTy);
219224 for (size_t s = 0 ; s < vec; ++s) {
220225 Value otherElem = otherElems[vecStart + s];
221- Value indexVal = createIndexAttrConstant (
222- rewriter, loc, this ->getTypeConverter ()-> getIndexType () , s);
226+ Value indexVal = LLVM::createIndexConstant (
227+ rewriter, loc, this ->getTypeConverter (), s);
223228 v = insert_element (vecTy, v, otherElem, indexVal);
224229 }
225230 falseVal = v;
@@ -259,6 +264,7 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
259264 ConversionPatternRewriter &rewriter) const override {
260265 Value ptr = op.getPtr ();
261266 Value value = op.getValue ();
267+ Value mask = op.getMask ();
262268
263269 Value llPtr = adaptor.getPtr ();
264270 Value llMask = adaptor.getMask ();
@@ -281,24 +287,24 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
281287 // Determine the vectorization size
282288 SmallVector<Value> maskElems;
283289 if (llMask) {
284- Value mask = op.getMask ();
285290 maskElems = unpackLLElements (loc, llMask, rewriter);
286291 assert (valueElems.size () == maskElems.size ());
287292
288293 unsigned maskAlign = getMaskAlignment (mask);
289294 vec = std::min (vec, maskAlign);
290295 }
291296
292- Value mask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
293297 const size_t dtsize =
294298 std::max<int >(1 , valueElemTy.getIntOrFloatBitWidth () / 8 );
295299 const size_t valueElemNBits = dtsize * 8 ;
296300
297301 auto cacheMod = op.getCache ();
298302 const int numVecs = elemsPerThread / vec;
303+ Value rDataMask = redundantDataMask (valueTy, rewriter, loc, targetInfo);
299304 for (size_t vecStart = 0 ; vecStart < elemsPerThread; vecStart += vec) {
300- // TODO: optimization when ptr is AddPtr with constant offset
301305 size_t in_off = 0 ;
306+ Value pred = mask ? and_ (maskElems[vecStart], rDataMask) : rDataMask;
307+ auto vecTy = LLVM::getFixedVectorType (valueElemTy, vec);
302308
303309 const size_t maxWordWidth = std::max<size_t >(32 , valueElemNBits);
304310 const size_t totalWidth = valueElemNBits * vec;
@@ -307,33 +313,23 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
307313 const size_t wordNElems = width / valueElemNBits;
308314 assert (wordNElems * nWords * numVecs == elemsPerThread);
309315
310- // TODO(Superjomn) Add cache policy fields to StoreOp.
311- // TODO(Superjomn) Deal with cache policy here.
312-
313316 Type valArgTy = IntegerType::get (ctx, width);
314317 auto wordTy = vec_ty (valueElemTy, wordNElems);
315318
316319 SmallVector<std::pair<Value, std::string>> asmArgs;
317- for (size_t wordIdx = 0 ; wordIdx < nWords; ++wordIdx) {
318- // llWord is a width-len composition
319- Value llWord = undef (wordTy);
320- // Insert each value element to the composition
321- for (size_t elemIdx = 0 ; elemIdx < wordNElems; ++elemIdx) {
322- const size_t elemOffset = vecStart + wordIdx * wordNElems + elemIdx;
323- assert (elemOffset < valueElems.size ());
324- Value elem = valueElems[elemOffset];
325- if (elem.getType ().isInteger (1 ))
326- elem = sext (i8_ty, elem);
327- elem = bitcast (elem, valueElemTy);
328-
329- llWord = insert_element (wordTy, llWord, elem, i32_val (elemIdx));
330- }
331- llWord = bitcast (llWord, valArgTy);
332- Value maskVal = llMask ? and_ (mask, maskElems[vecStart]) : mask;
333- auto address = ptrElems[vecStart + wordIdx * wordNElems];
334- llStore (rewriter, loc, address, llWord, maskVal, cacheMod);
320+ Value elem = valueElems[vecStart];
321+ Value ptr = addrspacecast (ptr_ty (getContext ()), ptrElems[vecStart]);
322+
323+ // Create the store val
324+ Value storeVal = undef (vecTy);
325+ for (size_t s = 0 ; s < vec; ++s) {
326+ Value otherElem = valueElems[vecStart + s];
327+ Value indexVal = createIndexAttrConstant (
328+ rewriter, loc, this ->getTypeConverter ()->getIndexType (), s);
329+ storeVal = insert_element (vecTy, storeVal, otherElem, indexVal);
335330 }
336- }
331+ llStore (rewriter, loc, ptr, storeVal, pred, cacheMod);
332+ } // end vec
337333 rewriter.eraseOp (op);
338334 return success ();
339335 }
0 commit comments