@@ -39,6 +39,8 @@ static int __builtin_ctz(unsigned x) {
3939
4040namespace {
4141
/// Specialization-constant ID used by the block-IO load lowering. It is
/// materialized as an i32 constant and passed as the first argument of the
/// `__spirv_SpecConstant` call that is emitted when the pitch cannot be folded
/// to a compile-time constant; the call's result is compared against 64 to
/// select between the block-IO result and the generic load result.
/// NOTE(review): no rationale for the value 123 is visible here — confirm it
/// matches the ID the host side specializes.
/// NOTE(review): in the visible lowering the comparison uses the value
/// returned by the spec-constant call rather than the runtime `pitch` operand
/// — confirm that is intended.
42+ static constexpr int kBlockIOPitchSpecId = 123 ;
43+
4244Value maybeAnd (RewriterBase &rewriter, Location loc, Value a, Value b) {
4345 auto tb = TritonLLVMOpBuilder (loc, rewriter);
4446 if (a && b) {
@@ -338,6 +340,11 @@ struct LoadStoreConversionBase {
338340 triton::tools::getBoolEnv (" TRITON_INTEL_PREDICATED" );
339341};
340342
/// Forward declaration of the generic load lowering helper.
/// It is declared ahead of the block-IO conversion patterns so that the
/// block-IO load lowering can call it to produce its fallback result; the
/// definition appears after LoadOpConversion and forwards to
/// LoadOpConversion::emitGenericLoadImpl.
///
/// \param op            The triton::LoadOp being lowered.
/// \param llPtr         Converted pointer operand (callers pass adaptor.getPtr()).
/// \param llMask        Converted mask operand (callers pass adaptor.getMask()).
/// \param llOther       Converted `other` operand (callers pass adaptor.getOther()).
/// \param rewriter      Rewriter used to emit the lowered IR.
/// \param typeConverter Type converter used to convert the op's result type.
/// \param base          Supplies the LoadStoreConversionBase helpers the
///                      lowering calls (e.g. getVectorSize, getMaskAlignment).
/// \return The value produced by packLLElements for the loaded elements.
343+ static Value emitGenericLoad (triton::LoadOp op, Value llPtr, Value llMask,
344+ Value llOther, ConversionPatternRewriter &rewriter,
345+ const LLVMTypeConverter *typeConverter,
346+ const LoadStoreConversionBase &base);
347+
341348struct BlockIOConversionBase : public LoadStoreConversionBase {
342349 explicit BlockIOConversionBase (
343350 const triton::intel::TargetInfo &targetInfo,
@@ -1659,9 +1666,13 @@ struct LoadOpToBlockIOConversion
16591666 std::swap (baseWidth, baseHeight);
16601667 }
16611668 // HW requires the pitch to be at least 64 bytes.
1669+ bool needRuntimePitchCheck = false ;
1670+
16621671 if (auto pitchConst = mlir::triton::intel::getFoldedConstantValue (pitch)) {
16631672 if ((*pitchConst * elemSizeInBits / 8 ) < 64 )
16641673 return failure ();
1674+ } else {
1675+ needRuntimePitchCheck = true ;
16651676 }
16661677
16671678 baseWidth = b.trunc (i32_ty, baseWidth);
@@ -1889,10 +1900,72 @@ struct LoadOpToBlockIOConversion
18891900 }
18901901
18911902 Type llvmResultStructTy = typeConverter->convertType (op.getType ());
1892- Value resultStruct = packLLElements (loc, typeConverter, unpackedLoadedVals,
1893- rewriter, llvmResultStructTy);
1894- rewriter.replaceOp (op, {resultStruct});
18951903
1904+ Value blockIOResult = packLLElements (loc, typeConverter, unpackedLoadedVals,
1905+ rewriter, llvmResultStructTy);
1906+
1907+ Value finalResult;
1908+ if (!needRuntimePitchCheck) {
1909+ finalResult = blockIOResult;
1910+ } else {
1911+ MLIRContext *ctx = rewriter.getContext ();
1912+ ModuleOp module = op->getParentOfType <ModuleOp>();
1913+
1914+ auto i32Ty = IntegerType::get (ctx, 32 );
1915+ auto fnTy = LLVM::LLVMFunctionType::get (
1916+ i32Ty, ArrayRef<Type>{i32Ty, i32Ty}, /* isVarArg=*/ false );
1917+
1918+ LLVM::LLVMFuncOp specFn =
1919+ module .lookupSymbol <LLVM::LLVMFuncOp>(" __spirv_SpecConstant" );
1920+ if (!specFn) {
1921+ PatternRewriter::InsertionGuard guard (rewriter);
1922+ rewriter.setInsertionPointToStart (module .getBody ());
1923+
1924+ ImplicitLocOpBuilder ib (loc, rewriter);
1925+ specFn = LLVM::LLVMFuncOp::create (ib, " __spirv_SpecConstant" , fnTy);
1926+ // default linkage is External
1927+ }
1928+
1929+ // Default value (in bytes) if host doesn't specialize this ID.
1930+ // Using 0 means "disable block-IO by default".
1931+ Value specIdVal = LLVM::ConstantOp::create (
1932+ rewriter, loc, i32Ty,
1933+ rewriter.getI32IntegerAttr (kBlockIOPitchSpecId ));
1934+
1935+ Value defaultPitchBytes = LLVM::ConstantOp::create (
1936+ rewriter, loc, i32Ty, rewriter.getI32IntegerAttr (0 ));
1937+
1938+ // llvm.call @__spirv_SpecConstant(i32 specId, i32 default) -> i32
1939+ auto call = LLVM::CallOp::create (
1940+ rewriter, loc, TypeRange{i32Ty}, SymbolRefAttr::get (specFn),
1941+ ValueRange{specIdVal, defaultPitchBytes});
1942+
1943+ Value specPitchBytes = call.getResult ();
1944+
1945+ // cond = (specPitchBytes >= 64)
1946+ Value cond = b.icmp_sge (specPitchBytes, b.i32_val (64 ));
1947+
1948+ // Generic fallback lowering (gather load).
1949+ Value genericResult = emitGenericLoad (op,
1950+ adaptor.getPtr (), // llPtr
1951+ adaptor.getMask (), // llMask
1952+ adaptor.getOther (), // llOther
1953+ rewriter, typeConverter, *this );
1954+
1955+ auto createBlockIOResult = [&]() -> SmallVector<Value, 1 > {
1956+ return {blockIOResult};
1957+ };
1958+
1959+ Block &mergeBlock = LLVM::intel::createPredicatedBlock (
1960+ rewriter, loc,
1961+ cond, // true → block-IO
1962+ SmallVector<Value, 1 >{genericResult}, // false → generic
1963+ createBlockIOResult);
1964+
1965+ finalResult = mergeBlock.getArgument (0 );
1966+ }
1967+
1968+ rewriter.replaceOp (op, finalResult);
18961969 return success ();
18971970 }
18981971
@@ -2426,31 +2499,28 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24262499 : ConvertOpToLLVMPattern<triton::LoadOp>(converter, benefit),
24272500 LoadStoreConversionBase (targetInfo, axisAnalysisPass) {}
24282501
2429- LogicalResult
2430- matchAndRewrite (triton::LoadOp op, OpAdaptor adaptor,
2431- ConversionPatternRewriter &rewriter) const override {
2502+ /// Generic lowering for triton::LoadOp → LLVM struct value.
2503+ static Value emitGenericLoadImpl (triton::LoadOp op, Value llPtr, Value llMask,
2504+ Value llOther,
2505+ ConversionPatternRewriter &rewriter,
2506+ const LLVMTypeConverter *typeConverter,
2507+ const LoadStoreConversionBase &base) {
24322508 Location loc = op->getLoc ();
24332509 auto b = TritonLLVMOpBuilder (loc, rewriter);
2434- auto typeConverter = getTypeConverter ();
24352510 MLIRContext *ctx = rewriter.getContext ();
24362511
24372512 // original values
24382513 Value ptr = op.getPtr ();
24392514 Value mask = op.getMask ();
24402515 Value other = op.getOther ();
24412516
2442- // adaptor values
2443- Value llPtr = adaptor.getPtr ();
2444- Value llMask = adaptor.getMask ();
2445- Value llOther = adaptor.getOther ();
2446-
24472517 // Determine the vectorization size
24482518 Type valueElemTy =
24492519 typeConverter->convertType (getElementTypeOrSelf (op.getType ()));
24502520 unsigned numElems = getTotalElemsPerThread (op.getType ());
2451- unsigned vec = getVectorSize (ptr);
2521+ unsigned vec = base. getVectorSize (ptr);
24522522 if (llMask)
2453- vec = std::min<size_t >(vec, getMaskAlignment (mask));
2523+ vec = std::min<std:: size_t >(vec, base. getMaskAlignment (mask));
24542524
24552525 SmallVector<Value> ptrElems, maskElems, otherElems;
24562526 bool otherIsSplatConstInt = false ;
@@ -2459,9 +2529,10 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
24592529 if (isTensorPointerType (ptr.getType ())) {
24602530 // fallback to gather load.
24612531 auto tensorType = cast<RankedTensorType>(op.getType ());
2462- std::tie (ptrElems, maskElems, otherElems) = convertBlockPtrToTensorOfPtr (
2463- loc, llPtr, tensorType, valueElemTy, rewriter, op.getBoundaryCheck (),
2464- op.getPadding ());
2532+ std::tie (ptrElems, maskElems, otherElems) =
2533+ base.convertBlockPtrToTensorOfPtr (loc, llPtr, tensorType, valueElemTy,
2534+ rewriter, op.getBoundaryCheck (),
2535+ op.getPadding ());
24652536 } else {
24662537 // Get the LLVM values for pointers
24672538 ptrElems = unpackLLElements (loc, llPtr, rewriter);
@@ -2503,19 +2574,19 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25032574 if (unsigned canonicalVecStart = getCanonicalIndex (vecStart, regMask);
25042575 vecStart != canonicalVecStart) {
25052576 // For redundant registers, refer back to the canonical load
2506- for (int iVec = 0 ; iVec < vec; ++iVec)
2577+ for (int iVec = 0 ; iVec < static_cast < int >( vec) ; ++iVec)
25072578 loadedVals.push_back (loadedVals[canonicalVecStart + iVec]);
2508-
25092579 continue ;
25102580 }
25112581
25122582 // TODO: optimization when ptr is GEP with constant offset
2513- const size_t maxWordWidth = std::max<size_t >(32 , valueElemNBits);
2583+ const size_t maxWordWidth = std::max<std:: size_t >(32 , valueElemNBits);
25142584 const size_t totalWidth = valueElemNBits * vec;
25152585 const size_t width = std::min (totalWidth, maxWordWidth);
2516- const size_t nWords = std::max<size_t >(1 , totalWidth / width);
2586+ const size_t nWords = std::max<std:: size_t >(1 , totalWidth / width);
25172587 const size_t wordNElems = width / valueElemNBits;
25182588 const size_t movWidth = width < 16 ? 16 : width;
2589+ (void )movWidth; // keep variable but silence unused warning
25192590 assert (wordNElems * nWords * numVecs == numElems);
25202591
25212592 Value pred = maskElems.size () ? maskElems[vecStart] : Value{};
@@ -2554,9 +2625,7 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25542625 retTy, other_, v,
25552626 createIndexAttrConstant (
25562627 rewriter, loc, typeConverter->getIndexType (), ii))
2557- :
2558-
2559- v;
2628+ : v;
25602629 }
25612630 }
25622631 assert (other_ && " Expecting a valid value" );
@@ -2566,13 +2635,13 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
25662635 auto createLoadWithAttrs = [&]() {
25672636 return SmallVector<Value>{b.load (retTy, addrElem, alignment,
25682637 op.getIsVolatile (),
2569- getNonTemporalFlag (op))};
2638+ base. getNonTemporalFlag (op))};
25702639 };
25712640
25722641 Value ret;
25732642 if (!pred)
25742643 ret = createLoadWithAttrs ()[0 ];
2575- else if (canUsePredicatedInstructions (op))
2644+ else if (base. canUsePredicatedInstructions (op))
25762645 ret = TritonGEN::PredicatedLoadOp::create (
25772646 rewriter, loc, retTy, addrElem, b.i64_val (alignment), pred, other_);
25782647 else {
@@ -2604,13 +2673,29 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
26042673 } // end vec
26052674
26062675 Type llvmResultStructTy = typeConverter->convertType (op.getType ());
2607- Value resultStruct = packLLElements (loc, typeConverter, loadedVals,
2608- rewriter, llvmResultStructTy);
2676+ return packLLElements (loc, typeConverter, loadedVals, rewriter,
2677+ llvmResultStructTy);
2678+ }
2679+
/// Pattern entry point for triton::LoadOp.
/// Delegates the lowering to emitGenericLoadImpl, passing the adaptor's
/// converted ptr/mask/other operands, this pattern's type converter, and
/// `*this` as the LoadStoreConversionBase; then replaces `op` with the
/// returned struct value. This function always returns success().
2680+ LogicalResult
2681+ matchAndRewrite (triton::LoadOp op, OpAdaptor adaptor,
2682+ ConversionPatternRewriter &rewriter) const override {
2683+ Value resultStruct = emitGenericLoadImpl (
2684+ op, adaptor.getPtr (), adaptor.getMask (), adaptor.getOther (), rewriter,
2685+ getTypeConverter (), *this );
26092686 rewriter.replaceOp (op, {resultStruct});
26102687 return success ();
26112688 }
26122689};
26132690
/// Definition of the free-function emitGenericLoad helper.
/// Thin forwarding wrapper: every argument is passed unchanged to
/// LoadOpConversion::emitGenericLoadImpl and its result is returned as-is.
/// Being a free function lets code placed before LoadOpConversion in the file
/// (the block-IO load lowering) reach the generic lowering through a forward
/// declaration.
2691+ static Value emitGenericLoad (triton::LoadOp op, Value llPtr, Value llMask,
2692+ Value llOther, ConversionPatternRewriter &rewriter,
2693+ const LLVMTypeConverter *typeConverter,
2694+ const LoadStoreConversionBase &base) {
2695+ return LoadOpConversion::emitGenericLoadImpl (op, llPtr, llMask, llOther,
2696+ rewriter, typeConverter, base);
2697+ }
2698+
26142699struct StoreOpToBlockIOConversion
26152700 : public ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>,
26162701 public BlockIOConversionBase {
0 commit comments