@@ -206,12 +206,11 @@ struct LoadStoreConversionBase {
206206 // }
207207 // All the values are decomposed by `unpackLLElements` into a vector.
208208 // Defines the indices for the block pointer struct.
209- unsigned blockOffset = 0 , blockShape = 1 * rank, blockStride = 2 * rank,
210- blockBase = 3 * rank;
209+ const unsigned blockOffset = 0 , blockShape = 1 * rank,
210+ blockStride = 2 * rank, blockBase = 3 * rank;
211211 const SmallVector<Value> &blockPtr =
212212 unpackLLElements (loc, blockPointerStruct, rewriter);
213-
214- unsigned numElems = getTotalElemsPerThread (tensorType);
213+ const unsigned numElems = getTotalElemsPerThread (tensorType);
215214
216215 // Get the LLVM values for indices in block
217216 auto indices = emitIndices (loc, rewriter, targetInfo,
@@ -303,6 +302,34 @@ struct LoadStoreConversionBase {
303302 return std::make_tuple (ptrElems, maskElems, otherElems);
304303 }
305304
305+ // Ensure the operation doesn't have attributes that the IGC predicated
306+ // instruction cannot handle.
307+ template <typename OpType, typename = std::enable_if_t <llvm::is_one_of<
308+ OpType, LoadOp, StoreOp>::value>>
309+ bool canUsePredicatedInstructions (OpType op) const {
310+ if (!usePredicatedInstructions)
311+ return false ;
312+
313+ if constexpr (std::is_same_v<OpType, LoadOp>)
314+ return !op.getIsVolatile () && op.getCache () == CacheModifier::NONE;
315+
316+ return op.getCache () == CacheModifier::NONE;
317+ }
318+
319+ template <typename OpType, typename = std::enable_if_t <llvm::is_one_of<
320+ OpType, LoadOp, StoreOp>::value>>
321+ bool getNonTemporalFlag (OpType op) const {
322+ switch (op.getCache ()) {
323+ case triton::CacheModifier::CG:
324+ case triton::CacheModifier::CS:
325+ case triton::CacheModifier::CV:
326+ return true ;
327+ case triton::CacheModifier::CA:
328+ default :
329+ return false ;
330+ }
331+ }
332+
306333protected:
307334 const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass;
308335 const triton::intel::TargetInfo &targetInfo;
@@ -2438,11 +2465,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24382465 : retTys[0 ];
24392466
24402467 Value other_ = b.undef (retTy);
2441- if (otherElems.size ()) {
2468+ if (otherElems.empty ()) {
2469+ other_ = rewriter.create <LLVM::ConstantOp>(loc, retTy,
2470+ rewriter.getZeroAttr (retTy));
2471+ } else {
24422472 for (size_t ii = 0 ; ii < nWords; ++ii) {
24432473 size_t size = width / valueElemNBits;
2444-
2445- auto vecTy = vec_ty (valueElemTy, size);
2474+ VectorType vecTy = vec_ty (valueElemTy, size);
24462475 Value v = b.undef (vecTy);
24472476 for (size_t s = 0 ; s < size; ++s) {
24482477 Value falseVal = otherElems[vecStart + ii * size + s];
@@ -2468,36 +2497,21 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24682497
24692498 v;
24702499 }
2471- } else {
2472- other_ = rewriter.create <LLVM::ConstantOp>(loc, retTy,
2473- rewriter.getZeroAttr (retTy));
24742500 }
2501+ assert (other_ && " Expecting a valid value" );
24752502
24762503 Value addrElem = b.bitcast (ptrElems[vecStart], ptr_ty (ctx, 1 /* global*/ ));
24772504 uint32_t alignment = nWords * width / 8 ;
2478- auto createLoadWithAttrs = [&]() -> SmallVector<Value, 1 > {
2479- auto getNonTemporalFlag = [](triton::LoadOp loadOp) {
2480- switch (loadOp.getCache ()) {
2481- case triton::CacheModifier::CG:
2482- case triton::CacheModifier::CS:
2483- case triton::CacheModifier::CV:
2484- return true ;
2485- case triton::CacheModifier::CA:
2486- default :
2487- return false ;
2488- }
2489- };
2490-
2491- Value ret = b.load (retTy, addrElem, alignment, op.getIsVolatile (),
2492- getNonTemporalFlag (op));
2493- return {ret};
2505+ auto createLoadWithAttrs = [&]() {
2506+ return SmallVector<Value>{b.load (retTy, addrElem, alignment,
2507+ op.getIsVolatile (),
2508+ getNonTemporalFlag (op))};
24942509 };
24952510
24962511 Value ret;
2497-
24982512 if (!pred)
24992513 ret = createLoadWithAttrs ()[0 ];
2500- else if (usePredicatedInstructions )
2514+ else if (canUsePredicatedInstructions (op) )
25012515 ret = rewriter.create <TritonGEN::PredicatedLoadOp>(
25022516 loc, retTy, addrElem, b.i64_val (alignment), pred, other_);
25032517 else {
@@ -2519,6 +2533,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25192533 curr, LLVM::getVectorType (valueElemTy, width / valueElemNBits));
25202534 rets.push_back (curr);
25212535 }
2536+
25222537 int tmp = width / valueElemNBits;
25232538 for (size_t ii = 0 ; ii < vec; ++ii) {
25242539 Value loaded =
@@ -2931,18 +2946,21 @@ struct StoreOpConversion
29312946
29322947 Value addrElem = b.bitcast (ptrElems[vecStart], ptr_ty (ctx, 1 /* global*/ ));
29332948 uint32_t alignment = nWords * width / 8 ;
2934- auto createStore = [&]() -> ArrayRef<Value> {
2935- b.store (vecWord, addrElem, alignment);
2949+ auto createStoreWithAttrs = [&]() {
2950+ bool isVolatile = false ;
2951+ b.store (vecWord, addrElem, alignment, isVolatile,
2952+ getNonTemporalFlag (op));
29362953 return ArrayRef<Value>();
29372954 };
29382955
29392956 if (!maskVal)
2940- auto _ = createStore ();
2941- else if (usePredicatedInstructions )
2957+ auto _ = createStoreWithAttrs ();
2958+ else if (canUsePredicatedInstructions (op) )
29422959 rewriter.create <TritonGEN::PredicatedStoreOp>(
29432960 loc, addrElem, vecWord, b.i64_val (alignment), maskVal);
29442961 else
2945- LLVM::intel::createPredicatedBlock (rewriter, loc, maskVal, createStore);
2962+ LLVM::intel::createPredicatedBlock (rewriter, loc, maskVal,
2963+ createStoreWithAttrs);
29462964 }
29472965
29482966 rewriter.eraseOp (op);
0 commit comments